scmidas.model

class scmidas.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Decoder class for multi-modal data with shared and modality-specific decoding layers.

Parameters:
  • dims_x – Dict[str, list] Output dimensions for each modality.

  • dims_h – Dict[str, list] Hidden dimensions for each modality.

  • dim_z – int Latent dimension size.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘relu’).

  • drop – float Dropout rate.

  • kwargs – Dict[str, Any] Additional modality-specific configurations.

forward(latent_data: Tensor) → Dict[str, Tensor]

Forward pass for the decoder.

Parameters:

latent_data – torch.Tensor Latent variable input tensor of shape (batch_size, dim_z).

Returns:

Decoded outputs for each modality.

Return type:

Dict[str, torch.Tensor]

class scmidas.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Discriminator class for multi-modal latent variables.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality.

  • dims_s – Dict[str, int] Number of batch classes (the discriminator's output size) for each modality.

  • kwargs – Dict[str, Any] Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.

calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) → Tensor

Calculate cross-entropy loss for all modalities.

Parameters:
  • predictions – Dict[str, torch.Tensor] Dictionary of predicted logits for each modality.

  • targets – Dict[str, torch.Tensor] Dictionary of ground truth labels for each modality.

Returns:

Total normalized loss.

Return type:

torch.Tensor

forward(latent_inputs: Dict[str, Tensor]) → Dict[str, Tensor]

Forward pass for the discriminator.

Parameters:

latent_inputs – Dict[str, torch.Tensor] Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).

Returns:

Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).

Return type:

Dict[str, torch.Tensor]

class scmidas.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality (e.g., {‘rna’: [1000], ‘adt’: [100]}).

  • dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding (e.g., {‘rna’: 256, ‘adt’: 256}).

  • dim_z – int Latent dimension size (e.g., 32).

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘mish’).

  • drop – float Dropout rate.

  • kwargs – Dict[str, Any] Additional modality-specific configurations.

Notes

By default, RNA and ADT data are log1p-transformed in the encoder and exponentiated again after decoding. To skip this step, modify the configuration file (see the parameter ‘trsf_before_enc_’).
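The round trip described in this note can be sketched in NumPy as below. This is an illustration only: it assumes the post-decoding inverse is `np.expm1` (the exact inverse applied by scmidas is not documented here), and the count matrix is made up.

```python
import numpy as np

# Illustrative non-negative count matrix (cells x features)
counts = np.array([[0.0, 3.0, 10.0], [1.0, 0.0, 250.0]])

# Encoder side: compress the dynamic range of RNA/ADT counts
encoded_input = np.log1p(counts)

# Decoder side (assumed inverse): map values back to the count scale
restored = np.expm1(encoded_input)

print(np.allclose(restored, counts))  # True
```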

forward(data: Dict[str, Tensor], mask: Dict[str, Tensor]) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]

Forward pass for the encoder.

Parameters:
  • data – Dict[str, torch.Tensor] Input data for each modality.

  • mask – Dict[str, torch.Tensor] Masks for each modality.

Returns:

  • z_x_mu – Dict[str, torch.Tensor]

    Mean of the latent distribution for each modality.

  • z_x_logvar – Dict[str, torch.Tensor]

    Log-variance of the latent distribution for each modality.

Return type:

Tuple

class scmidas.model.MIDAS(mdata: MuData | None = None, *, save_model_path: str = './saved_models/scmidas', configs: Dict[str, Any] | None = None, batch_size: int = 256, n_save: int = 500, sampler_type: str = 'auto', viz_umap_tb: bool = False, transform: Dict[str, str] | None = None)

Bases: LightningModule

MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.

net

VAE Variational Autoencoder for multi-modal data encoding and decoding.

dsc

Discriminator Discriminator for distinguishing latent variables across batches.

configs

Dict[str, Any] Model and training configurations dynamically set as attributes.

automatic_optimization

bool Controls whether optimization is automatic or manually defined. Always True.

static calc_consistency_loss(z_uni: Dict[str, Tensor]) → Tensor

Calculate the consistency loss for unified latent variables across modalities.

Parameters:

z_uni – Dict[str, torch.Tensor] Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).

Returns:

Consistency loss, computed as the variance of the unified latent variables across modalities.

Return type:

torch.Tensor
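As a rough sketch of this loss, the NumPy function below stacks the per-modality latents and measures their variance across modalities. The reduction (summing over latent dimensions and cells, then dividing by batch size) and the name `consistency_loss_sketch` are assumptions for illustration, not the scmidas implementation.

```python
import numpy as np

def consistency_loss_sketch(z_uni):
    """Variance of the unified latents across modalities (assumed reduction)."""
    stacked = np.stack(list(z_uni.values()), axis=0)  # (n_mods, batch_size, latent_dim)
    per_entry_var = stacked.var(axis=0)                 # population variance across modalities
    batch_size = stacked.shape[1]
    return per_entry_var.sum() / batch_size

z_uni = {"rna": np.zeros((2, 1)), "adt": np.ones((2, 1))}
print(consistency_loss_sketch(z_uni))  # 0.25
```

Identical latents across modalities give a loss of zero, which is the behavior the consistency term rewards.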

static calc_dsc_loss(pred: Dict[str, Tensor], true: Dict[str, Tensor]) → Tensor

Calculate the discriminator loss using cross-entropy.

Parameters:
  • pred – Dict[str, torch.Tensor] Predicted logits for each modality.

  • true – Dict[str, torch.Tensor] Ground truth labels for each modality.

Returns:

Computed discriminator loss.

Return type:

torch.Tensor
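A NumPy sketch of a per-modality cross-entropy, assuming `true` holds integer batch labels, each modality's loss is averaged over cells, and the per-modality losses are summed. `dsc_loss_sketch` is a hypothetical name, and the actual normalization in scmidas may differ.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def dsc_loss_sketch(pred, true):
    total = 0.0
    for mod, logits in pred.items():
        labels = true[mod].reshape(-1).astype(int)  # labels may arrive as (batch_size, 1)
        log_probs = log_softmax(logits)
        # Mean negative log-probability of the true batch label over cells
        total += -log_probs[np.arange(len(labels)), labels].sum() / len(labels)
    return total

pred = {"rna": np.zeros((4, 3)), "adt": np.zeros((4, 3))}
true = {"rna": np.array([0, 1, 2, 0]), "adt": np.array([2, 2, 1, 0])}
loss = dsc_loss_sketch(pred, true)  # uniform logits: 2 * log(3)
```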

static calc_kld_loss(mu: Tensor, logvar: Tensor) → Tensor

Calculate the KLD loss for a single latent space.

Parameters:
  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).

Returns:

KLD loss for the latent space, normalized by batch size.

Return type:

torch.Tensor
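The usual closed-form KL divergence between a diagonal Gaussian and a standard normal prior is sketched below in NumPy, normalized by batch size as stated above. That MIDAS uses a standard normal prior with exactly this reduction is an assumption.

```python
import numpy as np

def kld_loss_sketch(mu, logvar):
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over dims and cells
    kld = -0.5 * (1 + logvar - mu ** 2 - np.exp(logvar))
    return kld.sum() / mu.shape[0]  # normalize by batch size

mu = np.ones((2, 3))
logvar = np.zeros((2, 3))
print(kld_loss_sketch(mu, logvar))  # 1.5
```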

static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor) → Tensor

Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.

Parameters:
  • dim_c – int Dimension of the biological latent space.

  • dim_u – int Dimension of the technical latent space.

  • lam_kld_c – float Weight for KLD loss of the biological latent space.

  • lam_kld_u – float Weight for KLD loss of the technical latent space.

  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).

Returns:

Weighted sum of KLD losses for the biological and technical latent spaces.

Return type:

torch.Tensor
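A sketch of the weighted split, assuming the first dim_c columns of `mu` and `logvar` hold the biological part and the remaining dim_u columns the technical part (the same [z_c, z_u] order that get_latent_representation() documents for kind='joint'). The helper names are hypothetical.

```python
import numpy as np

def kld(mu, logvar):
    # KL to a standard normal prior, normalized by batch size
    return (-0.5 * (1 + logvar - mu ** 2 - np.exp(logvar))).sum() / mu.shape[0]

def kld_z_loss_sketch(dim_c, dim_u, lam_kld_c, lam_kld_u, mu, logvar):
    assert mu.shape[1] == dim_c + dim_u, "z is assumed to be laid out as [c | u]"
    # First dim_c columns: biological part c; remaining dim_u columns: technical part u
    kld_c = kld(mu[:, :dim_c], logvar[:, :dim_c])
    kld_u = kld(mu[:, dim_c:], logvar[:, dim_c:])
    return lam_kld_c * kld_c + lam_kld_u * kld_u

mu = np.array([[1.0, 1.0, 0.0]])
logvar = np.zeros((1, 3))
print(kld_z_loss_sketch(2, 1, lam_kld_c=2.0, lam_kld_u=5.0, mu=mu, logvar=logvar))  # 2.0
```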

static calc_recon_loss(x: Dict[str, Tensor], s: Tensor, e: Dict[str, Tensor], x_r_pre: Dict[str, Tensor], s_r_pre: Dict[str, Tensor], dist: Dict[str, str], lam: Dict[str, float]) → Tuple[float, Dict[Tensor, Tensor]]

Calculate the reconstruction loss for input data and predicted outputs.

Parameters:
  • x – Dict[str, torch.Tensor] Original input data for each modality (x^m).

  • s – torch.Tensor Ground truth batch labels.

  • e – Dict[str, torch.Tensor] Feature mask for each modality.

  • x_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for each modality (x_r^m).

  • s_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for batch labels.

  • dist – Dict[str, str] Dictionary specifying the distribution type for each modality’s decoder.

  • lam – Dict[str, float] Dictionary containing reconstruction loss weights for each modality and for s.

Returns:

  • total_loss – torch.Tensor

    Total reconstruction loss, normalized by batch size.

  • losses – Dict[str, torch.Tensor]

    Dictionary containing reconstruction losses for each modality and for batch labels.

Return type:

Tuple
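A simplified NumPy sketch of a masked, per-modality reconstruction loss. The distribution names ("poisson", "bernoulli"), the treatment of `x_r_pre` as a rate or probability, and the elementwise masking with `e` are assumptions for illustration; the batch-label term weighted by lam for s is omitted.

```python
import numpy as np

# Hypothetical distribution names; the actual keys used in `dist` may differ.
def poisson_nll(rate, x):
    # Poisson negative log-likelihood, dropping the constant log(x!) term
    return rate - x * np.log(rate + 1e-8)

def bernoulli_nll(p, x):
    p = np.clip(p, 1e-8, 1 - 1e-8)
    return -(x * np.log(p) + (1 - x) * np.log(1 - p))

NLL = {"poisson": poisson_nll, "bernoulli": bernoulli_nll}

def recon_loss_sketch(x, e, x_r_pre, dist, lam):
    batch_size = next(iter(x.values())).shape[0]
    losses = {}
    for m in x:
        # Zero out the loss wherever the mask marks a feature as unobserved
        masked = NLL[dist[m]](x_r_pre[m], x[m]) * e[m]
        losses[m] = lam[m] * masked.sum() / batch_size
    return sum(losses.values()), losses

x = {"atac": np.array([[1.0, 0.0]])}
e = {"atac": np.array([[1.0, 0.0]])}  # second feature unobserved
x_r_pre = {"atac": np.array([[0.5, 0.9]])}
total, per_mod = recon_loss_sketch(x, e, x_r_pre, {"atac": "bernoulli"}, {"atac": 1.0})
# total equals log(2): only the first (observed) feature contributes
```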

classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto', viz_umap_tb=False, batch_names=None) → MIDAS

Configure the data and model parameters for training.

Parameters:
  • configs – dict Configurations of the model.

  • datalist – List[Dataset] List of datasets to be used for training.

  • dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.

  • dims_s – Dict[str, int] Number of batch classes for each modality.

  • s_joint – List[Dict[str, int]] Modality ID for each batch.

  • combs – List[List[str]] Combinations of modalities.

  • batch_size – int, optional Size of each training batch, by default 256.

  • n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.

  • save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.

  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.

  • viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.

  • batch_names – list, optional List of batch names, by default None.

Returns:

A configured MIDAS instance.

Return type:

MIDAS

classmethod configure_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, save_model_path: str = './saved_models/', n_save: int = 500, **kwargs: Dict[str, Any]) → MIDAS

Configure data from a directory and apply optional transformations.

Parameters:
  • configs – Dict[str, Any] Configurations of the model.

  • dir_path – str Path to the directory containing data files.

  • format – str, optional File format of the data files (‘mtx’, ‘csv’, or ‘vec’), by default ‘mtx’.

  • transform – Dict[str, str], optional A dictionary specifying transformations to apply to specific modalities. Example: {‘atac’: ‘binarize’} Default is None, which uses the default transformation settings.

  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.

  • viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.

  • save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.

  • n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.

  • kwargs – Dict[str, Any] Additional parameters passed to configure_data().

Returns:

A configured MIDAS instance.

Return type:

MIDAS

Examples

>>> from scmidas.model import MIDAS
>>> from scmidas.config import load_config
>>> configs = load_config()
>>> dir_path = 'XXX'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform=transform)

classmethod configure_data_from_mdata(configs: Dict[str, Any], mdata: MuData, dims_x: Dict[str, list], batch_key: str = 'batch', transform: Dict[str, str] | None = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, save_model_path: str = './saved_models/', n_save: int = 500, **kwargs: Any) → MIDAS

Configure the MIDAS model directly from a MuData object.

This method processes the MuData input to extract data, masks, and batch information, initializes the datasets, and sets up the model configuration.

Parameters:
  • configs – Dict[str, Any] Configurations of the model.

  • mdata – MuData The input MuData object containing multi-modal single-cell data. It is expected to contain AnnData objects for different modalities (e.g., RNA, ATAC).

  • dims_x – Dict[str, list] A dictionary specifying the input feature dimensions for each modality. Keys are modality names, and values are lists of dimensions (e.g., {‘rna’: [2000]}).

  • batch_key – str, default='batch' Column in .obs that identifies the source batch of each cell.

  • transform – Optional[Dict[str, str]], default=None A dictionary specifying transformations to apply to each modality. Example: {‘atac’: ‘binarize’}. If None, default transformations are used.

  • sampler_type – str, default='auto' Strategy for data sampling. Use ‘ddp’ for Distributed Data Parallel training, or ‘auto’ for standard training.

  • viz_umap_tb – bool, default=False If True, enables UMAP visualization logs in TensorBoard during the training process.

  • save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.

  • n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.

  • **kwargs – Any Additional keyword arguments passed to the underlying configure_data method (e.g., batch_size).

Returns:

An initialized instance of the MIDAS class, ready for training or inference.

Return type:

MIDAS

configure_optimizers() → List[Optimizer]

Configure optimizers for the MIDAS model.

Returns:

List of optimizers for the network and discriminator.

Return type:

List[torch.optim.Optimizer]

static get_datasets_from_adata(data: List[Dict[str, AnnData]], mask: List[Dict[str, str]], batch_names: List[str], transform: Dict[str, str] = None)

Build datasets from in-memory AnnData objects.

Parameters:
  • data – List[Dict[str, AnnData]] List of data dictionaries, where keys are modalities and values are AnnData objects.

  • mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask values.

  • batch_names – List[str] List of batch names.

  • transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.

Returns:

  • datasets – List[MultiModalDataset]

    List of initialized MultiModalDataset objects.

  • dims_s – Dict[str, int]

    Dimensions for batch correction for each modality.

  • s_joint – List[Dict[str, int]]

    Modality indices for each batch.

  • combs – List[List[str]]

    List of modality combinations for each batch.

Return type:

Tuple

static get_datasets_from_dir(data: List[Dict[str, str]], mask: List[Dict[str, str]], batch_names: List[str], transform: Dict[str, str] = None, format: str = 'mtx')

Build datasets from data files in a directory.

Parameters:
  • data – List[Dict[str, str]] List of data dictionaries, where keys are modalities and values are file paths.

  • mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths.

  • batch_names – List[str] List of batch names.

  • transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.

  • format – str File type of the input data, default is ‘mtx’. [‘vec’, ‘mtx’, ‘csv’]

Returns:

  • datasets – List[MultiModalDataset]

    List of initialized MultiModalDataset objects.

  • dims_s – Dict[str, int]

    Dimensions for batch correction for each modality.

  • s_joint – List[Dict[str, int]]

    Modality indices for each batch.

  • combs – List[List[str]]

    List of modality combinations for each batch.

Return type:

Tuple

get_emb_umap(pred_dir: str = None, pred_format: str = None, save_dir: str = None, drop_c_umap: bool = False, drop_u_umap: bool = False, color_by: str = 'batch', n_obs: int = None, verbose=True, **kwargs) → Tuple[List[AnnData], List[Figure]]

Generate UMAP visualizations for biological and technical latent embeddings.

This function loads predicted latent representations and computes UMAP embeddings for visualization. Two types of embeddings are supported:

  1. Biological embedding (z_c)

  2. Technical embedding (z_u)

For large datasets, the function can optionally subsample observations to accelerate UMAP computation.

Parameters:
  • pred_dir – str, optional Directory containing predicted results generated by predict() or predict_streaming(). If None, predictions will be generated on-the-fly using self.predict().

  • pred_format – {“npy”, “csv”}, optional File format of saved prediction files when loading from disk. Only used when pred_dir is provided.

  • save_dir – str, optional Directory to save the generated UMAP figures. If None, figures will not be written to disk.

  • drop_c_umap – bool, default=False Whether to skip UMAP visualization for the biological embedding (z_c).

  • drop_u_umap – bool, default=False Whether to skip UMAP visualization for the technical embedding (z_u).

  • color_by

    str, default=“batch” Column name in adata.obs used to color cells in UMAP plots. Common options include:

    • “batch” : batch label

    • “s_joint” : subset or dataset identifier

    • any other metadata column stored in adata.obs

  • n_obs – int, optional Number of observations to randomly subsample before computing UMAP. Useful for large datasets to speed up visualization. If None, all observations will be used.

  • verbose – bool, default=True Whether to display progress bars and logging information.

  • **kwargs – Dict[str, Any] Additional keyword arguments passed to scanpy.pl.umap().

Returns:

  • all_adata – List[AnnData]

    List of AnnData objects containing the computed UMAP embeddings.

  • all_figures – List[matplotlib.figure.Figure]

    List of generated UMAP figure objects.

Return type:

Tuple

Notes

  • UMAP is computed using scanpy.pp.neighbors() followed by scanpy.tl.umap().

  • The biological embedding (z_c) captures biological variation, while the technical embedding (z_u) reflects batch or technical effects.

  • For very large datasets (e.g., >1M cells), it is recommended to set n_obs (e.g., 20,000) to reduce computation time.

Examples

Generate UMAP from saved predictions:

>>> model.get_emb_umap(pred_dir="./predictions")

Generate UMAP with subsampling and custom coloring:

>>> model.get_emb_umap(
...     pred_dir="./predictions",
...     n_obs=20000,
...     color_by="batch"
... )

Generate UMAP and save figures:

>>> model.get_emb_umap(
...     pred_dir="./predictions",
...     save_dir="./figs"
... )

get_imputed_values(mdata: MuData | None = None, *, modality: str = 'rna', verbose: bool = False) → ndarray

Return imputed expression values for a single modality.

For mosaic data, this fills in cells that originally lacked the modality with model-inferred values. Cells that already had real observations also get the model’s reconstruction (useful for denoising).

Parameters:
  • mdata – MuData, optional MuData to align against. Defaults to the MuData used at construction time.

  • modality – str Modality name (e.g. 'rna', 'adt', 'atac'). Must be one of self.mods.

  • verbose – bool Whether to show prediction progress bars.

Returns:

Array of shape (mdata.n_obs, n_features[modality]), aligned to mdata.obs_names. Cells absent from training data yield NaN rows (a warning is logged).

Return type:

np.ndarray

static get_info_from_dir(dir_path: str, format: str)

Extract data file paths, mask file paths, and feature dimensions from a dataset directory.

Parameters:
  • dir_path – str Path to the directory containing data and mask files.

  • format – str File format of the data: ‘mtx’, ‘csv’, or ‘vec’.

Returns:

  • data – List[Dict[str, str]]

    List of dictionaries where keys are modalities and values are file paths.

  • mask – List[Dict[str, str]]

    List of dictionaries where keys are modalities and values are mask file paths.

  • dims_x – Dict[str, list]

    Dictionary containing feature dimensions for each modality.

Return type:

Tuple

Notes

The directory should be organized as:

dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv # the first sample
        vec/mod1/0001.csv # the second sample
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv
        vec/mod1/0001.csv
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
    ....

or, with one matrix file per modality:

dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
    ....

static get_info_from_mdata(mdata, batch_key='batch')

get_latent_representation(mdata: MuData | None = None, *, kind: str = 'c', verbose: bool = False) → ndarray

Return the joint latent representation aligned to mdata.obs_names.

This is the convenience wrapper around predict() that most users want: it runs the joint latent forward pass, stitches the per-batch outputs, and reorders the result so each row matches mdata.obs_names — ready to drop straight into mdata.obsm['X_midas'].

Parameters:
  • mdata – MuData, optional MuData to align against. Defaults to the MuData used at construction time. Must include the same cells as the registered MuData (same obs_names).

  • kind

    {‘c’, ‘u’, ‘joint’} Which latent to return:

    • 'c' (default): biological latent z_c of shape (n_obs, dim_c). Recommended for clustering / visualization.

    • 'u': technical latent z_u of shape (n_obs, dim_u).

    • 'joint': concatenation [z_c, z_u].

  • verbose – bool Whether to show prediction progress bars.

Returns:

Array of shape (mdata.n_obs, dim) aligned to mdata.obs_names. Cells absent from training data yield NaN rows (a warning is logged).

Return type:

np.ndarray
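The array below stands in for the output of get_latent_representation(kind='joint'); this NumPy sketch recovers z_c and z_u by slicing (using the documented [z_c, z_u] order) and drops the NaN rows that appear for cells absent from the training data. Sizes and values are illustrative.

```python
import numpy as np

dim_c, dim_u = 4, 2  # illustrative sizes
rng = np.random.default_rng(0)
joint = rng.normal(size=(5, dim_c + dim_u))  # stand-in for kind='joint' output
joint[3] = np.nan                            # a cell absent from training data

# 'joint' is [z_c, z_u], so the two parts can be recovered by slicing
z_c, z_u = joint[:, :dim_c], joint[:, dim_c:]

# Drop NaN rows before clustering or neighbor-graph construction
keep = ~np.isnan(z_c).any(axis=1)
z_c_clean = z_c[keep]
print(z_c_clean.shape)  # (4, 4)
```

The same boolean `keep` mask can be used to subset the matching entries of mdata.obs_names.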

classmethod load(dir_path: str, mdata: MuData, **kwargs: Any) → MIDAS

Load a MIDAS model previously saved via save().

If mdata has not been registered (no _scmidas in mdata.uns), this method auto-registers it using the batch_key saved alongside the model.

Parameters:
  • dir_path – str Directory written by save().

  • mdata – MuData Multi-modal data the model was trained on (or a query dataset with the same modality structure).

  • **kwargs – Any Forwarded to MIDAS (e.g. batch_size, save_model_path).

Returns:

A model with the saved weights loaded.

Return type:

MIDAS

load_checkpoint(checkpoint_path: str, start_epoch: int = 0, **kwargs)

Load model and optimizer states from a checkpoint file.

Parameters:
  • checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.

  • start_epoch – int Number of epochs the model has already been trained for.

  • kwargs – Dict[str, Any] Additional configurations for torch.load().

Raises:

AssertionError – If the provided checkpoint path does not exist.

log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])

Log losses for monitoring and debugging during training.

Parameters:
  • recon_loss – torch.Tensor Reconstruction loss.

  • kld_loss – torch.Tensor KLD loss.

  • consistency_loss – torch.Tensor Consistency loss.

  • loss_net – torch.Tensor Total VAE loss.

  • loss_dsc – torch.Tensor Discriminator loss.

  • recon_dict – Dict[str, torch.Tensor] Per-modality reconstruction losses.

on_train_end()

Save the final model checkpoint at the end of training.

on_train_epoch_end()

Save a model checkpoint at the end of each training epoch with a meaningful filename.

predict(return_in_memory: bool = True, save_dir: str | None = None, save_format: str = 'npy', joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False, verbose: bool = True)

Run model inference in a streaming manner.

This method supports three prediction modes:

  1. Return predictions in memory (recommended for small or medium datasets).

  2. Stream predictions to disk per mini-batch (recommended for large datasets).

  3. Perform both simultaneously.

Notes

  • If return_in_memory=False, prediction tensors do not accumulate in RAM, making this method suitable for very large datasets.

  • If save_dir is provided, predictions are written incrementally to disk per mini-batch.

  • If batch_correct=True, a second pass over the dataset is performed:

    Pass 1: Compute joint latent representations and estimate the technical centroid using online statistics.

    Pass 2: Reconstruct data with batch correction and stream the corrected outputs.

  • If translate=True, mod_latent is automatically set to True.

Parameters:
  • return_in_memory – bool, default=True Whether to keep predictions in memory and return them as a nested dictionary. Set to False for large datasets to avoid OOM.

  • save_dir – str or None, default=None Output directory for streaming prediction results. If None, predictions are not saved to disk.

  • save_format

    {“npy”, “csv”}, default=“npy” File format used when save_dir is provided.

    • “npy” : NumPy binary format (recommended; fast and compact).

    • “csv” : CSV text format (not recommended for large arrays).

  • joint_latent

    bool, default=True Whether to compute the joint latent representation conditioned on all observed modalities.

    Stored as:
    • z[“joint”] (raw latent)

    • or postprocessed into z_c[“joint”] and z_u[“joint”].

  • mod_latent

    bool, default=False Whether to compute latent representations conditioned on each individual modality.

    For each modality m, a single-modality forward pass is performed and stored as:

    z[m]

  • impute

    bool, default=False Whether to generate imputed data (x_impt) from the joint latent representation.

    Stored as:

    x_impt[modality]

  • batch_correct

    bool, default=False Whether to estimate a technical centroid and perform batch-effect correction on reconstructed data.

    Stored as:

    x_bc[modality]

  • translate

    bool, default=False Whether to perform cross-modality translation.

    For each available input modality subset, missing modalities are generated and stored as:

    x_trans[“<input_mods>_to_<target_mod>”]

  • input

    bool, default=False Whether to include the original input data and masks in the output.

    Stored as:

    x[modality] and mask[modality] (if available)

  • verbose – bool, default=True Whether to display progress bars (tqdm) and logging messages.

Returns:

output – Nested dictionary of predictions when return_in_memory=True; otherwise None.

Return type:

dict or None

Raises:
  • ValueError – If both return_in_memory=False and save_dir=None, or if an unsupported save_format is specified.

  • KeyError – If required prediction fields are missing.
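When batch_correct=True, Pass 1 estimates the technical centroid with online statistics. One simple form of this is a running mean updated per mini-batch, sketched below in NumPy. Treating the centroid as the mean of z_u and the `RunningMean` class itself are assumptions; scmidas may track different statistics.

```python
import numpy as np

class RunningMean:
    """Streaming mean over mini-batches (a simple form of online statistics)."""

    def __init__(self, dim):
        self.count = 0
        self.mean = np.zeros(dim)

    def update(self, batch):
        n = batch.shape[0]
        total = self.count + n
        # Shift the running mean toward this batch's mean, weighted by batch size
        self.mean += (batch.mean(axis=0) - self.mean) * (n / total)
        self.count = total

rng = np.random.default_rng(0)
z_u_all = rng.normal(loc=3.0, size=(1000, 2))  # stand-in technical latents
tracker = RunningMean(dim=2)
for start in range(0, len(z_u_all), 256):       # mini-batches of 256 cells
    tracker.update(z_u_all[start:start + 256])

print(np.allclose(tracker.mean, z_u_all.mean(axis=0)))  # True
```

Only the count and the current mean are kept in memory, which is what makes this kind of estimate compatible with return_in_memory=False.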

static print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str])

Print summary of mask density and dataset information.

Parameters:
  • mask – List[Dict[str, str]] List of mask dictionaries.

  • datalist – List[Dataset] List of datasets.

  • batch_names – List[str] List of batch names.

save(dir_path: str, *, overwrite: bool = False) → None

Save this MIDAS model to a directory.

Writes two files:

  • model.pt — VAE / Discriminator weights and optimizer states.

  • setup.json — minimal data-setup metadata so the model can be reloaded on a fresh process.

This is the recommended user-facing save API (v0.3+); pair with MIDAS.load(). The legacy save_checkpoint() / load_checkpoint() flat-file API still works.

Parameters:
  • dir_path – str Directory path. Created if missing.

  • overwrite – bool Replace dir_path if it already exists.

save_checkpoint(checkpoint_path: str)

Save the current model and optimizer states to a checkpoint file.

Parameters:

checkpoint_path – str Path to save the checkpoint file.

Raises:

ValueError – If checkpoint_path is an invalid or empty string.

static setup_mudata(mdata: MuData, batch_key: str = 'batch', dims_x: Dict[str, list] | None = None) → None

Register a MuData object for use with MIDAS.

Stores the data-setup metadata (batch key, modality names, feature dimensions, batch list) under mdata.uns['_scmidas']. After this call, the MuData can be passed directly to MIDAS.

Parameters:
  • mdata – MuData Multi-modal data. Each modality (mdata[m]) must contain batch_key in its .obs.

  • batch_key – str Column in each modality’s .obs that identifies the source batch.

  • dims_x – dict, optional Per-modality feature dimensions, e.g. {'rna': [4045], 'adt': [224], 'atac': [2897, 2007, ...]}. Provide this for ATAC data that should be encoded by chromosome chunks (the published architecture). If omitted, each modality is treated as a single chunk equal to its full feature count.

Notes

Mutates mdata in-place; returns None.
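When ATAC peaks are sorted by chromosome, the per-chromosome chunk sizes for dims_x can be derived from the peak names. The helper below is hypothetical and assumes names of the form 'chr1:100-600'; other naming schemes need a different parser.

```python
from itertools import groupby

def atac_chunk_dims(peak_names):
    """Count consecutive peaks per chromosome, e.g. 'chr1:100-600' -> 'chr1'.

    Assumes peaks are sorted so that each chromosome forms one contiguous block.
    """
    chroms = [name.split(":")[0] for name in peak_names]
    return [len(list(group)) for _, group in groupby(chroms)]

peaks = ["chr1:100-600", "chr1:900-1400", "chr2:50-550", "chrX:10-510"]
dims_x = {"rna": [4045], "atac": atac_chunk_dims(peaks)}
print(dims_x["atac"])  # [2, 1, 1]
```

The resulting dictionary could then be passed as MIDAS.setup_mudata(mdata, batch_key='batch', dims_x=dims_x); the chunk sizes for a modality should add up to its total feature count.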

train(**kwargs)

Set the module in training mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Parameters:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

self

Return type:

Module

train_dataloader() → DataLoader

Create a DataLoader for training, using the appropriate sampler.

Returns:

Configured DataLoader instance for training.

Return type:

DataLoader

train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor])

Train the discriminator with modality-specific latent representations.

Parameters:
  • c_all – Dict[str, torch.Tensor] Dictionary of latent representations for each modality.

  • targets – Dict[str, torch.Tensor] Ground truth batch labels for each modality.

training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → Tensor

Execute a single training step for MIDAS.

Parameters:
  • batch – Dict[str, Dict[str, torch.Tensor]] Input batch containing modality data, batch indices, and masks.

  • batch_idx – int Index of the current training batch.

Returns:

Total VAE loss for the current batch.

Return type:

torch.Tensor

static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip: int = -1)

Update model parameters using backpropagation.

Parameters:
  • loss – torch.Tensor Computed loss for backpropagation.

  • model – torch.nn.Module Model to update.

  • optimizer – torch.optim.Optimizer Optimizer for parameter updates.

  • grad_clip – int True to allow clipping gradient.

class scmidas.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)

Bases: Module

Decoder for reconstructing batch ID.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.

  • dim_u – int Latent dimension size for the input (e.g., 2).

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) → Tensor

Forward pass for S_Decoder.

Parameters:

data – torch.Tensor Latent input tensor of shape (batch_size, dim_u).

Returns:

Reconstructed tensor of shape (batch_size, n_batches).

Return type:

torch.Tensor

class scmidas.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)

Bases: Module

Encoder that maps batch IDs to latent variables.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.

  • dim_z – int Size of the latent dimension.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) → Tensor

Forward pass for S_Encoder.

Parameters:

data – torch.Tensor Input tensor of shape (batch_size, 1), containing batch indices.

Returns:

Encoded tensor of shape (batch_size, dim_z * 2).

Return type:

torch.Tensor
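A NumPy sketch of how a (batch_size, dim_z * 2) output is commonly split into a mean and a log-variance and sampled with the reparameterization trick. That the first dim_z columns are the mean is an assumption about the layout, and the random matrix stands in for a real S_Encoder output.

```python
import numpy as np

dim_z = 2
rng = np.random.default_rng(0)
encoded = rng.normal(size=(4, dim_z * 2))  # stand-in for S_Encoder output

# Assumed layout: first dim_z columns are the mean, last dim_z the log-variance
z_s_mu, z_s_logvar = encoded[:, :dim_z], encoded[:, dim_z:]

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
eps = rng.standard_normal(z_s_mu.shape)
z_s = z_s_mu + np.exp(0.5 * z_s_logvar) * eps
print(z_s.shape)  # (4, 2)
```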

class scmidas.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling from distributions.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality, e.g., {‘rna’: [1000], ‘adt’: [100], ‘atac’: [10, 10, 10]}.

  • dims_s – Dict[str, int] Number of batch classes for each modality.

  • kwargs – Dict[str, Any] Additional configurations for encoders, decoders, and other modules.

encode_batch(s: Tensor) Tuple[list, list] | None[source]#

Encode batch IDs into latent variables.

Parameters:

s – torch.Tensor Batch IDs.

Returns:

  • z_s_mu: List[torch.Tensor]

    Mean of the batch-ID latent variables.

  • z_s_logvar: List[torch.Tensor]

    Log-variance of the batch-ID latent variables.

Return type:

Optional[Tuple[list, list]]

forward(data: Dict[str, Tensor]) Tuple[Dict[str, Tensor], Tensor | None, Tensor, Tensor, Tensor, Tensor, Tensor, Dict[str, Tensor], Dict[str, Tensor]][source]#

Forward pass for the VAE.

Parameters:

data – Dict[str, torch.Tensor] Input data dictionary containing:

  • ‘x’: Dict[str, torch.Tensor], modality-specific input data.

  • ‘e’: Dict[str, torch.Tensor], modality-specific masks.

  • ‘s’ (optional): torch.Tensor, batch IDs.

Returns:

  • x_r_pre: Dict[str, torch.Tensor]

    Reconstructed modality-specific data.

  • s_r_pre: Optional[torch.Tensor]

    Reconstructed batch indices if ‘s’ is provided; otherwise None.

  • z_mu: torch.Tensor

    Mean of the combined latent variables.

  • z_logvar: torch.Tensor

    Log-variance of the combined latent variables.

  • z: torch.Tensor

    Sampled latent variables.

  • c: torch.Tensor

    Biological information variables.

  • u: torch.Tensor

    Technical noise variables.

  • z_uni: Dict[str, torch.Tensor]

    Unified latent variables for each modality.

  • c_all: Dict[str, torch.Tensor]

    Modality-specific biological information variables.

Return type:

Tuple

gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) Dict[str, Tensor][source]#

Generate real data from reconstructed data.

Parameters:
  • x_r_pre – Dict[str, torch.Tensor] Dictionary of reconstructed data tensors for each modality.

  • sampling – bool, optional Whether to sample the output (default: True).

Returns:

Generated real data for each modality.

Return type:

Dict[str, torch.Tensor]

generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) Tuple[Dict[str, Tensor], Dict[str, Tensor]][source]#

Generate unified latent variables and modality-specific representations.

Parameters:
  • z_x_mu – Dict[str, torch.Tensor] Means of modality-specific latent variables.

  • z_x_logvar – Dict[str, torch.Tensor] Log-variances of modality-specific latent variables.

  • z_s_mu – List[torch.Tensor] Mean of the batch-ID latent variables.

  • z_s_logvar – List[torch.Tensor] Log-variance of the batch-ID latent variables.

  • c – torch.Tensor Biological information.

Returns:

  • z_uni: Dict[str, torch.Tensor]

    Collection of latent variables for the unimodal inputs.

  • c_all: Dict[str, torch.Tensor]

    Collection of biological information for the unimodal and joint inputs.

Return type:

Tuple

get_dim_h() Dict[str, List[int]][source]#

Compute hidden dimensions for each modality.

Returns:

A dictionary containing the hidden dimensions for each modality.

Return type:

Dict[str, List[int]]

static poe(mus: List[Tensor], logvars: List[Tensor]) Tuple[Tensor, Tensor][source]#

Product of Experts (PoE) for combining Gaussian distributions.

Parameters:
  • mus – list of torch.Tensor List of mean tensors for each Gaussian.

  • logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.

Returns:

  • combined_mean: torch.Tensor

    Mean of the combined Gaussian distribution.

  • combined_logvar: torch.Tensor

    Log-variance of the combined Gaussian distribution.

Return type:

Tuple
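The product of Gaussian experts is itself Gaussian. Its precision is the sum of the expert precisions, and its mean is the precision-weighted average of the expert means: σ² = (Σᵢ σᵢ⁻²)⁻¹ and μ = σ² Σᵢ μᵢ σᵢ⁻². Below is a minimal sketch of that formula. `poe_sketch` is a hypothetical name, and these docs do not state whether `VAE.poe` also includes a standard-normal prior expert.

```python
from typing import List, Tuple

import torch


def poe_sketch(mus: List[torch.Tensor],
               logvars: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine Gaussian experts by precision-weighted averaging."""
    mu = torch.stack(mus, dim=0)          # (n_experts, batch, dim)
    logvar = torch.stack(logvars, dim=0)
    precision = torch.exp(-logvar)        # 1 / sigma_i^2 for each expert
    combined_var = 1.0 / precision.sum(dim=0)
    combined_mean = (mu * precision).sum(dim=0) * combined_var
    return combined_mean, torch.log(combined_var)


# Two unit-variance experts centred at 0 and 2 combine to mean 1, variance 0.5.
m, lv = poe_sketch([torch.zeros(1, 2), torch.full((1, 2), 2.0)],
                   [torch.zeros(1, 2), torch.zeros(1, 2)])
```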

static sample(name: str, data: Tensor, sampling: bool) Tensor[source]#

Select and apply a sampling function based on the distribution name.

Parameters:
  • name – str Name of the distribution.

  • data – torch.Tensor Input data tensor.

  • sampling – bool Whether to apply sampling.

Returns:

Sampled or original data tensor.

Return type:

torch.Tensor
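The distribution names accepted by `sample` are not listed in these docs. Below is a minimal sketch of a name-to-sampler dispatch. It assumes the hypothetical names "POISSON" and "BERNOULLI", mapped to `torch.poisson` and `torch.bernoulli`. Unknown names and `sampling=False` both return the input unchanged, matching the "sampled or original data tensor" description.

```python
from typing import Callable, Dict

import torch

# Hypothetical name -> sampler table; the names scmidas actually uses may differ.
_SAMPLERS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "POISSON": torch.poisson,      # data interpreted as non-negative Poisson rates
    "BERNOULLI": torch.bernoulli,  # data interpreted as probabilities in [0, 1]
}


def sample_sketch(name: str, data: torch.Tensor, sampling: bool) -> torch.Tensor:
    """Draw a sample for a known distribution name; otherwise return data as-is."""
    if sampling and name in _SAMPLERS:
        return _SAMPLERS[name](data)
    return data


rates = torch.full((3, 4), 2.0)
counts = sample_sketch("POISSON", rates, sampling=True)  # integer-valued draws
```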

static sample_gaussian(mu: Tensor, logvar: Tensor) Tensor[source]#

Sample from a Gaussian distribution using the reparameterization trick.

Parameters:
  • mu – torch.Tensor Mean of the Gaussian distribution.

  • logvar – torch.Tensor Log-variance of the Gaussian distribution.

Returns:

Sampled tensor.

Return type:

torch.Tensor
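The reparameterization trick writes a draw as z = μ + σ·ε with ε ~ N(0, I) and σ = exp(logvar / 2). The randomness then sits in ε, so gradients can flow through μ and logvar. Below is a minimal sketch of the technique. `sample_gaussian_sketch` is a hypothetical name, not the scmidas source.

```python
import torch


def sample_gaussian_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Reparameterized Gaussian sample: z = mu + sigma * eps, eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # noise that does not depend on mu or logvar
    return mu + eps * std          # differentiable with respect to mu and logvar


mu = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)
z = sample_gaussian_sketch(mu, logvar)
z.sum().backward()  # gradients reach mu and logvar through the sample
```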

sample_latent(z_mu: Tensor, z_logvar: Tensor) Tensor[source]#

Sample latent variables from a Gaussian distribution.

Parameters:
  • z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).

  • z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).

Returns:

Sampled latent variables of shape (batch_size, latent_dim).

Return type:

torch.Tensor