# Source code for scmidas.plot

"""Plotting helpers for MIDAS / MuData users.

The high-level helpers (:func:`umap`, :func:`modality_grid`) take a
:class:`MuData` directly and route through a temporary :class:`AnnData`
wrapper, side-stepping the current limitations of scanpy + MuData
plotting. They are exposed as both ``scmidas.plot.X`` and the shorter
``scmidas.pl.X``.

The AnnData-only helpers :func:`plot_umap`, :func:`plot_umap_grid`,
:func:`plot_z_umap_grid` are retained for backwards compatibility with
older tutorial code.
"""
from __future__ import annotations

import copy
import logging
from typing import Iterable, Optional, Sequence, Union

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc

logger = logging.getLogger(__name__)


def umap(
    mdata,
    *,
    basis: str = 'X_midas',
    color: Union[str, Sequence[str]] = 'batch',
    obs_keys: Optional[Sequence[str]] = None,
    n_neighbors: int = 30,
    min_dist: float = 0.3,
    random_state: int = 42,
    shuffle: bool = True,
    recompute: bool = False,
    **kwargs,
):
    """Compute and plot a UMAP from a MuData embedding (one-liner).

    Wraps the MuData's selected embedding as a thin AnnData and routes
    through ``sc.pl.umap``, avoiding the current limitations of scanpy +
    MuData plotting.

    Parameters:
        mdata : MuData
            Multi-modal data with an integration embedding in
            ``mdata.obsm[basis]`` (e.g. written by
            :meth:`~scmidas.model.MIDAS.get_latent_representation`).
        basis : str
            Key in ``mdata.obsm`` to use as the representation. Default
            ``'X_midas'`` matches :meth:`MIDAS.get_latent_representation`.
        color : str or sequence of str
            One or more ``mdata.obs`` columns to color by. Mirrors scanpy's
            ``color`` argument.
        obs_keys : str or sequence of str, optional
            Extra ``mdata.obs`` columns to copy onto the temporary AnnData
            in addition to ``color``. Any sequence type (list, tuple, ...)
            is accepted and a single string is treated as one key. Names
            that are not columns of ``mdata.obs`` are ignored.
        n_neighbors, min_dist, random_state
            ``n_neighbors`` is forwarded to ``sc.pp.neighbors``;
            ``min_dist`` and ``random_state`` are forwarded to
            ``sc.tl.umap``. ``random_state`` also seeds the shuffle.
        shuffle : bool
            If True, pass the temporary AnnData through
            ``sc.pp.subsample(fraction=1)`` before plotting so that cells
            are not drawn in their stored order (unbiased visual density
            when batches/cell types overlap).
        recompute : bool
            If False (default) and ``mdata.obsm['X_umap_<basis>']`` already
            exists, reuse it as-is; ``n_neighbors``, ``min_dist`` and
            ``random_state`` then have no effect on the coordinates. If
            True, recompute the UMAP.
        **kwargs
            Forwarded to ``sc.pl.umap``.

    Returns:
        AnnData: The temporary AnnData used for plotting (so callers can
        access ``.obsm['X_umap']`` if they want to keep it). Its cells are
        in shuffled order when ``shuffle`` is True.

    Raises:
        KeyError: If ``basis`` is not a key of ``mdata.obsm``.

    Note:
        Freshly computed coordinates are written back to
        ``mdata.obsm['X_umap_<basis>']`` so later calls can reuse them.
    """
    if basis not in mdata.obsm:
        raise KeyError(
            f"mdata.obsm[{basis!r}] not found. Run "
            f"model.get_latent_representation() and assign it first."
        )

    color_list = [color] if isinstance(color, str) else list(color)

    if obs_keys is None:
        extra_keys = []
    elif isinstance(obs_keys, str):
        extra_keys = [obs_keys]
    else:
        # list() accepts tuples and other sequences; ``seq + list`` does not.
        extra_keys = list(obs_keys)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # copied columns come out in a deterministic order (a set would not).
    requested = dict.fromkeys(extra_keys + color_list)
    keep_cols = [c for c in requested if c in mdata.obs.columns]
    obs = mdata.obs[keep_cols].copy() if keep_cols else None

    cache_key = f'X_umap_{basis}'
    have_cache = cache_key in mdata.obsm and not recompute

    adata = ad.AnnData(X=mdata.obsm[basis], obs=obs)
    adata.obs_names = mdata.obs_names

    if have_cache:
        adata.obsm['X_umap'] = mdata.obsm[cache_key]
    else:
        sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep='X')
        sc.tl.umap(adata, random_state=random_state, min_dist=min_dist)
        # Cache on the MuData before shuffling so the stored coordinates
        # stay aligned with mdata.obs_names.
        mdata.obsm[cache_key] = adata.obsm['X_umap']

    if shuffle:
        sc.pp.subsample(adata, fraction=1, random_state=random_state)

    sc.pl.umap(adata, color=color_list, **kwargs)
    return adata
