diff --git a/README.md b/README.md index 76f7399..109137d 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html ``` The version of PyTorch and DGL should be suitable to the CUDA version of your machine. You can find the appropriate version on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) website. +#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120) +Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order: + +``` +# 1. Blackwell-capable PyTorch +pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 + +# 2. DGL (this step downgrades torch to 2.4.0 — expected) +pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html + +# 3. Restore torch 2.7 on top of DGL +pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 +``` + +`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. The `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_.so` on torch 2.7 — scNiche no longer calls those routines, but downstream user code that does will need to be patched or the DGL wheel rebuilt from source. + ### Install scNiche ``` diff --git a/scniche/preprocess/_build.py b/scniche/preprocess/_build.py index b688be4..820db2e 100644 --- a/scniche/preprocess/_build.py +++ b/scniche/preprocess/_build.py @@ -124,7 +124,10 @@ def prepare_data_batch( print("Constructing done.") mydataset = myDataset(g_list) - dataloader = GraphDataLoader(mydataset, batch_size=1, shuffle=False, pin_memory=True) + dataloader = GraphDataLoader( + mydataset, batch_size=1, shuffle=False, pin_memory=True, + num_workers=4, + ) adata.uns['dataloader'] = dataloader return adata diff --git a/scniche/trainer/_model.py b/scniche/trainer/_model.py index 24564e4..89f86f3 100644 --- a/scniche/trainer/_model.py +++ b/scniche/trainer/_model.py @@ -3,7 +3,6 @@ # This project is licensed under the GPL-3.0 License. 
-import scipy.sparse as sp
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -59,11 +58,12 @@ def consensus_graph(self, graphs, device):
         fused_adj = torch.clamp(fused_adj, 0, 1)
         fused_adj = torch.round(fused_adj + 0.1)
 
-        # build symmetric adjacency matrix
-        adj_np = fused_adj.detach().cpu().numpy()
-        adj_np += adj_np.T
-        adj_sparse = sp.csr_matrix(adj_np)
-        g = dgl.from_scipy(adj_sparse, device=device)
+        # build symmetric adjacency matrix ON GPU (avoids GPU<->CPU<->scipy roundtrip)
+        with torch.no_grad():
+            sym = fused_adj.detach()
+            sym = sym + sym.T
+            src, dst = (sym > 0).nonzero(as_tuple=True)
+            g = dgl.graph((src.int(), dst.int()), num_nodes=fused_adj.shape[0], device=device)
 
         return fused_adj, g
 
diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py
index 1bef903..388cbf2 100644
--- a/scniche/trainer/_train.py
+++ b/scniche/trainer/_train.py
@@ -1,12 +1,30 @@
+import copy
+import os
 from anndata import AnnData
 from tqdm import tqdm
 import pandas as pd
 import numpy as np
+from sklearn.cluster import KMeans
+from sklearn.metrics import adjusted_rand_score
 from ._model import *
 from ._utils import shuffling
 from typing import Optional
 
 
