Trainer rewrite: AMP, convergence-aware stopping, Runner.save(), clean adata separation #30
Open
jackytamkc wants to merge 4 commits into ZJUFanLab:main from
Conversation
- _train.py: replace `graph.adjacency_matrix().to_dense().shape[0]` with
  `graph.num_nodes()` at the two pos_weight sites. Same value, no N*N
  allocation, and avoids the torch-pinned libdgl_sparse_pytorch load that
  fails on torch 2.7 with the DGL 2.4 cu124 wheel.
- README: add Blackwell / sm_120 install instructions alongside the
  existing cu113 path.

Backward-compatible: num_nodes() works on all supported DGL versions;
README changes are additive.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
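The heart of the fix can be illustrated without DGL; below, a SciPy sparse matrix stands in for the graph's adjacency (the N*N blow-up is the same), and the `pos_weight` formula is a plausible sketch of a BCE positive-class weight, not scNiche's exact expression:

```python
import scipy.sparse as sp

# Stand-in for graph.adjacency_matrix(): a sparse N x N adjacency.
n = 2000
adj = sp.random(n, n, density=0.001, format='csr', random_state=0)

# Old pattern: materialise the dense matrix just to read N -> N*N floats.
n_dense = adj.toarray().shape[0]

# Cheap pattern (analogue of graph.num_nodes()): read N from metadata.
n_cheap = adj.shape[0]

assert n_dense == n_cheap  # same value, no N*N allocation

# The value feeds a BCE pos_weight, e.g. (N*N - nnz) / nnz for a sparse graph.
pos_weight = float(n_cheap * n_cheap - adj.nnz) / adj.nnz
```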
Motivation
----------
Transparency and no wasted epochs. The original training loop runs a
fixed number of epochs with no record of what actually happened inside
and no way to stop early when the model has already converged. With
hundreds of thousands of cells per slice this is expensive, and because
the total loss is dominated by a BCE floor on sparse adjacency
reconstruction, loss curves plateau long before the embedding
stabilises (and, conversely, small loss improvements don't necessarily
mean the clustering is still changing). Users need a convergence signal
that matches how the output is used.
Changes
-------
- _train.py (Runner + Runner_batch):
* Mixed precision: bf16 autocast around the GCN forward and
adjacency-reconstruction losses. The MIM / discriminator block is
kept in fp32 because `log(x + 1e-6)` near 1.0 is unsafe in bf16.
* Loss-based early stopping with `patience`, `min_delta`, and
optional `restore_best` to roll the weights back to the best epoch
before producing the final embedding.
* `ReduceLROnPlateau` scheduler (`lr_patience`, `lr_factor`,
`min_lr`) with verbose logging when the LR drops.
* ARI-based convergence probe (Runner_batch only): every
`ari_every` epochs, snapshot the embedding, KMeans-cluster it
(`ari_k`), and compute ARI vs. the previous snapshot. Stop when
`ari_stop_patience` consecutive probes stay above
`ari_stop_threshold`. This is the signal that matches downstream
use of the embedding, not the raw loss.
* Batch caching: move every batch to device once up front and reuse
the same graph/feat/adj tensors across epochs, eliminating the
per-epoch H2D transfer. The (multiprocess) dataloader is dropped
from `adata.uns` after caching so `adata.copy()` works downstream.
* Persist `loss`, `lr_history`, `ari_history`, and `early_stop`
metadata (best epoch, best loss, stopped_epoch, stopped_by,
restored_best) to `adata.uns` for post-hoc inspection.
* `_infer_embedding` helper shared by the ARI probe and the final
embedding pass.
- _model.py (MGAE.forward): replace the
  `fused_adj -> cpu -> numpy -> scipy.csr -> dgl.from_scipy` roundtrip
  with a GPU-native `dgl.graph((src, dst), num_nodes=...)`. Same
  resulting graph, but it stays on device and shaves a visible fraction
  off the forward-pass time.
- _build.py: `num_workers=4` on the batch `GraphDataLoader`. Batches
are cached once, so this only costs extra workers for that one
upfront pass; for large datasets the caching step is noticeably
faster.
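The two stopping signals above can be sketched together in a toy loop. Everything here (the fake loss curve, the frozen embedding, k=4, the variable names) is illustrative, not the Runner's actual code:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

rng = np.random.default_rng(0)
emb = rng.normal(size=(200, 8))          # stand-in for the model embedding

patience, min_delta = 3, 1e-4            # loss-plateau knobs
ari_every, ari_stop_threshold, ari_stop_patience = 2, 0.95, 2

best_loss, bad_epochs = float('inf'), 0
ari_hits, prev_labels, stopped_by = 0, None, None

for epoch in range(50):
    loss = 1.0 / (epoch + 1)             # toy loss, still improving
    emb = emb + rng.normal(scale=1e-4, size=emb.shape)  # embedding has stabilised

    # Signal 1: loss plateau (patience / min_delta)
    if loss < best_loss - min_delta:
        best_loss, bad_epochs = loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            stopped_by = 'loss'
            break

    # Signal 2: ARI between successive clusterings of embedding snapshots
    if (epoch + 1) % ari_every == 0:
        labels = KMeans(n_clusters=4, n_init=10, random_state=0).fit_predict(emb)
        if prev_labels is not None:
            if adjusted_rand_score(prev_labels, labels) >= ari_stop_threshold:
                ari_hits += 1
            else:
                ari_hits = 0
            if ari_hits >= ari_stop_patience:
                stopped_by = 'ari'
                break
        prev_labels = labels
```

Because the toy embedding barely moves between probes, successive clusterings agree and the ARI signal fires long before the loss plateaus, which is exactly the situation the probe is designed for.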
Backward compatibility
----------------------
All new `fit()` arguments default to values that reproduce the
original behaviour:
- `patience=None`, `ari_every=None`, `ari_stop_threshold=None` -> no
early stopping, no ARI probe.
- `lr_patience=5` with `lr_factor=0.5` does introduce an LR schedule
by default; pass `lr_patience=None` to keep the LR fixed at the
original value.
- `use_amp=True` by default; pass `use_amp=False` for the original
fp32 path.
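The default schedule maps onto PyTorch's stock `ReduceLROnPlateau`; a minimal sketch of the knobs named above on a deliberately plateaued loss (the toy model and constant loss are illustrative):

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Defaults described above: halve the LR after 5 epochs without improvement.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=5, min_lr=1e-6)

for epoch in range(20):
    loss = 1.0                      # toy: the loss is stuck on a plateau
    sched.step(loss)

final_lr = opt.param_groups[0]['lr']  # reduced from the initial 1e-3
```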
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation
----------
`adata` should only carry trained results, not the model plumbing.
Previously `_build.py` wrote DGL graphs (`g_*`), the `GraphDataLoader`,
and `batch_idx` into `adata.uns` and the trainer read them back from
there — coupling the data container to the training lifecycle and
polluting any `adata.copy()` / `adata.write_h5ad()` afterwards with
objects that don't belong.
Changes
-------
- Runner.__init__ and Runner_batch.__init__ now pop the scaffolding
keys from `adata.uns` into `self.*`:
Runner: `g_<view>` -> self.graph
Runner_batch: `dataloader` -> self.dataloader,
`batch_idx` -> self.batch_idx
After the Runner is constructed, `adata.uns` is clean of model
plumbing; only subsequent training-result writes (`loss`,
`lr_history`, `ari_history`, `early_stop`, and `obsm['X_scniche']`)
remain.
- Runner_batch.fit now reads `batch_size` and the final reindex from
`self.batch_idx` instead of `adata.uns['batch_idx']`, and no longer
needs the explicit `adata.uns.pop('dataloader', None)` (it was never
there after __init__).
- New Runner.save(output_dir) / Runner_batch.save(output_dir):
{output_dir}/adata.h5ad — results-only AnnData
{output_dir}/model.pt — {model_state, discriminator_state,
config} with enough config to rebuild
MGAE + Discriminator for later
inference.
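The checkpoint layout can be sketched with stand-in modules; `MGAE` and `Discriminator` are replaced by `nn.Linear` here and the `config` keys are hypothetical — only the `{model_state, discriminator_state, config}` shape mirrors the PR:

```python
import os
import tempfile
import torch

# Hypothetical stand-ins for MGAE and Discriminator.
model = torch.nn.Linear(8, 4)
discriminator = torch.nn.Linear(4, 1)
config = {'in_dim': 8, 'hidden_dim': 4}  # enough to rebuild the modules

out = tempfile.mkdtemp()
torch.save({'model_state': model.state_dict(),
            'discriminator_state': discriminator.state_dict(),
            'config': config},
           os.path.join(out, 'model.pt'))

# Later inference: rebuild from config, then load the states back.
ckpt = torch.load(os.path.join(out, 'model.pt'))
model2 = torch.nn.Linear(ckpt['config']['in_dim'], ckpt['config']['hidden_dim'])
model2.load_state_dict(ckpt['model_state'])
```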
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation
----------
`adata` should carry only downstream-relevant trained parameters — enough
to run niche clustering, enrichment plots, etc. — and nothing else.
Previously `fit()` wrote `loss`, `lr_history`, `ari_history`, and
`early_stop` into `adata.uns`, which mixes training diagnostics (which
belong with the model) into the data object (which belongs to the
analysis pipeline).

Changes
-------
- Runner.fit / Runner_batch.fit: write training diagnostics to
  `self.loss_history`, `self.lr_history`, `self.ari_history`, and
  `self.early_stop` instead of `adata.uns[...]`. The only thing still
  written to `adata` is the embedding at `obsm['X_scniche']`.
- Runner.save / Runner_batch.save: bundle those diagnostics into
  `model.pt` under a `training` key alongside the state dicts, so they
  are persisted with the model they describe.
- Updated docstrings on the `fit()` return value and the ARI probe.

Result: the AnnData returned from `fit()` is downstream-ready — nothing
in `adata.uns` is there because of training bookkeeping.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Motivation

Two problems with the existing trainer:

1. The training loop runs a fixed number of epochs with no convergence
   signal and no record of what actually happened inside, so large
   datasets waste compute after the embedding has stabilised.
2. `prepare_data{,_batch}` writes DGL graphs, the `GraphDataLoader`, and
   `batch_idx` into `adata.uns`, and `fit()` then writes `loss` /
   `lr_history` / `early_stop` there too. The result: `adata.copy()`
   pulls in multiprocess dataloaders, `write_h5ad()` carries scaffolding
   it can't serialise cleanly, and the "trained AnnData" is not actually
   downstream-ready.

This PR addresses both: stop on a signal that matches how the embedding
is used, persist enough training metadata to reconstruct what happened,
and cleanly separate the model (state dicts + diagnostics, saved as
`model.pt`) from the data (embedding only, saved as `adata.h5ad`).

What changes
Commit 2 — Trainer rewrite (9555191)

- `torch.autocast(cuda, bf16)` around the GCN forward and
  adjacency-reconstruction losses. The MIM / discriminator block stays
  fp32 — `log(x + 1e-6)` near 1.0 is unsafe in bf16.
- Loss-based early stopping with `patience`, `min_delta`, and optional
  `restore_best` (rolls weights back to the best-loss epoch before
  producing the final embedding).
- `ReduceLROnPlateau` scheduler (`lr_patience`, `lr_factor`, `min_lr`)
  with verbose logging when the LR drops.
- ARI-based convergence probe (`Runner_batch` only): every `ari_every`
  epochs, snapshot the embedding, KMeans-cluster it (`ari_k`), and
  compute ARI vs. the previous snapshot. Stop when `ari_stop_patience`
  consecutive probes stay above `ari_stop_threshold`. This is the signal
  that matches downstream use, not the raw loss.
- `_infer_embedding` helper shared by the ARI probe and the final
  embedding pass.
- `_model.py` / `_build.py`: MGAE.forward builds its fused graph
  GPU-natively (no scipy.csr roundtrip); the batch `GraphDataLoader`
  gets `num_workers=4` for the one-time upfront caching pass.

Commit 3 — Keep model scaffolding off `adata` (7006515)

- `Runner.__init__` pops `g_<view>` into `self.graph`.
  `Runner_batch.__init__` pops `dataloader` into `self.dataloader` and
  `batch_idx` into `self.batch_idx`.
- `adata.uns` is clean of model plumbing — `adata.copy()` /
  `write_h5ad()` work without needing workarounds like
  `adata.uns.pop('dataloader', None)`.
- `Runner.save(output_dir)` / `Runner_batch.save(output_dir)` writes:
  - `{output_dir}/adata.h5ad` — results-only AnnData
  - `{output_dir}/model.pt` — `{model_state, discriminator_state,
    config}` with enough config to rebuild MGAE + Discriminator for
    later inference.

Commit 4 — Move training diagnostics off `adata` too (ddfcf09)

- `loss_history`, `lr_history`, `ari_history`, `early_stop` now live on
  the Runner (`self.*`) instead of `adata.uns`.
- `save()` bundles them into `model.pt` under a `training` key,
  alongside the state dicts.
- The AnnData returned from `fit()` contains only `obsm['X_scniche']`
  as a training-related addition — nothing in `adata.uns` is there
  because of training bookkeeping. It is genuinely downstream-ready
  (niche clustering, enrichment plots, etc.) without any scaffolding or
  diagnostics noise.

Backward compatibility

Defaults reproduce the old training behaviour as closely as possible:

- `patience=None`, `ari_every=None`, `ari_stop_threshold=None` → early
  stopping and ARI probe are off by default.
- `lr_patience=5`, `lr_factor=0.5` → an LR schedule is on by default.
  Pass `lr_patience=None` to hold the LR fixed.
- `use_amp=True` by default. Pass `use_amp=False` for the original fp32
  path.
- `restore_best=True` — weights revert to the best-loss epoch before the
  final embedding pass.

Breaking change (intentional): training diagnostics are no longer
written to `adata.uns`. Code that reads `adata.uns['loss']` /
`adata.uns['lr_history']` / `adata.uns['ari_history']` /
`adata.uns['early_stop']` needs to read from `model.loss_history` /
`model.lr_history` / `model.ari_history` / `model.early_stop` instead,
or from the corresponding keys inside `model.pt` after `Runner.save()`.
This is the point of the separation — mixing them in previously was the
bug.

Test plan

- `Runner.fit()` with all defaults on a small dataset: finishes in
  `epochs`, `adata.obsm['X_scniche']` sane, `adata.uns` untouched by
  training.
- `Runner.fit(use_amp=False, lr_patience=None, restore_best=False)`
  produces the same embedding as before this PR on a fixed seed.
- `Runner_batch.fit(ari_every=10, ari_stop_threshold=0.98,
  ari_stop_patience=2)` on a large dataset: stops at the first pair of
  probes hitting ARI ≥ 0.98; `model.early_stop['stopped_by'] == 'ari'`.
- `Runner_batch.fit(patience=10, min_delta=1e-4)`: stops on loss
  plateau; `model.early_stop['stopped_by'] == 'loss'`.
- `Runner.save('out/')` / `Runner_batch.save('out/')` writes
  `adata.h5ad` + `model.pt`; `model.pt` round-trips back into a
  rebuildable MGAE + Discriminator, and the `training` block contains
  loss/lr/ari/early_stop.
- `adata.copy()` works immediately after `Runner_batch(...)`
  construction (was previously broken by the multiprocess dataloader
  sitting in `adata.uns`).

🤖 Generated with Claude Code