
Trainer rewrite: AMP, convergence-aware stopping, Runner.save(), clean adata separation #30

Open
jackytamkc wants to merge 4 commits into ZJUFanLab:main from jackytamkc:pr2-trainer-rewrite

Conversation


jackytamkc commented Apr 21, 2026

Depends on #29. Until #29 merges, the diff shown here includes its two-line num_nodes() patch (commit 4c811f9). Once #29 merges, that commit becomes a no-op and only the three trainer commits remain.

Motivation

Two problems with the existing trainer:

  1. No convergence signal, no way to stop early. The loop runs a fixed number of epochs with no record of what happened inside. Because the total loss is dominated by a BCE floor on sparse adjacency reconstruction, loss curves plateau long before the embedding stops changing — and small late-epoch loss improvements don't necessarily correspond to any change in the downstream clustering. Users end up either training far too long or stopping too early based on a signal that doesn't reflect how the output is actually used.
  2. Training scaffolding bleeds into the data object. prepare_data{,_batch} writes DGL graphs, the GraphDataLoader, and batch_idx into adata.uns, and fit() then writes loss / lr_history / early_stop there too. The result: adata.copy() pulls in multiprocess dataloaders, write_h5ad() carries scaffolding it can't serialise cleanly, and the "trained AnnData" is not actually downstream-ready.

This PR addresses both: stop on a signal that matches how the embedding is used, persist enough training metadata to reconstruct what happened, and cleanly separate the model (state dicts + diagnostics, saved as model.pt) from the data (embedding only, saved as adata.h5ad).

What changes

Commit 2 — Trainer rewrite (9555191)

  • Mixed precision: torch.autocast(cuda, bf16) around the GCN forward and adjacency-reconstruction losses. The MIM / discriminator block stays fp32 — log(x + 1e-6) near 1.0 is unsafe in bf16. (Sketched after this list.)
  • Loss-based early stopping with patience, min_delta, and optional restore_best (rolls weights back to the best-loss epoch before producing the final embedding).
  • ReduceLROnPlateau scheduler (lr_patience, lr_factor, min_lr) with verbose logging when the LR drops.
  • ARI-based convergence probe (Runner_batch only): every ari_every epochs, snapshot the embedding, KMeans-cluster it (ari_k), and compute ARI vs. the previous snapshot. Stop when ari_stop_patience consecutive probes stay above ari_stop_threshold. This is the signal that matches downstream use, not the raw loss. (Also sketched after this list.)
  • Batch caching: move every batch to device once up front and reuse the same graph/feat/adj tensors across epochs — eliminates the per-epoch H2D transfer.
  • _infer_embedding helper shared by the ARI probe and the final embedding pass.
  • _model.py / _build.py: MGAE.forward builds its fused graph GPU-natively (no scipy.csr roundtrip; sketched below); the batch GraphDataLoader gets num_workers=4 for the one-time upfront caching pass.
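
A minimal sketch of the AMP split described in the mixed-precision bullet above. This is illustrative only: the variable names (use_amp, model, discriminator, graph, feat, adj_label, pos_weight) and the exact loss terms are placeholders, not the PR's actual identifiers.

```python
import torch
import torch.nn.functional as F

use_amp = True  # corresponds to the use_amp flag on fit()

# bf16 autocast around the GCN forward + adjacency reconstruction only.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_amp):
    emb, adj_recon = model(graph, feat)              # GCN forward
    loss_adj = F.binary_cross_entropy_with_logits(
        adj_recon, adj_label, pos_weight=pos_weight
    )

# The MIM / discriminator term stays fp32: 1.0 + 1e-6 rounds back to 1.0 at
# bf16 precision, so log(x + 1e-6) near 1.0 loses the offset. Compute it
# outside the autocast region on an fp32 copy of the embedding.
score = discriminator(emb.float())
loss_mim = -torch.log(score + 1e-6).mean()

(loss_adj + loss_mim).backward()
```

Because bf16 keeps the fp32 exponent range, no GradScaler is needed; the change stays confined to the autocast context.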
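
Similarly, a sketch of the ARI convergence probe from the bullet above. train_one_epoch() and infer_embedding() are placeholders (the latter standing in for the PR's _infer_embedding helper), n_epochs and ari_k carry illustrative values, and the remaining names mirror the fit() keyword arguments.

```python
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

n_epochs = 200                                   # illustrative epoch budget
ari_every, ari_k = 10, 8                         # probe frequency / KMeans clusters (illustrative)
ari_stop_threshold, ari_stop_patience = 0.98, 2

prev_labels, stable_probes = None, 0
for epoch in range(n_epochs):
    train_one_epoch()                            # placeholder for the actual training step

    if ari_every and (epoch + 1) % ari_every == 0:
        emb = infer_embedding()                  # assumed to return a NumPy snapshot of the embedding
        labels = KMeans(n_clusters=ari_k, n_init=10).fit_predict(emb)
        if prev_labels is not None:
            ari = adjusted_rand_score(prev_labels, labels)
            # Count consecutive probes whose clustering agrees with the previous snapshot.
            stable_probes = stable_probes + 1 if ari >= ari_stop_threshold else 0
            if stable_probes >= ari_stop_patience:
                break                            # clustering has stopped changing; stop training
        prev_labels = labels
```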
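
And a sketch of the GPU-native fused-graph construction from the _model.py bullet, assuming fused_adj is an (N, N) adjacency tensor already on the GPU (the name is illustrative).

```python
import torch
import dgl

# New path: build the graph directly from edge index tensors on the GPU.
src, dst = torch.nonzero(fused_adj, as_tuple=True)
g = dgl.graph((src, dst), num_nodes=fused_adj.shape[0])  # stays on fused_adj's device

# Old path (for contrast): fused_adj.cpu().numpy() -> scipy.sparse.csr_matrix
# -> dgl.from_scipy(...) -> .to(device), a full GPU->CPU->GPU roundtrip.
```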

Commit 3 — Keep model scaffolding off adata (7006515)

  • Runner.__init__ pops g_<view> into self.graph.
  • Runner_batch.__init__ pops dataloader into self.dataloader and batch_idx into self.batch_idx.
  • After Runner construction, adata.uns is clean of model plumbing — adata.copy() / write_h5ad() work without needing workarounds like adata.uns.pop('dataloader', None).
  • New Runner.save(output_dir) / Runner_batch.save(output_dir) (sketched after this list) writes:
    • {output_dir}/adata.h5ad — results-only AnnData
    • {output_dir}/model.pt — {model_state, discriminator_state, config} with enough config to rebuild MGAE + Discriminator for later inference.
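
A rough sketch of what this save() amounts to, assuming the Runner holds self.adata, self.model, self.discriminator, and a self.config dict (those attribute names are assumptions); the training block shown here is only added in commit 4.

```python
import os
import torch

def save(self, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Results-only AnnData: embedding in obsm['X_scniche'], no scaffolding in uns.
    self.adata.write_h5ad(os.path.join(output_dir, "adata.h5ad"))

    # State dicts + enough config to rebuild MGAE + Discriminator, plus the
    # training diagnostics that commit 4 moves off adata.uns.
    torch.save(
        {
            "model_state": self.model.state_dict(),
            "discriminator_state": self.discriminator.state_dict(),
            "config": self.config,
            "training": {
                "loss_history": self.loss_history,
                "lr_history": self.lr_history,
                "ari_history": self.ari_history,
                "early_stop": self.early_stop,
            },
        },
        os.path.join(output_dir, "model.pt"),
    )
```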

Commit 4 — Move training diagnostics off adata too (ddfcf09)

  • loss_history, lr_history, ari_history, early_stop now live on the Runner (self.*) instead of adata.uns.
  • save() bundles them into model.pt under a training key, alongside the state dicts.
  • Result: the AnnData returned from fit() contains only obsm['X_scniche'] as a training-related addition — nothing in adata.uns is there because of training bookkeeping. It is genuinely downstream-ready (niche clustering, enrichment plots, etc.) without any scaffolding or diagnostics noise.

Backward compatibility

Defaults reproduce the old training behaviour as closely as possible; an illustrative usage sketch follows the list:

  • patience=None, ari_every=None, ari_stop_threshold=None → early stopping and ARI probe are off by default.
  • lr_patience=5, lr_factor=0.5 → an LR schedule is on by default. Pass lr_patience=None to hold the LR fixed.
  • use_amp=True by default. Pass use_amp=False for the original fp32 path.
  • restore_best=True — weights revert to best-loss epoch before the final embedding pass.
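
An illustrative usage sketch of how these defaults compose. The keyword names follow this PR; the import path, constructor signature, and variable names are placeholders.

```python
from scniche import Runner, Runner_batch   # hypothetical import path

runner = Runner(adata)                     # constructor arguments are placeholders
adata = runner.fit()                       # defaults: AMP + LR schedule on, no early stop / ARI probe

# Closest to the pre-PR behaviour:
adata = runner.fit(use_amp=False, lr_patience=None, restore_best=False)

# Large dataset: stop once the clustering stops changing between probes.
runner_b = Runner_batch(adata_big)
adata_big = runner_b.fit(ari_every=10, ari_stop_threshold=0.98, ari_stop_patience=2)
runner_b.save("out/")                      # writes out/adata.h5ad + out/model.pt
```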

Breaking change (intentional): training diagnostics are no longer written to adata.uns. Code that reads adata.uns['loss'] / adata.uns['lr_history'] / adata.uns['ari_history'] / adata.uns['early_stop'] needs to read from model.loss_history / model.lr_history / model.ari_history / model.early_stop instead, or from the corresponding keys inside model.pt after Runner.save(). This is the point of the separation — mixing them in previously was the bug.
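
For the migration, a minimal example of reading the diagnostics back out of model.pt, assuming the layout described above (a training key next to the state dicts):

```python
import torch

# weights_only=False may be needed on newer torch, depending on what config holds.
bundle = torch.load("out/model.pt", map_location="cpu")
print(bundle["training"]["early_stop"])         # e.g. {'stopped_by': 'ari', ...}
print(bundle["training"]["loss_history"][:5])   # first few epoch losses
```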

Test plan

  • Runner.fit() with all defaults on a small dataset: training completes, adata.obsm['X_scniche'] is sane, adata.uns untouched by training.
  • Runner.fit(use_amp=False, lr_patience=None, restore_best=False) produces the same embedding as before this PR on a fixed seed.
  • Runner_batch.fit(ari_every=10, ari_stop_threshold=0.98, ari_stop_patience=2) on a large dataset: stops at the first pair of probes hitting ARI ≥ 0.98; model.early_stop['stopped_by'] == 'ari'.
  • Runner_batch.fit(patience=10, min_delta=1e-4): stops on loss plateau; model.early_stop['stopped_by'] == 'loss'.
  • Runner.save('out/') / Runner_batch.save('out/') writes adata.h5ad + model.pt; model.pt round-trips back into a rebuildable MGAE + Discriminator, and the training block contains loss/lr/ari/early_stop.
  • adata.copy() works immediately after Runner_batch(...) construction (was previously broken by the multiprocess dataloader sitting in adata.uns).

🤖 Generated with Claude Code

jackytamkc and others added 4 commits April 21, 2026 13:46
- _train.py: replace `graph.adjacency_matrix().to_dense().shape[0]`
  with `graph.num_nodes()` at the two pos_weight sites. Same value,
  no N*N allocation, and avoids the torch-pinned libdgl_sparse_pytorch
  load that fails on torch 2.7 with the DGL 2.4 cu124 wheel.
- README: add Blackwell / sm_120 install instructions alongside the
  existing cu113 path.

Backward-compatible: num_nodes() works on all supported DGL versions;
README changes are additive.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation
----------
Transparency and no wasted epochs. The original training loop runs a
fixed number of epochs with no record of what actually happened inside
and no way to stop early when the model has already converged. With
hundreds of thousands of cells per slice this is expensive, and because
the total loss is dominated by a BCE floor on sparse adjacency
reconstruction, loss curves plateau long before the embedding
stabilises (and, conversely, small loss improvements don't necessarily
mean the clustering is still changing). Users need a convergence signal
that matches how the output is used.

Changes
-------
- _train.py (Runner + Runner_batch):
  * Mixed precision: bf16 autocast around the GCN forward and
    adjacency-reconstruction losses. The MIM / discriminator block is
    kept in fp32 because `log(x + 1e-6)` near 1.0 is unsafe in bf16.
  * Loss-based early stopping with `patience`, `min_delta`, and
    optional `restore_best` to roll the weights back to the best epoch
    before producing the final embedding.
  * `ReduceLROnPlateau` scheduler (`lr_patience`, `lr_factor`,
    `min_lr`) with verbose logging when the LR drops.
  * ARI-based convergence probe (Runner_batch only): every
    `ari_every` epochs, snapshot the embedding, KMeans-cluster it
    (`ari_k`), and compute ARI vs. the previous snapshot. Stop when
    `ari_stop_patience` consecutive probes stay above
    `ari_stop_threshold`. This is the signal that matches downstream
    use of the embedding, not the raw loss.
  * Batch caching: move every batch to device once up front and reuse
    the same graph/feat/adj tensors across epochs, eliminating the
    per-epoch H2D transfer. The (multiprocess) dataloader is dropped
    from `adata.uns` after caching so `adata.copy()` works downstream.
  * Persist `loss`, `lr_history`, `ari_history`, and `early_stop`
    metadata (best epoch, best loss, stopped_epoch, stopped_by,
    restored_best) to `adata.uns` for post-hoc inspection.
  * `_infer_embedding` helper shared by the ARI probe and the final
    embedding pass.
- _model.py (MGAE.forward): replace the
  `fused_adj -> cpu -> numpy -> scipy.csr -> dgl.from_scipy` roundtrip
  with a GPU-native `dgl.graph((src, dst), num_nodes=...)`. Same
  resulting graph, but it stays on device and shaves a visible fraction
  off the forward-pass time.
- _build.py: `num_workers=4` on the batch `GraphDataLoader`. Batches
  are cached once, so this only costs extra workers for that one
  upfront pass; for large datasets the caching step is noticeably
  faster.

Backward compatibility
----------------------
All new `fit()` arguments default to values that reproduce the
original behaviour:
- `patience=None`, `ari_every=None`, `ari_stop_threshold=None` -> no
  early stopping, no ARI probe.
- `lr_patience=5` with `lr_factor=0.5` does introduce an LR schedule
  by default; pass `lr_patience=None` to keep the LR fixed at the
  original value.
- `use_amp=True` by default; pass `use_amp=False` for the original
  fp32 path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation
----------
`adata` should only carry trained results, not the model plumbing.
Previously `_build.py` wrote DGL graphs (`g_*`), the `GraphDataLoader`,
and `batch_idx` into `adata.uns` and the trainer read them back from
there — coupling the data container to the training lifecycle and
polluting any `adata.copy()` / `adata.write_h5ad()` afterwards with
objects that don't belong.

Changes
-------
- Runner.__init__ and Runner_batch.__init__ now pop the scaffolding
  keys from `adata.uns` into `self.*`:
    Runner:       `g_<view>` -> self.graph
    Runner_batch: `dataloader` -> self.dataloader,
                  `batch_idx`  -> self.batch_idx
  After the Runner is constructed, `adata.uns` is clean of model
  plumbing; only subsequent training-result writes (`loss`,
  `lr_history`, `ari_history`, `early_stop`, and `obsm['X_scniche']`)
  remain.
- Runner_batch.fit now reads `batch_size` and the final reindex from
  `self.batch_idx` instead of `adata.uns['batch_idx']`, and no longer
  needs the explicit `adata.uns.pop('dataloader', None)` (it was never
  there after __init__).
- New Runner.save(output_dir) / Runner_batch.save(output_dir):
    {output_dir}/adata.h5ad — results-only AnnData
    {output_dir}/model.pt   — {model_state, discriminator_state,
                               config} with enough config to rebuild
                               MGAE + Discriminator for later
                               inference.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation
----------
`adata` should carry only downstream-relevant trained parameters —
enough to run niche clustering, enrichment plots, etc. — and nothing
else. Previously `fit()` wrote `loss`, `lr_history`, `ari_history`,
and `early_stop` into `adata.uns`, which mixes training diagnostics
(which belong with the model) into the data object (which belongs to
the analysis pipeline).

Changes
-------
- Runner.fit / Runner_batch.fit: write training diagnostics to
  `self.loss_history`, `self.lr_history`, `self.ari_history`,
  `self.early_stop` instead of `adata.uns[...]`. The only thing still
  written to `adata` is the embedding at `obsm['X_scniche']`.
- Runner.save / Runner_batch.save: bundle those diagnostics into
  `model.pt` under a `training` key alongside the state dicts, so they
  are persisted with the model they describe.
- Updated docstrings on `fit()` return value and the ARI probe.

Result: the AnnData returned from `fit()` is downstream-ready —
nothing in `adata.uns` is there because of training bookkeeping.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
jackytamkc changed the title from "Trainer: AMP, early stopping, LR schedule, batch caching, ARI convergence probe" to "Trainer rewrite: AMP, convergence-aware stopping, Runner.save(), clean adata separation" on Apr 21, 2026
jackytamkc marked this pull request as ready for review April 21, 2026 13:56