Visualizing the training process with TensorBoard

PyTorch Lightning emits scalar logs for every training metric (recon, KLD, discriminator, adversarial, alignment, …). MIDAS hooks into this stream through L.Trainer(logger=...), so you can pipe the metrics into TensorBoard without modifying MIDAS itself.

Logging scalars

Create a TensorBoardLogger and pass it to model.train via logger=; MIDAS forwards that argument to L.Trainer. With the v0.3 API:

import scmidas
from lightning.pytorch import loggers as pl_loggers

# mdata is a MuData object that has already been loaded
scmidas.MIDAS.setup_mudata(mdata, batch_key='batch')
model = scmidas.MIDAS(mdata)

tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='./logs/',
    version='my_run',  # any descriptive name
)
model.train(max_epochs=2000, logger=tb_logger)

Then in a terminal:

tensorboard --logdir ./logs/lightning_logs

Open the URL (e.g. http://localhost:6006) in a browser to watch the loss curves update live.

Figure: TensorBoard scalar dashboard
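The --logdir value follows from how TensorBoardLogger lays out its output: save_dir, then the logger name (which defaults to lightning_logs), then the version. The helper below is a hypothetical, standard-library-only sketch that mirrors that layout so you can check which directory to point TensorBoard at; expected_log_dir is not part of MIDAS or Lightning.

```python
from pathlib import Path
from typing import Union


def expected_log_dir(
    save_dir: str,
    version: Union[int, str],
    name: str = "lightning_logs",
) -> Path:
    """Mirror the save_dir/name/version layout used by TensorBoardLogger.

    Hypothetical helper for illustration only; it does not touch the disk.
    """
    # String versions are used verbatim; integer versions become "version_<n>".
    version_dir = version if isinstance(version, str) else f"version_{version}"
    return Path(save_dir) / name / version_dir


run_dir = expected_log_dir("./logs/", version="my_run")
print(run_dir.as_posix())         # logs/lightning_logs/my_run
print(run_dir.parent.as_posix())  # logs/lightning_logs
```

Pointing --logdir at the parent directory (./logs/lightning_logs) rather than at a single run lets TensorBoard list every version side by side, which is convenient when comparing runs.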

Logging UMAPs of the joint latent during training

To watch the joint biological latent z_c evolve during training as well, pass viz_umap_tb=True to scmidas.MIDAS together with a checkpointing interval (n_save):

model = scmidas.MIDAS(
    mdata,
    viz_umap_tb=True,
    n_save=200,           # write a UMAP every 200 epochs
)
# tb_logger is the TensorBoardLogger created in the previous example
model.train(max_epochs=2000, logger=tb_logger)

A UMAP image is added to the TensorBoard Images tab at each save, that is, every n_save epochs (every 200 epochs in the example above).
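To estimate how many UMAP images a run will produce, the sketch below lists the epochs at which a save would fire. It assumes a save happens whenever the number of completed epochs is a multiple of n_save; MIDAS may count epochs differently (for example, zero-based), so treat the exact epoch numbers as approximate. umap_snapshot_epochs is a hypothetical helper, not a MIDAS function.

```python
from typing import List


def umap_snapshot_epochs(max_epochs: int, n_save: int) -> List[int]:
    """List the epochs at which a UMAP snapshot would be written.

    ASSUMPTION: a snapshot fires whenever the completed-epoch count is a
    multiple of n_save. The real trigger inside MIDAS may differ.
    """
    if n_save <= 0:
        raise ValueError("n_save must be a positive integer")
    return list(range(n_save, max_epochs + 1, n_save))


epochs = umap_snapshot_epochs(max_epochs=2000, n_save=200)
print(len(epochs))  # 10
print(epochs[:3])   # [200, 400, 600]
```

Under this assumption, a smaller n_save gives a smoother picture of how the latent space evolves, at the cost of computing a UMAP more often.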