def modality_grid(
    model,
    mdata,
    *,
    batch_key: str = 'batch',
    label_key: str = 'label',
    figsize: float = 2.0,
    point_size: float = 2.0,
    fontsize: int = 10,
    transpose: bool = False,
    random_state: int = 42,
):
    """Per-modality vs per-batch UMAP grid (joint c plus single-modality c).

    Internally runs ``model.predict(joint_latent=True, mod_latent=True)``,
    concatenates the per-modality biological latents (the ``'z_c'`` entries
    of each batch's output), computes one shared UMAP, and tiles it as a
    (batch x modality) grid coloured by ``label_key``. The grid view
    answers: "does each modality on its own carry enough signal to separate
    cell types in each batch?"

    Parameters:
        model : MIDAS
            A constructed (and trained or checkpoint-loaded) MIDAS model.
            Its ``predict``, ``batch_names`` and ``combs`` members are used.
        mdata : MuData
            The MuData passed to ``MIDAS(mdata)``.
        batch_key : str
            Column in ``mdata[m].obs`` identifying the batch, where ``m`` is
            the first modality listed for that batch in ``model.combs``.
        label_key : str
            Column in ``mdata[m].obs`` used for colouring. Cells are
            labelled ``'?'`` when the column is absent.
        figsize : float
            Width and height (in inches) of each subplot.
        point_size : float
            Scatter point size passed to ``sc.pl.umap``.
        fontsize : int
            Font size of the shared legend.
        transpose : bool
            If True, swap rows/columns (modality-by-batch instead of
            batch-by-modality).
        random_state : int
            Seed used by ``sc.tl.umap``.

    Returns:
        AnnData: The aggregated AnnData used for the plot, with obs columns
        ``'type'``, ``batch_key`` and ``label_key``.

    Raises:
        RuntimeError: If no latent could be collected for any batch.
    """
    out = model.predict(joint_latent=True, mod_latent=True, verbose=False)

    pieces, types_, batches_, labels_ = [], [], [], []
    for batch_id, b in enumerate(model.batch_names):
        block = out[b]
        first_mod = model.combs[batch_id][0]
        # cell IDs aligned to the FIRST modality in this batch
        mod_obs = mdata[first_mod].obs
        sub_obs = mod_obs[mod_obs[batch_key].astype(str) == str(b)]
        if label_key in sub_obs.columns:
            sub_labels = sub_obs[label_key].astype(str).values
        else:
            sub_labels = np.full(len(sub_obs), '?')

        for k, z in block.get('z_c', {}).items():
            if z.shape[0] != len(sub_obs):
                # Labels cannot be aligned with this latent; skip it, but
                # say so instead of dropping it silently.
                logger.warning(
                    "Skipping latent %r of batch %r: %d rows but %d cells "
                    "with %s == %r in modality %r.",
                    k, b, z.shape[0], len(sub_obs), batch_key, str(b),
                    first_mod,
                )
                continue
            pieces.append(z)
            types_.append(np.full(z.shape[0], k.upper()))
            batches_.append(np.full(z.shape[0], b))
            labels_.append(sub_labels)

    if not pieces:
        raise RuntimeError(
            "No per-modality latents to plot. Was the model trained?"
        )

    X = np.concatenate(pieces)
    adata = ad.AnnData(X=X)
    adata.obs['type'] = pd.Categorical(np.concatenate(types_))
    adata.obs[batch_key] = pd.Categorical(np.concatenate(batches_))
    adata.obs[label_key] = pd.Categorical(np.concatenate(labels_))
    sc.pp.neighbors(adata, use_rep='X')
    sc.tl.umap(adata, random_state=random_state)

    axis1 = batch_key if not transpose else 'type'
    axis2 = 'type' if not transpose else batch_key
    rows = adata.obs[axis1].cat.categories.tolist()
    cols = adata.obs[axis2].cat.categories.tolist()

    # Preferred display order for the modality axis: ATAC, RNA, ADT, JOINT
    # (any modalities outside this list keep their original ordering and
    # come last).
    _preferred = ['ATAC', 'RNA', 'ADT', 'JOINT']

    def _reorder(types):
        front = [t for t in _preferred if t in types]
        rest = [t for t in types if t not in _preferred]
        return front + rest

    if transpose:
        rows = _reorder(rows)
    else:
        cols = _reorder(cols)

    # squeeze=False always yields a 2-D (rows x cols) array of axes.
    # np.atleast_2d on the squeezed result would turn an (n, 1) grid into
    # shape (1, n) and break axes[i, j] indexing.
    fig, axes = plt.subplots(
        len(rows),
        len(cols),
        figsize=(figsize * len(cols), figsize * len(rows)),
        squeeze=False,
    )

    # Draw once on a throw-away figure just to harvest the legend entries
    # for the shared legend.
    fig_dummy, ax_dummy = plt.subplots()
    sc.pl.umap(adata, color=label_key, show=False, ax=ax_dummy)
    handles, leg_labels = ax_dummy.get_legend_handles_labels()
    plt.close(fig_dummy)

    for i, r in enumerate(rows):
        for j, c in enumerate(cols):
            ax = axes[i, j]
            sc.pl.umap(adata, show=False, ax=ax, s=point_size)  # background
            sub = adata[(adata.obs[axis1] == r) & (adata.obs[axis2] == c)]
            if sub.n_obs > 0:
                sc.pl.umap(sub, color=label_key, show=False, ax=ax,
                           s=point_size)
            legend = ax.get_legend()
            if legend is not None:
                legend.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel('')
            ax.set_title(c if i == 0 else '')
            ax.set_ylabel(r if j == 0 else '')

    if handles:
        fig.legend(handles, leg_labels, loc='center',
                   bbox_to_anchor=(0.5, -0.02),
                   ncol=len(leg_labels), fontsize=fontsize)
    # Lay out the grid figure explicitly rather than whichever figure pyplot
    # currently considers active.
    fig.tight_layout(rect=[0.05, 0.05, 1, 1])
    plt.show()
    return adata
