Memory-efficient, fast graph partitioner implementations for GriNNder (MLSys '26).
Supports three partitioners: GriNNder's switching-aware partitioning (MLSys '26), Spinner (ICDE '17), and Random.
Quick start:

```python
from grdpart import GrinnderPartitioner
import torch

# Random graph: 100 nodes, 500 edges, as a [2, num_edges] edge index
edge_index = torch.randint(0, 100, (2, 500))
partitioner = GrinnderPartitioner(num_parts=4)  # num_parts must be a power of 2
result = partitioner.partition(edge_index)

print(f"Nodes: {result.partition.shape[0]}")
print(f"Partitions: {result.num_parts}")
print(f"Balance: {result.balance:.3f}")
```

Installation:

```bash
conda install -c conda-forge grdpart
```

For development:
```bash
git clone https://github.com/AIS-SNU/GriNNder.git
cd GriNNder/grdpart
pip install -e .
```

| Partitioner | Algorithm | Best For |
|---|---|---|
| GrinnderPartitioner | GriNNder's switching-aware partitioning with optimized parallel implementation | GNN training with feature reuse |
| SpinnerPartitioner | Spinner's algorithm with Gather-Apply-Scatter (GAS) implementation | General-purpose balanced partitioning |
| RandomPartitioner | Uniform random | Quality baseline |
`partition()` accepts several input formats. An edge index tensor:

```python
from grdpart import GrinnderPartitioner
import torch

# edge_index shape: [2, num_edges]
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
partitioner = GrinnderPartitioner(num_parts=4)
result = partitioner.partition(edge_index)
```

A dense NumPy adjacency matrix:

```python
from grdpart import GrinnderPartitioner
import numpy as np

# adjacency matrix: [num_nodes, num_nodes]
adj = np.random.randint(0, 2, (20, 20))
partitioner = GrinnderPartitioner(num_parts=4)
result = partitioner.partition(adj)
```

A PyTorch Geometric `Data` object:

```python
from grdpart import GrinnderPartitioner
import torch
from torch_geometric.data import Data

# Any [2, num_edges] edge index works; here, a random graph with 100 nodes
edge_index = torch.randint(0, 100, (2, 500))
data = Data(edge_index=edge_index, num_nodes=100)
partitioner = GrinnderPartitioner(num_parts=4)
result = partitioner.partition(data)
```

Result object returned by all partitioners:
Fields:
- `partition` (`torch.Tensor`): partition assignment per node, shape `[num_nodes]`
- `num_parts` (`int`): number of partitions
- `perm` (`torch.Tensor`): permutation that sorts nodes by partition
- `ptr` (`torch.Tensor`): partition boundary pointers, shape `[num_parts + 1]`
- `balance` (`float`): load balance ratio (`max_size / avg_size`), >= 1.0

Methods:
- `partition_sizes() -> List[int]`: number of nodes in each partition
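To make the relationship between `partition`, `perm`, `ptr`, and `balance` concrete, here is a pure-Python sketch that recomputes them from a plain list of partition IDs. The helper `summarize_partition` is hypothetical (not part of grdpart), and it assumes `perm` lists nodes in ascending ID order within each partition; the library's tensors may order nodes within a partition differently.

```python
from typing import List, Tuple


def summarize_partition(partition: List[int], num_parts: int) -> Tuple[List[int], List[int], List[int], float]:
    """Recompute perm, ptr, per-partition sizes, and balance from a node -> partition list."""
    # Count nodes per partition (what partition_sizes() reports).
    sizes = [0] * num_parts
    for p in partition:
        sizes[p] += 1

    # Prefix sums: ptr[i]..ptr[i + 1] is the slice of perm holding partition i.
    ptr = [0] * (num_parts + 1)
    for i, s in enumerate(sizes):
        ptr[i + 1] = ptr[i] + s

    # Stable sort of node IDs by partition ID (assumption: ascending IDs within a partition).
    perm = sorted(range(len(partition)), key=lambda v: partition[v])

    # balance = max_size / avg_size; exactly 1.0 only for a perfectly even split.
    avg_size = len(partition) / num_parts
    balance = max(sizes) / avg_size
    return perm, ptr, sizes, balance


perm, ptr, sizes, balance = summarize_partition([1, 0, 1, 2, 0, 1], num_parts=3)
print(perm)     # [1, 4, 0, 2, 5, 3]
print(ptr)      # [0, 2, 5, 6]
print(sizes)    # [2, 3, 1]
print(balance)  # 1.5
```

Under these assumptions, the nodes of partition `i` are `perm[ptr[i]:ptr[i + 1]]`.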
`SpinnerPartitioner`: streaming partitioning with the Gather-Apply-Scatter (GAS) algorithm.
Parameters:
- `num_parts` (`int`): number of partitions, >= 1
- `capacity` (`float`): over-capacity factor, >= 1.0 (default: 1.1)
- `beta` (`float`): load penalty weight, >= 0 (default: 1.0)
- `max_iter` (`int`): maximum iterations, >= 1 (default: 50)
- `halting_eps` (`float`): convergence threshold, >= 0 (default: 1e-4)
- `halting_window` (`int`): window size for convergence, >= 1 (default: 5)
- `num_threads` (`int`): number of OpenMP threads, >= 1 (default: 4)
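To show how `capacity`, `beta`, `max_iter`, `halting_eps`, and `halting_window` interact, here is a simplified, single-threaded sketch of Spinner-style label propagation, loosely following the scoring and probabilistic migration rules of the Spinner paper. It is not grdpart's GAS/OpenMP implementation: the function `spinner_sketch`, its `seed` argument, and the adjacency-list input are hypothetical, and details such as tie-breaking and load accounting may differ from the library.

```python
import random
from typing import List, Sequence


def spinner_sketch(adj: Sequence[Sequence[int]], num_parts: int, capacity: float = 1.1,
                   beta: float = 1.0, max_iter: int = 50, halting_eps: float = 1e-4,
                   halting_window: int = 5, seed: int = 0) -> List[int]:
    """Simplified Spinner-style label propagation (illustration only, not grdpart's code)."""
    rng = random.Random(seed)
    n = len(adj)
    deg = [len(nbrs) for nbrs in adj]
    # Per-partition capacity in degree units: C = capacity * (sum of degrees) / num_parts.
    cap = capacity * sum(deg) / num_parts
    label = [rng.randrange(num_parts) for _ in range(n)]  # random initial assignment
    best, stall = float("-inf"), 0

    for _ in range(max_iter):
        load = [0] * num_parts
        for v in range(n):
            load[label[v]] += deg[v]

        # Phase 1: every vertex scores each label against the frozen current state.
        candidate = list(label)
        demand = [0] * num_parts  # degree mass that wants to migrate into each label
        total_score = 0.0
        for v in range(n):
            if deg[v] == 0:
                continue  # isolated vertices have no neighbors to follow
            counts = [0] * num_parts
            for u in adj[v]:
                counts[label[u]] += 1
            # Fraction of neighbors in label l, minus a load penalty weighted by beta.
            scores = [counts[l] / deg[v] - beta * load[l] / cap for l in range(num_parts)]
            total_score += scores[label[v]]
            target = max(range(num_parts), key=scores.__getitem__)
            if scores[target] > scores[label[v]]:
                candidate[v] = target
                demand[target] += deg[v]

        # Phase 2: migrate with probability (remaining capacity) / (demand).
        for v in range(n):
            target = candidate[v]
            if target != label[v]:
                if rng.random() < max(0.0, cap - load[target]) / demand[target]:
                    label[v] = target

        # Halting: stop after `halting_window` rounds without improving by > halting_eps.
        if total_score > best + halting_eps:
            best, stall = total_score, 0
        else:
            stall += 1
            if stall >= halting_window:
                break
    return label


# Two 4-cliques (nodes 0-3 and 4-7) joined by the single edge 3-4.
edges = [(a, b) for grp in (range(0, 4), range(4, 8)) for a in grp for b in grp if a < b]
edges.append((3, 4))
adj = [[] for _ in range(8)]
for a, b in edges:
    adj[a].append(b)
    adj[b].append(a)
labels = spinner_sketch(adj, num_parts=2)
print(labels)
```

The two-phase update mirrors the synchronous style of GAS: scores are computed from a frozen snapshot, and the migration probability throttles moves so that a popular label does not overshoot its capacity in a single round.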
`GrinnderPartitioner`: reuse-aware partitioning with second-best target scoring.
Note: `num_parts` must be a power of 2.
Parameters:
- `num_parts` (`int`): number of partitions (must be a power of 2)
- `capacity` (`float`): over-capacity factor, >= 1.0 (default: 1.1)
- `beta` (`float`): load penalty weight, >= 0 (default: 1.0)
- `max_iter` (`int`): maximum iterations, >= 1 (default: 50)
- `halting_eps` (`float`): convergence threshold, >= 0 (default: 1e-4)
- `halting_window` (`int`): window size for convergence, >= 1 (default: 5)
- `reuse_aware` (`bool`): enable reuse-aware scoring (default: True)
- `refine` (`bool`): enable partition refinement (default: True)
- `num_threads` (`int`): number of OpenMP threads, >= 1 (default: 4)
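Since `num_parts` must be a power of 2, it can help to validate arguments before constructing a partitioner. A minimal sketch with hypothetical helpers `is_power_of_two` and `validate_grinnder_params` that mirror the documented constraints; they are not part of grdpart, and the library's own error handling may differ.

```python
def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of two exactly when it has a single set bit.
    return n > 0 and (n & (n - 1)) == 0


def validate_grinnder_params(num_parts: int, capacity: float = 1.1, beta: float = 1.0,
                             max_iter: int = 50, halting_eps: float = 1e-4,
                             halting_window: int = 5, num_threads: int = 4) -> None:
    """Raise ValueError for the first argument that violates a documented constraint."""
    checks = [
        (is_power_of_two(num_parts), "num_parts must be a power of 2"),
        (capacity >= 1.0, "capacity must be >= 1.0"),
        (beta >= 0, "beta must be >= 0"),
        (max_iter >= 1, "max_iter must be >= 1"),
        (halting_eps >= 0, "halting_eps must be >= 0"),
        (halting_window >= 1, "halting_window must be >= 1"),
        (num_threads >= 1, "num_threads must be >= 1"),
    ]
    for ok, message in checks:
        if not ok:
            raise ValueError(message)


validate_grinnder_params(num_parts=8)  # passes silently
try:
    validate_grinnder_params(num_parts=6)
except ValueError as err:
    print(err)  # prints: num_parts must be a power of 2
```

Note that this check accepts `num_parts=1` (2^0); whether `GrinnderPartitioner` accepts a single partition is not stated above.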
`RandomPartitioner`: uniform random baseline partitioner.
Parameters:
- `num_parts` (`int`): number of partitions, >= 1
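A uniform random baseline draws each node's partition independently. The sketch below uses a hypothetical `random_partition` function with a `seed` argument (the parameter list above documents no seed for `RandomPartitioner`) and computes balance the same way as the result's `balance` field (`max_size / avg_size`):

```python
import random
from typing import List


def random_partition(num_nodes: int, num_parts: int, seed: int = 0) -> List[int]:
    """Assign every node to a partition drawn uniformly from [0, num_parts)."""
    rng = random.Random(seed)
    return [rng.randrange(num_parts) for _ in range(num_nodes)]


parts = random_partition(num_nodes=10_000, num_parts=4)
sizes = [parts.count(p) for p in range(4)]
balance = max(sizes) / (len(parts) / 4)  # max_size / avg_size
print(sizes, f"balance={balance:.3f}")
```

Node counts come out close to even, but edges are ignored entirely: with 4 parts, an edge's endpoints land in the same partition with probability 1/4, so about three quarters of edges cross partitions. This is why Random serves as a quality baseline.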
Partitioning the Reddit graph into 4 parts on CPU:

```python
from grdpart import GrinnderPartitioner, SpinnerPartitioner, RandomPartitioner
from torch_geometric.datasets import Reddit

data = Reddit(root="/tmp/Reddit")[0]

for name, p in [
    ("GriNNder", GrinnderPartitioner(num_parts=4, num_threads=32)),
    ("Spinner", SpinnerPartitioner(num_parts=4, num_threads=32)),
    ("Random", RandomPartitioner(num_parts=4)),
]:
    result = p.partition(data)
    print(f"{name:<12} balance={result.balance:.4f}")
```

Partitioning the ogbn-products graph into 4 parts on CPU:
```python
from grdpart import GrinnderPartitioner, SpinnerPartitioner, RandomPartitioner
from ogb.nodeproppred import PygNodePropPredDataset

dataset = PygNodePropPredDataset(name="ogbn-products", root="/tmp/ogbn-products")
data = dataset[0]

for name, p in [
    ("GriNNder", GrinnderPartitioner(num_parts=4, num_threads=32)),
    ("Spinner", SpinnerPartitioner(num_parts=4, num_threads=32)),
    ("Random", RandomPartitioner(num_parts=4)),
]:
    result = p.partition(data)
    print(f"{name:<12} balance={result.balance:.4f}")
```

Reusability measures the average number of unique neighboring partitions per node; it is reported in the C++ log output, which is enabled by default. Lower reusability means less partition expansion, i.e., fewer feature recomputations during GNN training. GriNNder keeps reusability low while also optimizing balance, which is critical for GriNNder's storage offloading pipeline.
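The reusability metric can be sketched in a few lines of pure Python. This is an illustrative reading of the definition above, not the C++ logger's code: the function `reusability` is hypothetical, edges are treated as undirected, and a node's own partition is counted (so every node contributes at least 1). The logger may use a different convention, for example excluding the node's own partition.

```python
from typing import List, Sequence, Set, Tuple


def reusability(edges: Sequence[Tuple[int, int]], partition: Sequence[int]) -> float:
    """Average number of distinct partitions that need each node's feature."""
    # Start with each node's own partition (assumption: it is counted).
    parts_needing: List[Set[int]] = [{p} for p in partition]
    for u, v in edges:
        # Undirected edge: each endpoint's partition needs the other endpoint's feature.
        parts_needing[u].add(partition[v])
        parts_needing[v].add(partition[u])
    return sum(len(s) for s in parts_needing) / len(partition)


path = [(0, 1), (1, 2), (2, 3)]
print(reusability(path, [0, 0, 1, 1]))  # 1.5: only the cut edge 1-2 adds a partition
print(reusability(path, [0, 1, 0, 1]))  # 2.0: every edge is cut
```

On the path graph, the contiguous split scores lower than the alternating one, matching the intuition that lower reusability means less partition expansion.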
MIT License (AIS-SNU)
For the full GriNNder project, see https://github.com/AIS-SNU/GriNNder
```bibtex
@inproceedings{song2026grinnder,
  title={Gri{NN}der: Breaking the Memory Capacity Wall in Full-Graph {GNN} Training with Storage Offloading},
  author={Song, Jaeyong and Park, Seongyeon and Jang, Hongsun and Jung, Jaewon and Lim, Hunseong and Hong, Junguk and Lee, Jinho},
  booktitle={Ninth Conference on Machine Learning and Systems (MLSys 2026)},
  year={2026},
  url={https://openreview.net/forum?id=8SNPzGRldN}
}
```