# Source code for scmidas.api

"""High-level conveniences on top of the MIDAS class."""
from __future__ import annotations

import logging
from typing import Any, Optional

import mudata as mu

from .config import load_config
from .model import MIDAS

logger = logging.getLogger(__name__)


# Tuned for the bundled quickstart dataset (PBMC mosaic, 1200 cells, 500
# HVGs + 224 ADT). On a single mid-range GPU these defaults converge in
# roughly one minute and produce a clean lineage-separated UMAP. They
# are NOT general-purpose — for full datasets, fall back to the
# paper defaults via ``MIDAS.configure_data_from_mdata(...).train()``.
_QUICKSTART_DEFAULTS: dict[str, Any] = {
    'batch_size': 128,
    'max_epochs': 65,
    'lr_net': 3e-4,
    'lr_dsc': 3e-4,
}


def integrate(
    mdata: mu.MuData,
    *,
    batch_key: str = 'batch',
    max_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    accelerator: str = 'auto',
    devices: Any = 1,
    strategy: str = 'auto',
    save_model_path: str = './saved_models/scmidas',
    seed: Optional[int] = 42,
    key_added: str = 'X_midas',
    configs: Any = None,
    **kwargs: Any,
) -> MIDAS:
    """One-call MIDAS pipeline for users who want a sensible default.

    Equivalent to::

        scmidas.MIDAS.setup_mudata(mdata, batch_key=batch_key)
        model = scmidas.MIDAS(mdata, configs=..., batch_size=..., ...)
        model.train(max_epochs=..., accelerator=..., ...)
        mdata.obsm[key_added] = model.get_latent_representation()

    .. warning::
       The default training hyperparameters (``batch_size=128``,
       ``max_epochs=65``, ``lr=3e-4``) are tuned for the **toy quickstart
       dataset** bundled with scmidas (a small PBMC mosaic). They are
       **not appropriate for real analyses**. For full datasets pass your
       own ``max_epochs`` (typically 1000-2000), consider passing
       ``batch_size=256``, and supply your own ``configs`` to control the
       learning rates.

    Parameters
    ----------
    mdata : MuData
        Multi-modal single-cell data.
    batch_key : str
        Column in each modality's ``.obs`` that identifies the source batch.
    max_epochs : int, optional
        Training epochs. Default 65 (quickstart-tuned). For real data,
        override with 1000-2000.
    batch_size : int, optional
        Mini-batch size. Default 128 (quickstart-tuned). For real data,
        256 is a more typical choice.
    accelerator, devices, strategy
        Forwarded to ``lightning.Trainer``. Default ``'auto'`` picks GPU
        if available.
    save_model_path : str
        Where to write checkpoints during training.
    seed : int, optional
        If not None, calls ``lightning.seed_everything(seed)`` before setup,
        so the run is reproducible.
    key_added : str
        Key under which the biological latent ``z_c`` is written to
        ``mdata.obsm``. Defaults to ``'X_midas'`` so that
        ``sc.pp.neighbors(mdata, use_rep='X_midas')`` works without
        further arguments.
    configs : optional
        Model configuration passed to ``MIDAS(configs=...)``. When omitted,
        ``load_config()`` is used and its ``lr_net`` / ``lr_dsc`` entries are
        replaced with the quickstart learning rate (3e-4). When given, it is
        passed through unchanged and the quickstart learning rates are not
        applied.
    **kwargs
        Additional keyword arguments forwarded to ``MIDAS(...)``.

    Returns
    -------
    MIDAS
        A trained MIDAS model. The biological latent has already been
        written to ``mdata.obsm[key_added]``.
    """
    if seed is not None:
        # Seed before setup and model construction so the whole run is
        # reproducible.
        import lightning as L

        L.seed_everything(seed, verbose=False)

    # ``configs`` is an explicit parameter (rather than left to ``**kwargs``)
    # because it is always passed to MIDAS(...) below; a ``configs`` entry in
    # ``**kwargs`` would collide and raise TypeError.
    use_quickstart_lr = configs is None
    if use_quickstart_lr:
        configs = load_config()
        configs['lr_net'] = _QUICKSTART_DEFAULTS['lr_net']
        configs['lr_dsc'] = _QUICKSTART_DEFAULTS['lr_dsc']

    bsz = _QUICKSTART_DEFAULTS['batch_size'] if batch_size is None else batch_size
    eps = _QUICKSTART_DEFAULTS['max_epochs'] if max_epochs is None else max_epochs

    logger.info(
        'scmidas.integrate(): toy-tuned defaults — '
        'batch_size=%d, max_epochs=%d, lr=%s. '
        'For real datasets, override max_epochs (e.g. 2000) '
        'and consider batch_size=256.',
        bsz,
        eps,
        _QUICKSTART_DEFAULTS['lr_net'] if use_quickstart_lr else 'from configs',
    )

    MIDAS.setup_mudata(mdata, batch_key=batch_key)
    model = MIDAS(
        mdata,
        configs=configs,
        batch_size=bsz,
        save_model_path=save_model_path,
        **kwargs,
    )
    model.train(
        max_epochs=eps,
        accelerator=accelerator,
        devices=devices,
        strategy=strategy,
    )
    mdata.obsm[key_added] = model.get_latent_representation()
    return model