# ---------------------------------------------------------------------------
# Legacy AnnData-only helpers (used by the original tutorials)
# ---------------------------------------------------------------------------
def plot_umap(
    adata,
    key='z_c_joint',
    do_pca=False,
    n_comps=32,
    color='batch',
    shuffle=True,
    **kwargs
):
    """
    Builds a UMAP from one latent representation of an AnnData object and
    plots it.

    The representation in ``adata.obsm[key]`` feeds the neighbourhood graph
    either directly or, when ``do_pca`` is set, after scaling and PCA. All
    computation runs on a separate working AnnData object (a deep copy of
    the input, or a new object built from the representation).

    Args:
        adata (AnnData): The input annotated data matrix.
        key (str, optional): Key in `adata.obsm` holding the representation.
            Defaults to 'z_c_joint'.
        do_pca (bool, optional): If True, scale the representation and run
            PCA on it, then build the neighbour graph on the PCA result.
            Defaults to False.
        n_comps (int, optional): Number of principal components requested
            when `do_pca` is True. Defaults to 32.
        color (str, optional): Key in `adata.obs` used to color the plot.
            Defaults to 'batch'.
        shuffle (bool, optional): If True, pass the working object through
            `sc.pp.subsample(fraction=1)` before the neighbour computation,
            so cells are not processed/drawn in stored order.
            Defaults to True.
        **kwargs: Additional keyword arguments passed to `sc.pl.umap`.

    Returns:
        None: Displays the plot.
    """
    if not do_pca:
        # Work on a full copy and use the stored representation as-is.
        work = copy.deepcopy(adata)
        rep = key
    else:
        # Promote the representation to X so it can be scaled and reduced.
        work = sc.AnnData(adata.obsm[key])
        work.obs = adata.obs
        sc.pp.scale(work)
        sc.pp.pca(work, n_comps=n_comps)
        rep = 'X_pca'

    if shuffle:
        sc.pp.subsample(work, fraction=1)

    sc.pp.neighbors(work, use_rep=rep)
    sc.tl.umap(work)
    sc.pl.umap(work, color=color, **kwargs)
