16 changes: 16 additions & 0 deletions README.md
@@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html
```
The versions of PyTorch and DGL must match the CUDA version of your machine. You can find the appropriate versions on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) websites.

#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120)
Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order:

```
# 1. Blackwell-capable PyTorch
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128

# 2. DGL (this step downgrades torch to 2.4.0 — expected)
pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html

# 3. Restore torch 2.7 on top of DGL
pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128
```

`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. However, the `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_<X>.so` on torch 2.7. scNiche no longer calls those routines, but downstream user code that does will need to be patched, or the DGL wheel rebuilt from source.
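The node count previously read off the dense adjacency matrix is the same value that `DGLGraph.num_nodes()` returns, so affected call sites can usually be patched mechanically. As a minimal sketch (the helper name is hypothetical), the `pos_weight` expression in `scniche/trainer/_train.py` depends only on node and edge counts:

```python
def bce_pos_weight(num_nodes: int, num_edges: int) -> float:
    # Hypothetical helper mirroring the pos_weight expression in
    # scniche/trainer/_train.py: ratio of candidate non-edges to edges,
    # computed without touching dgl.sparse / adjacency_matrix().
    return float(num_nodes ** 2 - num_edges / 2) / num_edges * 2

# e.g. 100 nodes, 500 edges: (100**2 - 500 / 2) / 500 * 2 = 39.0
print(bce_pos_weight(100, 500))
```

Because only `num_nodes()` is needed, this path stays clear of the torch-version-sensitive `libdgl_sparse` extension.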


### Install scNiche
```
4 changes: 2 additions & 2 deletions scniche/trainer/_train.py
@@ -105,7 +105,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ):
optim = torch.optim.Adam(self.model.parameters(), lr=lr)
# loss
pos_weight = torch.Tensor(
-            [float(self.graph[0].adjacency_matrix().to_dense().shape[0] ** 2 - self.edges / 2) / self.edges * 2]
+            [float(self.graph[0].num_nodes() ** 2 - self.edges / 2) / self.edges * 2]
)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device)
criterion_m = torch.nn.MSELoss().to(self.device)
@@ -231,7 +231,7 @@ def fit(self, lr: Optional[float] = 0.01, epochs: Optional[int] = 100, ):

# loss
pos_weight = torch.Tensor(
-            [float(graphs[0].adjacency_matrix().to_dense().shape[0] ** 2 - edges / 2) / edges * 2]
+            [float(graphs[0].num_nodes() ** 2 - edges / 2) / edges * 2]
)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(self.device)
criterion_m = torch.nn.MSELoss().to(self.device)