Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,4 @@ example/outputs
count.out
run_scripts
example/data/TSP
example/data/BREC/
92 changes: 92 additions & 0 deletions example/LRGB/criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss for dense binary edge prediction.

    Down-weights well-classified examples via the modulating factor
    ``(1 - p_t) ** gamma`` so training focuses on hard examples
    (Lin et al., "Focal Loss for Dense Object Detection"). Only the
    'binary' task type is implemented; it operates on the dense
    edge-prediction logits produced by the model and masks out
    padded node pairs.
    """

    def __init__(
        self,
        gamma=2,
        alpha=None,
        reduction='mean',
        task_type='binary',
    ):
        """
        :param gamma: Focusing parameter, controls the strength of the
            modulating factor (1 - p_t)^gamma.
        :param alpha: Balancing factor for the positive class; scalar in
            [0, 1], or None to disable class balancing.
        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'.
        :param task_type: Only 'binary' is supported by this implementation.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type

    def forward(self, inputs, graph):
        """Compute the focal loss for the configured task type.

        :param inputs: Sequence of model outputs; the last element holds
            the dense edge-prediction logits (squeezed to [B, N, N]).
        :param graph: Batched graph carrying ``adj_label`` (ground-truth
            adjacency) and ``pair_mask`` (valid node-pair mask).
        :return: Tuple of (loss, empty stats dict).
        :raises ValueError: If ``task_type`` is not 'binary'.
        """
        if self.task_type == 'binary':
            return self.binary_focal_loss(inputs, graph)
        raise ValueError(
            f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.")

    def binary_focal_loss(self, inputs, graph):
        """ Focal loss for binary classification. """
        inputs = inputs[-1]  # last model output holds the edge prediction
        inputs = inputs.squeeze(-1)

        targets = graph.adj_label
        mask = graph.pair_mask

        # clamp guards p_t against exact zeros from sigmoid underflow
        probs = torch.sigmoid(inputs).clamp(min=1e-10)
        targets = targets.float()

        # BCE is computed from raw logits for numerical stability
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # p_t: probability assigned to the true class of each pair
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Class balancing: alpha for positives, (1 - alpha) for negatives
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal weighting and zero out padded / invalid pairs
        loss = focal_weight * bce_loss
        loss = loss * mask

        if self.reduction == 'mean':
            # Average over valid pairs only; clamp avoids division by zero
            return loss.sum() / mask.sum().clamp(min=1), {}
        elif self.reduction == 'sum':
            return loss.sum(), {}
        return loss, {}
23 changes: 23 additions & 0 deletions example/LRGB/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

def transform_pcqm(graph):
    """Prepend a zero-feature supernode (node 0) to a PCQM contact graph.

    The node-feature matrix gains one all-zero row at index 0, and every
    stored edge index is shifted by one so it keeps referring to the same
    original node.
    """
    # One zero row in front of the feature matrix (same dtype/device as x).
    zero_row = graph.x.new_zeros(1, graph.x.size(-1))
    graph.x = torch.cat((zero_row, graph.x), dim=-2)
    # Shift node references past the newly inserted supernode.
    graph.edge_index = graph.edge_index + 1
    graph.edge_label_index = graph.edge_label_index + 1
    return graph
147 changes: 147 additions & 0 deletions example/LRGB/gcn_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from common.graph import graph_preprocess

try:
from ogb.graphproppred.mol_encoder import AtomEncoder
except ImportError as e:
AtomEncoder = None



@dataclass
class GCNContactConfig:
    """Hyper-parameters for the GCN contact-prediction model."""

    layers_mp: int = 5  # number of message-passing (GCNConv) layers
    layers_post_mp: int = 1  # depth of the linear post-MP head (MLPNoAct)
    dim_inner: int = 275  # hidden width used throughout the network
    dropout: float = 0.0  # dropout prob after each MP layer (0 disables)
    batchnorm: bool = True  # insert BatchNorm1d after every GCNConv
    act: str = "relu"  # NOTE(review): stored but the forward pass hard-codes F.relu
    agg: str = "mean"  # NOTE(review): unused in this file — presumably kept for config parity
    edge_decoding: str = "dot"  # decoding scheme; only dot product is implemented
    gcn_add_self_loops: bool = True  # forwarded to GCNConv
    gcn_normalize: bool = True  # forwarded to GCNConv (symmetric normalization)


class MLPNoAct(nn.Module):
    """A stack of same-width Linear layers with no nonlinearity between them.

    Serves as the post-message-passing head; weights are Xavier-uniform
    and biases zero, applied via :meth:`reset_parameters`.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        assert num_layers >= 1
        self.net = nn.Sequential(
            *(nn.Linear(dim, dim, bias=True) for _ in range(num_layers))
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Linear: Xavier-uniform weight, zero bias."""
        for module in self.net.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear stack to ``x``."""
        return self.net(x)


class GCNContactModel(nn.Module):
    """GCN baseline for the PCQM4Mv2-Contact link-prediction task (LRGB).

    Pipeline: OGB AtomEncoder -> ``layers_mp`` GCNConv layers (each
    followed by optional BatchNorm, ReLU, optional dropout) -> linear
    post-MP head -> dense pairwise dot-product logits over node pairs.
    """

    def __init__(self, cfg: Optional[GCNContactConfig] = None):
        """Build encoder, message-passing stack, and post-MP head.

        :param cfg: Model hyper-parameters; defaults to ``GCNContactConfig()``.
        :raises ImportError: If the optional ``ogb`` dependency is missing.
        """
        super().__init__()
        self.cfg = cfg or GCNContactConfig()
        dim = self.cfg.dim_inner

        # AtomEncoder is None when the `ogb` import at module top failed.
        if AtomEncoder is None:
            raise ImportError(
                "ogb is required to match LRGB Atom encoder. "
                "Please `pip install ogb` in your environment."
            )
        self.node_encoder = AtomEncoder(emb_dim=dim)

        # Message passing stack: one GCNConv (+ optional BatchNorm) per layer.
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(self.cfg.layers_mp):
            self.convs.append(
                GCNConv(
                    in_channels=dim,
                    out_channels=dim,
                    improved=False,
                    cached=False,
                    add_self_loops=self.cfg.gcn_add_self_loops,
                    normalize=self.cfg.gcn_normalize,
                    bias=True,
                )
            )
            if self.cfg.batchnorm:
                self.bns.append(nn.BatchNorm1d(dim))

        self.post_mp = MLPNoAct(dim=dim, num_layers=self.cfg.layers_post_mp)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize encoder, convs, batch norms, and the post-MP head."""
        # AtomEncoder may not expose reset_parameters in all ogb versions.
        if hasattr(self.node_encoder, "reset_parameters"):
            self.node_encoder.reset_parameters()

        for i, conv in enumerate(self.convs):
            conv.reset_parameters()
            if self.cfg.batchnorm:
                self.bns[i].reset_parameters()

        self.post_mp.reset_parameters()

    def _encode_nodes(self, data: Data) -> torch.Tensor:
        """Embed integer atom features with the OGB AtomEncoder.

        :raises ValueError: If ``data.x`` is absent.
        :raises TypeError: If ``data.x`` is not integer-typed (AtomEncoder
            indexes embedding tables, so long dtype is mandatory).
        """
        if data.x is None:
            raise ValueError("data.x is required for PCQM4Mv2Contact AtomEncoder.")
        if data.x.dtype != torch.long:
            raise TypeError(f"Expected data.x dtype torch.long, got {data.x.dtype}")
        return self.node_encoder(data.x)

    def _mp(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run the GCN message-passing stack: conv -> (BN) -> ReLU -> (dropout)."""
        for i, conv in enumerate(self.convs):
            h = conv(h, edge_index)
            if self.cfg.batchnorm:
                h = self.bns[i](h)
            h = F.relu(h)
            if self.cfg.dropout > 0:
                h = F.dropout(h, p=self.cfg.dropout, training=self.training)
        return h

    def _decode_edges_dot(self, h: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor:
        """Dot-product logit for each queried (src, dst) pair.

        Currently unused by ``forward`` (see the dense decoding there).
        """
        # edge_label_index: [2, E_pred]
        src, dst = edge_label_index[0], edge_label_index[1]
        # single logit per edge (matches dot decoding head expectation)
        return (h[src] * h[dst]).sum(dim=-1, keepdim=True)  # [E_pred, 1]

    def preprocess(self, data):
        """No-op preprocessing hook kept for interface compatibility."""
        return data

    def forward(self, data: Data) -> Tuple[None, None, torch.Tensor]:
        """Return ``(None, None, logits)`` with dense pairwise edge logits.

        :param data: Batch with long-typed ``x``, ``edge_index``, and
            ``edge_label_index``; the latter is validated here even though
            the dense decoding below scores every node pair.
        :raises ValueError: If ``edge_index`` or ``edge_label_index`` is missing.
        """
        if data.edge_index is None:
            raise ValueError("data.edge_index is required.")
        if not hasattr(data, "edge_label_index") or data.edge_label_index is None:
            raise ValueError("data.edge_label_index [2, E_pred] is required for edge prediction.")

        h = self._encode_nodes(data)
        h = self._mp(h, data.edge_index)

        # Head post-mp on nodes, then dot decode
        h = self.post_mp(h)

        data.x = h
        # NOTE(review): graph_preprocess presumably densifies node features to
        # [B, N, D] so the batched matmul below is valid — confirm in common.graph.
        data = graph_preprocess(data, supernode=False)
        logits = data.x @ data.x.transpose(1, 2)
        logits = logits.unsqueeze(-1)  # [B, N, N, 1]
        return None, None, logits
        # original impl with edge decoding:
        # logits = self._decode_edges_dot(h, data.edge_label_index)
        # data.x = h
        # return None, None, logits
95 changes: 95 additions & 0 deletions example/LRGB/mrr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 Beijing Academy of Artificial Intelligence (BAAI)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import torch

def _eval_mrr(y_pred_pos, y_pred_neg):
""" Compute Hits@k and Mean Reciprocal Rank (MRR).

Implementation from OGB:
https://github.com/snap-stanford/ogb/blob/master/ogb/linkproppred/evaluate.py

Args:
y_pred_neg: array with shape (batch size, num_entities_neg).
y_pred_pos: array with shape (batch size, )
"""

y_pred = torch.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1)
argsort = torch.argsort(y_pred, dim=1, descending=True)
ranking_list = torch.nonzero(argsort == 0, as_tuple=False)
ranking_list = ranking_list[:, 1] + 1
# average within graph
hits1 = (ranking_list <= 1).to(torch.float).mean().item()
hits3 = (ranking_list <= 3).to(torch.float).mean().item()
hits10 = (ranking_list <= 10).to(torch.float).mean().item()
mrr = (1. / ranking_list.to(torch.float)).mean().item()

# print(f"hits@1 {hits1:.5f}")
# print(f"hits@3 {hits3:.5f}")
# print(f"hits@10 {hits10:.5f}")
# print(f"mrr {mrr:.5f}")
return hits1, hits3, hits10, mrr


class EdgeMRR:
    """Accumulates dense edge predictions and computes ranking metrics.

    For every ground-truth (positive) edge of each stored graph, the
    prediction is ranked against all other candidate targets of its
    source node. Metrics are summed over graphs; the caller averages
    using ``sample_count``.
    """

    def __init__(self):
        self.states = defaultdict(list)

    def clean(self):
        """Drop all accumulated predictions and labels."""
        self.states = defaultdict(list)

    def add_batch(self, pred, graph_batch):
        """Store per-graph dense predictions/labels restricted to real nodes."""
        edge_pred = pred[-1]  # last model output holds the edge predictions
        num_graphs = graph_batch.single_mask.shape[0]
        for b in range(num_graphs):
            keep = torch.where(graph_batch.single_mask[b])[0]
            self.states["preds"].append(edge_pred[b][keep][:, keep])
            self.states["trues"].append(graph_batch.adj_label[b][keep][:, keep])

    def compute(self):
        """Sum Hits@k / MRR over all stored graphs (caller does the averaging)."""
        pred_list = self.states["preds"]
        true_list = self.states["trues"]
        per_metric = ([], [], [], [])
        for pred, true in zip(pred_list, true_list):
            n = pred.shape[0]
            pos_idx = torch.where(true == 1)
            num_pos = pos_idx[0].shape[0]
            # Graphs without any positive edge contribute nothing.
            if num_pos == 0:
                continue
            pred_pos = pred[pos_idx]

            # For each positive (src, dst), the negatives are every other
            # column of the src row.
            neg_mask = torch.ones([num_pos, n], dtype=torch.bool)
            neg_mask[torch.arange(num_pos), pos_idx[1]] = False
            pred_neg = pred[pos_idx[0]][neg_mask].view(num_pos, -1)

            for bucket, value in zip(per_metric, _eval_mrr(pred_pos, pred_neg)):
                bucket.append(value)

        # Sum across graphs; NaNs (degenerate cases) are counted as zero.
        totals = [
            torch.nan_to_num(torch.tensor(bucket), nan=0).sum().item()
            for bucket in per_metric
        ]

        return {
            'hits@1': totals[0],
            'hits@3': totals[1],
            'hits@10': totals[2],
            'mrr': totals[3],
            "sample_count": len(pred_list),
        }
Loading
Loading