16 changes: 16 additions & 0 deletions README.md
@@ -37,6 +37,22 @@ pip install dgl==1.1.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html
```
The versions of PyTorch and DGL must match the CUDA version of your machine. You can find the appropriate versions on the [PyTorch](https://pytorch.org/get-started/locally/) and [DGL](https://www.dgl.ai/) websites.

#### Blackwell GPUs (RTX PRO 6000, RTX 50-series, sm_120)
Blackwell requires PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and run on sm_120 via JIT forwarding. Install in this order:

```
# 1. Blackwell-capable PyTorch
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128

# 2. DGL (this step downgrades torch to 2.4.0; that is expected)
pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html

# 3. Restore torch 2.7 on top of DGL
pip install --force-reinstall --no-deps torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128
```

`pip check` will warn that `dgl requires torch<=2.4.0`; this warning is safe to ignore for scNiche's code paths. The `DGLGraph.adjacency_matrix()` / `dgl.sparse` routines are not ABI-compatible across torch versions and will fail to load `libdgl_sparse_pytorch_<X>.so` on torch 2.7. scNiche no longer calls those routines, but downstream user code that does will need to be patched or the DGL wheel rebuilt from source.
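For downstream code that still needs a SciPy adjacency matrix, one possible patch (a sketch, not part of scNiche) is to rebuild the matrix from the edge list that `g.edges()` returns, which sidesteps loading the `dgl.sparse` extension entirely:

```
# Sketch (not scNiche code): rebuild a SciPy adjacency matrix from a graph's
# edge list instead of calling DGLGraph.adjacency_matrix(), which needs the
# libdgl_sparse extension that fails to load on torch 2.7.
import numpy as np
import scipy.sparse as sp

def adjacency_from_edges(src, dst, num_nodes):
    # src, dst: 1-D integer arrays, e.g. the tensors returned by g.edges()
    # moved to CPU with .cpu().numpy()
    data = np.ones(len(src), dtype=np.float32)
    return sp.csr_matrix((data, (src, dst)), shape=(num_nodes, num_nodes))

# Usage with a DGLGraph g (assumed):
#   src, dst = g.edges()
#   adj = adjacency_from_edges(src.cpu().numpy(), dst.cpu().numpy(), g.num_nodes())
```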


### Install scNiche
```
5 changes: 4 additions & 1 deletion scniche/preprocess/_build.py
@@ -124,7 +124,10 @@ def prepare_data_batch(
    print("Constructing done.")

    mydataset = myDataset(g_list)
-    dataloader = GraphDataLoader(mydataset, batch_size=1, shuffle=False, pin_memory=True)
+    dataloader = GraphDataLoader(
+        mydataset, batch_size=1, shuffle=False, pin_memory=True,
+        num_workers=4,
+    )
    adata.uns['dataloader'] = dataloader

    return adata
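One caveat on the `num_workers=4` addition: DataLoader worker subprocesses use the `spawn` start method on Windows and macOS, where they require a picklable dataset and can be fragile. A hypothetical fallback helper (an illustration, not part of this PR) could pick a safer default:

```
import os
import platform

def safe_num_workers(requested=4):
    # Hypothetical helper: fall back to in-process loading on platforms
    # that default to the "spawn" start method, where DataLoader workers
    # require a picklable dataset and are more failure-prone.
    if platform.system() in ("Windows", "Darwin"):
        return 0
    return min(requested, os.cpu_count() or 1)
```

It would be used as `GraphDataLoader(..., num_workers=safe_num_workers())`.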
12 changes: 6 additions & 6 deletions scniche/trainer/_model.py
@@ -3,7 +3,6 @@
# This project is licensed under the GPL-3.0 License.


-import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -59,11 +58,12 @@ def consensus_graph(self, graphs, device):
        fused_adj = torch.clamp(fused_adj, 0, 1)
        fused_adj = torch.round(fused_adj + 0.1)

-        # build symmetric adjacency matrix
-        adj_np = fused_adj.detach().cpu().numpy()
-        adj_np += adj_np.T
-        adj_sparse = sp.csr_matrix(adj_np)
-        g = dgl.from_scipy(adj_sparse, device=device)
+        # build symmetric adjacency matrix ON GPU (avoids GPU<->CPU<->scipy roundtrip)
+        with torch.no_grad():
+            sym = fused_adj.detach()
+            sym = sym + sym.T
+            src, dst = (sym > 0).nonzero(as_tuple=True)
+        g = dgl.graph((src.int(), dst.int()), num_nodes=fused_adj.shape[0], device=device)

        return fused_adj, g
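A quick CPU-side sanity check for this hunk (a sketch using NumPy/SciPy stand-ins for the torch tensors) confirms that the thresholded edge-list path yields the same edge set as the removed scipy route:

```
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
# stand-in for fused_adj after clamp/round: a 0/1 matrix
fused = (rng.random((6, 6)) > 0.6).astype(np.float32)

# old route: symmetrize, then convert through scipy.sparse
old_edges = set(zip(*sp.csr_matrix(fused + fused.T).nonzero()))

# new route: symmetrize, threshold, take nonzero indices directly
src, dst = ((fused + fused.T) > 0).nonzero()
new_edges = set(zip(src.tolist(), dst.tolist()))

assert old_edges == new_edges
```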
