scmidas.model
- class scmidas.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Decoder class for multi-modal data with shared and modality-specific decoding layers.
- Parameters:
dims_x – Dict[str, list] Output dimensions for each modality.
dims_h – Dict[str, list] Hidden dimensions for each modality.
dim_z – int Latent dimension size.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘relu’).
drop – float Dropout rate.
kwargs – Dict[str, Any] Additional modality-specific configurations.
- class scmidas.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Discriminator class for multi-modal latent variables.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality.
dims_s – Dict[str, int] Dimensions of the classes for each modality.
kwargs – Dict[str, Any] Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.
- calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) → Tensor
Calculate cross-entropy loss for all modalities.
- Parameters:
predictions – Dict[str, torch.Tensor] Dictionary of predicted logits for each modality.
targets – Dict[str, torch.Tensor] Dictionary of ground truth labels for each modality.
- Returns:
Total normalized loss.
- Return type:
torch.Tensor
- forward(latent_inputs: Dict[str, Tensor]) → Dict[str, Tensor]
Forward pass for the discriminator.
- Parameters:
latent_inputs – Dict[str, torch.Tensor] Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).
- Returns:
Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).
- Return type:
Dict[str, torch.Tensor]
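calculate_loss is documented as a cross-entropy over all modalities that returns a "total normalized loss", but the normalization is not specified. The NumPy sketch below assumes each modality contributes its batch-mean cross-entropy between logits of shape (batch_size, dims_s[modality]) and integer labels, and that the per-modality terms are summed. The function names are hypothetical, not the library's internals.

```python
import numpy as np


def softmax_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Batch-mean cross-entropy between unnormalized logits and integer labels."""
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each row.
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def multimodal_ce_loss(predictions: dict, targets: dict) -> float:
    """Sum of per-modality cross-entropies (assumed normalization)."""
    return sum(softmax_cross_entropy(predictions[m], targets[m]) for m in predictions)


# Uniform logits over 4 classes give log(4) per modality.
preds = {"rna": np.zeros((3, 4)), "adt": np.zeros((3, 4))}
labels = {"rna": np.array([0, 1, 2]), "adt": np.array([3, 3, 0])}
loss = multimodal_ce_loss(preds, labels)
```

Whether the real method averages rather than sums across modalities would change only a constant factor per batch.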
- class scmidas.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality (e.g., {‘rna’:[1000], ‘adt’:[100]}).
dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding (e.g., {‘rna’:256, ‘adt’:256}).
dim_z – int Latent dimension size (e.g., 32).
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘mish’).
drop – float Dropout rate.
kwargs – Dict[str, Any] Additional modality-specific configurations.
Notes
By default, RNA and ADT data are log1p-transformed in the encoder and will be exponentiated after decoding. To skip this step, modify the configuration file. See parameter ‘trsf_before_enc_’.
- forward(data: Dict[str, Tensor], mask: Dict[str, Tensor]) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
Forward pass for the encoder.
- Parameters:
data – Dict[str, torch.Tensor] Input data for each modality.
mask – Dict[str, torch.Tensor] Masks for each modality.
- Returns:
- z_x_mu : Dict[str, torch.Tensor]
Mean values for latent space for each modality.
- z_x_logvar : Dict[str, torch.Tensor]
Log-variance values for latent space for each modality.
- Return type:
Tuple
- class scmidas.model.MIDAS(mdata: MuData | None = None, *, save_model_path: str = './saved_models/scmidas', configs: Dict[str, Any] | None = None, batch_size: int = 256, n_save: int = 500, sampler_type: str = 'auto', viz_umap_tb: bool = False, transform: Dict[str, str] | None = None)
Bases: LightningModule
MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.
- net
VAE Variational Autoencoder for multi-modal data encoding and decoding.
- dsc
Discriminator Discriminator for distinguishing latent variables across batches.
- configs
Dict[str, Any] Model and training configurations dynamically set as attributes.
- automatic_optimization
bool Controls whether optimization is automatic or manually defined. Always True.
- static calc_consistency_loss(z_uni: Dict[str, Tensor]) → Tensor
Calculate the consistency loss for unified latent variables across modalities.
- Parameters:
z_uni – Dict[str, torch.Tensor] Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).
- Returns:
Consistency loss computed as the variance of the unified latent variables.
- Return type:
torch.Tensor
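The docstring describes this loss only as "the variance of the unified latent variables". One plausible reading, sketched in NumPy below, takes the variance across modalities for each cell and latent dimension, sums over dimensions, and averages over cells. The reduction order and the name consistency_loss are assumptions, not the library's implementation.

```python
import numpy as np


def consistency_loss(z_uni: dict) -> float:
    """Variance of unified latents across modalities (illustrative reduction)."""
    # Shape (n_modalities, batch_size, latent_dim).
    stacked = np.stack(list(z_uni.values()), axis=0)
    # Population variance across the modality axis, per cell and latent dimension.
    var = stacked.var(axis=0)
    # Sum over latent dimensions, then average over cells in the batch.
    return float(var.sum(axis=1).mean())


# Two modalities that disagree by 1 in every coordinate.
loss = consistency_loss({"rna": np.zeros((2, 3)), "adt": np.ones((2, 3))})
```

The loss is zero exactly when every modality maps a cell to the same unified latent, which is the agreement the term is meant to encourage.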
- static calc_dsc_loss(pred: Dict[str, Tensor], true: Dict[str, Tensor]) → Tensor
Calculate the discriminator loss using cross-entropy.
- Parameters:
pred – Dict[str, torch.Tensor] Predicted logits for each modality.
true – Dict[str, torch.Tensor] Ground truth labels for each modality.
- Returns:
Computed discriminator loss.
- Return type:
torch.Tensor
- static calc_kld_loss(mu: Tensor, logvar: Tensor) → Tensor
Calculate the KLD loss for a single latent space.
- Parameters:
mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).
- Returns:
KLD loss for the latent space, normalized by batch size.
- Return type:
torch.Tensor
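For a diagonal Gaussian posterior N(mu, exp(logvar)) and a standard-normal prior, the KL divergence has a closed form. The NumPy sketch below assumes a standard-normal prior and sums over all elements before dividing by the batch size, matching the "normalized by batch size" wording; kld_loss is a hypothetical name.

```python
import numpy as np


def kld_loss(mu: np.ndarray, logvar: np.ndarray) -> float:
    """KL(N(mu, exp(logvar)) || N(0, I)), summed and divided by batch size."""
    # Closed-form per-element KL for a diagonal Gaussian vs. a standard normal.
    kld = -0.5 * (1.0 + logvar - mu**2 - np.exp(logvar))
    # Sum over cells and latent dimensions, then normalize by the number of cells.
    return float(kld.sum() / mu.shape[0])


# mu = 1, variance = 1: each element contributes 0.5; 6 elements / 2 cells = 1.5.
loss = kld_loss(np.ones((2, 3)), np.zeros((2, 3)))
```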
- static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor) → Tensor
Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.
- Parameters:
dim_c – int Dimension of the biological latent space.
dim_u – int Dimension of the technical latent space.
lam_kld_c – float Weight for KLD loss of the biological latent space.
lam_kld_u – float Weight for KLD loss of the technical latent space.
mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).
- Returns:
Weighted sum of KLD losses for the biological and technical latent spaces.
- Return type:
torch.Tensor
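calc_kld_z_loss weights the biological and technical parts of z separately. The sketch assumes the first dim_c columns of mu and logvar hold z_c and the next dim_u columns hold z_u, and reuses the closed-form KLD against a standard-normal prior. The column layout and helper names are assumptions.

```python
import numpy as np


def kld(mu: np.ndarray, logvar: np.ndarray) -> float:
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed, divided by batch size.
    return float((-0.5 * (1.0 + logvar - mu**2 - np.exp(logvar))).sum() / mu.shape[0])


def kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float,
               mu: np.ndarray, logvar: np.ndarray) -> float:
    if mu.shape[1] != dim_c + dim_u:
        raise ValueError("mu must have dim_c + dim_u columns")
    # Assumed layout: columns [0, dim_c) are z_c, columns [dim_c, dim_c + dim_u) are z_u.
    mu_c, mu_u = mu[:, :dim_c], mu[:, dim_c:]
    logvar_c, logvar_u = logvar[:, :dim_c], logvar[:, dim_c:]
    return lam_kld_c * kld(mu_c, logvar_c) + lam_kld_u * kld(mu_u, logvar_u)


# z_c part contributes 1.0, z_u part 0.5 (weighted by 2.0), total 2.0.
loss = kld_z_loss(2, 1, 1.0, 2.0, np.ones((2, 3)), np.zeros((2, 3)))
```

Separate weights let the technical latent be regularized more or less strongly than the biological one.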
- static calc_recon_loss(x: Dict[str, Tensor], s: Tensor, e: Dict[str, Tensor], x_r_pre: Dict[str, Tensor], s_r_pre: Dict[str, Tensor], dist: Dict[str, str], lam: Dict[str, float]) → Tuple[float, Dict[Tensor, Tensor]]
Calculate the reconstruction loss for input data and predicted outputs.
- Parameters:
x – Dict[str, torch.Tensor] Original input data for each modality (x^m).
s – torch.Tensor Ground truth batch labels.
e – Dict[str, torch.Tensor] Masks for each modality.
x_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for each modality (x_r^m).
s_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for batch labels.
dist – Dict[str, str] Dictionary specifying the distribution type for each modality’s decoder.
lam – Dict[str, float] Dictionary containing reconstruction loss weights for each modality and for s.
- Returns:
- total_loss : torch.Tensor
Total reconstruction loss, normalized by batch size.
- losses : Dict[str, torch.Tensor]
Dictionary containing reconstruction losses for each modality and for batch labels.
- Return type:
Tuple
- classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto', viz_umap_tb=False, batch_names=None) → MIDAS
Configure the data and model parameters for training.
- Parameters:
configs – dict Configurations of the model.
datalist – List[Dataset] List of datasets to be used for training.
dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.
dims_s – Dict[str, int] Dimensions of the classes for each modality.
s_joint – List[Dict[str, int]] Modality ID for each batch.
combs – List[List[str]] Combinations of modalities.
batch_size – int, optional Size of each training batch, by default 256.
n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.
save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. Use ‘ddp’ for a distributed sampler.
viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.
batch_names – list, optional List of batch names, by default None.
- Returns:
A configured MIDAS instance.
- Return type:
MIDAS
- classmethod configure_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, save_model_path: str = './saved_models/', n_save: int = 500, **kwargs: Dict[str, Any]) → MIDAS
Configure data from a directory and apply optional transformations.
- Parameters:
configs – Dict[str, Any] Configurations of the model.
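dir_path – str Path to the directory containing data files.
format – str, optional File format of the input data, by default ‘mtx’. One of ‘mtx’, ‘csv’, ‘vec’.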
transform – Dict[str, str], optional A dictionary specifying transformations to apply to specific modalities. Example: {‘atac’: ‘binarize’} Default is None, which uses the default transformation settings.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. Use ‘ddp’ for a distributed sampler.
viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.
save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.
n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.
kwargs – Dict[str, Any] Additional parameters passed to configure_data().
- Returns:
A configured MIDAS instance.
- Return type:
MIDAS
Examples
>>> from scmidas.model import MIDAS
>>> from scmidas.config import load_config
>>> configs = load_config()
>>> dir_path = 'XXX'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform=transform)
- classmethod configure_data_from_mdata(configs: Dict[str, Any], mdata: MuData, dims_x: Dict[str, list], batch_key: str = 'batch', transform: Dict[str, str] | None = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, save_model_path: str = './saved_models/', n_save: int = 500, **kwargs: Any) → MIDAS
Configure the MIDAS model directly from a MuData object.
This method processes the MuData input to extract data, masks, and batch information, initializes the datasets, and sets up the model configuration.
- Parameters:
configs – Dict[str, Any] Configurations of the model.
mdata – MuData The input MuData object containing multi-modal single-cell data. It is expected to contain AnnData objects for different modalities (e.g., RNA, ATAC).
dims_x – Dict[str, list] A dictionary specifying the input feature dimensions for each modality. Keys are modality names, and values are lists of dimensions (e.g., {‘rna’: [2000]}).
batch_key – str, default=‘batch’ Column in each modality’s .obs that identifies the source batch.
transform – Optional[Dict[str, str]], default=None A dictionary specifying transformations to apply to each modality. Example: {‘atac’: ‘binarize’}. If None, default transformations are used.
sampler_type – str, default=’auto’ Strategy for data sampling. Use ‘ddp’ for Distributed Data Parallel training, or ‘auto’ for standard training.
viz_umap_tb – bool, default=False If True, enables UMAP visualization logs in TensorBoard during the training process.
save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.
n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.
**kwargs – Any Additional keyword arguments passed to the underlying configure_data method (e.g., batch_size).
- Returns:
An initialized instance of the MIDAS class, ready for training or inference.
- Return type:
MIDAS
- configure_optimizers() → List[Optimizer]
Configure optimizers for the MIDAS model.
- Returns:
List of optimizers for the network and discriminator.
- Return type:
List[torch.optim.Optimizer]
- static get_datasets_from_adata(data: List[Dict[str, AnnData]], mask: List[Dict[str, str]], batch_names: List[str], transform: Dict[str, str] = None)
Build datasets from AnnData inputs.
- Parameters:
data – List[Dict[str, AnnData]] List of data dictionaries, where keys are modalities and values are AnnData objects.
mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask values.
batch_names – List[str] List of batch names.
transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.
- Returns:
- datasets : List[MultiModalDataset]
List of initialized MultiModalDataset objects.
- dims_s : Dict[str, int]
Dimensions for batch correction for each modality.
- s_joint : List[Dict[str, int]]
Modality indices for each batch.
- combs : List[List[str]]
List of modality combinations for each batch.
- Return type:
Tuple
- static get_datasets_from_dir(data: List[Dict[str, str]], mask: List[Dict[str, str]], batch_names: List[str], transform: Dict[str, str] = None, format: str = 'mtx')
Build datasets from data files in a directory.
- Parameters:
data – List[Dict[str, str]] List of data dictionaries, where keys are modalities and values are file paths.
mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths.
batch_names – List[str] List of batch names.
transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.
format – str File type of the input data, default is ‘mtx’. One of ‘vec’, ‘mtx’, ‘csv’.
- Returns:
- datasets : List[MultiModalDataset]
List of initialized MultiModalDataset objects.
- dims_s : Dict[str, int]
Dimensions for batch correction for each modality.
- s_joint : List[Dict[str, int]]
Modality indices for each batch.
- combs : List[List[str]]
List of modality combinations for each batch.
- Return type:
Tuple
- get_emb_umap(pred_dir: str = None, pred_format: str = None, save_dir: str = None, drop_c_umap: bool = False, drop_u_umap: bool = False, color_by: str = 'batch', n_obs: int = None, verbose=True, **kwargs) → Tuple[List[AnnData], List[Figure]]
Generate UMAP visualizations for biological and technical latent embeddings.
This function loads predicted latent representations and computes UMAP embeddings for visualization. Two types of embeddings are supported: the biological embedding (z_c) and the technical embedding (z_u).
For large datasets, the function can optionally subsample observations to accelerate UMAP computation.
- Parameters:
pred_dir – str, optional Directory containing predicted results generated by predict() or predict_streaming(). If None, predictions will be generated on-the-fly using self.predict().
pred_format – {“npy”, “csv”}, optional File format of saved prediction files when loading from disk. Only used when pred_dir is provided.
save_dir – str, optional Directory to save the generated UMAP figures. If None, figures will not be written to disk.
drop_c_umap – bool, default=False Whether to skip UMAP visualization for the biological embedding (z_c).
drop_u_umap – bool, default=False Whether to skip UMAP visualization for the technical embedding (z_u).
color_by – str, default=“batch” Column name in adata.obs used to color cells in UMAP plots. Common options include “batch” (batch label), “s_joint” (subset or dataset identifier), or any other metadata column stored in adata.obs.
n_obs – int, optional Number of observations to randomly subsample before computing UMAP. Useful for large datasets to speed up visualization. If None, all observations will be used.
verbose – bool, default=True Whether to display progress bars and logging information.
**kwargs – Dict[str, Any] Additional keyword arguments passed to scanpy.pl.umap().
- Returns:
- all_adata : List[AnnData]
List of AnnData objects containing the computed UMAP embeddings.
- all_figures : List[matplotlib.figure.Figure]
List of generated UMAP figure objects.
- Return type:
Tuple[List[AnnData], List[matplotlib.figure.Figure]]
Notes
- UMAP is computed using scanpy.pp.neighbors() followed by scanpy.tl.umap().
- The biological embedding (z_c) captures biological variation, while the technical embedding (z_u) reflects batch or technical effects.
- For very large datasets (e.g., >1M cells), it is recommended to set n_obs (e.g., 20,000) to reduce computation time.
Examples
Generate UMAP from saved predictions:
>>> model.get_emb_umap(pred_dir="./predictions")
Generate UMAP with subsampling and custom coloring:
>>> model.get_emb_umap(
...     pred_dir="./predictions",
...     n_obs=20000,
...     color_by="batch"
... )
Generate UMAP and save figures:
>>> model.get_emb_umap(
...     pred_dir="./predictions",
...     save_dir="./figs"
... )
- get_imputed_values(mdata: MuData | None = None, *, modality: str = 'rna', verbose: bool = False) → ndarray
Return imputed expression values for a single modality.
For mosaic data, this fills in cells that originally lacked the modality with model-inferred values. Cells that already had real observations also get the model’s reconstruction (useful for denoising).
- Parameters:
mdata – MuData, optional MuData to align against. Defaults to the MuData used at construction time.
modality – str Modality name (e.g. 'rna', 'adt', 'atac'). Must be one of self.mods.
verbose – bool Whether to show prediction progress bars.
- Returns:
Array of shape (mdata.n_obs, n_features[modality]), aligned to mdata.obs_names. Cells absent from training data yield NaN rows (a warning is logged).
- Return type:
np.ndarray
- static get_info_from_dir(dir_path: str, format: str)
Extract data, mask, and feature dimensions from a dataset directory.
- Parameters:
dir_path – str Path to the directory containing data and mask files.
format – str File format of the input data: ‘mtx’, ‘csv’, or ‘vec’.
- Returns:
- data : List[Dict[str, str]]
List of dictionaries where keys are modalities and values are file paths.
- mask : List[Dict[str, str]]
List of dictionaries where keys are modalities and values are mask file paths.
- dims_x : Dict[str, list]
Dictionary containing feature dimensions for each modality.
- Return type:
Tuple
Notes
The directory should be organized as:
dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv  # the first sample
        vec/mod1/0001.csv  # the second sample
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv
        vec/mod1/0001.csv
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
Alternatively:
dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
- get_latent_representation(mdata: MuData | None = None, *, kind: str = 'c', verbose: bool = False) → ndarray
Return the joint latent representation aligned to mdata.obs_names.
This is the convenience wrapper around predict() that most users want: it runs the joint latent forward pass, stitches the per-batch outputs, and reorders the result so each row matches mdata.obs_names, ready to drop straight into mdata.obsm['X_midas'].
- Parameters:
mdata – MuData, optional MuData to align against. Defaults to the MuData used at construction time. Must include the same cells as the registered MuData (same obs_names).
kind – {‘c’, ‘u’, ‘joint’} Which latent to return:
- 'c' (default): biological latent z_c of shape (n_obs, dim_c). Recommended for clustering and visualization.
- 'u': technical latent z_u of shape (n_obs, dim_u).
- 'joint': concatenation [z_c, z_u] of shape (n_obs, dim_c + dim_u).
verbose – bool Whether to show prediction progress bars.
- Returns:
Array of shape (mdata.n_obs, dim) aligned to mdata.obs_names. Cells absent from training data yield NaN rows (a warning is logged).
- Return type:
np.ndarray
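The "stitch and reorder" step can be illustrated without the model. The NumPy sketch below assumes each batch yields its cell names together with an embedding array, concatenates them, and reorders rows to a target obs_names list, filling cells that appear in no batch with NaN. align_to_obs is a hypothetical helper, not scmidas internals.

```python
import numpy as np


def align_to_obs(obs_names, batch_outputs):
    """Reorder stacked per-batch rows to obs_names; unseen cells become NaN rows.

    batch_outputs: list of (cell_names, array) pairs, one per batch (assumed format).
    """
    names = [name for cells, _ in batch_outputs for name in cells]
    values = np.concatenate([arr for _, arr in batch_outputs], axis=0)
    row_of = {name: i for i, name in enumerate(names)}
    out = np.full((len(obs_names), values.shape[1]), np.nan)
    for i, name in enumerate(obs_names):
        j = row_of.get(name)
        if j is not None:
            out[i] = values[j]
    return out


batches = [(["c2", "c1"], np.array([[2.0], [1.0]])), (["c3"], np.array([[3.0]]))]
aligned = align_to_obs(["c1", "c2", "c3", "c4"], batches)
```

With the documented API, the aligned result is typically stored as mdata.obsm['X_midas'] = model.get_latent_representation(mdata, kind='c').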
- classmethod load(dir_path: str, mdata: MuData, **kwargs: Any) → MIDAS
Load a MIDAS model previously saved via save().
If mdata has not been registered (no '_scmidas' entry in mdata.uns), this method auto-registers it using the batch_key saved alongside the model.
- Parameters:
dir_path – str Directory previously written by save().
mdata – MuData Data to use with the loaded model; registered automatically if needed.
- Returns:
A model with the saved weights loaded.
- Return type:
MIDAS
- load_checkpoint(checkpoint_path: str, start_epoch: int = 0, **kwargs)
Load model and optimizer states from a checkpoint file.
- Parameters:
checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.
start_epoch – int Number of epochs the model has already been trained for.
kwargs – Dict[str, Any] Additional configurations for torch.load().
- Raises:
AssertionError – If the provided checkpoint path does not exist.
- log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])
Log losses for monitoring and debugging during training.
- Parameters:
recon_loss – torch.Tensor Reconstruction loss.
kld_loss – torch.Tensor KLD loss.
consistency_loss – torch.Tensor Consistency loss.
recon_dict – Dict[str, torch.Tensor] Per-modality reconstruction losses.
loss_net – torch.Tensor Total VAE loss.
loss_dsc – torch.Tensor Discriminator loss.
- on_train_epoch_end()
Save a model checkpoint at the end of each training epoch with a meaningful filename.
- predict(return_in_memory: bool = True, save_dir: str | None = None, save_format: str = 'npy', joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False, verbose: bool = True)
Run model inference in a streaming manner.
This method supports three prediction modes:
Return predictions in memory (recommended for small or medium datasets).
Stream predictions to disk per mini-batch (recommended for large datasets).
Perform both simultaneously.
Notes
- If return_in_memory=False, prediction tensors will not accumulate in RAM, making this method suitable for very large datasets.
- If save_dir is provided, predictions are written incrementally to disk per mini-batch.
- If batch_correct=True, a second pass over the dataset is performed:
  - Pass 1: Compute joint latent representations and estimate the technical centroid using online statistics.
  - Pass 2: Reconstruct data with batch correction and stream the corrected outputs.
- If translate=True, mod_latent will be automatically set to True.
- Parameters:
return_in_memory – bool, default=True Whether to keep predictions in memory and return them as a nested dictionary. Set to False for large datasets to avoid OOM.
save_dir – str or None, default=None Output directory for streaming prediction results. If None, predictions are not saved to disk.
save_format –
{“npy”, “csv”}, default=“npy” File format used when save_dir is provided.
“npy” : NumPy binary format (recommended; fast and compact).
“csv” : CSV text format (not recommended for large arrays).
joint_latent –
bool, default=True Whether to compute the joint latent representation conditioned on all observed modalities.
Stored as z[“joint”] (raw latent), or postprocessed into z_c[“joint”] and z_u[“joint”].
mod_latent –
bool, default=False Whether to compute latent representations conditioned on each individual modality.
For each modality m, a single-modality forward pass is performed and stored as:
z[m]
impute –
bool, default=False Whether to generate imputed data (x_impt) from the joint latent representation.
Stored as:
x_impt[modality]
batch_correct –
bool, default=False Whether to estimate a technical centroid and perform batch-effect correction on reconstructed data.
Stored as:
x_bc[modality]
translate –
bool, default=False Whether to perform cross-modality translation.
For each available input modality subset, missing modalities are generated and stored as:
x_trans[“<input_mods>_to_<target_mod>”]
input –
bool, default=False Whether to include the original input data and masks in the output.
Stored as:
x[modality] and mask[modality] (if available)
verbose – bool, default=True Whether to display progress bars (tqdm) and logging messages.
- Returns:
output – Nested dictionary of predictions when return_in_memory=True; otherwise None.
- Return type:
dict or None
- Raises:
ValueError – If both return_in_memory=False and save_dir=None, or if an unsupported save_format is specified.
KeyError – If required prediction fields are missing.
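Pass 1 of batch_correct estimates a technical centroid "using online statistics". The docs do not define the centroid; assuming it is the per-dimension mean of the technical latent z_u over all cells, a running mean can be updated one mini-batch at a time so that no full latent matrix is held in memory. RunningMean is a hypothetical helper, not part of scmidas.

```python
import numpy as np


class RunningMean:
    """Streaming mean over mini-batches of row vectors (hypothetical helper)."""

    def __init__(self, dim: int):
        self.n = 0
        self.mean = np.zeros(dim)

    def update(self, batch: np.ndarray) -> None:
        m = batch.shape[0]
        if m == 0:
            return
        total = self.n + m
        # Merge the batch mean into the running mean, weighted by row counts.
        self.mean += (batch.mean(axis=0) - self.mean) * (m / total)
        self.n = total


rm = RunningMean(2)
rm.update(np.array([[1.0, 2.0], [3.0, 4.0]]))  # first mini-batch of z_u
rm.update(np.array([[5.0, 6.0]]))              # second mini-batch
```

Weighting each batch mean by its row count gives the same result as one mean over all rows, up to floating-point rounding.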
- static print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str])
Print summary of mask density and dataset information.
- Parameters:
mask – List[Dict[str, str]] List of mask dictionaries.
datalist – List[Dataset] List of datasets.
batch_names – List[str] List of batch names.
- save(dir_path: str, *, overwrite: bool = False) → None
Save this MIDAS model to a directory.
Writes two files:
- model.pt: VAE / Discriminator weights and optimizer states.
- setup.json: minimal data-setup metadata so the model can be reloaded in a fresh process.
This is the recommended user-facing save API (v0.3+); pair it with MIDAS.load(). The legacy save_checkpoint() / load_checkpoint() flat-file API still works.
- Parameters:
dir_path – str Directory path. Created if missing.
overwrite – bool Replace dir_path if it already exists.
- save_checkpoint(checkpoint_path: str)
Save the current model and optimizer states to a checkpoint file.
- Parameters:
checkpoint_path – str Path to save the checkpoint file.
- Raises:
ValueError – If checkpoint_path is an invalid or empty string.
- static setup_mudata(mdata: MuData, batch_key: str = 'batch', dims_x: Dict[str, list] | None = None) → None
Register a MuData object for use with MIDAS.
Stores the data-setup metadata (batch key, modality names, feature dimensions, batch list) under mdata.uns['_scmidas']. After this call, the MuData can be passed directly to MIDAS.
- Parameters:
mdata – MuData Multi-modal data. Each modality (mdata[m]) must contain batch_key in its .obs.
batch_key – str Column in each modality’s .obs that identifies the source batch.
dims_x – dict, optional Per-modality feature dimensions, e.g. {'rna': [4045], 'adt': [224], 'atac': [2897, 2007, ...]}. Provide this for ATAC data that should be encoded by chromosome chunks (the published architecture). If omitted, each modality is treated as a single chunk equal to its full feature count.
Notes
Mutates mdata in place; returns None.
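For chunked modalities such as ATAC, the list in dims_x gives one size per chunk, and the sizes should add up to the modality's full feature count. The sketch below shows how such a list partitions a cells x features matrix, assuming chunks are contiguous and ordered like the features. split_into_chunks is a hypothetical helper; how scmidas slices chunks internally is not documented here.

```python
import numpy as np


def split_into_chunks(x: np.ndarray, chunk_dims: list) -> list:
    """Split a cells x features matrix into consecutive feature chunks."""
    if sum(chunk_dims) != x.shape[1]:
        raise ValueError("chunk sizes must add up to the number of features")
    # Split points are the running totals of every chunk size except the last.
    boundaries = np.cumsum(chunk_dims)[:-1]
    return np.split(x, boundaries, axis=1)


x = np.arange(20).reshape(2, 10)          # 2 cells, 10 features
chunks = split_into_chunks(x, [3, 5, 2])  # e.g. three chromosome-sized chunks
```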
- train(**kwargs)
Set the module in training mode.
This has any effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
- Parameters:
mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
self
- Return type:
Module
- train_dataloader() → DataLoader
Create a DataLoader for training, using the appropriate sampler.
- Returns:
Configured DataLoader instance for training.
- Return type:
DataLoader
- train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor])
Train the discriminator with modality-specific latent representations.
- Parameters:
c_all – Dict[str, torch.Tensor] Dictionary of latent representations for each modality.
targets – Dict[str, torch.Tensor] Ground truth batch labels for each modality.
- training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → Tensor
Executes a single training step for MIDAS.
- Parameters:
batch – Dict[str, Dict[str, torch.Tensor]] Input batch containing modality data, batch indices, and masks.
batch_idx – int Index of the current training batch.
- Returns:
Total VAE loss for the current batch.
- Return type:
torch.Tensor
- static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip: int = -1)
Update model parameters using backpropagation.
- Parameters:
loss – torch.Tensor Computed loss for backpropagation.
model – torch.nn.Module Model to update.
optimizer – torch.optim.Optimizer Optimizer for parameter updates.
grad_clip – int Gradient-clipping threshold; the default (-1) disables clipping.
- class scmidas.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)
Bases: Module
Decoder for reconstructing batch ID.
- Parameters:
n_batches – int Number of distinct batches.
dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.
dim_u – int Latent dimension size for the input (e.g., 2).
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmidas.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)
Bases: Module
Encoder for batch ID latent variables.
- Parameters:
n_batches – int Number of distinct batches.
dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.
dim_z – int Size of the latent dimension.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmidas.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling from distributions.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality, e.g., {‘rna’: [1000], ‘adt’: [100], ‘atac’: [10, 10, 10]}.
dims_s – Dict[str, int] Dimensions of the classes for each modality.
kwargs – Dict[str, Any] Additional configurations for encoders, decoders, and other modules.
- encode_batch(s: Tensor) → Tuple[list, list] | None
Encode batch IDs into latent variables.
- Parameters:
s – torch.Tensor Batch IDs.
- Returns:
- z_s_mu : List[torch.Tensor]
Mean of the batch-ID latent variables.
- z_s_logvar : List[torch.Tensor]
Log-variance of the batch-ID latent variables.
- Return type:
Optional[Tuple[list, list]]
- forward(data: Dict[str, Tensor]) → Tuple[Dict[str, Tensor], Tensor | None, Tensor, Tensor, Tensor, Tensor, Tensor, Dict[str, Tensor], Dict[str, Tensor]]
Forward pass for the VAE.
- Parameters:
data – Dict[str, torch.Tensor] Input data dictionary containing:
- ‘x’: Dict[str, torch.Tensor], modality-specific input data.
- ‘e’: Dict[str, torch.Tensor], modality-specific masks.
- ‘s’ (optional): torch.Tensor, batch IDs.
- Returns:
- x_r_pre : Dict[str, torch.Tensor]
Reconstructed modality-specific data.
- s_r_pre : Optional[torch.Tensor]
If ‘s’ is provided, the reconstructed batch indices; otherwise None.
- z_mu : torch.Tensor
Mean of the combined latent variables.
- z_logvar : torch.Tensor
Log-variance of the combined latent variables.
- z : torch.Tensor
Sampled latent variables.
- c : torch.Tensor
Biological information variables.
- u : torch.Tensor
Technical noise variables.
- z_uni : Dict[str, torch.Tensor]
Unified latent variables for each modality.
- c_all : Dict[str, torch.Tensor]
Modality-specific biological information variables.
- Return type:
Tuple
- gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) → Dict[str, Tensor]
Generate real data from reconstructed data.
- Parameters:
x_r_pre – Dict[str, torch.Tensor] Dictionary of reconstructed data tensors for each modality.
sampling – bool, optional Whether to sample the output (default: True).
- Returns:
Generated real data for each modality.
- Return type:
Dict[str, torch.Tensor]
- generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
Generate unified latent variables and modality-specific representations.
- Parameters:
z_x_mu – Dict[str, torch.Tensor] Means of modality-specific latent variables.
z_x_logvar – Dict[str, torch.Tensor] Log-variances of modality-specific latent variables.
z_s_mu – List[torch.Tensor] Mean of the batch-ID latent variables.
z_s_logvar – List[torch.Tensor] Log-variance of the batch-ID latent variables.
c – torch.Tensor Biological information.
- Returns:
- z_uni : Dict[str, torch.Tensor]
Collection of latent variables for the unimodal inputs.
- c_all : Dict[str, torch.Tensor]
Collection of biological information for the unimodal and joint inputs.
- Return type:
Tuple
- get_dim_h() → Dict[str, List[int]]
Compute hidden dimensions for each modality.
- Returns:
A dictionary containing the hidden dimensions for each modality.
- Return type:
Dict[str, List[int]]
- static poe(mus: List[Tensor], logvars: List[Tensor]) → Tuple[Tensor, Tensor]
Product of Experts (PoE) for combining Gaussian distributions.
- Parameters:
mus – list of torch.Tensor List of mean tensors for each Gaussian.
logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.
- Returns:
- combined_mean: torch.Tensor
Mean of the combined Gaussian distribution.
- combined_logvar: torch.Tensor
Log-variance of the combined Gaussian distribution.
- Return type:
Tuple
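A product of Gaussian experts is again Gaussian: precisions (inverse variances) add, and the combined mean is the precision-weighted average of the expert means. The NumPy sketch below implements that textbook rule for diagonal Gaussians given as (mu, logvar). It does not add a standard-normal prior expert, which some PoE implementations include; whether scmidas does is not stated here.

```python
import numpy as np


def poe(mus, logvars):
    """Combine diagonal Gaussians N(mu_i, exp(logvar_i)) by multiplying densities."""
    mus = np.stack(mus)
    # Precision = 1 / variance for each expert.
    precisions = np.exp(-np.stack(logvars))
    combined_var = 1.0 / precisions.sum(axis=0)
    # Precision-weighted average of the expert means.
    combined_mean = combined_var * (mus * precisions).sum(axis=0)
    return combined_mean, np.log(combined_var)


# Two unit-variance experts at 0 and 2 combine to mean 1, variance 0.5.
mean, logvar = poe([np.array([0.0]), np.array([2.0])], [np.array([0.0]), np.array([0.0])])
```

Experts with small variance dominate the product, so a modality that is confident about a cell pulls the joint latent toward its own estimate.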
- static sample(name: str, data: Tensor, sampling: bool) → Tensor
Apply the sampling function associated with the given distribution name.
- Parameters:
name – str Name of the distribution.
data – torch.Tensor Input data tensor.
sampling – bool Whether to apply sampling.
- Returns:
- torch.Tensor
Sampled or original data tensor.
- static sample_gaussian(mu: Tensor, logvar: Tensor) → Tensor
Sample from a Gaussian distribution using the reparameterization trick.
- Parameters:
mu – torch.Tensor Mean of the Gaussian distribution.
logvar – torch.Tensor Log-variance of the Gaussian distribution.
- Returns:
- torch.Tensor
Sampled tensor.
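The reparameterization trick writes a Gaussian sample as mu + eps * exp(logvar / 2) with eps drawn from N(0, I), so in PyTorch gradients can flow through mu and logvar while all randomness stays in eps. The NumPy version below only illustrates the arithmetic; the explicit rng argument is an assumption added for reproducibility.

```python
import numpy as np


def sample_gaussian(mu: np.ndarray, logvar: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # Standard deviation from log-variance: std = exp(logvar / 2).
    std = np.exp(0.5 * logvar)
    # Noise is drawn independently of mu and logvar.
    eps = rng.standard_normal(mu.shape)
    return mu + eps * std


rng = np.random.default_rng(0)
mu = np.full(10000, 3.0)
logvar = np.full_like(mu, np.log(4.0))  # variance 4, std 2
samples = sample_gaussian(mu, logvar, rng)
```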
- sample_latent(z_mu: Tensor, z_logvar: Tensor) → Tensor
Sample latent variables from a Gaussian distribution.
- Parameters:
z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).
z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).
- Returns:
Sampled latent variables of shape (batch_size, latent_dim).
- Return type:
torch.Tensor