def plot_umap_grid(adata, axis1, axis2, color, figsize=2, point_size=2, fontsize=10, background=True):
    """
    Plots a grid (facet plot) of UMAPs split by two categorical variables.

    Each subplot shows the cells belonging to one (axis1, axis2) combination
    on the global UMAP, which visualizes how the groups are distributed
    within the shared UMAP space. Combinations with no cells are left
    without a foreground layer.

    Args:
        adata (AnnData): Annotated data matrix with pre-computed UMAP
            coordinates (`X_umap`).
        axis1 (str): Key in `adata.obs` defining the rows of the grid.
        axis2 (str): Key in `adata.obs` defining the columns of the grid.
        color (str): Key in `adata.obs` used for coloring the points.
        figsize (float, optional): The size (in inches) of each subplot.
            Defaults to 2.
        point_size (float, optional): The size of the scatter points.
            Defaults to 2.
        fontsize (int, optional): Font size for the legend. Defaults to 10.
        background (bool, optional): If True, plots all cells (uncolored) in
            the background of each subplot to show the global structure.
            Defaults to True.

    Returns:
        None: Displays the plot.
    """
    row_names = adata.obs[axis1].unique()
    col_names = adata.obs[axis2].unique()
    nrows = len(row_names)
    ncols = len(col_names)

    # squeeze=False keeps the axes as a 2-D array even for a single row or
    # column, so axes[i, j] is always valid.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=[figsize * ncols, figsize * nrows],
        squeeze=False,
    )

    # Draw once on a throw-away figure just to harvest the legend entries
    # for the shared legend.
    fig_dummy, ax_dummy = plt.subplots()
    sc.pl.umap(adata, color=color, show=False, ax=ax_dummy)
    handles, labels_ = ax_dummy.get_legend_handles_labels()
    plt.close(fig_dummy)

    for i, k1 in enumerate(row_names):
        for j, k2 in enumerate(col_names):
            cell_ax = axes[i, j]
            if background:
                sc.pl.umap(adata, show=False, ax=cell_ax, s=point_size)  # background
            subset = adata[(adata.obs[axis1] == k1) & (adata.obs[axis2] == k2)]
            if subset.n_obs > 0:
                sc.pl.umap(subset, color=color, show=False, ax=cell_ax,
                           s=point_size)
            # No legend exists when nothing categorical was drawn on this
            # axis, so guard before hiding it.
            legend = cell_ax.get_legend()
            if legend is not None:
                legend.set_visible(False)
            cell_ax.set_xticks([])
            cell_ax.set_yticks([])
            cell_ax.set_xlabel('')
            cell_ax.set_ylabel(k1 if j == 0 else '')
            cell_ax.set_title(k2 if i == 0 else '')

    # create global legend (skipped when there are no entries to show)
    if handles:
        fig.legend(handles, labels_, loc='center',
                   bbox_to_anchor=(0.5, -0.02),
                   ncol=len(labels_), fontsize=fontsize)

    # adjust the grid figure explicitly rather than pyplot's current figure
    fig.tight_layout(rect=[0.1, 0.05, 1, 1])
    plt.show()
