From 4c811f98c8adfdf5aba45a797486a9ecb501d753 Mon Sep 17 00:00:00 2001 From: jackytamkc Date: Tue, 21 Apr 2026 13:46:42 +0100 Subject: [PATCH 1/4] Support Blackwell GPUs (torch 2.7 + cu128 + DGL 2.4 cu124) - _train.py: replace `graph.adjacency_matrix().to_dense().shape[0]` with `graph.num_nodes()` at the two pos_weight sites. Same value, no N*N allocation, and avoids the torch-pinned libdgl_sparse_pytorch load that fails on torch 2.7 with the DGL 2.4 cu124 wheel. - README: add Blackwell / sm_120 install instructions alongside the existing cu113 path. Backward-compatible: num_nodes() works on all supported DGL versions; README changes are additive. Co-Authored-By: Claude Opus 4.7 --- README.md | 16 ++++++++++++++++ scniche/trainer/_train.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 76f7399..109137d 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html ``` The version of PyTorch and DGL should be suitable to the CUDA version of your machine. You can find the appropriate version on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) website. +#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120) +Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order: + +``` +# 1. Blackwell-capable PyTorch +pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 + +# 2. DGL (this step downgrades torch to 2.4.0 — expected) +pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html + +# 3. 
Restore torch 2.7 on top of DGL +pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 +``` + +`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. The `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_.so` on torch 2.7 — scNiche no longer calls those routines, but downstream user code that does will need to be patched or the DGL wheel rebuilt from source. + ### Install scNiche ``` diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 1bef903..1d6b75a 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -105,7 +105,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): optim = torch.optim.Adam(self.model.parameters(), lr=lr) # loss pos_weight = torch.Tensor( - [float(self.graph[0].adjacency_matrix().to_dense().shape[0] ** 2 - self.edges / 2) / self.edges * 2] + [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device) @@ -231,7 +231,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # loss pos_weight = torch.Tensor( - [float(graphs[0].adjacency_matrix().to_dense().shape[0] ** 2 - edges / 2) / edges * 2] + [float(graphs[0].num_nodes() ** 2 - edges / 2) / edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device) From 9555191226dec8a629ce8e812afb271d322172a1 Mon Sep 17 00:00:00 2001 From: jackytamkc Date: Tue, 21 Apr 2026 13:53:38 +0100 Subject: [PATCH 2/4] Trainer: AMP, early stopping, LR schedule, batch caching, ARI probe Motivation ---------- Transparency and no wasted epochs. 
The original training loop runs a fixed number of epochs with no record of what happened inside beyond the per-epoch loss, and no way to stop early when the model has already converged. With hundreds of thousands of cells per slice this is expensive, and because the total loss is dominated by a BCE floor on sparse adjacency reconstruction, loss curves plateau long before the embedding stabilises (and, conversely, small loss improvements don't necessarily mean the clustering is still changing). Users need a convergence signal that matches how the output is used. Changes ------- - _train.py (Runner + Runner_batch): * Mixed precision: bf16 autocast around the GCN forward and adjacency-reconstruction losses. The MIM / discriminator block is kept in fp32 because `log(x + 1e-6)` near 1.0 is unsafe in bf16. * Loss-based early stopping with `patience`, `min_delta`, and optional `restore_best` to roll the weights back to the best epoch before producing the final embedding. * `ReduceLROnPlateau` scheduler (`lr_patience`, `lr_factor`, `min_lr`) with verbose logging when the LR drops. * ARI-based convergence probe (Runner_batch only): every `ari_every` epochs, snapshot the embedding, KMeans-cluster it (`ari_k`), and compute ARI vs. the previous snapshot. Stop when `ari_stop_patience` consecutive probes stay at or above `ari_stop_threshold`. This is the signal that matches downstream use of the embedding, not the raw loss. * Batch caching (Runner_batch only): move every batch to device once up front and reuse the same graph/feat/adj tensors across epochs, eliminating the per-epoch H2D transfer. The (multiprocess) dataloader is dropped from `adata.uns` after caching so `adata.copy()` works downstream. * Persist `loss`, `lr_history`, `ari_history`, and `early_stop` metadata (best epoch, best loss, stopped_epoch, stopped_by, restored_best) to `adata.uns` for post-hoc inspection; `ari_history` and `stopped_by` are written by Runner_batch only. * `_infer_embedding` helper used by the ARI probe to snapshot the embedding in eval mode; the final embedding pass in Runner_batch runs the same no-grad/autocast forward loop inline rather than calling the helper. 
- _model.py (MGAE.forward): replace the `fused_adj -> cpu -> numpy -> scipy.csr -> dgl.from_scipy` roundtrip with a GPU-native `dgl.graph((src, dst), num_nodes=...)`. Same resulting graph, but stays on device and shaves a visible fraction of the forward pass. - _build.py: `num_workers=4` on the batch `GraphDataLoader`. Batches are cached once, so this only costs extra workers for that one upfront pass; for large datasets the caching step is noticeably faster. Backward compatibility ---------------------- Existing `fit(lr=..., epochs=...)` calls keep working, and the stopping-related arguments default to off. Three defaults do change behaviour, however, and must be overridden to reproduce the original run exactly: - `patience=None`, `ari_every=None`, `ari_stop_threshold=None` -> no early stopping, no ARI probe. - `lr_patience=5` with `lr_factor=0.5` does introduce an LR schedule by default; pass `lr_patience=None` to keep the LR fixed at the original value. - `use_amp=True` by default; pass `use_amp=False` for the original fp32 path. - `restore_best=True` by default, so the final embedding is produced from the best-loss epoch's weights rather than the last epoch's; pass `restore_best=False` for the original behaviour. Co-Authored-By: Claude Opus 4.7 --- scniche/preprocess/_build.py | 5 +- scniche/trainer/_model.py | 12 +- scniche/trainer/_train.py | 335 ++++++++++++++++++++++++++++------- 3 files changed, 283 insertions(+), 69 deletions(-) diff --git a/scniche/preprocess/_build.py b/scniche/preprocess/_build.py index b688be4..820db2e 100644 --- a/scniche/preprocess/_build.py +++ b/scniche/preprocess/_build.py @@ -124,7 +124,10 @@ def prepare_data_batch( print("Constructing done.") mydataset = myDataset(g_list) - dataloader = GraphDataLoader(mydataset, batch_size=1, shuffle=False, pin_memory=True) + dataloader = GraphDataLoader( + mydataset, batch_size=1, shuffle=False, pin_memory=True, + num_workers=4, + ) adata.uns['dataloader'] = dataloader return adata diff --git a/scniche/trainer/_model.py b/scniche/trainer/_model.py index 24564e4..89f86f3 100644 --- a/scniche/trainer/_model.py +++ b/scniche/trainer/_model.py @@ -3,7 +3,6 @@ # This project is licensed under the GPL-3.0 License. 
-import scipy.sparse as sp import torch import torch.nn as nn import torch.nn.functional as F @@ -59,11 +58,12 @@ def consensus_graph(self, graphs, device): fused_adj = torch.clamp(fused_adj, 0, 1) fused_adj = torch.round(fused_adj + 0.1) - # build symmetric adjacency matrix - adj_np = fused_adj.detach().cpu().numpy() - adj_np += adj_np.T - adj_sparse = sp.csr_matrix(adj_np) - g = dgl.from_scipy(adj_sparse, device=device) + # build symmetric adjacency matrix ON GPU (avoids GPU<->CPU<->scipy roundtrip) + with torch.no_grad(): + sym = fused_adj.detach() + sym = sym + sym.T + src, dst = (sym > 0).nonzero(as_tuple=True) + g = dgl.graph((src.int(), dst.int()), num_nodes=fused_adj.shape[0], device=device) return fused_adj, g diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 1d6b75a..2361094 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -1,12 +1,29 @@ +import copy from anndata import AnnData from tqdm import tqdm import pandas as pd import numpy as np +from sklearn.cluster import KMeans +from sklearn.metrics import adjusted_rand_score from ._model import * from ._utils import shuffling from typing import Optional +def _infer_embedding(model, cached_batches, device, use_amp): + """Forward-pass cached batches in eval mode and return a (N_cells, latent) ndarray.""" + was_training = model.training + model.eval() + emb = [] + with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + for b in cached_batches: + _, _, z = model.forward(b['graphs'], b['feats'], device) + emb.append(z.float().data.cpu().numpy()) + if was_training: + model.train() + return np.stack(emb, axis=0) # (batches, cells_per_batch, latent) + + class Runner: def __init__( self, @@ -77,14 +94,39 @@ def __init__( )) print("Mutual Information Matrix Size for training: {}".format(self.mik.shape)) - def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): + def fit( + self, + lr: Optional[float] = 0.01, + 
epochs: Optional[int] = 100, + patience: Optional[int] = None, + min_delta: float = 1e-4, + restore_best: bool = True, + lr_patience: Optional[int] = 5, + lr_factor: float = 0.5, + min_lr: float = 1e-6, + use_amp: bool = True, + ): """ Fit the scNiche model. Args: lr: Optional[float] - Learning rate for the optimizer. + Initial learning rate for the optimizer. epochs: Optional[int] - Number of training epochs. + Maximum number of training epochs. + patience: Optional[int] + Early-stopping patience. Stop if loss has not improved by `min_delta` for `patience` consecutive epochs. If None, early stopping is disabled. + min_delta: float + Minimum loss improvement to reset the patience / LR-scheduler counters. + restore_best: bool + If True, restore the model weights from the epoch with the lowest training loss before producing the final embedding. + lr_patience: Optional[int] + ReduceLROnPlateau patience — halve the LR after this many epochs with no improvement. If None, LR is held constant. + lr_factor: float + Multiplicative factor applied to the LR when the plateau is hit (e.g. 0.5 halves the LR). + min_lr: float + Floor for the LR; the scheduler will not reduce below this value. + use_amp: bool + If True (default), run the GCN forward and adjacency-reconstruction losses under `torch.autocast(cuda, bf16)`. The discriminator/MIM block is kept in fp32 for numerical stability (log near 1.0). Returns: adata: AnnData Anndata object with learned embeddings (stored in `adata.obsm['X_scniche']`) and training loss (stored in `adata.uns['loss']`). 
@@ -103,6 +145,12 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # optimizer optim = torch.optim.Adam(self.model.parameters(), lr=lr) + scheduler = None + if lr_patience is not None: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, mode='min', factor=lr_factor, patience=lr_patience, + threshold=min_delta, threshold_mode='abs', min_lr=min_lr, + ) # loss pos_weight = torch.Tensor( [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2] @@ -115,30 +163,66 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): self.model.train() loss_all = [] + lr_all = [] + best_loss = float('inf') + best_state = None + best_epoch = -1 + no_improve = 0 + stopped_epoch = epochs pbar = tqdm(range(epochs)) for epoch in pbar: - adj_r, adj_logits, z = self.model.forward(self.graph, self.feat, self.device) - loss_gre = sum(criterion_m(adj_r, adj) for adj in self.adj) / self.views - loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(self.adj)) / self.views - - loss_mim = 0 - for i in range(self.mik.shape[1]): - z_shuf = shuffling(z, latent=self.hidden_size[-1], device=self.device) - z_comb = torch.cat((z, z_shuf), 1) - z_shuf_scores = self.model_d(z_comb) - z_idx = torch.cat((z, z[self.mik[:, i]]), 1) - z_scores = self.model_d(z_idx) - loss_mim += - torch.mean( - torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) - ) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + adj_r, adj_logits, z = self.model.forward(self.graph, self.feat, self.device) + loss_gre = sum(criterion_m(adj_r, adj) for adj in self.adj) / self.views + loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(self.adj)) / self.views + + # keep MIM loss in fp32: log(x + 1e-6) near 1.0 is unsafe in bf16 + with torch.autocast(device_type='cuda', enabled=False): + z_fp32 = z.float() + loss_mim = 0 + for i in range(self.mik.shape[1]): + z_shuf = shuffling(z_fp32, 
latent=self.hidden_size[-1], device=self.device) + z_comb = torch.cat((z_fp32, z_shuf), 1) + z_shuf_scores = self.model_d(z_comb) + z_idx = torch.cat((z_fp32, z_fp32[self.mik[:, i]]), 1) + z_scores = self.model_d(z_idx) + loss_mim += - torch.mean( + torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) + ) loss = loss_gre + loss_rec + loss_mim optim.zero_grad() loss.backward() optim.step() + current_lr = optim.param_groups[0]['lr'] + lr_all.append(current_lr) pbar.set_description('Train Epoch: {}'.format(epoch + 1)) - pbar.set_postfix(loss=f"{loss:.4f}") + pbar.set_postfix(loss=f"{loss:.4f}", lr=f"{current_lr:.2e}") loss_all.append(loss.data.cpu().numpy()) + loss_val = float(loss.item()) + if scheduler is not None: + prev_lr = current_lr + scheduler.step(loss_val) + new_lr = optim.param_groups[0]['lr'] + if self.verbose and new_lr < prev_lr: + print(f" epoch {epoch + 1}: LR reduced {prev_lr:.2e} -> {new_lr:.2e}") + if loss_val < best_loss - min_delta: + best_loss = loss_val + best_epoch = epoch + no_improve = 0 + if restore_best: + best_state = copy.deepcopy(self.model.state_dict()) + else: + no_improve += 1 + if patience is not None and no_improve >= patience: + stopped_epoch = epoch + 1 + if self.verbose: + print(f"Early stopping at epoch {stopped_epoch}; best loss {best_loss:.4f} at epoch {best_epoch + 1}.") + break + + if restore_best and best_state is not None: + self.model.load_state_dict(best_state) + if self.verbose: print("Training done.") @@ -146,6 +230,13 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): _, _, z = self.model.forward(self.graph, self.feat, self.device) self.adata.uns['loss'] = loss_all + self.adata.uns['lr_history'] = lr_all + self.adata.uns['early_stop'] = { + 'best_epoch': best_epoch + 1, + 'best_loss': best_loss, + 'stopped_epoch': stopped_epoch, + 'restored_best': bool(restore_best and best_state is not None), + } self.adata.obsm['X_scniche'] = z.data.cpu().numpy() return self.adata @@ -201,7 
+292,39 @@ def __init__( )) print("Batch size: {}".format(len(self.dataloader))) - def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): + def fit( + self, + lr: Optional[float] = 0.01, + epochs: Optional[int] = 100, + patience: Optional[int] = None, + min_delta: float = 1e-4, + restore_best: bool = True, + lr_patience: Optional[int] = 5, + lr_factor: float = 0.5, + min_lr: float = 1e-6, + use_amp: bool = True, + ari_every: Optional[int] = None, + ari_k: int = 10, + ari_stop_threshold: Optional[float] = None, + ari_stop_patience: int = 2, + ): + """ + Extra args: + ari_every: if set, every N epochs snapshot the embedding, KMeans it + (with k=ari_k), and compute ARI vs the previous snapshot's labels. + Results are appended to adata.uns['ari_history'] as list of + (epoch, ari) pairs. ARI >= ~0.98 between consecutive snapshots + indicates the embedding has converged. + ari_k: k for the KMeans used in the ARI probe. Pick close to your + expected niche count. + ari_stop_threshold: if set (e.g. 0.98), training stops when the + probe ARI stays >= this value for `ari_stop_patience` consecutive + probes. This is a meaningful convergence signal for this loss + structure (the total loss plateaus at the BCE floor even while + the embedding is still being shaped). + ari_stop_patience: how many consecutive high-ARI probes to require + before stopping. + """ # model batch_size = len(self.adata.uns['batch_idx'][0]) @@ -212,49 +335,83 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # optimizer optim = torch.optim.Adam(self.model.parameters(), lr=lr) + scheduler = None + if lr_patience is not None: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, mode='min', factor=lr_factor, patience=lr_patience, + threshold=min_delta, threshold_mode='abs', min_lr=min_lr, + ) + + # pre-cache all batches on device ONCE. Graph ndata (feat, adj, mik) moves + # with the graph, so subsequent iterations pay zero H2D cost. 
+ if self.verbose: + print("-------Caching batches on device...") + cached_batches = [] + for batch in self.dataloader: + graphs = [batch[i].to(self.device) for i in range(len(batch))] + feats = [g.ndata['feat'] for g in graphs] + adj_sparse = [g.ndata['adj'] for g in graphs] + mik = torch.cat([g.ndata['mik'].to(self.device) for g in graphs], dim=1).long() + edges = sum(g.number_of_edges() for g in graphs) + n_nodes = graphs[0].num_nodes() + pos_weight = torch.tensor( + [float(n_nodes ** 2 - edges / 2) / edges * 2], device=self.device + ) + bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + cached_batches.append({ + 'graphs': graphs, 'feats': feats, 'adj_sparse': adj_sparse, + 'mik': mik, 'bce': bce, + }) + # free the dataloader: its multiprocess iterators aren't picklable, which + # breaks anndata.copy() / adata[mask].copy() downstream. + self.dataloader = None + self.adata.uns.pop('dataloader', None) + # stateless loss — build once + mse = torch.nn.MSELoss() + latent_dim = self.hidden_size[-1] if self.verbose: print("-------Start training...") self.model.train() loss_all = [] + lr_all = [] + ari_history = [] # list of (epoch_1based, ari_vs_prev_snapshot) + prev_snap_labels = None + ari_good_streak = 0 + ari_stopped = False + best_loss = float('inf') + best_state = None + best_epoch = -1 + no_improve = 0 + stopped_epoch = epochs pbar = tqdm(range(epochs)) for epoch in pbar: batch_loss = 0 - for batch in self.dataloader: - - graphs = [batch[i] for i in range(len(batch))] - feats = [g.ndata['feat'] for g in graphs] - adjs = [g.ndata['adj'].to_dense() for g in graphs] - mik = np.hstack((g.ndata['mik'] for g in graphs)) - edges = sum(g.number_of_edges() for g in graphs) - - # loss - pos_weight = torch.Tensor( - [float(graphs[0].num_nodes() ** 2 - edges / 2) / edges * 2] - ) - criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) - criterion_m = torch.nn.MSELoss().to(self.device) - - # to device - feats = [feat.to(self.device) for feat in 
feats] - graphs = [g.to(self.device) for g in graphs] - adjs = [adj.to(self.device) for adj in adjs] - - adj_r, adj_logits, z = self.model.forward(graphs, feats, self.device) - loss_gre = sum(criterion_m(adj_r, adj) for adj in adjs) / self.views - loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(adjs)) / self.views - - loss_mim = 0 - for i in range(mik.shape[1]): - z_shuf = shuffling(z, latent=self.hidden_size[-1], device=self.device) - z_comb = torch.cat((z, z_shuf), 1) - z_shuf_scores = self.model_d(z_comb) - z_idx = torch.cat((z, z[mik[:, i]]), 1) - z_scores = self.model_d(z_idx) - loss_mim += - torch.mean( - torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) - ) + for b in cached_batches: + graphs = b['graphs'] + feats = b['feats'] + mik = b['mik'] + bce = b['bce'] + adjs = [a.to_dense() for a in b['adj_sparse']] + + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + adj_r, adj_logits, z = self.model.forward(graphs, feats, self.device) + loss_gre = sum(mse(adj_r, adj) for adj in adjs) / self.views + loss_rec = sum(bce(adj_logits[i], adj) for i, adj in enumerate(adjs)) / self.views + + with torch.autocast(device_type='cuda', enabled=False): + z_fp32 = z.float() + loss_mim = 0 + for i in range(mik.shape[1]): + z_shuf = shuffling(z_fp32, latent=latent_dim, device=self.device) + z_comb = torch.cat((z_fp32, z_shuf), 1) + z_shuf_scores = self.model_d(z_comb) + z_idx = torch.cat((z_fp32, z_fp32[mik[:, i]]), 1) + z_scores = self.model_d(z_idx) + loss_mim += - torch.mean( + torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) + ) loss = loss_gre + loss_rec + loss_mim optim.zero_grad() @@ -264,22 +421,67 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): batch_loss += loss.item() loss_all.append(batch_loss) + current_lr = optim.param_groups[0]['lr'] + lr_all.append(current_lr) pbar.set_description('Train Epoch: {}'.format(epoch)) - 
pbar.set_postfix(loss=f"{batch_loss:.4f}") + pbar.set_postfix(loss=f"{batch_loss:.4f}", lr=f"{current_lr:.2e}") + + if scheduler is not None: + prev_lr = current_lr + scheduler.step(batch_loss) + new_lr = optim.param_groups[0]['lr'] + if self.verbose and new_lr < prev_lr: + print(f" epoch {epoch + 1}: LR reduced {prev_lr:.2e} -> {new_lr:.2e}") + if batch_loss < best_loss - min_delta: + best_loss = batch_loss + best_epoch = epoch + no_improve = 0 + if restore_best: + best_state = copy.deepcopy(self.model.state_dict()) + else: + no_improve += 1 + if patience is not None and no_improve >= patience: + stopped_epoch = epoch + 1 + if self.verbose: + print(f"Early stopping at epoch {stopped_epoch}; best loss {best_loss:.4f} at epoch {best_epoch + 1}.") + break + + # ARI convergence probe + if ari_every is not None and ari_every > 0 and ((epoch + 1) % ari_every == 0): + emb_snap = _infer_embedding(self.model, cached_batches, self.device, use_amp) + emb_snap_flat = emb_snap.reshape(-1, emb_snap.shape[-1]) + snap_labels = KMeans(n_clusters=ari_k, random_state=123, n_init=10).fit_predict(emb_snap_flat) + if prev_snap_labels is not None: + ari = float(adjusted_rand_score(prev_snap_labels, snap_labels)) + ari_history.append((epoch + 1, ari)) + if self.verbose: + print(f" epoch {epoch + 1}: ARI vs prev snapshot = {ari:.4f}") + if ari_stop_threshold is not None and ari >= ari_stop_threshold: + ari_good_streak += 1 + if ari_good_streak >= ari_stop_patience: + stopped_epoch = epoch + 1 + ari_stopped = True + if self.verbose: + print(f"ARI-based stop at epoch {stopped_epoch}: " + f"{ari_good_streak} consecutive probes >= {ari_stop_threshold}.") + prev_snap_labels = snap_labels + break + else: + ari_good_streak = 0 + prev_snap_labels = snap_labels + + if restore_best and best_state is not None: + self.model.load_state_dict(best_state) if self.verbose: print("Training done.") self.model.eval() emb = [] - for batch in tqdm(self.dataloader): - graphs = [batch[i] for i in 
range(len(batch))] - feats = [g.ndata['feat'] for g in graphs] - graphs = [g.to(self.device) for g in graphs] - feats = [feat.to(self.device) for feat in feats] - - _, _, z = self.model.forward(graphs, feats, self.device) - emb.append(list(z.data.cpu().numpy())) + with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + for b in tqdm(cached_batches): + _, _, z = self.model.forward(b['graphs'], b['feats'], self.device) + emb.append(list(z.float().data.cpu().numpy())) emb = np.array(emb) emb = pd.DataFrame(np.reshape(emb, (-1, emb.shape[2]))) @@ -291,5 +493,14 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): emb = emb.reindex(self.adata.obs_names) self.adata.uns['loss'] = loss_all + self.adata.uns['lr_history'] = lr_all + self.adata.uns['ari_history'] = ari_history + self.adata.uns['early_stop'] = { + 'best_epoch': best_epoch + 1, + 'best_loss': best_loss, + 'stopped_epoch': stopped_epoch, + 'stopped_by': 'ari' if ari_stopped else ('loss' if stopped_epoch < epochs else 'epochs'), + 'restored_best': bool(restore_best and best_state is not None), + } self.adata.obsm['X_scniche'] = np.array(emb) return self.adata From 700651522c1a7609daf8e823778bb6c55f741807 Mon Sep 17 00:00:00 2001 From: jackytamkc Date: Tue, 21 Apr 2026 14:02:52 +0100 Subject: [PATCH 3/4] Keep model scaffolding off adata; add Runner.save(output_dir) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Motivation ---------- `adata` should only carry trained results, not the model plumbing. Previously `_build.py` wrote DGL graphs (`g_*`), the `GraphDataLoader`, and `batch_idx` into `adata.uns` and the trainer read them back from there — coupling the data container to the training lifecycle and polluting any `adata.copy()` / `adata.write_h5ad()` afterwards with objects that don't belong. 
Changes ------- - Runner.__init__ and Runner_batch.__init__ now pop the scaffolding keys from `adata.uns` into `self.*`: Runner: `g_` -> self.graph Runner_batch: `dataloader` -> self.dataloader, `batch_idx` -> self.batch_idx After the Runner is constructed, `adata.uns` is clean of model plumbing; only subsequent training-result writes (`loss`, `lr_history`, `ari_history`, `early_stop`, and `obsm['X_scniche']`) remain. - Runner_batch.fit now reads `batch_size` and the final reindex from `self.batch_idx` instead of `adata.uns['batch_idx']`, and no longer needs the explicit `adata.uns.pop('dataloader', None)` (it was never there after __init__). - New Runner.save(output_dir) / Runner_batch.save(output_dir): {output_dir}/adata.h5ad — results-only AnnData {output_dir}/model.pt — {model_state, discriminator_state, config} with enough config to rebuild MGAE + Discriminator for later inference. Co-Authored-By: Claude Opus 4.7 --- scniche/trainer/_train.py | 73 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 2361094..4142b2d 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -1,4 +1,5 @@ import copy +import os from anndata import AnnData from tqdm import tqdm import pandas as pd @@ -61,8 +62,9 @@ def __init__( self.views = len(self.choose_views) + # move model scaffolding off adata.uns so adata only carries training results self.graph_name = ['g_' + view for view in self.choose_views] - self.graph = [self.adata.uns[graph_name] for graph_name in self.graph_name] + self.graph = [self.adata.uns.pop(graph_name) for graph_name in self.graph_name] self.adj = [g.ndata['adj'].to_dense() for g in self.graph] self.mik = np.hstack((g.ndata['mik'] for g in self.graph)) self.edges = sum(g.number_of_edges() for g in self.graph) @@ -240,6 +242,35 @@ def fit( self.adata.obsm['X_scniche'] = z.data.cpu().numpy() return self.adata + def save(self, 
output_dir: str): + """Save trained model + adata to `output_dir`. + + Writes: + {output_dir}/adata.h5ad — AnnData carrying only training results + (embedding in `obsm['X_scniche']`; loss / lr_history / early_stop + metadata in `uns`). Model scaffolding (graphs) was already popped + off `adata.uns` at Runner construction. + {output_dir}/model.pt — model + discriminator state dicts and the + config needed to rebuild the MGAE/Discriminator pair. + """ + if not hasattr(self, 'model'): + raise RuntimeError("Call .fit() before .save().") + os.makedirs(output_dir, exist_ok=True) + self.adata.write_h5ad(os.path.join(output_dir, 'adata.h5ad')) + torch.save({ + 'model_state': self.model.state_dict(), + 'discriminator_state': self.model_d.state_dict(), + 'config': { + 'in_feats': self.in_feats, + 'hidden_size_v': self.hidden_size_v, + 'hidden_size': self.hidden_size, + 'views': self.views, + 'choose_views': self.choose_views, + 'n_nodes': self.adj[0].shape[0], + 'mode': 'full', + }, + }, os.path.join(output_dir, 'model.pt')) + class Runner_batch: def __init__( @@ -255,7 +286,9 @@ def __init__( Initialize the Runner_batch class for training scNihce model with batch training strategy. 
""" self.adata = adata - self.dataloader = self.adata.uns['dataloader'] + # move model scaffolding off adata.uns so adata only carries training results + self.dataloader = self.adata.uns.pop('dataloader') + self.batch_idx = self.adata.uns.pop('batch_idx') self.choose_views = choose_views if self.choose_views is None: self.choose_views = ['X_cn_norm', 'X_data', 'X_data_nbr'] @@ -327,7 +360,7 @@ def fit( """ # model - batch_size = len(self.adata.uns['batch_idx'][0]) + batch_size = len(self.batch_idx[0]) self.model = MGAE(self.in_feats, self.hidden_size_v, self.hidden_size, self.views, batch_size) self.model_d = Discriminator(latent_dim=self.hidden_size[-1]) self.model = self.model.to(self.device) @@ -365,7 +398,6 @@ def fit( # free the dataloader: its multiprocess iterators aren't picklable, which # breaks anndata.copy() / adata[mask].copy() downstream. self.dataloader = None - self.adata.uns.pop('dataloader', None) # stateless loss — build once mse = torch.nn.MSELoss() latent_dim = self.hidden_size[-1] @@ -486,7 +518,7 @@ def fit( emb = np.array(emb) emb = pd.DataFrame(np.reshape(emb, (-1, emb.shape[2]))) - idx = np.array(self.adata.uns['batch_idx']).flatten().tolist() + idx = np.array(self.batch_idx).flatten().tolist() emb.index = idx emb = emb[~emb.index.duplicated()] emb.index = self.adata.obs_names[emb.index] @@ -504,3 +536,34 @@ def fit( } self.adata.obsm['X_scniche'] = np.array(emb) return self.adata + + def save(self, output_dir: str): + """Save trained model + adata to `output_dir`. + + Writes: + {output_dir}/adata.h5ad — AnnData carrying only training results + (embedding in `obsm['X_scniche']`; loss / lr_history / ari_history + / early_stop metadata in `uns`). Model scaffolding (dataloader, + batch_idx) was already popped off `adata.uns` at Runner_batch + construction. + {output_dir}/model.pt — model + discriminator state dicts and the + config needed to rebuild the MGAE/Discriminator pair. 
+ """ + if not hasattr(self, 'model'): + raise RuntimeError("Call .fit() before .save().") + os.makedirs(output_dir, exist_ok=True) + self.adata.write_h5ad(os.path.join(output_dir, 'adata.h5ad')) + torch.save({ + 'model_state': self.model.state_dict(), + 'discriminator_state': self.model_d.state_dict(), + 'config': { + 'in_feats': self.in_feats, + 'hidden_size_v': self.hidden_size_v, + 'hidden_size': self.hidden_size, + 'views': self.views, + 'choose_views': self.choose_views, + 'n_nodes': len(self.batch_idx[0]), + 'batch_idx': self.batch_idx, + 'mode': 'batch', + }, + }, os.path.join(output_dir, 'model.pt')) From ddfcf098db813e69962d73c4dd8f66445042d195 Mon Sep 17 00:00:00 2001 From: jackytamkc Date: Tue, 21 Apr 2026 14:39:05 +0100 Subject: [PATCH 4/4] Move training diagnostics off adata onto Runner / model.pt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Motivation ---------- `adata` should carry only downstream-relevant trained parameters — enough to run niche clustering, enrichment plots, etc. — and nothing else. Previously `fit()` wrote `loss`, `lr_history`, `ari_history`, and `early_stop` into `adata.uns`, which mixes training diagnostics (which belong with the model) into the data object (which belongs to the analysis pipeline). Changes ------- - Runner.fit / Runner_batch.fit: write training diagnostics to `self.loss_history`, `self.lr_history`, `self.ari_history`, `self.early_stop` instead of `adata.uns[...]`. The only thing still written to `adata` is the embedding at `obsm['X_scniche']`. - Runner.save / Runner_batch.save: bundle those diagnostics into `model.pt` under a `training` key alongside the state dicts, so they are persisted with the model they describe. - Updated docstrings on `fit()` return value and the ARI probe. Result: the AnnData returned from `fit()` is downstream-ready — nothing in `adata.uns` is there because of training bookkeeping. 
Co-Authored-By: Claude Opus 4.7 --- scniche/trainer/_train.py | 60 +++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 4142b2d..388cbf2 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -131,7 +131,7 @@ def fit( If True (default), run the GCN forward and adjacency-reconstruction losses under `torch.autocast(cuda, bf16)`. The discriminator/MIM block is kept in fp32 for numerical stability (log near 1.0). Returns: adata: AnnData - Anndata object with learned embeddings (stored in `adata.obsm['X_scniche']`) and training loss (stored in `adata.uns['loss']`). + Anndata object with learned embeddings (stored in `adata.obsm['X_scniche']`). Training diagnostics (loss_history, lr_history, early_stop) live on the Runner and are persisted via `Runner.save()` — they are intentionally kept off `adata` so the returned object is downstream-ready without carrying training scaffolding. """ # to device @@ -231,9 +231,9 @@ def fit( self.model.eval() _, _, z = self.model.forward(self.graph, self.feat, self.device) - self.adata.uns['loss'] = loss_all - self.adata.uns['lr_history'] = lr_all - self.adata.uns['early_stop'] = { + self.loss_history = loss_all + self.lr_history = lr_all + self.early_stop = { 'best_epoch': best_epoch + 1, 'best_loss': best_loss, 'stopped_epoch': stopped_epoch, @@ -246,12 +246,12 @@ def save(self, output_dir: str): """Save trained model + adata to `output_dir`. Writes: - {output_dir}/adata.h5ad — AnnData carrying only training results - (embedding in `obsm['X_scniche']`; loss / lr_history / early_stop - metadata in `uns`). Model scaffolding (graphs) was already popped - off `adata.uns` at Runner construction. - {output_dir}/model.pt — model + discriminator state dicts and the - config needed to rebuild the MGAE/Discriminator pair. 
+ {output_dir}/adata.h5ad — AnnData carrying only the downstream- + relevant trained result (`obsm['X_scniche']`). Model scaffolding + and training diagnostics do not appear in this file. + {output_dir}/model.pt — model + discriminator state dicts, rebuild + config, and the training diagnostics (loss_history, lr_history, + early_stop) that belong with the model rather than the data. """ if not hasattr(self, 'model'): raise RuntimeError("Call .fit() before .save().") @@ -269,6 +269,11 @@ def save(self, output_dir: str): 'n_nodes': self.adj[0].shape[0], 'mode': 'full', }, + 'training': { + 'loss_history': self.loss_history, + 'lr_history': self.lr_history, + 'early_stop': self.early_stop, + }, }, os.path.join(output_dir, 'model.pt')) @@ -345,9 +350,10 @@ def fit( Extra args: ari_every: if set, every N epochs snapshot the embedding, KMeans it (with k=ari_k), and compute ARI vs the previous snapshot's labels. - Results are appended to adata.uns['ari_history'] as list of - (epoch, ari) pairs. ARI >= ~0.98 between consecutive snapshots - indicates the embedding has converged. + Results are appended to `self.ari_history` as a list of + (epoch, ari) pairs and written into model.pt via Runner_batch.save(). + ARI >= ~0.98 between consecutive snapshots indicates the + embedding has converged. ari_k: k for the KMeans used in the ARI probe. Pick close to your expected niche count. ari_stop_threshold: if set (e.g. 
0.98), training stops when the @@ -524,10 +530,10 @@ def fit( emb.index = self.adata.obs_names[emb.index] emb = emb.reindex(self.adata.obs_names) - self.adata.uns['loss'] = loss_all - self.adata.uns['lr_history'] = lr_all - self.adata.uns['ari_history'] = ari_history - self.adata.uns['early_stop'] = { + self.loss_history = loss_all + self.lr_history = lr_all + self.ari_history = ari_history + self.early_stop = { 'best_epoch': best_epoch + 1, 'best_loss': best_loss, 'stopped_epoch': stopped_epoch, @@ -541,13 +547,13 @@ def save(self, output_dir: str): """Save trained model + adata to `output_dir`. Writes: - {output_dir}/adata.h5ad — AnnData carrying only training results - (embedding in `obsm['X_scniche']`; loss / lr_history / ari_history - / early_stop metadata in `uns`). Model scaffolding (dataloader, - batch_idx) was already popped off `adata.uns` at Runner_batch - construction. - {output_dir}/model.pt — model + discriminator state dicts and the - config needed to rebuild the MGAE/Discriminator pair. + {output_dir}/adata.h5ad — AnnData carrying only the downstream- + relevant trained result (`obsm['X_scniche']`). Model scaffolding + and training diagnostics do not appear in this file. + {output_dir}/model.pt — model + discriminator state dicts, rebuild + config, and the training diagnostics (loss_history, lr_history, + ari_history, early_stop) that belong with the model rather than + the data. """ if not hasattr(self, 'model'): raise RuntimeError("Call .fit() before .save().") @@ -566,4 +572,10 @@ def save(self, output_dir: str): 'batch_idx': self.batch_idx, 'mode': 'batch', }, + 'training': { + 'loss_history': self.loss_history, + 'lr_history': self.lr_history, + 'ari_history': self.ari_history, + 'early_stop': self.early_stop, + }, }, os.path.join(output_dir, 'model.pt'))