import os
import datetime
from typing import Any, Dict, List, Optional, Tuple
from anndata import AnnData
from mudata import MuData
import natsort
import toml
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from matplotlib import pyplot as plt
from scipy.sparse import csr_matrix
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, Dataset
import lightning as L
from lightning.pytorch.utilities import rank_zero_only
import logging
# Project-Specific Imports
from .data import MyDistributedSampler, MultiBatchSampler, MultiModalDataset
from .utils import *
logger = logging.getLogger(__name__)
from .nn import MLP, Layer1D, distribution_registry, transform_registry
[docs]
class Encoder(nn.Module):
"""
Encoder class for multi-modal data with modality-specific pre-processing, encoding,
and shared encoding layers.
Parameters:
dims_x : Dict[str, list]
Input dimensions for each modality (e.g, {'rna':[1000], 'adt':[100]}).
dims_h : Dict[str, list]
Hidden dimensions for each modality after pre-encoding (e.g, {'rna':256, 'adt':256}).
dim_z : int
Latent dimension size (e.g, 32).
norm : str
Normalization type (e.g., 'ln' for LayerNorm).
out_trans : str
Output activation function (e.g., 'mish').
drop : float
Dropout rate.
kwargs : Dict[str, Any]
Additional modality-specific configurations.
Notes:
By default, RNA and ADT data are log1p-transformed in the encoder and will be exponentiated after decoding.
To skip this step, modify the configuration file. See parameter 'trsf_before_enc_'.
"""
def __init__(
self,
dims_x: Dict[str, list],
dims_h: Dict[str, list],
dim_z: int,
norm: str,
out_trans: str,
drop: float,
**kwargs,
):
super(Encoder, self).__init__()
self.dims_x = dims_x
self.dims_h = dims_h
self.dim_z = dim_z
self.norm = norm
self.out_trans = out_trans
self.drop = drop
# Dynamically set additional arguments as attributes
for key, value in kwargs.items():
setattr(self, key, value)
# Extract transformations to apply before encoding
self.trsf_before_enc = filter_keys(kwargs, 'trsf_before_enc')
# Shared encoder across all modalities
shared_encoder = MLP(
self.dims_shared_enc + [self.dim_z * 2],
hid_norm=self.norm,
hid_drop=self.drop,
)
# Initialize modality-specific encoders
# mod1 -> (opt) transform[mod1] -> (opt) pre_encoder[mod1] ->
# (opt) transform_concat[mod1] -> indiv_enc[mod1] -> share_encoder -> z_mod1
self.pre_encoders = nn.ModuleDict() # Modality-specific pre-encoding layers
self.transform_concat = nn.ModuleDict() # Post-concatenation layers
encoders = {} # Final encoders for each modality
for modality, input_dims in dims_x.items():
# For truncated input, such as ATAC
if len(input_dims) > 1:
self.pre_encoders[modality] = nn.ModuleList([
MLP([dim] + kwargs[f'dims_before_enc_{modality}'], hid_norm=self.norm, hid_drop=self.drop)
for dim in input_dims
])
self.transform_concat[modality] = Layer1D(self.dims_h[modality],
self.norm,
self.out_trans,
self.drop)
# Create individual encoder for the modality
indiv_enc = MLP(
[self.dims_h[modality][0], self.dims_shared_enc[0]],
out_trans=self.out_trans,
norm=self.norm,
drop=self.drop,
)
encoders[modality] = nn.Sequential(indiv_enc, shared_encoder)
self.encoders = nn.ModuleDict(encoders)
[docs]
def forward(
self,
data: Dict[str, torch.Tensor],
mask: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Forward pass for the encoder.
Parameters:
data : Dict[str, torch.Tensor]
Input data for each modality.
mask : Dict[str, torch.Tensor]
Masks for each modality.
Returns:
Tuple:
- z_x_mu : Dict[str, torch.Tensor]
Mean values for latent space for each modality.
- z_x_logvar : Dict[str, torch.Tensor]
Log-variance values for latent space for each modality.
"""
data = data.copy()
mask = mask.copy()
# Apply transformations before encoding
for modality in data.keys():
if f'trsf_before_enc_{modality}' in self.trsf_before_enc:
transformation = self.trsf_before_enc[f'trsf_before_enc_{modality}']
data[modality] = transform_registry.get(transformation)(data[modality])
# Apply masks to data.
# Use out-of-place multiplication: an in-place ``*=`` here would
# mutate the upstream batch tensors for any modality that did not
# take the ``trsf_before_enc_*`` branch above (whose transform
# would otherwise have produced a fresh tensor). The mathematical
# result is identical because mask is a 0/1 modality-presence
# indicator, but mutating the caller's batch dict is fragile and
# makes the encoder unsafe to call multiple times on the same
# input (see ``predict``'s mod_latent / translate paths).
for modality, mask_value in mask.items():
data[modality] = data[modality] * mask_value
# Pre-encode and concatenate if necessary, for truncated inputs
for modality in data.keys():
if modality in self.pre_encoders:
# Split and process individual dimensions
batches = data[modality].split(self.dims_x[modality], dim=1)
processed_batches = [
self.pre_encoders[modality][i](batch) for i, batch in enumerate(batches)
]
# Concatenate processed batches and transform
data[modality] = self.transform_concat[modality](torch.cat(processed_batches, dim=1))
# Encode data and split into mean and log-variance
z_x_mu, z_x_logvar = {}, {}
for modality, modality_data in data.items():
encoded = self.encoders[modality](modality_data)
z_x_mu[modality], z_x_logvar[modality] = encoded.split(self.dim_z, dim=1)
return z_x_mu, z_x_logvar
[docs]
class Decoder(nn.Module):
"""
Decoder class for multi-modal data with shared and modality-specific decoding layers.
Parameters:
dims_x : Dict[str, list]
Output dimensions for each modality.
dims_h : Dict[str, list]
Hidden dimensions for each modality.
dim_z : int
Latent dimension size.
norm : str
Normalization type (e.g., 'ln' for LayerNorm).
out_trans : str
Output activation function (e.g., 'relu').
drop : float
Dropout rate.
kwargs : Dict[str, Any]
Additional modality-specific configurations.
"""
def __init__(
self,
dims_x: Dict[str, list],
dims_h: Dict[str, list],
dim_z: int,
norm: str,
out_trans: str,
drop: float,
**kwargs,
):
super(Decoder, self).__init__()
self.dims_x = dims_x
self.dims_h = dims_h
self.dim_z = dim_z
self.norm = norm
self.out_trans = out_trans
self.drop = drop
# Dynamically set additional arguments as attributes
for key, value in kwargs.items():
setattr(self, key, value)
# z -> shared_decoder -> (opt) post_decoders[mod1] -> (opt) transform_concat[mod1] -> mod1
# Shared decoder layer
total_hidden_dims = sum(dim[0] for dim in dims_h.values())
self.shared_decoder = MLP(
[self.dim_z] + self.dims_shared_dec + [total_hidden_dims],
hid_norm=self.norm,
hid_drop=self.drop,
)
# Modality-specific decoders
self.post_decoders = nn.ModuleDict()
self.transform_concat = nn.ModuleDict()
for modality, output_dims in dims_x.items():
# Modality-specific post-decoding layers
if len(output_dims) > 1:
self.post_decoders[modality] = nn.ModuleList([
MLP(kwargs[f'dims_after_dec_{modality}'] + [dim],
hid_norm=self.norm, hid_drop=self.drop)
for dim in output_dims
])
# Layer to process concatenated outputs
self.transform_concat[modality] = Layer1D(self.dims_h[modality],
self.norm, self.out_trans,
self.drop)
[docs]
def forward(self, latent_data: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Forward pass for the decoder.
Parameters:
latent_data : torch.Tensor
Latent variable input tensor of shape (batch_size, dim_z).
Returns:
Dict[str, torch.Tensor] :
Decoded outputs for each modality.
"""
# Pass through the shared decoder
shared_output = self.shared_decoder(latent_data)
# Split shared decoder output into modality-specific chunks
modality_outputs = shared_output.split(
[dim[0] for dim in self.dims_h.values()],
dim=1,
)
# Create a dictionary to hold the modality-specific outputs
data_dict = {modality: output
for modality, output in zip(self.dims_x.keys(), modality_outputs)}
# Process each modality-specific output
for modality, post_decoders in self.post_decoders.items():
# Apply transformation layer
processed_output = self.transform_concat[modality](data_dict[modality])
batches = processed_output.split(self.__dict__[f'dims_after_dec_{modality}'][0], dim=1)
# Apply modality-specific post-decoders
data_dict[modality] = torch.cat(
[post_decoders[i](batch) for i, batch in enumerate(batches)],
dim=1,
)
# Apply activation functions based on distribution
for modality, output in data_dict.items():
distribution = self.__dict__[f'distribution_dec_{modality}']
activation_fn = distribution_registry.get_activate(distribution)
data_dict[modality] = activation_fn(output)
return data_dict
[docs]
class S_Encoder(nn.Module):
"""
Encoder for batch ID latent variables.
Parameters:
n_batches : int
Number of distinct batches.
dims_enc_s : List[int]
List of dimensions for hidden layers in the encoder.
dim_z : int
Latent dimension size for the latent.
norm : str
Normalization type (e.g., 'ln' for LayerNorm).
drop : float
Dropout rate.
"""
def __init__(
self,
n_batches: int,
dims_enc_s: List[int],
dim_z: int,
norm: str,
drop: float
):
super(S_Encoder, self).__init__()
self.n_batches = n_batches
self.dims_enc_s = dims_enc_s
self.dim_z = dim_z
self.norm = norm
self.drop = drop
# Define the encoder MLP
self.s_encoder = MLP(
[self.n_batches] + self.dims_enc_s + [self.dim_z * 2],
hid_norm=self.norm,
hid_drop=self.drop,
)
[docs]
def forward(self, data: torch.Tensor) -> torch.Tensor:
"""
Forward pass for S_Encoder.
Parameters:
data : torch.Tensor
Input tensor of shape (batch_size, 1), containing batch indices.
Returns:
torch.Tensor :
Encoded tensor of shape (batch_size, dim_z * 2).
"""
# One-hot encode the batch indices
one_hot_data = nn.functional.one_hot(data.squeeze(1), num_classes=self.n_batches).float()
# Pass through the encoder
return self.s_encoder(one_hot_data)
[docs]
class S_Decoder(nn.Module):
"""
Decoder for reconstructing batch ID.
Parameters:
n_batches : int
Number of distinct batches.
dims_dec_s : List[int]
List of dimensions for hidden layers in the decoder.
dim_u : int
Latent dimension size for the input (e.g, 2).
norm : str
Normalization type (e.g., 'ln' for LayerNorm).
drop : float
Dropout rate.
"""
def __init__(
self,
n_batches: int,
dims_dec_s: List[int],
dim_u: int,
norm: str,
drop: float):
super(S_Decoder, self).__init__()
self.n_batches = n_batches
self.dims_dec_s = dims_dec_s
self.dim_u = dim_u
self.norm = norm
self.drop = drop
# Define the decoder MLP
self.s_decoder = MLP(
[self.dim_u] + self.dims_dec_s + [self.n_batches],
hid_norm=self.norm,
hid_drop=self.drop,
)
[docs]
def forward(self, data: torch.Tensor) -> torch.Tensor:
"""
Forward pass for S_Decoder.
Parameters:
data : torch.Tensor
Latent input tensor of shape (batch_size, dim_u).
Returns:
torch.Tensor :
Reconstructed tensor of shape (batch_size, n_batches).
"""
return self.s_decoder(data)
[docs]
class VAE(nn.Module):
"""
Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and
sampling from distributions.
Parameters:
dims_x : Dict[str, list]
Input dimensions for each modality, e.g {'rna'=[1000], 'adt'=[100], 'atac'=[10,10,10]}.
dims_s : Dict[str, int]
Dimensions of the classes for each modality.
kwargs : Dict[str, Any]
Additional configurations for encoders, decoders, and other modules.
"""
def __init__(self, dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs):
super(VAE, self).__init__()
self.dims_x = dims_x
self.dims_s = dims_s
self.mods = set(dims_x.keys())
logger.debug(f'Initializing VAE with modalities: {self.mods}')
logger.debug(f'Initializing VAE with dims_s: {self.dims_s}')
logger.debug(f'Initializing VAE with dims_x: {self.dims_x}')
# Dynamically set additional arguments
for key, value in kwargs.items():
setattr(self, key, value)
self.available_mods = set(self.dims_x.keys())
self.dim_z = self.dim_c + self.dim_u
self.dims_h = self.get_dim_h()
self.n_batches = dims_s['joint']
# Initialize modules
self.encoder = Encoder(self.dims_x, self.dims_h, self.dim_z, self.norm, self.out_trans, self.drop,
**filter_keys(self.__dict__, '_enc'))
self.decoder = Decoder(self.dims_x, self.dims_h, self.dim_z, self.norm, self.out_trans, self.drop,
**filter_keys(self.__dict__, '_dec'))
self.s_encoder = S_Encoder(self.n_batches, self.dims_enc_s, self.dim_z, self.norm, self.drop)
self.s_decoder = S_Decoder(self.n_batches, self.dims_dec_s, self.dim_u, self.norm, self.drop)
# Batch correction and sampling configurations
self.batch_correction = False
self.u_centroid = None
self.drop_s = False
self.sampling = False
self.sample_num = 0
[docs]
def forward(self, data: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Dict[str, torch.Tensor],
Dict[str, torch.Tensor]]:
"""
Forward pass for the VAE.
Parameters:
data : Dict[str, torch.Tensor]
Input data dictionary containing:
- 'x': Dict[str, torch.Tensor], modality-specific input data.
- 'e': Dict[str, torch.Tensor], modality-specific masks.
- 's' (optional): torch.Tensor, dimensions of the output classes for each modality.
Returns:
Tuple:
- x_r_pre : Dict[str, torch.Tensor]
Reconstructed modality-specific data.
- s_r_pre : Optional[torch.Tensor]
If 's' is provided, return reconstructed batch indices.
If 's' is not provided, return None.
- z_mu : torch.Tensor
Mean of the combined latent variables.
- z_logvar : torch.Tensor
Log-variance of the combined latent variables.
- z : torch.Tensor
Sampled latent variables.
- c : torch.Tensor
Biological information variables.
- u : torch.Tensor
Technical noise variables.
- z_uni : Dict[str, torch.Tensor]
Unified latent variables for each modality.
- c_all : Dict[str, torch.Tensor]
Modality-specific Biological information variables.
"""
x = data['x']
e = data['e']
s = None
# Handle batch-specific information. See https://github.com/labomics/midas/issues/12.
if not self.drop_s and 's' in data:
s_drop_rate = self.s_drop_rate if self.training else 0
if torch.rand([]).item() < 1 - s_drop_rate:
s = data['s']
# Encode data
# check device:
logger.debug(f"x device: {next(iter(x.values())).device}")
logger.debug(f"model device: {next(self.parameters()).device}")
z_x_mu, z_x_logvar = self.encoder(x, e)
z_s_mu, z_s_logvar = self.encode_batch(s)
# Combine latent variables using Product of Experts
z_mu, z_logvar = self.poe(
list(z_x_mu.values()) + z_s_mu,
list(z_x_logvar.values()) + z_s_logvar,
)
# Sample latent variables
z = self.sample_latent(z_mu, z_logvar)
# Split latent variables into c and u
c, u = z.split([self.dim_c, self.dim_u], dim=1)
# Perform batch correction if enabled
if self.batch_correction:
z[:, self.dim_c:] = self.u_centroid.type_as(z).unsqueeze(0)
# Decode data
x_r_pre = self.decoder(z)
# Decode batch-specific information
s_r_pre = self.s_decoder(u) if s is not None else None
# Generate unified latent variables and modality-specific c
z_uni, c_all = self.generate_unified_latent(z_x_mu, z_x_logvar, z_s_mu, z_s_logvar, c)
return x_r_pre, s_r_pre, z_mu, z_logvar, z, c, u, z_uni, c_all
[docs]
def encode_batch(self, s: torch.Tensor) -> Optional[Tuple[list, list]]:
"""
Encode batch IDs latent variables.
Parameters:
s : torch.Tensor
Batch IDs.
Returns:
Optional[Tuple[list, list]]:
- z_s_mu : List[torch.Tensor]
Mean of batch IDs latent variables.
- z_s_logvar : List[torch.Tensor]
Log-variance of batch IDs latent variables.
"""
if s is not None:
z_s_mu, z_s_logvar = self.s_encoder(s['joint']).split(self.dim_z, dim=1)
return [z_s_mu], [z_s_logvar]
return [], []
[docs]
def sample_latent(self, z_mu: torch.Tensor, z_logvar: torch.Tensor) -> torch.Tensor:
"""
Sample latent variables from a Gaussian distribution.
Parameters:
z_mu : torch.Tensor
Mean of the latent variables of shape (batch_size, latent_dim).
z_logvar : torch.Tensor
Log-variance of the latent variables of shape (batch_size, latent_dim).
Returns:
torch.Tensor:
Sampled latent variables of shape (batch_size, latent_dim).
"""
if self.training:
return self.sample_gaussian(z_mu, z_logvar)
elif self.sampling and self.sample_num > 0:
z_mu_expand = z_mu.unsqueeze(1)
z_logvar_expand = z_logvar.unsqueeze(1).expand(-1, self.sample_num, self.dim_z)
return self.sample_gaussian(z_mu_expand, z_logvar_expand).reshape(-1, self.dim_z)
return z_mu
[docs]
def generate_unified_latent(
self,
z_x_mu: Dict[str, torch.Tensor],
z_x_logvar: Dict[str, torch.Tensor],
z_s_mu: List[torch.Tensor],
z_s_logvar: List[torch.Tensor],
c: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Generate unified latent variables and modality-specific representations.
Parameters:
z_x_mu : Dict[str, torch.Tensor]
Means of modality-specific latent variables.
z_x_logvar : Dict[str, torch.Tensor]
Log-variances of modality-specific latent variables.
z_s_mu : List[torch.Tensor]
Mean of the batch-ID latent variables.
z_s_logvar : List[torch.Tensor]
Log-variance of the batch-ID latent variables.
c : torch.Tensor
Biological information.
Returns:
Tuple:
- z_uni : Dict[str, torch.Tensor]:
Collection of latent variables for the unimodal inputs.
- c_all : Dict[str, torch.Tensor]:
Collection of biological information for the unimodal and joint inputs.
"""
z_uni = {}
c_all = {}
for modality, z_x_mu_mod in z_x_mu.items():
# Combine modality-specific and batch-specific latent variables
z_uni_mu, z_uni_logvar = self.poe([z_x_mu_mod] + z_s_mu, [z_x_logvar[modality]] + z_s_logvar)
# fix here
z_uni[modality] = self.sample_latent(z_uni_mu, z_uni_logvar)
# Extract shared latent representation (biological information)
c_all[modality] = z_uni[modality][:, :self.dim_c]
# Add joint representation
c_all['joint'] = c
return z_uni, c_all
[docs]
def get_dim_h(self) -> Dict[str, List[int]]:
"""
Compute hidden dimensions for each modality.
Returns:
Dict[str, List[int]]:
A dictionary containing the hidden dimensions for each modality.
"""
dims_h = self.dims_x.copy()
# Adjust dimensions based on pre-encoding layers
for key in filter_keys(self.__dict__, 'dims_before_enc_'):
modality = key.split('_')[-1]
if (modality in self.dims_x) and (len(self.dims_x[modality]) > 1):
dims_h[modality] = [sum([self.__dict__[key][-1]] * len(self.dims_x[modality]))]
return dims_h
[docs]
def gen_real_data(self,
x_r_pre: Dict[str, torch.Tensor],
sampling: bool = True) -> Dict[str, torch.Tensor]:
"""
Generate real data from reconstructed data.
Parameters:
x_r_pre : Dict[str, torch.Tensor]
Dictionary of reconstructed data tensors for each modality.
sampling : bool, optional
Whether to sample the output (default: True).
Returns:
Dict[str, torch.Tensor]:
Generated real data for each modality.
"""
x_r = {}
for modality, tensor in x_r_pre.items():
# Apply inverse transformations if needed
if f'trsf_before_enc_{modality}' in self.__dict__:
tensor = reverse_trsf(self.__dict__[f'trsf_before_enc_{modality}'].split('_')[-1], tensor)
# Apply sampling or directly return the data
x_r[modality] = self.sample(
self.__dict__[f'distribution_dec_{modality}'].split('_')[-1], tensor, sampling)
return x_r
[docs]
@staticmethod
def sample(name: str, data: torch.Tensor, sampling: bool) -> torch.Tensor:
"""
Map a sampling function based on the distribution name.
Parameters:
name : str
Name of the distribution.
data : torch.Tensor
Input data tensor.
sampling : bool
Whether to apply sampling.
Returns:
torch.Tensor
Sampled or original data tensor.
"""
if sampling:
return distribution_registry.get_sampling(name)(data)
return data
[docs]
@staticmethod
def sample_gaussian(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
Sample from a Gaussian distribution using the reparameterization trick.
Parameters:
mu : torch.Tensor
Mean of the Gaussian distribution.
logvar : torch.Tensor
Log-variance of the Gaussian distribution.
Returns:
torch.Tensor
Sampled tensor.
"""
std = (0.5 * logvar).exp()
eps = torch.randn_like(std)
return mu + std * eps
[docs]
@staticmethod
def poe(mus: List[torch.Tensor], logvars: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Product of Experts (PoE) for combining Gaussian distributions.
Parameters:
mus : list of torch.Tensor
List of mean tensors for each Gaussian.
logvars : list of torch.Tensor
List of log-variance tensors for each Gaussian.
Returns:
Tuple :
- combined_mean: torch.Tensor
Mean of the combined Gaussian distribution.
- combined_logvar: torch.Tensor
Log-variance of the combined Gaussian distribution.
"""
# Add prior distributions with zero mean and unit variance
try:
mus = [torch.zeros_like(mus[0])] + mus
except:
logger.debug(mus)
logvars = [torch.zeros_like(logvars[0])] + logvars
# Calculate precision and combined precision
precisions = torch.exp(-torch.stack(logvars, dim=1)) # Shape: (batch_size, num_experts, latent_dim)
precision_sum = precisions.sum(dim=1)
# Calculate combined mean and variance
weighted_means = (torch.stack(mus, dim=1) * precisions).sum(dim=1)
combined_mean = weighted_means / precision_sum
combined_logvar = torch.log(1 / precision_sum)
return combined_mean, combined_logvar
[docs]
class Discriminator(nn.Module):
"""
Discriminator class for multi-modal latent variables.
Parameters:
dims_x : Dict[str, list]
Input dimensions for each modality.
dims_s : Dict[str, int]
Dimensions of the classes for each modality.
kwargs : Dict[str, Any]
Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.
"""
def __init__(self, dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs):
super(Discriminator, self).__init__()
self.dims_x = dims_x
self.dims_s = dims_s
# Dynamically set additional arguments as attributes
for key, value in kwargs.items():
setattr(self, key, value)
# Combine modality keys with 'joint' modality
self.modalities = list(self.dims_x.keys()) + ['joint']
# Create predictors for each modality
self.predictors = nn.ModuleDict({
modality: MLP(
[self.dim_c] + self.dims_dsc + [self.dims_s[modality]],
hid_norm=self.norm,
hid_drop=self.drop
)
for modality in self.modalities
})
# Cross-entropy loss function
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum') # log_softmax + nll
[docs]
def forward(self, latent_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Forward pass for the discriminator.
Parameters:
latent_inputs : Dict[str, torch.Tensor]
Dictionary of latent inputs for each modality, where keys are modality names
and values are tensors of shape (batch_size, dim_c).
Returns:
Dict[str, torch.Tensor] :
Dictionary of logits for each modality, where keys are modality names
and values are tensors of shape (batch_size, dims_s[modality]).
"""
return {modality: self.predictors[modality](latent_input)
for modality, latent_input in latent_inputs.items()}
[docs]
def calculate_loss(self,
predictions: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculate cross-entropy loss for all modalities.
Parameters:
predictions : Dict[str, torch.Tensor]
Dictionary of predicted logits for each modality.
targets : Dict[str, torch.Tensor]
Dictionary of ground truth labels for each modality.
Returns:
torch.Tensor :
Total normalized loss.
"""
total_loss = sum(
self.cross_entropy_loss(pred, targets[modality].squeeze(1))
for modality, pred in predictions.items()
)
# Normalize the total loss by the batch size of the joint modality
batch_size = predictions['joint'].size(0)
return total_loss / batch_size
[docs]
class MIDAS(L.LightningModule):
"""
MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.
Attributes:
net : VAE
Variational Autoencoder for multi-modal data encoding and decoding.
dsc : Discriminator
Discriminator for distinguishing latent variables across batches.
configs : Dict[str, Any]
Model and training configurations dynamically set as attributes.
automatic_optimization : bool
Controls whether optimization is automatic or manually defined. Always True.
"""
def __init__(
self,
mdata: Optional["MuData"] = None,
*,
save_model_path: str = './saved_models/scmidas',
configs: Optional[Dict[str, Any]] = None,
batch_size: int = 256,
n_save: int = 500,
sampler_type: str = 'auto',
viz_umap_tb: bool = False,
transform: Optional[Dict[str, str]] = None,
):
"""Construct a MIDAS model.
Two construction paths are supported:
- **Recommended (v0.3+):** ``MIDAS(mdata, ...)`` — pass a MuData that
has already been registered via :func:`MIDAS.setup_mudata`. The
model reads ``mdata.uns['_scmidas']`` and builds its internal
datalist directly. Instance state is fully encapsulated.
- **Legacy:** ``MIDAS()`` (no args) — relies on class-level
attributes previously set by :meth:`configure_data` /
:meth:`configure_data_from_mdata` /
:meth:`configure_data_from_dir`. Kept for backwards compatibility;
will be removed in 0.4.0.
"""
super(MIDAS, self).__init__()
self.start_epoch = 0
self.load_optimizer_state = False
if mdata is not None:
self._init_from_mdata(
mdata,
save_model_path=save_model_path,
configs=configs,
batch_size=batch_size,
n_save=n_save,
sampler_type=sampler_type,
viz_umap_tb=viz_umap_tb,
transform=transform,
)
# Initialize VAE and Discriminator
self.net = VAE(self.dims_x, self.dims_s, **self.configs)
self.dsc = Discriminator(self.dims_x, self.dims_s, **self.configs)
# Dynamically set configurations as attributes
for key, value in self.configs.items():
setattr(self, key, value)
# Disable automatic optimization to manually control training steps. Always True.
self.automatic_optimization = False
def _init_from_mdata(
self,
mdata: "MuData",
*,
save_model_path: str,
configs: Optional[Dict[str, Any]],
batch_size: int,
n_save: int,
sampler_type: str,
viz_umap_tb: bool,
transform: Optional[Dict[str, str]],
) -> None:
"""Populate instance state from a registered MuData object."""
if '_scmidas' not in mdata.uns:
raise RuntimeError(
"mdata has not been registered with MIDAS. Call "
"scmidas.MIDAS.setup_mudata(mdata, batch_key=...) first."
)
setup = dict(mdata.uns['_scmidas'])
if configs is None:
from .config import load_config
configs = dict(load_config())
else:
configs = dict(configs)
batch_key = setup['batch_key']
dims_x = {m: list(map(int, v)) for m, v in setup['dims_x'].items()}
atac_dims = dims_x.get('atac')
if atac_dims is not None and len(atac_dims) == 1:
logger.warning(
f"Detected ATAC with only one dimension [{atac_dims[0]}]. "
"This will cause the data to be encoded directly instead of by chromosome, as described in our paper. "
"We recommend splitting the ATAC data by chromosome."
)
if 'dims_before_enc_atac' in configs and 'dims_after_dec_atac' in configs:
raise ValueError(
'Invalid ATAC configuration: both "dims_before_enc_atac" and '
'"dims_after_dec_atac" are present in configs, but '
'len(dims_x["atac"]) == 1.'
)
# ATAC is binarized by default — that is the recommended setting and
# the one used in our published results. Users who need raw counts
# can override by passing ``transform={'atac': None}`` or by
# supplying their own dict with a different value for ``'atac'``.
effective_transform: Dict[str, Any] = {}
if 'atac' in mdata.mod:
effective_transform['atac'] = 'binarize'
if transform is not None:
for k, v in transform.items():
if v is None:
effective_transform.pop(k, None)
else:
effective_transform[k] = v
transform_arg = effective_transform if effective_transform else None
data, mask, batch_names = self.get_info_from_mdata(mdata, batch_key)
datalist, dims_s, s_joint, combs = self.get_datasets_from_adata(
data, mask, batch_names, transform=transform_arg,
)
self.configs = configs
self.dims_x = dims_x
self.dims_s = dims_s
self.batch_names = list(batch_names)
self.datalist = datalist
self.s_joint = s_joint
self.combs = combs
self.mods = list(dims_x.keys())
self.save_model_path = save_model_path
self.batch_size = batch_size
self.n_save = n_save
self.sampler_type = sampler_type
self.viz_umap_tb = viz_umap_tb
self._mdata = mdata
[docs]
@staticmethod
def setup_mudata(
mdata: "MuData",
batch_key: str = 'batch',
dims_x: Optional[Dict[str, list]] = None,
) -> None:
"""Register a MuData object for use with MIDAS.
Stores the data-setup metadata (batch key, modality names, feature
dimensions, batch list) under ``mdata.uns['_scmidas']``. After this
call, the MuData can be passed directly to :class:`MIDAS`.
Parameters:
mdata : MuData
Multi-modal data. Each modality (``mdata[m]``) must contain
``batch_key`` in its ``.obs``.
batch_key : str
Column in each modality's ``.obs`` that identifies the
source batch.
dims_x : dict, optional
Per-modality feature dimensions, e.g.
``{'rna': [4045], 'adt': [224], 'atac': [2897, 2007, ...]}``.
Provide this for ATAC data that should be encoded by
chromosome chunks (the published architecture). If omitted,
each modality is treated as a single chunk equal to its
full feature count.
Notes:
Mutates ``mdata`` in-place; returns ``None``.
"""
if not hasattr(mdata, 'mod') or len(mdata.mod) == 0:
raise ValueError("mdata must be a MuData with at least one modality.")
missing = [m for m in mdata.mod if batch_key not in mdata[m].obs.columns]
if missing:
avail = {m: list(mdata[m].obs.columns) for m in missing}
raise ValueError(
f"batch_key='{batch_key}' not found in .obs of modalities {missing}. "
f"Available columns per modality: {avail}"
)
if dims_x is None:
dims_x = {m: [int(mdata[m].n_vars)] for m in mdata.mod}
else:
extra = set(dims_x) - set(mdata.mod)
if extra:
raise ValueError(
f"dims_x has modalities {sorted(extra)} not present in mdata.mod={list(mdata.mod)}"
)
dims_x = dict(dims_x)
for m in mdata.mod:
if m not in dims_x:
dims_x[m] = [int(mdata[m].n_vars)]
for m, chunks in dims_x.items():
total = int(sum(chunks))
actual = int(mdata[m].n_vars)
if total != actual:
raise ValueError(
f"dims_x[{m!r}] sums to {total} but mdata[{m!r}].n_vars={actual}."
)
dims_x[m] = list(map(int, chunks))
batch_names = []
for m in mdata.mod:
batch_names.extend(mdata[m].obs[batch_key].astype(str).unique().tolist())
batch_names = sorted(set(batch_names))
mdata.uns['_scmidas'] = {
'batch_key': batch_key,
'dims_x': dims_x,
'batch_names': batch_names,
'mods': list(mdata.mod.keys()),
'version': 1,
}
logger.info(
"setup_mudata: batch_key='%s', batches=%s, modalities=%s, dims_x=%s",
batch_key, batch_names, list(mdata.mod.keys()), dims_x,
)
[docs]
def train_dataloader(self) -> DataLoader:
"""
Create a DataLoader for training, using the appropriate sampler.
Returns:
DataLoader :
Configured DataLoader instance for training.
"""
# Concatenate all datasets
try:
dataset = ConcatDataset(self.datalist)
logger.info(f'Total number of samples: {len(dataset)} from {len(self.datalist)} datasets.')
except Exception as e:
raise ValueError('Failed to concatenate datasets. Please check the input datalist.') from e
# Select the appropriate sampler.
# 'auto' picks the DDP sampler when a process group is initialized;
# this matches the user-visible 'auto' name and prevents silent
# rank-agnostic sampling under DDP.
use_ddp_sampler = self.sampler_type == 'ddp' or (
self.sampler_type == 'auto'
and dist.is_available()
and dist.is_initialized()
)
if use_ddp_sampler:
logger.info('Using Distributed Data Parallel (DDP) sampler.')
sampler = MyDistributedSampler(dataset, batch_size=self.batch_size, n_max=self.n_max)
else:
logger.info('Using MultiBatchSampler for data loading.')
sampler = MultiBatchSampler(dataset, batch_size=self.batch_size, n_max=self.n_max)
# Create the DataLoader
try:
train_loader = DataLoader(
dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers
)
logger.info(f'DataLoader created with batch size {self.batch_size} and {self.num_workers} workers.')
except Exception as e:
raise RuntimeError('Failed to create DataLoader. Check DataLoader configuration.') from e
logger.debug(f'DataLoader: {len(train_loader)}')
return train_loader
[docs]
def training_step(self,
batch: Dict[str, Dict[str, torch.Tensor]],
batch_idx: int) -> torch.Tensor:
"""
Executes a single training step for MIDAS.
Parameters:
batch : Dict[str, Dict[str, torch.Tensor]]
Input batch containing modality data, batch indices, and masks.
batch_idx : int
Index of the current training batch.
Returns:
torch.Tensor :
Total VAE loss for the current batch.
"""
# Forward pass through the VAE
logger.debug(f"Training step - batch index: {batch_idx}")
logger.debug(f"Input: {batch}")
x_r_pre, s_r_pre, z_mu, z_logvar, z, c, u, z_uni, c_all = self.net(batch)
logger.debug(f"Current batch: {batch['s']['joint'][0]}")
c_all['joint'] = c
# Compute reconstruction loss
recon_loss, recon_dict = self.calc_recon_loss(
batch['x'], batch['s']['joint'], batch['e'],
x_r_pre, s_r_pre,
filter_keys(self.__dict__, 'distribution_dec_'),
filter_keys(self.__dict__, 'lam_recon_')
)
recon_loss *= self.lam_recon
# Compute KLD loss
kld_loss = self.calc_kld_z_loss(
self.dim_c, self.dim_u, self.lam_kld_c, self.lam_kld_u, z_mu, z_logvar
) * self.lam_kld
# Compute consistency loss
consistency_loss = self.calc_consistency_loss(z_uni) * self.lam_alignment
# Compute total VAE loss
loss_net = recon_loss + kld_loss + consistency_loss
# Train discriminator for n_iter_disc iterations
for _ in range(self.n_iter_disc):
self.train_discriminator(c_all, batch['s'])
# Compute adversarial loss for the VAE
s_pred = self.dsc(c_all)
loss_dsc = self.calc_dsc_loss(s_pred, batch['s']) * self.lam_dsc
loss_net = loss_net - loss_dsc * self.lam_adv
# Update VAE model
self.update_model(loss_net, self.net, self.net_optim, self.grad_clip)
# Log training losses
self.log_losses(recon_loss, kld_loss, consistency_loss, loss_net, loss_dsc, recon_dict)
return loss_net
[docs]
def train_discriminator(self,
c_all: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor]):
"""
Train the discriminator with modality-specific latent representations.
Parameters:
c_all : Dict[str, torch.Tensor]
Dictionary of latent representations for each modality.
targets : Dict[str, torch.Tensor]
Ground truth batch labels for each modality.
"""
s_pred = self.dsc(detach_tensors(c_all))
loss_dsc = self.calc_dsc_loss(s_pred, targets) * self.lam_dsc
self.update_model(loss_dsc, self.dsc, self.dsc_optim, self.grad_clip)
[docs]
def train(self, **kwargs):
trainer = L.Trainer(**kwargs)
trainer.fit(model=self)
[docs]
@rank_zero_only
def predict(
self,
return_in_memory: bool = True,
save_dir: Optional[str] = None,
save_format: str = "npy", # "npy" or "csv"
joint_latent: bool = True,
mod_latent: bool = False,
impute: bool = False,
batch_correct: bool = False,
translate: bool = False,
input: bool = False,
verbose: bool = True
):
"""
Run model inference in a streaming manner.
This method supports three prediction modes:
1. Return predictions in memory (recommended for small or medium datasets).
2. Stream predictions to disk per mini-batch (recommended for large datasets).
3. Perform both simultaneously.
Notes:
- If `return_in_memory=False`, prediction tensors will not accumulate in RAM,
making this method suitable for very large datasets.
- If `save_dir` is provided, predictions are written incrementally to disk
per mini-batch.
- If `batch_correct=True`, a second pass over the dataset is performed:
Pass 1:
Compute joint latent representations and estimate the technical
centroid using online statistics.
Pass 2:
Reconstruct data with batch correction and stream the corrected
outputs.
- If `translate=True`, `mod_latent` will be automatically set to True.
Parameters:
return_in_memory : bool, default=True
Whether to keep predictions in memory and return them as a nested
dictionary. Set to False for large datasets to avoid OOM.
save_dir : str or None, default=None
Output directory for streaming prediction results. If None,
predictions are not saved to disk.
save_format : {"npy", "csv"}, default="npy"
File format used when `save_dir` is provided.
- `"npy"` : NumPy binary format (recommended; fast and compact).
- `"csv"` : CSV text format (not recommended for large arrays).
joint_latent : bool, default=True
Whether to compute the joint latent representation conditioned on all
observed modalities.
Stored as:
- `z["joint"]` (raw latent)
- or postprocessed into `z_c["joint"]` and `z_u["joint"]`.
mod_latent : bool, default=False
Whether to compute latent representations conditioned on each
individual modality.
For each modality `m`, a single-modality forward pass is performed
and stored as:
`z[m]`
impute : bool, default=False
Whether to generate imputed data (`x_impt`) from the joint latent
representation.
Stored as:
`x_impt[modality]`
batch_correct : bool, default=False
Whether to estimate a technical centroid and perform batch-effect
correction on reconstructed data.
Stored as:
`x_bc[modality]`
translate : bool, default=False
Whether to perform cross-modality translation.
For each available input modality subset, missing modalities are
generated and stored as:
`x_trans["<input_mods>_to_<target_mod>"]`
input : bool, default=False
Whether to include the original input data and masks in the output.
Stored as:
`x[modality]`
`mask[modality]` (if available)
verbose : bool, default=True
Whether to display progress bars (`tqdm`) and logging messages.
Returns:
output : dict or None
Raises:
ValueError
If both `return_in_memory=False` and `save_dir=None`, or if an
unsupported `save_format` is specified.
KeyError
If required prediction fields are missing.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if verbose:
logger.info(f"Predicting using device: {device}")
model = self.net.to(device).eval()
_old_bc = getattr(model, "batch_correction", None)
_old_uc = getattr(model, "u_centroid", None)
if hasattr(model, "batch_correction"):
model.batch_correction = False
if translate:
mod_latent = True
# Choose sink(s)
sinks: List[BaseSink] = []
mem_sink = MemorySink() if return_in_memory else None
if mem_sink is not None:
sinks.append(mem_sink)
disk_sink = None
if save_dir is not None:
disk_sink = DiskSink(DiskSinkConfig(save_dir=save_dir, save_format=save_format))
sinks.append(disk_sink)
if not sinks:
raise ValueError("You must enable at least one of return_in_memory=True or save_dir!=None.")
# For batch_correct centroid computation (online; no full z storage)
online_stats: Optional[OnlineMeanByGroup] = None
if batch_correct:
online_stats = None
all_combinations = generate_all_combinations(self.mods) if translate else None
with torch.no_grad():
# -----------------------
# Pass 1: standard outputs (+ collect stats for batch_correct)
# -----------------------
for batch_id, dataset in enumerate(self.datalist):
batch_name = self.batch_names[batch_id]
loader = DataLoader(dataset, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)
if verbose:
logger.info("Processing batch %s: %s", batch_name, str(self.combs[batch_id]))
for i, batch in enumerate(tqdm(loader, desc=f"predict:{batch_name}", disable=not verbose)):
batch = convert_tensors_to_cuda(batch, device)
# Save s (labels / subset ids) if present
if "s" in batch and isinstance(batch["s"], dict):
for k, v in batch["s"].items():
for s in sinks:
s.write(batch_name, ["s", k], v)
# Always compute forward once (cheap vs branching), then selectively write
# Expected forward signature: x_r_pre, ..., z, c, u, ...
out = model(batch)
x_r_pre = out[0]
z = out[4]
# init online_stats after z known
if batch_correct and online_stats is None:
z_dim = z.shape[1]
u_dim = z_dim - self.dim_c
if u_dim <= 0:
raise ValueError(f"dim_c={self.dim_c} is invalid for z_dim={z_dim}")
online_stats = OnlineMeanByGroup(dim=u_dim)
# joint latent (z)
if joint_latent:
for s in sinks:
s.write(batch_name, ["z", "joint"], z)
# online stats for batch correction: use u=z[:, dim_c:], group id from s['joint'] if exists
if batch_correct:
# try common field names
if "s" in batch and isinstance(batch["s"], dict):
if "joint" in batch["s"]:
g = batch["s"]["joint"]
else:
# fallback: first key
g = next(iter(batch["s"].values()))
else:
raise ValueError("batch_correct=True requires batch['s'][...] for grouping.")
u = z[:, self.dim_c:]
online_stats.update(u, g)
# impute
if impute:
x_r = model.gen_real_data(x_r_pre, sampling=False)
for m, xm in x_r.items():
for s in sinks:
s.write(batch_name, ["x_impt", m], xm)
# save input + masks
if input:
for m in self.combs[batch_id]:
if "x" in batch and m in batch["x"]:
for s in sinks:
s.write(batch_name, ["x", m], batch["x"][m])
if "e" in batch and isinstance(batch["e"], dict) and m in batch["e"]:
# mask typically small; store as meta or normal tensor
# if you prefer single-file meta, use write_meta
mask_np = to_numpy(batch["e"][m])[0] if batch["e"][m].ndim >= 2 else to_numpy(batch["e"][m])
for s in sinks:
s.write_meta(batch_name, ["mask", m], mask_np)
# per-modality latent
if mod_latent:
for m in batch.get("x", {}).keys():
input_data = {"x": {m: batch["x"][m]}, "s": batch.get("s", {}), "e": {}}
if "e" in batch and m in batch["e"]:
input_data["e"][m] = batch["e"][m]
out_m = model(input_data)
z_m = out_m[4]
for s in sinks:
s.write(batch_name, ["z", m], z_m)
# translate (general: any input subset -> remaining outputs)
if translate and all_combinations is not None:
for input_mods, output_mods in all_combinations:
input_mods_sorted = sorted(input_mods)
# check availability in this minibatch
if not all(m in batch.get("x", {}) for m in input_mods_sorted):
continue
input_data = {
"x": {m: batch["x"][m] for m in input_mods_sorted},
"s": batch.get("s", {}),
"e": {}
}
if "e" in batch:
for m in input_mods_sorted:
if m in batch["e"]:
input_data["e"][m] = batch["e"][m]
out_t = model(input_data)
x_r_pre_t = out_t[0]
x_r_t = model.gen_real_data(x_r_pre_t, sampling=False)
for mod in output_mods:
key = "_".join(input_mods_sorted) + "_to_" + mod
for s in sinks:
s.write(batch_name, ["x_trans", key], x_r_t[mod])
# -----------------------
# Pass 2: batch correction reconstruction (streaming)
# -----------------------
if batch_correct:
if online_stats is None:
raise RuntimeError("Internal error: online_stats not initialized.")
u_centroid = online_stats.finalize_centroid().to(device)
# expected: model has fields for correction; adapt to your implementation
# (match your new code: model.u_centroid / model.batch_correction)
_bc_prev = getattr(model, "batch_correction", None)
_uc_prev = getattr(model, "u_centroid", None)
try:
if hasattr(model, "u_centroid"):
model.u_centroid = u_centroid
if hasattr(model, "batch_correction"):
model.batch_correction = True
if verbose:
logger.info("Batch correction (second pass) ...")
for batch_id, dataset in enumerate(self.datalist):
batch_name = self.batch_names[batch_id]
loader = DataLoader(dataset, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)
if verbose:
logger.info("Processing batch %s: %s", batch_name, str(self.combs[batch_id]))
for i, batch in enumerate(tqdm(loader, desc=f"batch_correct:{batch_name}", disable=not verbose)):
batch = convert_tensors_to_cuda(batch, device)
out = model(batch)
x_r_pre = out[0]
x_r = model.gen_real_data(x_r_pre, sampling=True)
for m in self.mods:
if m in x_r:
for s in sinks:
s.write(batch_name, ["x_bc", m], x_r[m])
finally:
if hasattr(model, "batch_correction"):
model.batch_correction = (_bc_prev if _bc_prev is not None else False)
if hasattr(model, "u_centroid"):
model.u_centroid = _uc_prev
# finalize sinks
disk_out = disk_sink.finalize() if disk_sink is not None else None
if mem_sink is None:
# pure disk mode
return disk_out
# Post-process memory output into pred_b like you had (z -> z_c / z_u)
raw = mem_sink.finalize()
pred_b: Dict[str, Any] = {}
for batch_name, d in raw.items():
pred_b[batch_name] = {}
# z split
if "z" in d:
pred_b[batch_name]["z_c"] = {}
pred_b[batch_name]["z_u"] = {}
for k, zt in d["z"].items():
znp = to_numpy(zt)
pred_b[batch_name]["z_c"][k] = znp[:, :self.dim_c]
pred_b[batch_name]["z_u"][k] = znp[:, self.dim_c:]
# others
for var in ["x_impt", "x_trans", "x_bc", "x", "s"]:
if var in d:
pred_b[batch_name][var] = {}
for k, vt in d[var].items():
pred_b[batch_name][var][k] = to_numpy(vt)
# masks (meta)
if "mask" in d:
pred_b[batch_name]["mask"] = d["mask"]
if not joint_latent and "z_c" in pred_b[batch_name]:
pred_b[batch_name]["z_c"].pop("joint", None)
pred_b[batch_name]["z_u"].pop("joint", None)
if disk_out is not None:
# if both: return both memory + disk manifest
return {"memory": pred_b, "disk": disk_out}
if hasattr(model, "batch_correction"):
model.batch_correction = (_old_bc if _old_bc is not None else False)
if hasattr(model, "u_centroid"):
model.u_centroid = _old_uc
return pred_b
[docs]
@rank_zero_only
def get_latent_representation(
self,
mdata: Optional["MuData"] = None,
*,
kind: str = 'c',
verbose: bool = False,
) -> 'np.ndarray':
"""Return the joint latent representation aligned to ``mdata.obs_names``.
This is the convenience wrapper around :meth:`predict` that most
users want: it runs the joint latent forward pass, stitches the
per-batch outputs, and reorders the result so each row matches
``mdata.obs_names`` — ready to drop straight into
``mdata.obsm['X_midas']``.
Parameters:
mdata : MuData, optional
MuData to align against. Defaults to the MuData used at
construction time. Must include the same cells as the
registered MuData (same ``obs_names``).
kind : {'c', 'u', 'joint'}
Which latent to return:
- ``'c'`` (default): biological latent ``z_c`` of shape
``(n_obs, dim_c)``. Recommended for clustering /
visualization.
- ``'u'``: technical latent ``z_u`` of shape
``(n_obs, dim_u)``.
- ``'joint'``: concatenation ``[z_c, z_u]``.
verbose : bool
Whether to show prediction progress bars.
Returns:
np.ndarray:
Array of shape ``(mdata.n_obs, dim)`` aligned to
``mdata.obs_names``. Cells absent from training data
yield NaN rows (a warning is logged).
"""
if kind not in ('c', 'u', 'joint'):
raise ValueError(f"kind must be 'c', 'u', or 'joint'; got {kind!r}")
target = mdata if mdata is not None else getattr(self, '_mdata', None)
if target is None:
raise RuntimeError(
"No MuData available. Either construct the model via "
"MIDAS(mdata, ...) or pass mdata= to get_latent_representation."
)
if '_scmidas' not in target.uns:
raise RuntimeError(
"mdata is not registered. Call MIDAS.setup_mudata(mdata, ...) first."
)
out = self.predict(joint_latent=True, verbose=verbose)
pieces: List['np.ndarray'] = []
cell_ids: List[str] = []
batch_key = target.uns['_scmidas']['batch_key']
for batch_id, batch_name in enumerate(self.batch_names):
block = out[batch_name]
if kind == 'c':
z = block['z_c']['joint']
elif kind == 'u':
z = block['z_u']['joint']
else:
z = np.concatenate([block['z_c']['joint'], block['z_u']['joint']], axis=1)
pieces.append(z)
# cell IDs: first modality in this batch's combs (matches MultiModalDataset
# iteration order, which sets per-batch sample order via `len(first mod)`)
first_mod = self.combs[batch_id][0]
mask = target[first_mod].obs[batch_key].astype(str) == str(batch_name)
ids = target[first_mod].obs_names[mask].tolist()
if len(ids) != z.shape[0]:
raise RuntimeError(
f"Latent stitching mismatch for batch {batch_name!r}: "
f"expected {len(ids)} cells from modality {first_mod!r}, "
f"got latent of shape {z.shape}."
)
cell_ids.extend(ids)
z_full = np.vstack(pieces)
id_to_row = {cid: i for i, cid in enumerate(cell_ids)}
n_obs = target.n_obs
out_arr = np.full((n_obs, z_full.shape[1]), np.nan, dtype=z_full.dtype)
missing: List[str] = []
for i, cid in enumerate(target.obs_names):
row = id_to_row.get(cid)
if row is None:
missing.append(cid)
else:
out_arr[i] = z_full[row]
if missing:
logger.warning(
"%d cells in mdata.obs_names have no latent representation "
"(first few: %s). Their rows are NaN.",
len(missing), missing[:3],
)
return out_arr
[docs]
@rank_zero_only
def get_imputed_values(
self,
mdata: Optional["MuData"] = None,
*,
modality: str = 'rna',
verbose: bool = False,
) -> 'np.ndarray':
"""Return imputed expression values for a single modality.
For mosaic data, this fills in cells that originally lacked the
modality with model-inferred values. Cells that already had real
observations also get the model's reconstruction (useful for
denoising).
Parameters:
mdata : MuData, optional
MuData to align against. Defaults to the MuData used at
construction time.
modality : str
Modality name (e.g. ``'rna'``, ``'adt'``, ``'atac'``).
Must be one of ``self.mods``.
verbose : bool
Whether to show prediction progress bars.
Returns:
np.ndarray:
Array of shape ``(mdata.n_obs, n_features[modality])``,
aligned to ``mdata.obs_names``. Cells absent from training
data yield NaN rows (a warning is logged).
"""
if modality not in self.mods:
raise ValueError(
f"modality={modality!r} not in registered modalities {self.mods}."
)
target = mdata if mdata is not None else getattr(self, '_mdata', None)
if target is None:
raise RuntimeError(
"No MuData available. Either construct via MIDAS(mdata, ...) "
"or pass mdata= to get_imputed_values."
)
if '_scmidas' not in target.uns:
raise RuntimeError(
"mdata is not registered. Call MIDAS.setup_mudata(mdata, ...) first."
)
out = self.predict(joint_latent=False, impute=True, verbose=verbose)
pieces: List['np.ndarray'] = []
cell_ids: List[str] = []
batch_key = target.uns['_scmidas']['batch_key']
for batch_id, batch_name in enumerate(self.batch_names):
block = out[batch_name].get('x_impt', {})
if modality not in block:
raise RuntimeError(
f"predict(impute=True) did not write x_impt[{modality!r}] "
f"for batch {batch_name!r}; cannot stitch."
)
x = block[modality]
pieces.append(x)
first_mod = self.combs[batch_id][0]
mask = target[first_mod].obs[batch_key].astype(str) == str(batch_name)
ids = target[first_mod].obs_names[mask].tolist()
if len(ids) != x.shape[0]:
raise RuntimeError(
f"Imputation stitching mismatch for batch {batch_name!r}: "
f"expected {len(ids)} cells from modality {first_mod!r}, "
f"got x_impt of shape {x.shape}."
)
cell_ids.extend(ids)
x_full = np.vstack(pieces)
id_to_row = {cid: i for i, cid in enumerate(cell_ids)}
out_arr = np.full((target.n_obs, x_full.shape[1]), np.nan, dtype=x_full.dtype)
missing: List[str] = []
for i, cid in enumerate(target.obs_names):
row = id_to_row.get(cid)
if row is None:
missing.append(cid)
else:
out_arr[i] = x_full[row]
if missing:
logger.warning(
"%d cells in mdata.obs_names have no imputed value "
"(first few: %s). Their rows are NaN.",
len(missing), missing[:3],
)
return out_arr
[docs]
def on_train_epoch_end(self):
"""
Save a model checkpoint at the end of each training epoch with a meaningful filename.
"""
# Save the checkpoint periodically based on n_save
total_epoch = self.current_epoch + self.start_epoch
if (total_epoch + 1) % self.n_save == 0:
os.makedirs(self.save_model_path, exist_ok=True)
# Get the current timestamp
timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
# Generate a descriptive checkpoint filename
checkpoint_filename = f'model_epoch{total_epoch+1}_{timestamp}.pt'
checkpoint_path = os.path.join(self.save_model_path, checkpoint_filename)
# Save the checkpoint
self.save_checkpoint(checkpoint_path)
if self.viz_umap_tb:
logger.info('Plotting UMAP...')
self.get_emb_umap(save_dir=self.save_model_path, n_obs=20000, verbose=False)
self.net.train()
[docs]
def on_train_end(self):
"""
Save the final model checkpoint at the end of training.
"""
os.makedirs(self.save_model_path, exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
checkpoint_filename = f'model_epoch{self.current_epoch+self.start_epoch}_{timestamp}.pt'
checkpoint_path = os.path.join(self.save_model_path, checkpoint_filename)
self.save_checkpoint(checkpoint_path)
[docs]
@rank_zero_only
def save(self, dir_path: str, *, overwrite: bool = False) -> None:
"""Save this MIDAS model to a directory.
Writes two files:
- ``model.pt`` — VAE / Discriminator weights and optimizer states.
- ``setup.json`` — minimal data-setup metadata so the model can be
reloaded on a fresh process.
This is the recommended user-facing save API (v0.3+); pair with
:meth:`MIDAS.load`. The legacy :meth:`save_checkpoint` /
:meth:`load_checkpoint` flat-file API still works.
Parameters:
dir_path : str
Directory path. Created if missing.
overwrite : bool
Replace ``dir_path`` if it already exists.
"""
import json
from pathlib import Path
p = Path(dir_path)
if p.exists() and any(p.iterdir()) and not overwrite:
raise FileExistsError(
f"{p} already exists and is non-empty. "
"Pass overwrite=True to replace it."
)
p.mkdir(parents=True, exist_ok=True)
state = {
'net': self.net.state_dict(),
'dsc': self.dsc.state_dict(),
'optim_net': self.net_optim.state_dict() if hasattr(self, 'net_optim') else None,
'optim_dsc': self.dsc_optim.state_dict() if hasattr(self, 'dsc_optim') else None,
'configs': self.configs,
'epoch': int(getattr(self, 'current_epoch', 0)) + int(getattr(self, 'start_epoch', 0)),
}
torch.save(state, str(p / 'model.pt'))
# batch_key may have been recorded on the registered mdata
batch_key = None
if getattr(self, '_mdata', None) is not None:
batch_key = self._mdata.uns.get('_scmidas', {}).get('batch_key')
setup = {
'dims_x': {m: list(map(int, v)) for m, v in self.dims_x.items()},
'dims_s': {m: int(v) for m, v in self.dims_s.items()},
'batch_names': list(map(str, self.batch_names)),
'mods': list(self.mods),
'batch_key': batch_key,
'scmidas_version': '0.3',
}
with open(p / 'setup.json', 'w') as f:
json.dump(setup, f, indent=2)
logger.info('MIDAS saved to %s', str(p))
[docs]
@classmethod
def load(
cls,
dir_path: str,
mdata: "MuData",
**kwargs: Any,
) -> 'MIDAS':
"""Load a MIDAS model previously saved via :meth:`save`.
If ``mdata`` has not been registered (no ``_scmidas`` in
``mdata.uns``), this method auto-registers it using the
``batch_key`` saved alongside the model.
Parameters:
dir_path : str
Directory written by :meth:`save`.
mdata : MuData
Multi-modal data the model was trained on (or a query
dataset with the same modality structure).
**kwargs : Any
Forwarded to :class:`MIDAS` (e.g. ``batch_size``,
``save_model_path``).
Returns:
MIDAS:
A model with the saved weights loaded.
"""
import json
from pathlib import Path
p = Path(dir_path)
model_pt = p / 'model.pt'
setup_json = p / 'setup.json'
if not model_pt.exists():
raise FileNotFoundError(f"Missing {model_pt}; not a MIDAS save directory.")
if '_scmidas' not in mdata.uns:
if not setup_json.exists():
raise RuntimeError(
f"mdata is not registered and {setup_json} is missing. "
"Call scmidas.MIDAS.setup_mudata(mdata, batch_key=...) "
"before MIDAS.load(...)."
)
with open(setup_json) as f:
saved = json.load(f)
cls.setup_mudata(mdata, batch_key=saved.get('batch_key', 'batch'))
state = torch.load(str(model_pt), weights_only=False)
model = cls(mdata, configs=state.get('configs'), **kwargs)
model.net.load_state_dict(state['net'])
model.dsc.load_state_dict(state['dsc'])
if state.get('optim_net') is not None:
model.load_optimizer_state = True
model.loaded_net_optim_state = state['optim_net']
model.loaded_dsc_optim_state = state['optim_dsc']
model.start_epoch = int(state.get('epoch', 0))
logger.info('MIDAS loaded from %s (epoch=%d)', str(p), model.start_epoch)
return model
[docs]
@rank_zero_only
def save_checkpoint(self, checkpoint_path: str):
"""
Save the current model and optimizer states to a checkpoint file.
Parameters:
checkpoint_path : str
Path to save the checkpoint file.
Raises:
ValueError:
If `checkpoint_path` is an invalid or empty string.
"""
# Validate the output path
if not checkpoint_path or not isinstance(checkpoint_path, str):
raise ValueError('Invalid checkpoint path. Please provide a valid string.')
# Create a state dictionary with model and optimizer states
checkpoint_data = {
'net': self.net.state_dict(), # State dictionary of the main model
'dsc': self.dsc.state_dict(), # State dictionary of the discriminator
'optim_net': self.net_optim.state_dict(), # State dictionary of the main optimizer
'optim_dsc': self.dsc_optim.state_dict() # State dictionary of the discriminator optimizer
}
# Save the state dictionary to the specified path
torch.save(checkpoint_data, checkpoint_path)
# Inform the user of successful save
logger.info(f'Checkpoint successfully saved to "{checkpoint_path}".')
[docs]
def load_checkpoint(self, checkpoint_path: str, start_epoch: int = 0, **kwargs):
"""
Load model and optimizer states from a checkpoint file.
Parameters:
checkpoint_path : str
Path to the checkpoint file containing saved model and optimizer states.
start_epoch: int
Indicate how many epoch the model has been trained.
kwargs : Dict[str, Any]
Additional configurations for torch.load().
Raises:
AssertionError:
If the provided checkpoint path does not exist.
"""
# Verify the checkpoint path exists
assert os.path.exists(checkpoint_path), f'Checkpoint path "{checkpoint_path}" does not exist.'
# Load the checkpoint file
checkpoint_data = torch.load(checkpoint_path, weights_only=True, **kwargs)
# Load the model state dictionaries
self.net.load_state_dict(checkpoint_data['net'])
self.dsc.load_state_dict(checkpoint_data['dsc'])
# Load the optimizer state dictionaries
self.load_optimizer_state = True
self.loaded_net_optim_state = checkpoint_data['optim_net']
self.loaded_dsc_optim_state = checkpoint_data['optim_dsc']
self.start_epoch = start_epoch # influence saving name of checkpoints
[docs]
@rank_zero_only
def get_emb_umap(
self,
pred_dir: str = None,
pred_format: str = None, #'npy' or 'csv'
save_dir: str = None,
drop_c_umap: bool = False,
drop_u_umap: bool = False,
color_by: str = "batch", # NEW: "batch" (default) or "s_joint" or any obs column you add
n_obs: int = None,
verbose=True,
**kwargs
) -> Tuple[List[sc.AnnData], List[plt.Figure]]:
"""
Generate UMAP visualizations for biological and technical latent embeddings.
This function loads predicted latent representations and computes UMAP
embeddings for visualization. Two types of embeddings are supported:
1. Biological embedding (`z_c`)
2. Technical embedding (`z_u`)
For large datasets, the function can optionally subsample observations to
accelerate UMAP computation.
Parameters:
pred_dir : str, optional
Directory containing predicted results generated by `predict()` or
`predict_streaming()`. If None, predictions will be generated on-the-fly
using `self.predict()`.
pred_format : {"npy", "csv"}, optional
File format of saved prediction files when loading from disk.
Only used when `pred_dir` is provided.
save_dir : str, optional
Directory to save the generated UMAP figures. If None, figures will not
be written to disk.
drop_c_umap : bool, default=False
Whether to skip UMAP visualization for the biological embedding (`z_c`).
drop_u_umap : bool, default=False
Whether to skip UMAP visualization for the technical embedding (`z_u`).
color_by : str, default="batch"
Column name in `adata.obs` used to color cells in UMAP plots.
Common options include:
- `"batch"` : batch label
- `"s_joint"` : subset or dataset identifier
- any other metadata column stored in `adata.obs`
n_obs : int, optional
Number of observations to randomly subsample before computing UMAP.
Useful for large datasets to speed up visualization. If None, all
observations will be used.
verbose : bool, default=True
Whether to display progress bars and logging information.
**kwargs : Dict[str, Any]
Additional keyword arguments passed to `scanpy.pl.umap()`.
Returns:
all_adata : List[AnnData]
List of AnnData objects containing the computed UMAP embeddings.
all_figures : List[matplotlib.figure.Figure]
List of generated UMAP figure objects.
Notes:
- UMAP is computed using `scanpy.pp.neighbors()` followed by
`scanpy.tl.umap()`.
- The biological embedding (`z_c`) captures biological variation,
while the technical embedding (`z_u`) reflects batch or technical effects.
- For very large datasets (e.g., >1M cells), it is recommended to set
`n_obs` (e.g., 20,000) to reduce computation time.
Examples
--------
Generate UMAP from saved predictions:
>>> model.get_emb_umap(pred_dir="./predictions")
Generate UMAP with subsampling and custom coloring:
>>> model.get_emb_umap(
... pred_dir="./predictions",
... n_obs=20000,
... color_by="batch"
... )
Generate UMAP and save figures:
>>> model.get_emb_umap(
... pred_dir="./predictions",
... save_dir="./figs"
... )
"""
def _unwrap_pred(p: Any) -> Dict[str, Any]:
# If predict returned {"memory": ..., "disk": ...}
if isinstance(p, dict) and "memory" in p and isinstance(p["memory"], dict):
return p["memory"]
return p
if verbose:
logger.info(f"Loading predicted data from: {pred_dir}")
if pred_dir is not None:
# IMPORTANT: adapt this call to your actual loader signature.
# If you're using the loader we discussed earlier, it would be something like:
pred = load_predicted(pred_dir, save_format=pred_format, dim_c=self.dim_c, split_z=True, var_names=['z'])
#
# If your project already has load_predicted(pred_dir, self.combs, mtx=use_mtx), keep it.
# pred = load_predicted(pred_dir, dim_c=self.dim_c, split_z=True) # <- adjust if needed
else:
# Use the new streaming predict (in-memory) by default
pred = self.predict(
return_in_memory=True,
save_dir=None,
joint_latent=True,
mod_latent=False,
impute=False,
batch_correct=False,
translate=False,
input=False,
verbose=verbose
)
pred = _unwrap_pred(pred)
# pred is expected to be: {batch_name: {...}, batch_name2: {...}}
if not isinstance(pred, dict) or len(pred) == 0:
raise ValueError("Empty prediction results.")
# ----------------------------
# Concatenate z_c/z_u across batches
# ----------------------------
zc_list, zu_list = [], []
batch_labels = []
s_joint_labels = [] # optional
for batch_name, data in pred.items():
if "z_c" not in data or "joint" not in data["z_c"]:
raise KeyError(f"Missing z_c/joint in batch '{batch_name}'")
if "z_u" not in data or "joint" not in data["z_u"]:
raise KeyError(f"Missing z_u/joint in batch '{batch_name}'")
zc = data["z_c"]["joint"]
zu = data["z_u"]["joint"]
# allow torch or numpy
if hasattr(zc, "detach"):
zc = zc.detach().cpu().numpy()
if hasattr(zu, "detach"):
zu = zu.detach().cpu().numpy()
n = zc.shape[0]
zc_list.append(zc)
zu_list.append(zu)
batch_labels.append(np.array([batch_name] * n, dtype=object))
# optional: keep s['joint'] if present (useful for coloring)
sj = None
if "s" in data and isinstance(data["s"], dict):
# common key names: 'joint' or first key
if "joint" in data["s"]:
sj = data["s"]["joint"]
else:
sj = next(iter(data["s"].values()))
if sj is not None:
if hasattr(sj, "detach"):
sj = sj.detach().cpu().numpy()
sj = np.asarray(sj).reshape(-1)
if sj.shape[0] == n:
s_joint_labels.append(sj.astype(int).astype(str))
bio_embedding = np.concatenate(zc_list, axis=0)
tech_embedding = np.concatenate(zu_list, axis=0)
batch_labels = np.concatenate(batch_labels, axis=0)
s_joint_labels = np.concatenate(s_joint_labels, axis=0) if len(s_joint_labels) else None
# ----------------------------
# Build UMAPs
# ----------------------------
all_adata: List[sc.AnnData] = []
all_figures: List[plt.Figure] = []
file_names = ["biological_information.png", "technical_noise.png"]
embeddings = [bio_embedding, tech_embedding]
for index, (embedding, file_name) in enumerate(zip(embeddings, file_names)):
if file_name == "biological_information.png" and drop_c_umap:
logger.info("Skipping biological embedding UMAP generation (drop_c_umap=True).")
continue
if file_name == "technical_noise.png" and drop_u_umap:
logger.info("Skipping technical embedding UMAP generation (drop_u_umap=True).")
continue
if verbose:
logger.info(f"Processing {'biological' if index == 0 else 'technical'} embedding...")
adata = sc.AnnData(embedding)
adata.obs["batch"] = batch_labels
if s_joint_labels is not None:
adata.obs["s_joint"] = s_joint_labels
# neighbors + umap (use the embedding directly as X)
if verbose:
logger.info(" - Computing neighbors...")
if n_obs:
sc.pp.subsample(adata, n_obs=min(len(adata), n_obs))
sc.pp.neighbors(adata, n_neighbors=30, use_rep="X") # X is already embedding
if verbose:
logger.info(" - Computing UMAP...")
sc.tl.umap(adata)
# pick color
plot_color = color_by
if plot_color is not None and plot_color not in adata.obs.columns:
logger.warning(
f"color_by='{plot_color}' not found in adata.obs. "
f"Available: {list(adata.obs.columns)}. Falling back to 'batch'."
)
plot_color = "batch"
if verbose:
logger.info(f" - Generating UMAP plot for {file_name}...")
fig = sc.pl.umap(
adata,
title=file_name[:-4],
color=plot_color,
show=False,
return_fig=True,
**kwargs,
)
all_figures.append(fig)
if save_dir:
fig_save_path = os.path.join(save_dir, "figs", f"epoch_{self.current_epoch + self.start_epoch + 1}_"+file_name)
os.makedirs(os.path.dirname(fig_save_path), exist_ok=True)
fig.savefig(fig_save_path, dpi=200, bbox_inches="tight")
if verbose:
logger.info(f" - UMAP plot saved to: {fig_save_path}")
if getattr(self, "logger", None) is not None and getattr(self, "viz_umap_tb", False):
self.logger.experiment.add_figure(file_name, fig, self.current_epoch + self.start_epoch)
all_adata.append(adata)
if verbose:
logger.info("UMAP generation completed.")
return all_adata, all_figures
[docs]
def log_losses(self,
recon_loss: torch.Tensor,
kld_loss, consistency_loss: torch.Tensor,
loss_net: torch.Tensor,
loss_dsc: torch.Tensor,
recon_dict: Dict[str, torch.Tensor]):
"""
Log losses for monitoring and debugging during training.
Parameters:
recon_loss : torch.Tensor
Reconstruction loss.
kld_loss : torch.Tensor
KLD loss.
consistency_loss : torch.Tensor
Consistency loss.
recon_dict : Dict[str, torch.Tensor]
Per-modality reconstruction losses.
loss_net : torch.Tensor
Total VAE loss.
loss_dsc : torch.Tensor
Discriminator loss.
"""
self.log_dict(
{
'loss_/recon_loss': recon_loss,
'loss_/kld_loss': kld_loss,
'loss_/consistency_loss': consistency_loss,
'loss/net': loss_net,
'loss/dsc':loss_dsc
},
prog_bar=True,
on_epoch=True,
sync_dist=True,
)
self.log_dict(recon_dict, on_epoch=True, sync_dist=True)
[docs]
@staticmethod
def update_model(
loss: torch.Tensor,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
grad_clip: int=-1):
"""
Update model parameters using backpropagation.
Parameters:
loss : torch.Tensor
Computed loss for backpropagation.
model : torch.nn.Module
Model to update.
optimizer : torch.optim.Optimizer
Optimizer for parameter updates.
grad_clip : int
True to allow clipping gradient.
"""
optimizer.zero_grad()
loss.backward()
if grad_clip > 0:
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
[docs]
@staticmethod
def calc_dsc_loss(pred: Dict[str, torch.Tensor], true: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculate the discriminator loss using cross-entropy.
Parameters:
pred : Dict[str, torch.Tensor]
Predicted logits for each modality.
true : Dict[str, torch.Tensor]
Ground truth labels for each modality.
Returns:
torch.Tensor :
Computed discriminator loss.
"""
cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum') # Cross-entropy loss
loss = {}
# Compute loss for each modality
for modality in pred:
loss[modality] = cross_entropy_loss(pred[modality], true[modality].squeeze(1))
# Normalize total loss by batch size
total_loss = sum(loss.values()) / pred['joint'].size(0)
return total_loss
[docs]
@staticmethod
def calc_kld_z_loss(dim_c: int,
dim_u: int,
lam_kld_c: float,
lam_kld_u: float,
mu: torch.Tensor,
logvar: torch.Tensor) -> torch.Tensor:
"""
Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.
Parameters:
dim_c : int
Dimension of the biological latent space.
dim_u : int
Dimension of the technical latent space.
lam_kld_c : float
Weight for KLD loss of the biological latent space.
lam_kld_u : float
Weight for KLD loss of the technical latent space.
mu : torch.Tensor
Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).
logvar : torch.Tensor
Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).
Returns:
torch.Tensor:
Weighted sum of KLD losses for the biological and technical latent spaces.
"""
# Split the mean and log-variance into biological (c) and technical (u) components
mu_c, mu_u = mu.split([dim_c, dim_u], dim=1)
logvar_c, logvar_u = logvar.split([dim_c, dim_u], dim=1)
# Calculate KLD losses for biological and technical latent spaces
kld_c_loss = MIDAS.calc_kld_loss(mu_c, logvar_c)
kld_u_loss = MIDAS.calc_kld_loss(mu_u, logvar_u)
# Combine the losses with their respective weights
kld_z_loss = kld_c_loss * lam_kld_c + kld_u_loss * lam_kld_u
return kld_z_loss
[docs]
@staticmethod
def calc_kld_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
Calculate the KLD loss for a single latent space.
Parameters:
mu : torch.Tensor
Mean of the latent variable distribution (batch_size x latent_dim).
logvar : torch.Tensor
Log-variance of the latent variable distribution (batch_size x latent_dim).
Returns:
torch.Tensor :
KLD loss for the latent space, normalized by batch size.
"""
# KLD loss formula: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)
return kld_loss
[docs]
@staticmethod
def calc_consistency_loss(z_uni: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculate the consistency loss for unified latent variables across modalities.
Parameters:
z_uni : Dict[str, torch.Tensor]
Dictionary of unified latent variables for each modality, where each value is a
tensor of shape (batch_size x latent_dim).
Returns:
torch.Tensor :
Consistency loss computed as the variance of the unified latent variables.
"""
# Stack the unified latent variables along a new dimension (modalities)
z_uni_stack = torch.stack(list(z_uni.values()), dim=0) # Shape: M x N x K (M=modalities, N=batch_size, K=latent_dim)
# Calculate the mean across modalities
z_uni_mean = z_uni_stack.mean(0, keepdim=True) # Shape: 1 x N x K
# Consistency loss is the variance across modalities
consistency_loss = ((z_uni_stack - z_uni_mean) ** 2).sum() / z_uni_stack.size(1) # Normalize by batch size
return consistency_loss
[docs]
@staticmethod
def calc_recon_loss(
x: Dict[str, torch.Tensor],
s: torch.Tensor,
e: Dict[str, torch.Tensor],
x_r_pre: Dict[str, torch.Tensor],
s_r_pre: Dict[str, torch.Tensor],
dist: Dict[str, str],
lam: Dict[str, float]
) -> Tuple[float, Dict[torch.Tensor, torch.Tensor]]:
"""
Calculate the reconstruction loss for input data and predicted outputs.
Parameters:
x : Dict[str, torch.Tensor]
Original input data for each modality (x^m).
s : torch.Tensor
Ground truth batch labels.
e : Dict[str, torch.Tensor]
Mask.
x_r_pre : Dict[str, torch.Tensor]
Reconstructed predictions for each modality (x_r^m).
s_r_pre : Dict[str, torch.Tensor]
Reconstructed predictions for batch labels.
dist : Dict[str, str]
Dictionary specifying the distribution type for each modality's decoder.
lam : Dict[str, float]
Dictionary containing reconstruction loss weights for each modality and for s.
Returns:
Tuple:
- total_loss : torch.Tensor
Total reconstruction loss, normalized by batch size.
- losses : Dict[str, torch.Tensor]
Dictionary containing reconstruction losses for each modality and for batch labels.
"""
losses = {}
# Compute reconstruction loss for each modality
for modality, x_original in x.items():
# Get the appropriate loss function based on the modality's decoder distribution
loss_fn = distribution_registry.get_loss(dist[f'distribution_dec_{modality}'])
# Check if there is an event-specific mask for the modality
if modality in e:
# Apply event-specific mask to the reconstruction loss
losses[f'recon_loss/{modality}'] = (
loss_fn(x_r_pre[modality], x_original) * e[modality]
).sum() * lam[f'lam_recon_{modality}']
else:
# Compute the reconstruction loss without a mask
losses[f'recon_loss/{modality}'] = (
loss_fn(x_r_pre[modality], x_original)
).sum() * lam[f'lam_recon_{modality}']
# Compute reconstruction loss for batch labels, if provided
if s_r_pre is not None:
# Use cross-entropy loss for batch label reconstruction
losses['recon_loss/s'] = (
distribution_registry.get_loss('CE')(s_r_pre, s.squeeze(1))
).sum() * lam['lam_recon_s']
# Normalize total loss by the batch size
total_loss = sum(losses.values()) / s.size(0)
return total_loss, losses
[docs]
@staticmethod
def get_info_from_mdata(mdata, batch_key='batch'):
batch_names = []
for k in mdata.mod.keys():
batch_names.extend(np.unique(mdata[k].obs[batch_key]).tolist())
batch_names = np.unique(batch_names)
data = []
mask = []
for b in batch_names:
t = {}
mt = {}
for m in mdata.mod.keys():
if b in mdata[m].obs[batch_key].values:
t[m] = mdata[m][mdata[m].obs[batch_key]==b]
if f'mask_{b}' in mdata[m].uns:
mt[m] = mdata[m].uns[f'mask_{b}']
data.append(t)
mask.append(mt)
return data, mask, batch_names
[docs]
@staticmethod
def get_info_from_dir(dir_path: str, format: str):
"""
Extract data, mask, and feature dimensions from a directory of vectors.
Parameters:
dir_path : str
Path to the directory containing data and mask files.
format : str
Support 'mtx', 'csv', and 'vec'.
Returns:
Tuple:
- data : List[Dict[str, str]]
List of dictionaries where keys are modalities and values are file paths.
- mask : List[Dict[str, str]]
List of dictionaries where keys are modalities and values are mask file paths.
- dims_x : Dict[str, list]
Dictionary containing feature dimensions for each modality.
Notes:
The directory should be organized as::
dataset/
feat/
# Dimensions of each modality: {mod1=[...], mod2=[...]}.
# Split the data into chunks if the length of the list is greater than 1.
# For instance, you can split the ATAC data by chromosomes.
feat_dims.toml
batch_0/
mask/mod1.csv
mask/mod2.csv
vec/mod1/0000.csv # the first sample
vec/mod1/0001.csv # the second sample
....
vec/mod2/0000.csv
vec/mod2/0001.csv
....
batch_1/
mask/mod1.csv
mask/mod2.csv
vec/mod1/0000.csv
vec/mod1/0001.csv
....
vec/mod2/0000.csv
vec/mod2/0001.csv
....
or like::
dataset/
feat/
# Dimensions of each modality: {mod1=[...], mod2=[...]}.
# Split the data into chunks if the length of the list is greater than 1.
# For instance, you can split the ATAC data by chromosomes.
feat_dims.toml
batch_0/
mask/mod1.csv
mask/mod2.csv
mat/mod1.mtx (.csv)
mat/mod2.mtx (.csv)
....
batch_1/
mask/mod1.csv
mask/mod2.csv
mat/mod1.mtx (.csv)
mat/mod2.mtx (.csv)
....
"""
data = [] # List to store data file paths
mask = [] # List to store mask file paths
batch_names = []
for batch_dir in natsort.natsorted(os.listdir(dir_path)):
if batch_dir != 'feat': # Ignore the 'feat' directory
data_batch = {}
mask_batch = {}
batch_path = os.path.join(dir_path, batch_dir)
batch_names.append(batch_dir)
# Collect file paths for data and masks
if format == 'vec':
if os.path.exists(batch_path):
vec_dir = os.path.join(batch_path, 'vec')
mask_dir = os.path.join(batch_path, 'mask')
for file in os.listdir(vec_dir):
data_batch[file] = os.path.join(vec_dir, file)
for file in os.listdir(mask_dir):
mask_batch[file[:-4]] = os.path.join(mask_dir, file)
elif format in ['csv', 'mtx']:
if os.path.exists(batch_path):
mat_dir = os.path.join(batch_path, 'mat')
mask_dir = os.path.join(batch_path, 'mask')
for file in os.listdir(mat_dir):
data_batch[file[:-4]] = os.path.join(mat_dir, file)
for file in os.listdir(mask_dir):
mask_batch[file[:-4]] = os.path.join(mask_dir, file)
data.append(data_batch)
mask.append(mask_batch)
# Load feature dimensions from 'feat_dims.toml'
dims_x = toml.load(os.path.join(dir_path, 'feat', 'feat_dims.toml'))
return data, mask, dims_x, batch_names
[docs]
@staticmethod
def get_datasets_from_dir(
data: List[Dict[str, str]],
mask: List[Dict[str, str]],
batch_names: List[str],
transform: Dict[str, str]=None,
format: str='mtx'):
"""
Configure data from directory.
Parameters:
data : List[Dict[str, str]]
List of data dictionaries, where keys are modalities and values are file paths.
mask : List[Dict[str, str]]
List of mask dictionaries, where keys are modalities and values are mask file paths.
batch_name : List[str]
List of batch names.
transform : Optional[Dict[str, str]]
Transformations to apply to specific modalities.
format : str
File type of the input data, default is 'vec'. ['vec', 'mtx', 'csv']
Returns:
Tuple:
- datasets : List[MultiModalDataset]
List of initialized `MultiModalDataset` objects.
- dims_s : Dict[str, int]
Dimensions for batch correction for each modality.
- s_joint : List[Dict[str, int]]
Modality indices for each batch.
- combs : List[List[str]]
List of modality combinations for each batch.
"""
s_joint = [] # Modality indices for each batch
n_s = {} # Counter for each modality
combs = [] # Modality combinations for each batch
datasets = [] # List of datasets
dims_s = {} # Dimensions for batch correction
for i, batch_data in enumerate(data):
batch_s = {} # Store batch-specific indices
batch_combs = [] # Modality combination for the current batch
# Assign batch index for each modality
for modality in batch_data.keys():
if modality in n_s:
batch_s[modality] = n_s[modality] + 1
n_s[modality] += 1
else:
batch_s[modality] = 0
n_s[modality] = 0
batch_combs.append(modality)
# Add joint batch information
batch_s['joint'] = i
n_s['joint'] = i
s_joint.append(batch_s)
combs.append(batch_combs)
# Determine file types for each modality
file_types = {
modality: format
for modality in batch_data.keys()
}
# Initialize MultiModalDataset
dataset = MultiModalDataset(batch_data, batch_s, file_types, mask[i], transform)
datasets.append(dataset)
# Define dimensions for batch correction
dims_s = {modality: count + 1 for modality, count in n_s.items()}
MIDAS.print_info(mask, datasets, batch_names)
return datasets, dims_s, s_joint, combs
[docs]
@staticmethod
def get_datasets_from_adata(
data: List[Dict[str, AnnData]],
mask: List[Dict[str, str]],
batch_names: List[str],
transform: Dict[str, str]=None):
"""
Configure data from a CSV input.
Parameters:
data : List[Dict[str, str]]
List of data dictionaries, where keys are modalities and values are adata object.
mask : List[Dict[str, str]]
List of mask dictionaries, where keys are modalities and values are mask values.
batch_name : List[str]
List of batch names.
transform : Optional[Dict[str, str]]
Transformations to apply to specific modalities.
format : str
File type of the input data, default is 'vec'. ['vec', 'mtx', 'csv']
Returns:
Tuple:
- datasets : List[MultiModalDataset]
List of initialized `MultiModalDataset` objects.
- dims_s : Dict[str, int]
Dimensions for batch correction for each modality.
- s_joint : List[Dict[str, int]]
Modality indices for each batch.
- combs : List[List[str]]
List of modality combinations for each batch.
"""
s_joint = [] # Modality indices for each batch
n_s = {} # Counter for each modality
combs = [] # Modality combinations for each batch
datasets = [] # List of datasets
dims_s = {} # Dimensions for batch correction
for i, batch_data in enumerate(data):
batch_s = {} # Store batch-specific indices
batch_combs = [] # Modality combination for the current batch
# Assign batch index for each modality
for modality in batch_data.keys():
if modality in n_s:
batch_s[modality] = n_s[modality] + 1
n_s[modality] += 1
else:
batch_s[modality] = 0
n_s[modality] = 0
batch_combs.append(modality)
# Add joint batch information
batch_s['joint'] = i
n_s['joint'] = i
s_joint.append(batch_s)
combs.append(batch_combs)
# Determine file types for each modality
file_types = {
modality: 'anndata'
for modality in batch_data.keys()
}
# Initialize MultiModalDataset
dataset = MultiModalDataset(batch_data, batch_s, file_types, mask[i], transform)
datasets.append(dataset)
# Define dimensions for batch correction
dims_s = {modality: count + 1 for modality, count in n_s.items()}
MIDAS.print_info(mask, datasets, batch_names)
return datasets, dims_s, s_joint, combs
[docs]
@staticmethod
@rank_zero_only
def print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str]):
"""
Print summary of mask density and dataset information.
Parameters:
mask : List[Dict[str, str]]
List of mask.
datalist : List[Dataset]
List of datasets.
batch_name : List[str]
List of batch names.
"""
# Calculate mask density for each batch
feature = []
valid_feature = []
for i, dataset in enumerate(datalist):
s1 = {}
s2 = {}
dataset = dataset[0]
mask_ = mask[i]
for m in dataset['x']:
s1['#%s'%m.upper()] = len(dataset['x'][m])
if m in mask_:
if isinstance(mask_[m], str):
t = pd.read_csv(mask_[m], index_col=0).values
else:
t = mask_[m]
s2['#VALID_'+m.upper()] = t.sum()
feature.append(s1)
valid_feature.append(s2)
valid_feature = pd.DataFrame(valid_feature)
valid_feature.index = batch_names
cell_number = pd.DataFrame({'#CELL':[len(dataset) for dataset in datalist]})
cell_number.index = batch_names
feature = pd.DataFrame(feature)
feature.index = batch_names
data = pd.concat([cell_number, feature, valid_feature], axis=1)
# Print summary
logger.info('Input data: \n' + data.to_string())