diff --git a/README.md b/README.md index 76f7399..109137d 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html ``` The version of PyTorch and DGL should be suitable to the CUDA version of your machine. You can find the appropriate version on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) website. +#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120) +Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order: + +``` +# 1. Blackwell-capable PyTorch +pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 + +# 2. DGL (this step downgrades torch to 2.4.0 — expected) +pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html + +# 3. Restore torch 2.7 on top of DGL +pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128 +``` + +`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. The `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_.so` on torch 2.7 — scNiche no longer calls those routines, but downstream user code that does will need to be patched or the DGL wheel rebuilt from source. + ### Install scNiche ``` diff --git a/scniche/trainer/_train.py b/scniche/trainer/_train.py index 1bef903..1d6b75a 100644 --- a/scniche/trainer/_train.py +++ b/scniche/trainer/_train.py @@ -105,7 +105,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): optim = torch.optim.Adam(self.model.parameters(), lr=lr) # loss pos_weight = torch.Tensor( - [float(self.graph[0].adjacency_matrix().to_dense().shape[0] ** 2 - self.edges / 2) / self.edges * 2] + [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device) @@ -231,7 +231,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ): # loss pos_weight = torch.Tensor( - [float(graphs[0].adjacency_matrix().to_dense().shape[0] ** 2 - edges / 2) / edges * 2] + [float(graphs[0].num_nodes() ** 2 - edges / 2) / edges * 2] ) criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device) criterion_m = torch.nn.MSELoss().to(self.device)