def plot_z_umap_grid(adata_list, batch_col='batch', color='label', figsize=2, point_size=2, fontsize=10, transpose=False):
    """
    Aggregates latent representations from a dictionary of AnnData objects,
    computes a joint UMAP, and plots a grid view.

    Every `.obsm` entry whose key starts with 'z_c' is collected from every
    AnnData, the matrices are concatenated, and a UMAP is re-computed on the
    result to visualize the alignment or distribution across different
    batches/types. The "type" of an entry is the upper-cased text after the
    last underscore of its key (e.g. 'z_c_joint' -> 'JOINT').

    Args:
        adata_list (dict): A dictionary where keys are batch identifiers and
            values are AnnData objects. Only the values are used.
        batch_col (str, optional): Key in `adata.obs` identifying the
            batch/sample. Defaults to 'batch'.
        color (str, optional): Key in `adata.obs` used for coloring.
            Defaults to 'label'.
        figsize (float, optional): The size (in inches) of each subplot.
            Defaults to 2.
        point_size (float, optional): The size of the scatter points.
            Defaults to 2.
        fontsize (int, optional): Font size for the legend. Defaults to 10.
        transpose (bool, optional): If True, swaps the row and column axes
            of the grid (Batch vs. Type). Defaults to False.

    Returns:
        None: Displays the plot.

    Raises:
        ValueError: If no `.obsm` key starting with 'z_c' is found in any of
            the supplied AnnData objects.
    """
    latents, batch_parts, type_parts, label_parts = [], [], [], []
    for batch_adata in adata_list.values():
        for k in batch_adata.obsm:
            if not k.startswith('z_c'):
                continue
            latents.append(batch_adata.obsm[k])
            batch_parts.append(batch_adata.obs[batch_col])
            type_parts.append([k.split('_')[-1].upper()] * batch_adata.n_obs)
            label_parts.append(batch_adata.obs[color])

    if not latents:
        # np.concatenate would raise an opaque ValueError on an empty list.
        raise ValueError(
            "No .obsm entries starting with 'z_c' were found in adata_list."
        )

    adata = sc.AnnData(np.concatenate(latents))
    # The aggregated object always uses the fixed column names 'batch' and
    # 'type'; `color` is assigned last and wins if it collides with them.
    adata.obs['batch'] = np.concatenate(batch_parts)
    adata.obs['type'] = np.concatenate(type_parts)
    adata.obs[color] = np.concatenate(label_parts)

    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    axis1 = 'batch' if not transpose else 'type'
    axis2 = 'type' if not transpose else 'batch'
    row_names = adata.obs[axis1].unique()
    col_names = adata.obs[axis2].unique()
    nrows = len(row_names)
    ncols = len(col_names)

    # squeeze=False keeps the axes as a 2-D array even for a single row or
    # column, so axes[i, j] is always valid.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=[figsize * ncols, figsize * nrows],
        squeeze=False,
    )

    # Draw once on a throw-away figure just to harvest the legend entries
    # for the shared legend.
    fig_dummy, ax_dummy = plt.subplots()
    sc.pl.umap(adata, color=color, show=False, ax=ax_dummy)
    handles, labels_ = ax_dummy.get_legend_handles_labels()
    plt.close(fig_dummy)

    for i, k1 in enumerate(row_names):
        for j, k2 in enumerate(col_names):
            cell_ax = axes[i, j]
            sc.pl.umap(adata, show=False, ax=cell_ax, s=point_size)  # background
            subset = adata[(adata.obs[axis1] == k1) & (adata.obs[axis2] == k2)]
            if subset.n_obs > 0:
                sc.pl.umap(subset, color=color, show=False, ax=cell_ax,
                           s=point_size)
            # Not every batch has every type, so a cell of the grid may have
            # no foreground layer and therefore no legend to hide.
            legend = cell_ax.get_legend()
            if legend is not None:
                legend.set_visible(False)
            cell_ax.set_xticks([])
            cell_ax.set_yticks([])
            cell_ax.set_xlabel('')
            cell_ax.set_ylabel(k1 if j == 0 else '')
            cell_ax.set_title(k2 if i == 0 else '')

    # create global legend (skipped when there are no entries to show)
    if handles:
        fig.legend(handles, labels_, loc='center',
                   bbox_to_anchor=(0.5, -0.02),
                   ncol=len(labels_), fontsize=fontsize)

    # adjust the grid figure explicitly rather than pyplot's current figure
    fig.tight_layout(rect=[0.1, 0.05, 1, 1])
    plt.show()