From 4c811f98c8adfdf5aba45a797486a9ecb501d753 Mon Sep 17 00:00:00 2001 From: jackytamkc Date: Tue, 21 Apr 2026 13:46:42 +0100 Subject: [PATCH] Support Blackwell GPUs (torch 2.7 + cu128 + DGL 2.4 cu124) - _train.py: replace `graph.adjacency_matrix().to_dense().shape[0]` with `graph.num_nodes()` at the two pos_weight sites. Same value, no N*N allocation, and avoids the torch-pinned libdgl_sparse_pytorch load that fails on torch 2.7 with the DGL 2.4 cu124 wheel. - README: add Blackwell / sm_120 install instructions alongside the existing cu113 path. Backward-compatible: num_nodes() works on all supported DGL versions; README changes are additive. Co-Authored-By: Claude Opus 4.7 --- README.md | 16 ++++++++++++++++ scniche/trainer/_train.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 76f7399..109137d 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html ``` The version of PyTorch and DGL should be suitable to the CUDA version of your machine. You can find the appropriate version on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) website. +#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120) +Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order: + +``` +# 1. Blackwell-capable PyTorch +pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 + +# 2. DGL (this step downgrades torch to 2.4.0 — expected) +pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html + +# 3. Restore torch 2.7 on top of DGL +pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 +``` + +`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. The `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_.so` on torch 2.7 — scNiche no longer calls those routines, but downstream user code that does will need to be patched or the DGL wheel rebuilt from source. + ### Install scNiche ``` diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 1bef903..1d6b75a 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -105,7 +105,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): optim = torch.optim.Adam(self.model.parameters(), lr=lr) # loss pos_weight = torch.Tensor( - [float(self.graph[0].adjacency_matrix().to_dense().shape[0] ** 2 - self.edges / 2) / self.edges * 2] + [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device) @@ -231,7 +231,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # loss pos_weight = torch.Tensor( - [float(graphs[0].adjacency_matrix().to_dense().shape[0] ** 2 - edges / 2) / edges * 2] + [float(graphs[0].num_nodes() ** 2 - edges / 2) / edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device)