+def _infer_embedding(model, cached_batches, device, use_amp):
+    """Forward-pass cached batches in eval mode and return a (batches, cells_per_batch, latent) ndarray."""
+    was_training = model.training
+    model.eval()
+    emb = []
+    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp):
+        for b in cached_batches:
+            _, _, z = model.forward(b['graphs'], b['feats'], device)
+            emb.append(z.float().data.cpu().numpy())
+    if was_training:
+        model.train()
+    return np.stack(emb, axis=0)  # (batches, cells_per_batch, latent)
+
+
 class Runner:
     def __init__(
         self,
@@ -44,8 +62,9 @@ def __init__(
 
         self.views = len(self.choose_views)
 
+        # move model scaffolding off adata.uns so adata only carries training results
         self.graph_name = ['g_' + view for view in self.choose_views]
-        self.graph = [self.adata.uns[graph_name] for graph_name in self.graph_name]
+        self.graph = [self.adata.uns.pop(graph_name) for graph_name in self.graph_name]
         self.adj = [g.ndata['adj'].to_dense() for g in self.graph]
         self.mik = np.hstack((g.ndata['mik'] for g in self.graph))
         self.edges = sum(g.number_of_edges() for g in self.graph)
@@ -77,17 +96,42 @@ def __init__(
             ))
             print("Mutual Information Matrix Size for training: {}".format(self.mik.shape))
 
-    def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ):
+    def fit(
+        self,
+        lr: Optional[float] = 0.01,
+        epochs: Optional[int] = 100,
+        patience: Optional[int] = None,
+        min_delta: float = 1e-4,
+        restore_best: bool = True,
+        lr_patience: Optional[int] = 5,
+        lr_factor: float = 0.5,
+        min_lr: float = 1e-6,
+        use_amp: bool = True,
+    ):
         """
         Fit the scNiche model.
 
         Args:
             lr: Optional[float]
-                Learning rate for the optimizer.
+                Initial learning rate for the optimizer.
             epochs: Optional[int]
-                Number of training epochs.
+                Maximum number of training epochs.
+            patience: Optional[int]
+                Early-stopping patience. Stop if loss has not improved by `min_delta` for `patience` consecutive epochs. If None, early stopping is disabled.
+            min_delta: float
+                Minimum loss improvement to reset the patience / LR-scheduler counters.
+            restore_best: bool
+                If True, restore the model weights from the epoch with the lowest training loss before producing the final embedding.
+            lr_patience: Optional[int]
+                ReduceLROnPlateau patience — halve the LR after this many epochs with no improvement.
If None, LR is held constant. + lr_factor: float + Multiplicative factor applied to the LR when the plateau is hit (e.g. 0.5 halves the LR). + min_lr: float + Floor for the LR; the scheduler will not reduce below this value. + use_amp: bool + If True (default), run the GCN forward and adjacency-reconstruction losses under `torch.autocast(cuda, bf16)`. The discriminator/MIM block is kept in fp32 for numerical stability (log near 1.0). Returns: adata: AnnData - Anndata object with learned embeddings (stored in `adata.obsm['X_scniche']`) and training loss (stored in `adata.uns['loss']`). + Anndata object with learned embeddings (stored in `adata.obsm['X_scniche']`). Training diagnostics (loss_history, lr_history, early_stop) live on the Runner and are persisted via `Runner.save()` — they are intentionally kept off `adata` so the returned object is downstream-ready without carrying training scaffolding. """ # to device @@ -103,9 +147,15 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # optimizer optim = torch.optim.Adam(self.model.parameters(), lr=lr) + scheduler = None + if lr_patience is not None: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, mode='min', factor=lr_factor, patience=lr_patience, + threshold=min_delta, threshold_mode='abs', min_lr=min_lr, + ) # loss pos_weight = torch.Tensor( - [float(self.graph[0].adjacency_matrix().to_dense().shape[0] ** 2 - self.edges / 2) / self.edges * 2] + [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device) @@ -115,40 +165,117 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): self.model.train() loss_all = [] + lr_all = [] + best_loss = float('inf') + best_state = None + best_epoch = -1 + no_improve = 0 + stopped_epoch = epochs pbar = tqdm(range(epochs)) for epoch in pbar: - adj_r, adj_logits, z = self.model.forward(self.graph, self.feat, self.device) - loss_gre = sum(criterion_m(adj_r, adj) for adj in self.adj) / self.views - loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(self.adj)) / self.views - - loss_mim = 0 - for i in range(self.mik.shape[1]): - z_shuf = shuffling(z, latent=self.hidden_size[-1], device=self.device) - z_comb = torch.cat((z, z_shuf), 1) - z_shuf_scores = self.model_d(z_comb) - z_idx = torch.cat((z, z[self.mik[:, i]]), 1) - z_scores = self.model_d(z_idx) - loss_mim += - torch.mean( - torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) - ) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + adj_r, adj_logits, z = self.model.forward(self.graph, self.feat, self.device) + loss_gre = sum(criterion_m(adj_r, adj) for adj in self.adj) / self.views + loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(self.adj)) / self.views + + # keep MIM loss in fp32: log(x + 1e-6) near 1.0 is unsafe in bf16 + with torch.autocast(device_type='cuda', enabled=False): + z_fp32 = z.float() + loss_mim = 0 + for i in range(self.mik.shape[1]): + z_shuf = shuffling(z_fp32, latent=self.hidden_size[-1], device=self.device) + z_comb = torch.cat((z_fp32, z_shuf), 1) + z_shuf_scores = self.model_d(z_comb) + z_idx = torch.cat((z_fp32, z_fp32[self.mik[:, i]]), 1) + z_scores = self.model_d(z_idx) + loss_mim += - torch.mean( + torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) + ) loss = loss_gre + loss_rec + loss_mim optim.zero_grad() loss.backward() 
optim.step() + current_lr = optim.param_groups[0]['lr'] + lr_all.append(current_lr) pbar.set_description('Train Epoch: {}'.format(epoch + 1)) - pbar.set_postfix(loss=f"{loss:.4f}") + pbar.set_postfix(loss=f"{loss:.4f}", lr=f"{current_lr:.2e}") loss_all.append(loss.data.cpu().numpy()) + loss_val = float(loss.item()) + if scheduler is not None: + prev_lr = current_lr + scheduler.step(loss_val) + new_lr = optim.param_groups[0]['lr'] + if self.verbose and new_lr < prev_lr: + print(f" epoch {epoch + 1}: LR reduced {prev_lr:.2e} -> {new_lr:.2e}") + if loss_val < best_loss - min_delta: + best_loss = loss_val + best_epoch = epoch + no_improve = 0 + if restore_best: + best_state = copy.deepcopy(self.model.state_dict()) + else: + no_improve += 1 + if patience is not None and no_improve >= patience: + stopped_epoch = epoch + 1 + if self.verbose: + print(f"Early stopping at epoch {stopped_epoch}; best loss {best_loss:.4f} at epoch {best_epoch + 1}.") + break + + if restore_best and best_state is not None: + self.model.load_state_dict(best_state) + if self.verbose: print("Training done.") self.model.eval() _, _, z = self.model.forward(self.graph, self.feat, self.device) - self.adata.uns['loss'] = loss_all + self.loss_history = loss_all + self.lr_history = lr_all + self.early_stop = { + 'best_epoch': best_epoch + 1, + 'best_loss': best_loss, + 'stopped_epoch': stopped_epoch, + 'restored_best': bool(restore_best and best_state is not None), + } self.adata.obsm['X_scniche'] = z.data.cpu().numpy() return self.adata + def save(self, output_dir: str): + """Save trained model + adata to `output_dir`. + + Writes: + {output_dir}/adata.h5ad — AnnData carrying only the downstream- + relevant trained result (`obsm['X_scniche']`). Model scaffolding + and training diagnostics do not appear in this file. + {output_dir}/model.pt — model + discriminator state dicts, rebuild + config, and the training diagnostics (loss_history, lr_history, + early_stop) that belong with the model rather than the data. + """ + if not hasattr(self, 'model'): + raise RuntimeError("Call .fit() before .save().") + os.makedirs(output_dir, exist_ok=True) + self.adata.write_h5ad(os.path.join(output_dir, 'adata.h5ad')) + torch.save({ + 'model_state': self.model.state_dict(), + 'discriminator_state': self.model_d.state_dict(), + 'config': { + 'in_feats': self.in_feats, + 'hidden_size_v': self.hidden_size_v, + 'hidden_size': self.hidden_size, + 'views': self.views, + 'choose_views': self.choose_views, + 'n_nodes': self.adj[0].shape[0], + 'mode': 'full', + }, + 'training': { + 'loss_history': self.loss_history, + 'lr_history': self.lr_history, + 'early_stop': self.early_stop, + }, + }, os.path.join(output_dir, 'model.pt')) + class Runner_batch: def __init__( @@ -164,7 +291,9 @@ def __init__( Initialize the Runner_batch class for training scNihce model with batch training strategy. 
""" self.adata = adata - self.dataloader = self.adata.uns['dataloader'] + # move model scaffolding off adata.uns so adata only carries training results + self.dataloader = self.adata.uns.pop('dataloader') + self.batch_idx = self.adata.uns.pop('batch_idx') self.choose_views = choose_views if self.choose_views is None: self.choose_views = ['X_cn_norm', 'X_data', 'X_data_nbr'] @@ -201,10 +330,43 @@ def __init__( )) print("Batch size: {}".format(len(self.dataloader))) - def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): + def fit( + self, + lr: Optional[float] = 0.01, + epochs: Optional[int] = 100, + patience: Optional[int] = None, + min_delta: float = 1e-4, + restore_best: bool = True, + lr_patience: Optional[int] = 5, + lr_factor: float = 0.5, + min_lr: float = 1e-6, + use_amp: bool = True, + ari_every: Optional[int] = None, + ari_k: int = 10, + ari_stop_threshold: Optional[float] = None, + ari_stop_patience: int = 2, + ): + """ + Extra args: + ari_every: if set, every N epochs snapshot the embedding, KMeans it + (with k=ari_k), and compute ARI vs the previous snapshot's labels. + Results are appended to `self.ari_history` as a list of + (epoch, ari) pairs and written into model.pt via Runner.save(). + ARI >= ~0.98 between consecutive snapshots indicates the + embedding has converged. + ari_k: k for the KMeans used in the ARI probe. Pick close to your + expected niche count. + ari_stop_threshold: if set (e.g. 0.98), training stops when the + probe ARI stays >= this value for `ari_stop_patience` consecutive + probes. This is a meaningful convergence signal for this loss + structure (the total loss plateaus at the BCE floor even while + the embedding is still being shaped). + ari_stop_patience: how many consecutive high-ARI probes to require + before stopping. + """ # model - batch_size = len(self.adata.uns['batch_idx'][0]) + batch_size = len(self.batch_idx[0]) self.model = MGAE(self.in_feats, self.hidden_size_v, self.hidden_size, self.views, batch_size) self.model_d = Discriminator(latent_dim=self.hidden_size[-1]) self.model = self.model.to(self.device) @@ -212,49 +374,82 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # optimizer optim = torch.optim.Adam(self.model.parameters(), lr=lr) + scheduler = None + if lr_patience is not None: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optim, mode='min', factor=lr_factor, patience=lr_patience, + threshold=min_delta, threshold_mode='abs', min_lr=min_lr, + ) + + # pre-cache all batches on device ONCE. Graph ndata (feat, adj, mik) moves + # with the graph, so subsequent iterations pay zero H2D cost. + if self.verbose: + print("-------Caching batches on device...") + cached_batches = [] + for batch in self.dataloader: + graphs = [batch[i].to(self.device) for i in range(len(batch))] + feats = [g.ndata['feat'] for g in graphs] + adj_sparse = [g.ndata['adj'] for g in graphs] + mik = torch.cat([g.ndata['mik'].to(self.device) for g in graphs], dim=1).long() + edges = sum(g.number_of_edges() for g in graphs) + n_nodes = graphs[0].num_nodes() + pos_weight = torch.tensor( + [float(n_nodes ** 2 - edges / 2) / edges * 2], device=self.device + ) + bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + cached_batches.append({ + 'graphs': graphs, 'feats': feats, 'adj_sparse': adj_sparse, + 'mik': mik, 'bce': bce, + }) + # free the dataloader: its multiprocess iterators aren't picklable, which + # breaks anndata.copy() / adata[mask].copy() downstream. 
+ self.dataloader = None + # stateless loss — build once + mse = torch.nn.MSELoss() + latent_dim = self.hidden_size[-1] if self.verbose: print("-------Start training...") self.model.train() loss_all = [] + lr_all = [] + ari_history = [] # list of (epoch_1based, ari_vs_prev_snapshot) + prev_snap_labels = None + ari_good_streak = 0 + ari_stopped = False + best_loss = float('inf') + best_state = None + best_epoch = -1 + no_improve = 0 + stopped_epoch = epochs pbar = tqdm(range(epochs)) for epoch in pbar: batch_loss = 0 - for batch in self.dataloader: - - graphs = [batch[i] for i in range(len(batch))] - feats = [g.ndata['feat'] for g in graphs] - adjs = [g.ndata['adj'].to_dense() for g in graphs] - mik = np.hstack((g.ndata['mik'] for g in graphs)) - edges = sum(g.number_of_edges() for g in graphs) - - # loss - pos_weight = torch.Tensor( - [float(graphs[0].adjacency_matrix().to_dense().shape[0] ** 2 - edges / 2) / edges * 2] - ) - criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) - criterion_m = torch.nn.MSELoss().to(self.device) - - # to device - feats = [feat.to(self.device) for feat in feats] - graphs = [g.to(self.device) for g in graphs] - adjs = [adj.to(self.device) for adj in adjs] - - adj_r, adj_logits, z = self.model.forward(graphs, feats, self.device) - loss_gre = sum(criterion_m(adj_r, adj) for adj in adjs) / self.views - loss_rec = sum(criterion(adj_logits[i], adj) for i, adj in enumerate(adjs)) / self.views - - loss_mim = 0 - for i in range(mik.shape[1]): - z_shuf = shuffling(z, latent=self.hidden_size[-1], device=self.device) - z_comb = torch.cat((z, z_shuf), 1) - z_shuf_scores = self.model_d(z_comb) - z_idx = torch.cat((z, z[mik[:, i]]), 1) - z_scores = self.model_d(z_idx) - loss_mim += - torch.mean( - torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) - ) + for b in cached_batches: + graphs = b['graphs'] + feats = b['feats'] + mik = b['mik'] + bce = b['bce'] + adjs = [a.to_dense() for a in b['adj_sparse']] + + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + adj_r, adj_logits, z = self.model.forward(graphs, feats, self.device) + loss_gre = sum(mse(adj_r, adj) for adj in adjs) / self.views + loss_rec = sum(bce(adj_logits[i], adj) for i, adj in enumerate(adjs)) / self.views + + with torch.autocast(device_type='cuda', enabled=False): + z_fp32 = z.float() + loss_mim = 0 + for i in range(mik.shape[1]): + z_shuf = shuffling(z_fp32, latent=latent_dim, device=self.device) + z_comb = torch.cat((z_fp32, z_shuf), 1) + z_shuf_scores = self.model_d(z_comb) + z_idx = torch.cat((z_fp32, z_fp32[mik[:, i]]), 1) + z_scores = self.model_d(z_idx) + loss_mim += - torch.mean( + torch.log(z_scores + 1e-6) + torch.log(1 - z_shuf_scores + 1e-6) + ) loss = loss_gre + loss_rec + loss_mim optim.zero_grad() @@ -264,32 +459,123 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): batch_loss += loss.item() loss_all.append(batch_loss) + current_lr = optim.param_groups[0]['lr'] + lr_all.append(current_lr) pbar.set_description('Train Epoch: {}'.format(epoch)) - pbar.set_postfix(loss=f"{batch_loss:.4f}") + pbar.set_postfix(loss=f"{batch_loss:.4f}", lr=f"{current_lr:.2e}") + + if scheduler is not None: + prev_lr = current_lr + scheduler.step(batch_loss) + new_lr = optim.param_groups[0]['lr'] + if self.verbose and new_lr < prev_lr: + print(f" epoch {epoch + 1}: LR reduced {prev_lr:.2e} -> {new_lr:.2e}") + if batch_loss < best_loss - min_delta: + best_loss = batch_loss + best_epoch = epoch + no_improve = 0 + if 
restore_best: + best_state = copy.deepcopy(self.model.state_dict()) + else: + no_improve += 1 + if patience is not None and no_improve >= patience: + stopped_epoch = epoch + 1 + if self.verbose: + print(f"Early stopping at epoch {stopped_epoch}; best loss {best_loss:.4f} at epoch {best_epoch + 1}.") + break + + # ARI convergence probe + if ari_every is not None and ari_every > 0 and ((epoch + 1) % ari_every == 0): + emb_snap = _infer_embedding(self.model, cached_batches, self.device, use_amp) + emb_snap_flat = emb_snap.reshape(-1, emb_snap.shape[-1]) + snap_labels = KMeans(n_clusters=ari_k, random_state=123, n_init=10).fit_predict(emb_snap_flat) + if prev_snap_labels is not None: + ari = float(adjusted_rand_score(prev_snap_labels, snap_labels)) + ari_history.append((epoch + 1, ari)) + if self.verbose: + print(f" epoch {epoch + 1}: ARI vs prev snapshot = {ari:.4f}") + if ari_stop_threshold is not None and ari >= ari_stop_threshold: + ari_good_streak += 1 + if ari_good_streak >= ari_stop_patience: + stopped_epoch = epoch + 1 + ari_stopped = True + if self.verbose: + print(f"ARI-based stop at epoch {stopped_epoch}: " + f"{ari_good_streak} consecutive probes >= {ari_stop_threshold}.") + prev_snap_labels = snap_labels + break + else: + ari_good_streak = 0 + prev_snap_labels = snap_labels + + if restore_best and best_state is not None: + self.model.load_state_dict(best_state) if self.verbose: print("Training done.") self.model.eval() emb = [] - for batch in tqdm(self.dataloader): - graphs = [batch[i] for i in range(len(batch))] - feats = [g.ndata['feat'] for g in graphs] - graphs = [g.to(self.device) for g in graphs] - feats = [feat.to(self.device) for feat in feats] - - _, _, z = self.model.forward(graphs, feats, self.device) - emb.append(list(z.data.cpu().numpy())) + with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_amp): + for b in tqdm(cached_batches): + _, _, z = self.model.forward(b['graphs'], b['feats'], self.device) + emb.append(list(z.float().data.cpu().numpy())) emb = np.array(emb) emb = pd.DataFrame(np.reshape(emb, (-1, emb.shape[2]))) - idx = np.array(self.adata.uns['batch_idx']).flatten().tolist() + idx = np.array(self.batch_idx).flatten().tolist() emb.index = idx emb = emb[~emb.index.duplicated()] emb.index = self.adata.obs_names[emb.index] emb = emb.reindex(self.adata.obs_names) - self.adata.uns['loss'] = loss_all + self.loss_history = loss_all + self.lr_history = lr_all + self.ari_history = ari_history + self.early_stop = { + 'best_epoch': best_epoch + 1, + 'best_loss': best_loss, + 'stopped_epoch': stopped_epoch, + 'stopped_by': 'ari' if ari_stopped else ('loss' if stopped_epoch < epochs else 'epochs'), + 'restored_best': bool(restore_best and best_state is not None), + } self.adata.obsm['X_scniche'] = np.array(emb) return self.adata + + def save(self, output_dir: str): + """Save trained model + adata to `output_dir`. + + Writes: + {output_dir}/adata.h5ad — AnnData carrying only the downstream- + relevant trained result (`obsm['X_scniche']`). Model scaffolding + and training diagnostics do not appear in this file. + {output_dir}/model.pt — model + discriminator state dicts, rebuild + config, and the training diagnostics (loss_history, lr_history, + ari_history, early_stop) that belong with the model rather than + the data. 
+ """ + if not hasattr(self, 'model'): + raise RuntimeError("Call .fit() before .save().") + os.makedirs(output_dir, exist_ok=True) + self.adata.write_h5ad(os.path.join(output_dir, 'adata.h5ad')) + torch.save({ + 'model_state': self.model.state_dict(), + 'discriminator_state': self.model_d.state_dict(), + 'config': { + 'in_feats': self.in_feats, + 'hidden_size_v': self.hidden_size_v, + 'hidden_size': self.hidden_size, + 'views': self.views, + 'choose_views': self.choose_views, + 'n_nodes': len(self.batch_idx[0]), + 'batch_idx': self.batch_idx, + 'mode': 'batch', + }, + 'training': { + 'loss_history': self.loss_history, + 'lr_history': self.lr_history, + 'ari_history': self.ari_history, + 'early_stop': self.early_stop, + }, + }, os.path.join(output_dir, 'model.pt'))