From db18a223c97409058667f3353a0d15ee3edfce0e Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Wed, 28 Jan 2026 11:25:00 +0800 Subject: [PATCH 1/8] add LRGB pcqm-contact example --- .gitignore | 1 + example/LRGB/criterion.py | 78 ++++++++++++++++ example/LRGB/data.py | 8 ++ example/LRGB/mrr.py | 81 +++++++++++++++++ example/LRGB/run.py | 181 ++++++++++++++++++++++++++++++++++++++ example/TSP/run.py | 1 + 6 files changed, 350 insertions(+) create mode 100644 example/LRGB/criterion.py create mode 100644 example/LRGB/data.py create mode 100644 example/LRGB/mrr.py create mode 100644 example/LRGB/run.py diff --git a/.gitignore b/.gitignore index 3b56bf9..b37650d 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,4 @@ example/outputs count.out run_scripts example/data/TSP +example/data/BREC/ diff --git a/example/LRGB/criterion.py b/example/LRGB/criterion.py new file mode 100644 index 0000000..9427483 --- /dev/null +++ b/example/LRGB/criterion.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FocalLoss(nn.Module): + def __init__( + self, + gamma=2, + alpha=None, + reduction='mean', + task_type='binary', + ): + """ + Unified Focal Loss class for binary, multi-class, and multi-label classification tasks. + :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma + :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used. + :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum' + :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label' + :param num_classes: Number of classes (only required for multi-class classification) + """ + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.task_type = task_type + + def forward(self, inputs, graph): + """ + Forward pass to compute the Focal Loss based on the specified task type. + :param inputs: Predictions (logits) from the model. + Shape: + - binary/multi-label: (batch_size, num_classes) + - multi-class: (batch_size, num_classes) + :param targets: Ground truth labels. + Shape: + - binary: (batch_size,) + - multi-label: (batch_size, num_classes) + - multi-class: (batch_size,) + """ + if self.task_type == 'binary': + return self.binary_focal_loss(inputs, graph) + else: + raise ValueError( + f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.") + + def binary_focal_loss(self, inputs, graph): + """ Focal loss for binary classification. """ + inputs = inputs[-1] # get the edge prediction + inputs = inputs.squeeze(-1) + + targets = graph.adj_label + mask = graph.pair_mask + + probs = torch.sigmoid(inputs).clamp(min=1e-10) + targets = targets.float() + + # Compute binary cross entropy + bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') + + # Compute focal weight + p_t = probs * targets + (1 - probs) * (1 - targets) + focal_weight = (1 - p_t) ** self.gamma + + # Apply alpha if provided + if self.alpha is not None: + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + bce_loss = alpha_t * bce_loss + + # Apply focal loss weighting + loss = focal_weight * bce_loss + loss = loss * mask + + if self.reduction == 'mean': + return loss.sum() / mask.sum().clamp(min=1), {} + elif self.reduction == 'sum': + return loss.sum(), {} + return loss, {} diff --git a/example/LRGB/data.py b/example/LRGB/data.py new file mode 100644 index 0000000..b64f1b2 --- /dev/null +++ b/example/LRGB/data.py @@ -0,0 +1,8 @@ +import torch + +def transform_pcqm(graph): + # Data(x=[36, 9], edge_index=[2, 72], edge_attr=[72, 3], edge_label_index=[2, 84], edge_label=[84]) + graph.x = torch.nn.functional.pad(graph.x, (0, 0, 1, 0)) + graph.edge_index = graph.edge_index + 1 + graph.edge_label_index = graph.edge_label_index + 1 + return graph \ No newline at end of file diff --git a/example/LRGB/mrr.py b/example/LRGB/mrr.py new file mode 100644 index 0000000..2cafb7d --- /dev/null +++ b/example/LRGB/mrr.py @@ -0,0 +1,81 @@ +from collections import defaultdict + +import torch + +def _eval_mrr(y_pred_pos, y_pred_neg): + """ Compute Hits@k and Mean Reciprocal Rank (MRR). + + Implementation from OGB: + https://github.com/snap-stanford/ogb/blob/master/ogb/linkproppred/evaluate.py + + Args: + y_pred_neg: array with shape (batch size, num_entities_neg). + y_pred_pos: array with shape (batch size, ) + """ + + y_pred = torch.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1) + argsort = torch.argsort(y_pred, dim=1, descending=True) + ranking_list = torch.nonzero(argsort == 0, as_tuple=False) + ranking_list = ranking_list[:, 1] + 1 + # average within graph + hits1 = (ranking_list <= 1).to(torch.float).mean().item() + hits3 = (ranking_list <= 3).to(torch.float).mean().item() + hits10 = (ranking_list <= 10).to(torch.float).mean().item() + mrr = (1. / ranking_list.to(torch.float)).mean().item() + + # print(f"hits@1 {hits1:.5f}") + # print(f"hits@3 {hits3:.5f}") + # print(f"hits@10 {hits10:.5f}") + # print(f"mrr {mrr:.5f}") + return hits1, hits3, hits10, mrr + + +class EdgeMRR: + def __init__(self): + self.states = defaultdict(lambda: []) + + def clean(self): + self.states = defaultdict(lambda: []) + + def add_batch(self, pred, graph_batch): + pred = pred[-1] + for b in range(graph_batch.single_mask.shape[0]): + indices = torch.where(graph_batch.single_mask[b])[0] + self.states["preds"].append(pred[b][indices][:, indices]) + self.states["trues"].append(graph_batch.adj_label[b][indices][:, indices]) + + def compute(self): + # pred: list of [n, n] + # true: list of [n, n] + pred_list = self.states["preds"] + true_list = self.states["trues"] + batch_stats = [[], [], [], []] + for pred, true in zip(pred_list, true_list): + n = pred.shape[0] + pos_edge_index = torch.where(true == 1) + pred_pos = pred[pos_edge_index] + num_pos_edges = pos_edge_index[0].shape[0] + if num_pos_edges == 0: + continue + + neg_mask = torch.ones([num_pos_edges, n], dtype=torch.bool) + neg_mask[torch.arange(num_pos_edges), pos_edge_index[1]] = False + pred_neg = pred[pos_edge_index[0]][neg_mask].view(num_pos_edges, -1) + + mrr_list = _eval_mrr(pred_pos, pred_neg) + for i, v in enumerate(mrr_list): + batch_stats[i].append(v) + # average among all graphs + res = [] + for i in range(4): + v = torch.tensor(batch_stats[i]) + v = torch.nan_to_num(v, nan=0).sum().item() + res.append(v) + + return { + 'hits@1': res[0], + 'hits@3': res[1], + 'hits@10': res[2], + 'mrr': res[3], + "sample_count": len(pred_list), + } \ No newline at end of file diff --git a/example/LRGB/run.py b/example/LRGB/run.py new file mode 100644 index 0000000..c462369 --- /dev/null +++ b/example/LRGB/run.py @@ -0,0 +1,181 @@ +# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import argparse +from pathlib import Path +from functools import partial + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +import wandb +import torch +from torch_geometric.datasets import LRGBDataset + +from common.utils import compute_init, print0, compute_cleanup, is_ddp_initialized, DummyWandb, is_master_process, load_checkpoint +from common.dataloader import create_dataloader +from common.model import ModelConfig, FloydNet +from common.trainer import Trainer +from LRGB.data import transform_pcqm +from LRGB.criterion import FocalLoss +from LRGB.mrr import EdgeMRR + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run FloydNet for LRGB task") + # data + parser.add_argument("--data_dir", type=str, default="data/LRGB", help="Directory for the dataset") + parser.add_argument("--name", type=str, default="pcqm-contact") + parser.add_argument("--batch_size", type=int, default=48, help="Batch size for training") + # model + parser.add_argument("--seed", type=int, default=158293, help="Random seed for initialization") + parser.add_argument("--n_embd", type=int, default=256, help="Embedding dimension") + parser.add_argument("--n_head", type=int, default=4, help="Number of attention heads") + parser.add_argument("--depth", type=int, default=32, help="Number of floyd transformer layers") + parser.add_argument("--dropout", type=float, default=0.15) + # training + parser.add_argument("--max_train_epoch_len", type=int, default=400, help="Maximum length of training epoch in number of samples") + parser.add_argument("--max_val_epoch_len", type=int, default=1000, help="Maximum length of validation epoch in number of samples") + parser.add_argument("--max_test_epoch_len", type=int, default=1000, help="Maximum length of test epoch in number of samples") + parser.add_argument("--max_epochs", type=int, default=5000, help="Maximum number of training epochs") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay for optimizer") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.98), help="Betas for AdamW optimizer") + parser.add_argument("--grad_clip", type=float, default=100.0, help="Gradient clipping value") + parser.add_argument("--grad_accumulation", type=int, default=1, help="Gradient accumulation steps") + parser.add_argument("--eval_interval", type=int, default=20, help="Evaluation interval (in epochs)") + parser.add_argument("--test_interval", type=int, default=20, help="Test interval (in epochs)") + parser.add_argument("--save_interval", type=int, default=100, help="Model saving interval (in epochs)") + # output + parser.add_argument("--wandb_name", type=str, default="dummy", help="WandB experiment name, dummy to disable WandB logging") + parser.add_argument("--output_dir", type=str, default="outputs/LRGB", help="Directory to save outputs") + parser.add_argument("--load_checkpoint", type=str, default=None, help="Path to load checkpoint from") + args = parser.parse_args() + return args + + +def build_dataloader(args): + dataset_builder = partial( + LRGBDataset, + root=args.data_dir, + name=args.name, + transform=partial(transform_pcqm), + ) + dataloaders = create_dataloader( + dataset=dataset_builder, + batch_size=args.batch_size, + train=True, + ) + return dataloaders + + +def build_model(args): + if args.load_checkpoint is not None: + print0(f"Loading model from checkpoint: {args.load_checkpoint}") + ckpt_model_config_path = Path(args.load_checkpoint).parent / "model_config.json" + print0(f"Using model config from checkpoint, ignoring command line model config args") + with open(ckpt_model_config_path, "r") as f: + model_config = json.load(f) + else: + model_config = dict( + n_embd=args.n_embd, + n_head=args.n_head, + n_out=1, + depth=args.depth, + enable_adj_emb=True, + enable_diffusion=False, + n_edge_feat=3, + edge_feat_vocab_size=120, + task_level="e", + n_decode_layers=4, + node_feat_vocab_size=120, + n_node_feat=9, + supernode=True, + dropout=args.dropout, + norm_fn="ln", + ) + + start_epoch = 0 + model = FloydNet(ModelConfig(**model_config)).to("cuda") + model.init_weights() + print0(model) + if is_ddp_initialized(): + torch.distributed.barrier() + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=tuple(args.adam_betas), eps=1e-8) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10) + + if args.load_checkpoint is not None: + model, optimizer, scheduler, start_epoch = load_checkpoint( + checkpoint_path=args.load_checkpoint, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + print0(f"Resumed from epoch {start_epoch}") + + if is_master_process(): + # save model config for loading model later + output_dir = Path(args.output_dir) / "checkpoints" + os.makedirs(output_dir, exist_ok=True) + with open(output_dir / "model_config.json", "w") as f: + json.dump(model_config, f, indent=4) + + return model, optimizer, scheduler, model_config, start_epoch + + +def main(args): + print0("Initializing distributed environment...") + compute_init(seed=args.seed) + + print0("Building dataloaders...") + dataloaders = build_dataloader(args) + print0("Building model, optimizer, and scheduler...") + model, optimizer, scheduler, model_config, start_epoch = build_model(args) + + criterion = FocalLoss() + logger = DummyWandb() if (args.wandb_name == "dummy" or not is_master_process()) else wandb.init( + project="floydnetwork", + name=args.wandb_name, + config=vars(args).copy(), + ) + trainer = Trainer( + model=model, + optimizer=optimizer, + scheduler=scheduler, + critn=criterion, + out_dir=args.output_dir, + max_epochs=args.max_epochs, + amp=True, + grad_clip_val=args.grad_clip, + grad_accumulation=args.grad_accumulation, + eval_interval=args.eval_interval, + test_interval=args.test_interval, + save_interval=args.save_interval, + max_train_epoch_len=args.max_train_epoch_len, + max_val_epoch_len=args.max_val_epoch_len, + max_test_epoch_len=args.max_test_epoch_len, + ) + print0("Starting training...") + trainer.fit(logger, dataloaders, start_epoch=start_epoch, metric=EdgeMRR()) + + compute_cleanup() + print0("Finished all processes.") + + +if __name__ == '__main__': + args = parse_args() + main(args) \ No newline at end of file diff --git a/example/TSP/run.py b/example/TSP/run.py index a28bb0b..94fdc06 100644 --- a/example/TSP/run.py +++ b/example/TSP/run.py @@ -115,6 +115,7 @@ def build_model(args): depth=args.depth, enable_adj_emb=False, enable_diffusion=True, + n_edge_feat=1, edge_feat_vocab_size=220, task_level="e", n_decode_layers=1, From 6532baaa6bb9004a28ba16be6bb585034b8dad0b Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Wed, 28 Jan 2026 11:25:47 +0800 Subject: [PATCH 2/8] update common for LRGB example --- example/common/graph.py | 14 +++++++- example/common/model.py | 67 +++++++++++++++++++++++++++++++++++---- example/common/trainer.py | 27 +++++++++++++--- example/common/utils.py | 11 +++++-- 4 files changed, 104 insertions(+), 15 deletions(-) diff --git a/example/common/graph.py b/example/common/graph.py index 7591699..e3eb82f 100644 --- a/example/common/graph.py +++ b/example/common/graph.py @@ -72,7 +72,7 @@ def batch_pad_e_y(y, ptr, counts, max_m): return padded_y -def graph_preprocess(graph): +def graph_preprocess(graph, supernode): """ Convert a PyG-style batched graph object into dense, per-graph padded tensors. @@ -116,6 +116,18 @@ def graph_preprocess(graph): if "y_edge" in graph: graph.y_edge = batch_pad_e_y(graph.y_edge, graph.y_edge_ptr, counts, max_m) + if supernode: + b, m = x.shape[:2] + is_supernode = torch.zeros((b, m), dtype=torch.long, device=adj.device) + adj_superedge = torch.zeros((b, m, m), dtype=torch.long, device=adj.device) + is_supernode[:, 0] = 1 + adj_superedge[:, 0, :] = 1 + adj_superedge[:, :, 0] = 2 + adj_superedge[:, 0, 0] = 3 + + graph.is_supernode = is_supernode + graph.adj_superedge = adj_superedge + graph.x = x graph.single_mask = mask graph.pair_mask = mask[:, :, None] * mask[:, None, :] diff --git a/example/common/model.py b/example/common/model.py index a883990..8a1d4c2 100644 --- a/example/common/model.py +++ b/example/common/model.py @@ -33,10 +33,16 @@ class ModelConfig: enable_adj_emb: bool = True enable_diffusion: bool = False task_level: str = "g" + n_edge_feat: int = 0 edge_feat_vocab_size: int = -1 n_decode_layers: int = 1 decoder_mask_by_adj: bool = False enable_ffn: bool = True + node_feat_vocab_size: int = 0 + n_node_feat: int = 0 + supernode: bool = False + dropout: float = 0.0 + norm_fn: str = "affine" class DiffusionEmbedder(nn.Module): def __init__(self, n_embd): @@ -110,12 +116,20 @@ def __init__( if config.enable_adj_emb: self.emb_adj = nn.Embedding(2, config.n_embd) - if config.edge_feat_vocab_size > 0: - self.emb_edge = nn.Embedding(config.edge_feat_vocab_size, config.n_embd) + if config.n_edge_feat > 0 and config.edge_feat_vocab_size > 0: + if config.n_edge_feat != 1: + self.emb_edge = nn.ModuleList([nn.Embedding(config.edge_feat_vocab_size, config.n_embd) for _ in range(config.n_edge_feat)]) + else: + self.emb_edge = nn.Embedding(config.edge_feat_vocab_size, config.n_embd) if config.enable_diffusion: self.diffusion_embedder = DiffusionEmbedder(config.n_embd) + if config.n_node_feat > 0 and config.node_feat_vocab_size > 0: + self.emb_node_i = nn.ModuleList([nn.Embedding(config.node_feat_vocab_size, config.n_embd) for _ in range(config.n_node_feat)]) + self.emb_node_j = nn.ModuleList([nn.Embedding(config.node_feat_vocab_size, config.n_embd) for _ in range(config.n_node_feat)]) + if config.supernode: + self.emb_superedge = nn.Embedding(4, config.n_embd) - self.blocks = nn.ModuleList([PivotalAttentionBlock(embed_dim=config.n_embd, num_heads=config.n_head, activation_fn="silu", norm_fn="affine", enable_ffn=config.enable_ffn) for _ in range(config.depth)]) + self.blocks = nn.ModuleList([PivotalAttentionBlock(embed_dim=config.n_embd, num_heads=config.n_head, activation_fn="silu", norm_fn=self.config.norm_fn, enable_ffn=config.enable_ffn, dropout=self.config.dropout) for _ in range(config.depth)]) if "g" in config.task_level: self.head_g = FFN(config.n_embd, config.n_out, config.n_decode_layers) if "v" in config.task_level: @@ -127,7 +141,18 @@ def init_weights(self): if hasattr(self, "emb_adj"): nn.init.normal_(self.emb_adj.weight, mean=0.0, std=1.0) if hasattr(self, "emb_edge"): - nn.init.normal_(self.emb_edge.weight, mean=0.0, std=1.0) + if self.config.n_edge_feat > 1: + for emb in self.emb_edge: + nn.init.normal_(emb.weight, mean=0.0, std=1.0) + else: + nn.init.normal_(self.emb_edge.weight, mean=0.0, std=1.0) + if hasattr(self, "emb_node_i"): + for emb in self.emb_node_i: + nn.init.normal_(emb.weight, mean=0.0, std=1.0) + for emb in self.emb_node_j: + nn.init.normal_(emb.weight, mean=0.0, std=1.0) + if hasattr(self, "emb_superedge"): + nn.init.normal_(self.emb_superedge.weight, mean=0.0, std=1.0) if hasattr(self, "diffusion_embedder"): self.diffusion_embedder.init_weights() if hasattr(self, "head_g"): @@ -141,16 +166,46 @@ def init_weights(self): b._reset_parameters() def preprocess(self, graph: pyg.data.Data): - return graph_preprocess(graph) + return graph_preprocess(graph, supernode=self.config.supernode) def embed(self, graph: pyg.data.Data): x = 0.0 if self.config.enable_adj_emb: x = x + self.emb_adj(graph.adj) if self.config.edge_feat_vocab_size > 0: - x = x + self.emb_edge(graph.adj_attr[:, :, :, 0]) + if self.config.n_edge_feat > 1: + for idx in range(self.config.n_edge_feat): + x = x + self.emb_edge[idx](graph.adj_attr[:, :, :, idx]) + else: + x = x + self.emb_edge(graph.adj_attr[:, :, :, 0]) + if self.config.n_node_feat > 0 and self.config.node_feat_vocab_size > 0: + if self.config.supernode: + graph.x = graph.x.to(torch.long) + emb_node_i = 0 + emb_node_j = 0 + for idx in range(self.config.n_node_feat): + emb_node_i = emb_node_i + self.emb_node_i[idx](graph.x[:, 1:, idx]) + emb_node_j = emb_node_j + self.emb_node_j[idx](graph.x[:, 1:, idx]) + # emb_node_i & j: [B, N - 1, c] + # add to superedge, which is first row and first column + # take care supernode it self is removed + n = graph.x.shape[1] + emb_node_i = emb_node_i[:, :, None] + emb_node_j = emb_node_j[:, None, :] + # i: [B, N - 1, 1, c] -> [B, N, N, c] + emb_node_i = torch.nn.functional.pad(emb_node_i, (0, 0, 0, n - 1, 1, 0)) + # j: [B, 1, N - 1, c] -> [B, N, N, c] + emb_node_j = torch.nn.functional.pad(emb_node_j, (0, 0, 1, 0, 0, n - 1)) + + x = x + emb_node_i + x = x + emb_node_j + else: + raise ValueError("Supernode must be enabled when using node features.") if self.config.enable_diffusion: x = x + self.diffusion_embedder(graph) + if self.config.supernode: + x = x + self.emb_superedge(graph.adj_superedge) + x = x * graph.pair_mask.unsqueeze(-1) return x diff --git a/example/common/trainer.py b/example/common/trainer.py index 24926d5..206be5e 100644 --- a/example/common/trainer.py +++ b/example/common/trainer.py @@ -95,6 +95,7 @@ def fit( logger, dataloader, start_epoch, + metric=None, ): start_fit_time = time.time() self.critn = self.critn.cuda() @@ -114,16 +115,18 @@ def fit( should_evaluate_this_epoch = (epoch == 0 or (epoch + 1) % self.eval_interval == 0 or epoch == self.max_epochs - 1) if should_evaluate_this_epoch: - metrics = self.eval_fn(dataloader["val"], "Val", epoch,) + metrics = self.eval_fn(dataloader["val"], "Val", epoch, metric=metric) min_val_loss = min(min_val_loss, metrics["loss"]) logger.log({f"val/{k}": v for k, v in metrics.items()}) + print("val metrics:", metrics) self.scheduler.step(loss) if (epoch + 1) % self.save_interval == 0 or epoch == self.max_epochs - 1: save_checkpoint(self.model, self.optimizer, self.scheduler, epoch, loss, self.out_dir / "checkpoints") should_test_this_epoch = ((epoch + 1) % self.test_interval == 0 or epoch == self.max_epochs - 1) if should_test_this_epoch: - metrics = self.eval_fn(dataloader["test"], "Test", epoch,) + metrics = self.eval_fn(dataloader["test"], "Test", epoch, metric=metric) + print("test metrics:", metrics) min_test_loss = min(min_test_loss, metrics["loss"]) logger.log({f"test/{k}": v for k, v in metrics.items()}) @@ -136,7 +139,7 @@ def test( epoch, sample_count_per_case=1, ): - loss = self.eval_fn(dataloader["test"], "Test", epoch, sample_count_per_case=sample_count_per_case) + loss = self.eval_fn(dataloader["test"], "Test", epoch, sample_count_per_case=sample_count_per_case, metric=metric) return loss def _train_epoch(self, loader, epoch): @@ -178,11 +181,13 @@ def _train_epoch(self, loader, epoch): return loss_sum / loss_count @torch.no_grad() - def _evaluate(self, loader, stage_name, epoch, **kwargs): + def _evaluate(self, loader, stage_name, epoch, metric=None, **kwargs): self.model.eval() loss_sum = 0.0 loss_count = 0 loss_breakdown_sum = {} + if metric is not None: + metric.clean() max_len = self.max_val_epoch_len if stage_name.lower() == "val" else self.max_test_epoch_len progress_bar = tqdm(loader, desc=f"{stage_name} Epoch {epoch+1}/{self.max_epochs}", leave=False, file=sys.stdout, total=min(len(loader), max_len), disable=not is_master_process()) @@ -203,6 +208,9 @@ def _evaluate(self, loader, stage_name, epoch, **kwargs): loss_sum += loss.item() loss_count += 1 + if metric is not None: + metric.add_batch(pred, graph_batch) + progress_bar.set_postfix(loss=f'{loss.item():.5f}') progress_bar.close() @@ -210,11 +218,20 @@ def _evaluate(self, loader, stage_name, epoch, **kwargs): metrics = {k: v / loss_count for k, v in loss_breakdown_sum.items()} metrics["loss"] = loss_sum / loss_count metrics = reduce_metrics(metrics) + + if metric is not None: + new_metric = metric.compute() + new_metric = reduce_metrics(new_metric, reduction="sum") + for k, v in new_metric.items(): + if k != "sample_count": + v = v / new_metric["sample_count"] + metrics[k] = v + return metrics @torch.no_grad() - def _evaluate_TSP(self, loader, stage_name, epoch, sample_count_per_case=1): + def _evaluate_TSP(self, loader, stage_name, epoch, sample_count_per_case=1, **kwargs): self.model.eval() out_dir = self.out_dir / f"{stage_name}_infer_results_epoch{epoch+1}" out_dir.mkdir(parents=True, exist_ok=True) diff --git a/example/common/utils.py b/example/common/utils.py index a317faf..3f74f5f 100644 --- a/example/common/utils.py +++ b/example/common/utils.py @@ -101,7 +101,7 @@ def empty_cache(): gc.collect() torch.cuda.synchronize() -def reduce_metrics(metric_dict): +def reduce_metrics(metric_dict, reduction="mean"): if not is_ddp_initialized(): return metric_dict @@ -110,8 +110,13 @@ def reduce_metrics(metric_dict): dist.all_reduce(metric_values, op=dist.ReduceOp.SUM) - world_size = dist.get_world_size() - metric_values /= world_size + if reduction == "mean": + world_size = dist.get_world_size() + metric_values /= world_size + elif reduction == "sum": + pass # already summed + else: + raise ValueError(f"Unsupported reduction type: {reduction}") avg_metric_dict = {k: v.item() for k, v in zip(metric_names, metric_values)} return avg_metric_dict From 719925091c33f901c43032364bc4e53f7f14e893 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 15:19:41 +0800 Subject: [PATCH 3/8] add gcn to compare --- example/LRGB/gcn_model.py | 147 ++++++++++++++++++++++++++++++++++++++ example/LRGB/run.py | 67 ++++++++++------- example/common/trainer.py | 2 +- 3 files changed, 188 insertions(+), 28 deletions(-) create mode 100644 example/LRGB/gcn_model.py diff --git a/example/LRGB/gcn_model.py b/example/LRGB/gcn_model.py new file mode 100644 index 0000000..8b47b23 --- /dev/null +++ b/example/LRGB/gcn_model.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.nn import GCNConv +from common.graph import graph_preprocess + +try: + from ogb.graphproppred.mol_encoder import AtomEncoder +except ImportError as e: + AtomEncoder = None + + + +@dataclass +class GCNContactConfig: + layers_mp: int = 5 + layers_post_mp: int = 1 + dim_inner: int = 275 + dropout: float = 0.0 + batchnorm: bool = True + act: str = "relu" + agg: str = "mean" + edge_decoding: str = "dot" + gcn_add_self_loops: bool = True + gcn_normalize: bool = True + + +class MLPNoAct(nn.Module): + def __init__(self, dim: int, num_layers: int): + super().__init__() + assert num_layers >= 1 + layers = [] + for _ in range(num_layers): + layers.append(nn.Linear(dim, dim, bias=True)) + self.net = nn.Sequential(*layers) + self.reset_parameters() + + def reset_parameters(self): + for m in self.net.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class GCNContactModel(nn.Module): + def __init__(self, cfg: Optional[GCNContactConfig] = None): + super().__init__() + self.cfg = cfg or GCNContactConfig() + dim = self.cfg.dim_inner + + if AtomEncoder is None: + raise ImportError( + "ogb is required to match LRGB Atom encoder. " + "Please `pip install ogb` in your environment." + ) + self.node_encoder = AtomEncoder(emb_dim=dim) + + # Message passing stack + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + for _ in range(self.cfg.layers_mp): + self.convs.append( + GCNConv( + in_channels=dim, + out_channels=dim, + improved=False, + cached=False, + add_self_loops=self.cfg.gcn_add_self_loops, + normalize=self.cfg.gcn_normalize, + bias=True, + ) + ) + if self.cfg.batchnorm: + self.bns.append(nn.BatchNorm1d(dim)) + + self.post_mp = MLPNoAct(dim=dim, num_layers=self.cfg.layers_post_mp) + + self.reset_parameters() + + def reset_parameters(self): + if hasattr(self.node_encoder, "reset_parameters"): + self.node_encoder.reset_parameters() + + for i, conv in enumerate(self.convs): + conv.reset_parameters() + if self.cfg.batchnorm: + self.bns[i].reset_parameters() + + self.post_mp.reset_parameters() + + def _encode_nodes(self, data: Data) -> torch.Tensor: + if data.x is None: + raise ValueError("data.x is required for PCQM4Mv2Contact AtomEncoder.") + if data.x.dtype != torch.long: + raise TypeError(f"Expected data.x dtype torch.long, got {data.x.dtype}") + return self.node_encoder(data.x) + + def _mp(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: + for i, conv in enumerate(self.convs): + h = conv(h, edge_index) + if self.cfg.batchnorm: + h = self.bns[i](h) + h = F.relu(h) + if self.cfg.dropout > 0: + h = F.dropout(h, p=self.cfg.dropout, training=self.training) + return h + + def _decode_edges_dot(self, h: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor: + # edge_label_index: [2, E_pred] + src, dst = edge_label_index[0], edge_label_index[1] + # single logit per edge (matches dot decoding head expectation) + return (h[src] * h[dst]).sum(dim=-1, keepdim=True) # [E_pred, 1] + + def preprocess(self, data): + return data + + def forward(self, data: Data) -> Tuple[None, None, torch.Tensor]: + if data.edge_index is None: + raise ValueError("data.edge_index is required.") + if not hasattr(data, "edge_label_index") or data.edge_label_index is None: + raise ValueError("data.edge_label_index [2, E_pred] is required for edge prediction.") + + h = self._encode_nodes(data) + h = self._mp(h, data.edge_index) + + # Head post-mp on nodes, then dot decode + h = self.post_mp(h) + + data.x = h + data = graph_preprocess(data, supernode=False) + logits = data.x @ data.x.transpose(1, 2) + logits = logits.unsqueeze(-1) # [B, N, N, 1] + return None, None, logits + # original impl with edge decoding: + # logits = self._decode_edges_dot(h, data.edge_label_index) + # data.x = h + # return None, None, logits diff --git a/example/LRGB/run.py b/example/LRGB/run.py index c462369..95b0551 100644 --- a/example/LRGB/run.py +++ b/example/LRGB/run.py @@ -31,6 +31,7 @@ from LRGB.data import transform_pcqm from LRGB.criterion import FocalLoss from LRGB.mrr import EdgeMRR +from LRGB.gcn_model import GCNContactModel, GCNContactConfig def parse_args(): @@ -45,6 +46,7 @@ def parse_args(): parser.add_argument("--n_head", type=int, default=4, help="Number of attention heads") parser.add_argument("--depth", type=int, default=32, help="Number of floyd transformer layers") parser.add_argument("--dropout", type=float, default=0.15) + parser.add_argument("--gcn", action="store_true", help="Use GCN model instead of FloydNet") # training parser.add_argument("--max_train_epoch_len", type=int, default=400, help="Maximum length of training epoch in number of samples") parser.add_argument("--max_val_epoch_len", type=int, default=1000, help="Maximum length of validation epoch in number of samples") @@ -82,35 +84,46 @@ def build_dataloader(args): def build_model(args): - if args.load_checkpoint is not None: - print0(f"Loading model from checkpoint: {args.load_checkpoint}") - ckpt_model_config_path = Path(args.load_checkpoint).parent / "model_config.json" - print0(f"Using model config from checkpoint, ignoring command line model config args") - with open(ckpt_model_config_path, "r") as f: - model_config = json.load(f) + if args.gcn: + print0("Using GCN model for LRGB task") + model_config = GCNContactConfig(dense_repr=args.dense_repr) + model = GCNContactModel(model_config).to("cuda") + print0(model) else: - model_config = dict( - n_embd=args.n_embd, - n_head=args.n_head, - n_out=1, - depth=args.depth, - enable_adj_emb=True, - enable_diffusion=False, - n_edge_feat=3, - edge_feat_vocab_size=120, - task_level="e", - n_decode_layers=4, - node_feat_vocab_size=120, - n_node_feat=9, - supernode=True, - dropout=args.dropout, - norm_fn="ln", - ) + if args.load_checkpoint is not None: + print0(f"Loading model from checkpoint: {args.load_checkpoint}") + ckpt_model_config_path = Path(args.load_checkpoint).parent / "model_config.json" + print0(f"Using model config from checkpoint, ignoring command line model config args") + with open(ckpt_model_config_path, "r") as f: + model_config = json.load(f) + else: + model_config = dict( + n_embd=args.n_embd, + n_head=args.n_head, + n_out=1, + depth=args.depth, + enable_adj_emb=True, + enable_diffusion=False, + n_edge_feat=3, + edge_feat_vocab_size=120, + task_level="e", + n_decode_layers=4, + node_feat_vocab_size=120, + n_node_feat=9, + supernode=True, + dropout=args.dropout, + norm_fn="ln", + ) + model = FloydNet(ModelConfig(**model_config)).to("cuda") + model.init_weights() + print0(model) + + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print0(f"Total trainable parameters: {trainable_params / 1e6:.2f}M") + start_epoch = 0 - model = FloydNet(ModelConfig(**model_config)).to("cuda") - model.init_weights() - print0(model) + if is_ddp_initialized(): torch.distributed.barrier() model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) @@ -127,7 +140,7 @@ def build_model(args): ) print0(f"Resumed from epoch {start_epoch}") - if is_master_process(): + if is_master_process() and not args.gcn: # save model config for loading model later output_dir = Path(args.output_dir) / "checkpoints" os.makedirs(output_dir, exist_ok=True) diff --git a/example/common/trainer.py b/example/common/trainer.py index 206be5e..9c176f8 100644 --- a/example/common/trainer.py +++ b/example/common/trainer.py @@ -223,7 +223,7 @@ def _evaluate(self, loader, stage_name, epoch, metric=None, **kwargs): new_metric = metric.compute() new_metric = reduce_metrics(new_metric, reduction="sum") for k, v in new_metric.items(): - if k != "sample_count": + if k != "sample_count" and "sample_count" in new_metric: v = v / new_metric["sample_count"] metrics[k] = v From 5d14ace07ddd5e89c2df5ed14c9918a1d201626f Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 15:22:01 +0800 Subject: [PATCH 4/8] add license --- example/LRGB/criterion.py | 14 ++++++++++++++ example/LRGB/data.py | 17 ++++++++++++++++- example/LRGB/mrr.py | 14 ++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/example/LRGB/criterion.py b/example/LRGB/criterion.py index 9427483..0b47cd1 100644 --- a/example/LRGB/criterion.py +++ b/example/LRGB/criterion.py @@ -1,3 +1,17 @@ +# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/example/LRGB/data.py b/example/LRGB/data.py index b64f1b2..3753209 100644 --- a/example/LRGB/data.py +++ b/example/LRGB/data.py @@ -1,7 +1,22 @@ +# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch def transform_pcqm(graph): - # Data(x=[36, 9], edge_index=[2, 72], edge_attr=[72, 3], edge_label_index=[2, 84], edge_label=[84]) + # format Data(x=[36, 9], edge_index=[2, 72], edge_attr=[72, 3], edge_label_index=[2, 84], edge_label=[84]) + # add supernode at the beginning graph.x = torch.nn.functional.pad(graph.x, (0, 0, 1, 0)) graph.edge_index = graph.edge_index + 1 graph.edge_label_index = graph.edge_label_index + 1 diff --git a/example/LRGB/mrr.py b/example/LRGB/mrr.py index 2cafb7d..b5dc502 100644 --- a/example/LRGB/mrr.py +++ b/example/LRGB/mrr.py @@ -1,3 +1,17 @@ +# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import defaultdict import torch From 3c968309ca46b093fa8f336eeb34fe6ea0ba9134 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 20:40:41 +0800 Subject: [PATCH 5/8] update --- example/LRGB/run.py | 6 +++--- example/common/model.py | 2 +- example/common/trainer.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/example/LRGB/run.py b/example/LRGB/run.py index 95b0551..8d06653 100644 --- a/example/LRGB/run.py +++ b/example/LRGB/run.py @@ -46,7 +46,7 @@ def parse_args(): parser.add_argument("--n_head", type=int, default=4, help="Number of attention heads") parser.add_argument("--depth", type=int, default=32, help="Number of floyd transformer layers") parser.add_argument("--dropout", type=float, default=0.15) - parser.add_argument("--gcn", action="store_true", help="Use GCN model instead of FloydNet") + parser.add_argument("--arch", type=str, default="floydnet", choices=["floydnet", "gcn"], help="Model architecture to use") # training parser.add_argument("--max_train_epoch_len", type=int, default=400, help="Maximum length of training epoch in number of samples") parser.add_argument("--max_val_epoch_len", type=int, default=1000, help="Maximum length of validation epoch in number of samples") @@ -84,7 +84,7 @@ def build_dataloader(args): def build_model(args): - if args.gcn: + if args.arch == "gcn": print0("Using GCN model for LRGB task") model_config = GCNContactConfig(dense_repr=args.dense_repr) model = GCNContactModel(model_config).to("cuda") @@ -140,7 +140,7 @@ def build_model(args): ) print0(f"Resumed from epoch {start_epoch}") - if is_master_process() and not args.gcn: + if is_master_process() and args.arch == "floydnet": # save model config for loading model later output_dir = Path(args.output_dir) / "checkpoints" os.makedirs(output_dir, exist_ok=True) diff --git a/example/common/model.py b/example/common/model.py index 8a1d4c2..2aaa1f0 100644 --- a/example/common/model.py +++ b/example/common/model.py @@ -117,7 +117,7 @@ def __init__( if config.enable_adj_emb: self.emb_adj = nn.Embedding(2, config.n_embd) if config.n_edge_feat > 0 and config.edge_feat_vocab_size > 0: - if config.n_edge_feat != 1: + if config.n_edge_feat > 1: self.emb_edge = nn.ModuleList([nn.Embedding(config.edge_feat_vocab_size, config.n_embd) for _ in range(config.n_edge_feat)]) else: self.emb_edge = nn.Embedding(config.edge_feat_vocab_size, config.n_embd) diff --git a/example/common/trainer.py b/example/common/trainer.py index 9c176f8..6c65cdd 100644 --- a/example/common/trainer.py +++ b/example/common/trainer.py @@ -118,7 +118,7 @@ def fit( metrics = self.eval_fn(dataloader["val"], "Val", epoch, metric=metric) min_val_loss = min(min_val_loss, metrics["loss"]) logger.log({f"val/{k}": v for k, v in metrics.items()}) - print("val metrics:", metrics) + print0(f"val metrics: {metrics}") self.scheduler.step(loss) if (epoch + 1) % self.save_interval == 0 or epoch == self.max_epochs - 1: save_checkpoint(self.model, self.optimizer, self.scheduler, epoch, loss, self.out_dir / "checkpoints") @@ -126,7 +126,7 @@ def fit( should_test_this_epoch = ((epoch + 1) % self.test_interval == 0 or epoch == self.max_epochs - 1) if should_test_this_epoch: metrics = self.eval_fn(dataloader["test"], "Test", epoch, metric=metric) - print("test metrics:", metrics) + print0(f"test metrics: {metrics}") min_test_loss = min(min_test_loss, metrics["loss"]) logger.log({f"test/{k}": v for k, v in metrics.items()}) @@ -138,6 +138,7 @@ def test( dataloader, epoch, sample_count_per_case=1, + metric=None, ): loss = self.eval_fn(dataloader["test"], "Test", epoch, sample_count_per_case=sample_count_per_case, metric=metric) return loss From c98fffc3f5bd46a313c8c601d84e248eab074744 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 20:56:54 +0800 Subject: [PATCH 6/8] update comment --- example/LRGB/mrr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/LRGB/mrr.py b/example/LRGB/mrr.py index b5dc502..f8e17ac 100644 --- a/example/LRGB/mrr.py +++ b/example/LRGB/mrr.py @@ -79,7 +79,7 @@ def compute(self): mrr_list = _eval_mrr(pred_pos, pred_neg) for i, v in enumerate(mrr_list): batch_stats[i].append(v) - # average among all graphs + # sum among all graphs, will do average outside the metric res = [] for i in range(4): v = torch.tensor(batch_stats[i]) From df78fd751784d2f5cb52ab4c10f36c51996cbad0 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 21:03:38 +0800 Subject: [PATCH 7/8] update readme --- example/README.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/example/README.md b/example/README.md index f6b8623..b285085 100644 --- a/example/README.md +++ b/example/README.md @@ -1,10 +1,11 @@ ### Benchmarks -The paper reports results on **three benchmarks**: +The paper reports results on **four benchmarks**: - Graph Count - BREC - TSP +- LRGB ## 🚀 Key Results @@ -134,4 +135,21 @@ torchrun \ --wandb_name TSP_exp ``` ---- \ No newline at end of file +--- + +### LRGB + +The LRGB benchmark and dataset construction follow: +https://github.com/vijaydwivedi75/lrgb + +#### PCQM-Contact + +```bash +source .venv/bin/activate +cd example +torchrun \ + --nproc_per_node=8 \ + -m LRGB.run \ + --name pcqm-contact \ + --wandb_name LRGB_pcqm-contact +``` From 9df38ac03c96de06559c9e872de285b39c194669 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 29 Jan 2026 21:07:03 +0800 Subject: [PATCH 8/8] update --- example/LRGB/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/LRGB/run.py b/example/LRGB/run.py index 8d06653..e079236 100644 --- a/example/LRGB/run.py +++ b/example/LRGB/run.py @@ -86,7 +86,7 @@ def build_dataloader(args): def build_model(args): if args.arch == "gcn": print0("Using GCN model for LRGB task") - model_config = GCNContactConfig(dense_repr=args.dense_repr) + model_config = GCNContactConfig() model = GCNContactModel(model_config).to("cuda") print0(model) else: