Add DrugBank metadata extraction and W1-W3 runners #9

Open
yananlong wants to merge 3 commits into main from w1-w3-controller

Conversation

@yananlong
Owner

Summary

  • add DrugBank metadata extraction script with expanded columns
  • add W1-W3 fingerprint benchmark runner (GBDT + FP-MLP, splits, drift metrics)
  • add controller to orchestrate W1-W3 runs across seeds/FP settings

Testing

  • not run (compute-heavy)

Copilot AI review requested due to automatic review settings February 11, 2026 02:15

Copilot AI left a comment

Pull request overview

Adds new data-prep and experiment-running scripts to support W1–W3 drift benchmarking for fingerprint-based DDI models, including DrugBank-derived metadata/date fields and an orchestrator for multi-seed / FP settings sweeps.

Changes:

  • Add extract_drugbank_metadata.py to parse DrugBank XML and emit enriched TSV metadata (marketing dates, regulatory flags, external IDs, optional scaffolds).
  • Add run_w1_w3_fp_bench.py to run W1–W3 split benchmarks (GBDT + FP-MLP) and drift metrics (entropy + MMD) over temporal windows.
  • Add run_w1_w3_controller.py to orchestrate repeated benchmark runs across seeds / Morgan FP settings.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| scripts/run_w1_w3_fp_bench.py | Implements W1–W3 benchmark runner, feature construction, split strategies, FP-MLP training, and drift metrics/plots. |
| scripts/run_w1_w3_controller.py | Adds a sweep controller to run the benchmark script across parameter grids and manage output directories. |
| scripts/extract_drugbank_metadata.py | Adds DrugBank XML extraction into metadata TSVs used by temporal/scaffold-based sampling. |


Comment on lines +594 to +603
def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float) -> float:
    x_sq = np.sum(x * x, axis=1, keepdims=True)
    y_sq = np.sum(y * y, axis=1, keepdims=True)
    dist = x_sq + y_sq.T - 2 * x @ y.T
    kxx = np.exp(-gamma * (x_sq + x_sq.T - 2 * x @ x.T))
    kyy = np.exp(-gamma * (y_sq + y_sq.T - 2 * y @ y.T))
    kxy = np.exp(-gamma * dist)
    return float(kxx.mean() + kyy.mean() - 2 * kxy.mean())


Copilot AI Feb 11, 2026

_mmd_rbf materializes full NxN kernel matrices (kxx, kyy) and NxM matrices (dist, kxy). With the default mmd_sample=3000, each 3000x3000 float64 array takes about 72 MB, so the function can allocate hundreds of MB in total and may run out of memory or slow down significantly. Consider computing MMD in chunks and/or using float32 (or reducing the default sample) to keep memory bounded.
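
To put rough numbers on the memory concern, the sketch below counts only the four named arrays in the original function (kxx, kyy, dist, kxy) and ignores NumPy temporaries, so it is a lower bound rather than a measurement. It assumes both inputs are subsampled to mmd_sample=3000 rows; the helper names dense_mmd_bytes and chunk_block_bytes are illustrative and not part of the PR.

```python
def dense_mmd_bytes(n: int, m: int, itemsize: int = 8) -> int:
    """Bytes held by kxx (n x n), kyy (m x m), and dist/kxy (n x m each)."""
    return itemsize * (n * n + m * m + 2 * n * m)


def chunk_block_bytes(chunk_size: int, itemsize: int = 8) -> int:
    """Bytes for one chunk_size x chunk_size block in a chunked variant."""
    return itemsize * chunk_size * chunk_size


MB = 1_000_000  # decimal megabytes

print(dense_mmd_bytes(3000, 3000) / MB)              # float64: 288.0
print(dense_mmd_bytes(3000, 3000, itemsize=4) / MB)  # float32: 144.0
print(chunk_block_bytes(512) / MB)                   # one 512 x 512 float64 block: 2.097152
```

Under these assumptions, switching to float32 halves the footprint, while chunking reduces the live kernel data to a few blocks of about 2 MB each.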

Suggested change
-def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float) -> float:
-    x_sq = np.sum(x * x, axis=1, keepdims=True)
-    y_sq = np.sum(y * y, axis=1, keepdims=True)
-    dist = x_sq + y_sq.T - 2 * x @ y.T
-    kxx = np.exp(-gamma * (x_sq + x_sq.T - 2 * x @ x.T))
-    kyy = np.exp(-gamma * (y_sq + y_sq.T - 2 * y @ y.T))
-    kxy = np.exp(-gamma * dist)
-    return float(kxx.mean() + kyy.mean() - 2 * kxy.mean())
+def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float, chunk_size: int = 512) -> float:
+    """Compute RBF-kernel MMD^2 between x and y using chunked kernel means.
+    This avoids materializing full NxN / NxM kernel matrices to keep memory bounded.
+    The computation includes diagonal terms, matching the original implementation.
+    """
+    def _kernel_mean(a: np.ndarray, b: np.ndarray) -> float:
+        n_a = a.shape[0]
+        n_b = b.shape[0]
+        # Accumulate sum of kernel values and total number of pairs
+        total = 0.0
+        count = 0
+        for i in range(0, n_a, chunk_size):
+            a_block = a[i : i + chunk_size]
+            a_sq = np.sum(a_block * a_block, axis=1, keepdims=True)
+            for j in range(0, n_b, chunk_size):
+                b_block = b[j : j + chunk_size]
+                b_sq = np.sum(b_block * b_block, axis=1, keepdims=True)
+                # Squared Euclidean distances for the block
+                dist_block = a_sq + b_sq.T - 2.0 * (a_block @ b_block.T)
+                k_block = np.exp(-gamma * dist_block)
+                total += float(k_block.sum())
+                count += k_block.size
+        if count == 0:
+            return 0.0
+        return total / float(count)
+    kxx_mean = _kernel_mean(x, x)
+    kyy_mean = _kernel_mean(y, y)
+    kxy_mean = _kernel_mean(x, y)
+    return float(kxx_mean + kyy_mean - 2.0 * kxy_mean)
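
Since the PR notes that tests were not run, one cheap check before adopting a chunked version is to compare it against the dense computation on small random inputs. The following is a minimal sketch assuming only NumPy and non-empty inputs; mmd_rbf_dense and mmd_rbf_chunked are illustrative restatements of the two variants, not the functions in the PR.

```python
import numpy as np


def _sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of a and rows of b."""
    a_sq = np.sum(a * a, axis=1, keepdims=True)
    b_sq = np.sum(b * b, axis=1, keepdims=True)
    return a_sq + b_sq.T - 2.0 * (a @ b.T)


def mmd_rbf_dense(x: np.ndarray, y: np.ndarray, gamma: float) -> float:
    """Dense MMD^2 estimate with diagonal terms included."""
    kxx = np.exp(-gamma * _sq_dists(x, x))
    kyy = np.exp(-gamma * _sq_dists(y, y))
    kxy = np.exp(-gamma * _sq_dists(x, y))
    return float(kxx.mean() + kyy.mean() - 2.0 * kxy.mean())


def mmd_rbf_chunked(x: np.ndarray, y: np.ndarray, gamma: float, chunk_size: int = 512) -> float:
    """Same estimate, accumulating kernel sums one block at a time."""
    def kernel_mean(a: np.ndarray, b: np.ndarray) -> float:
        total = 0.0
        for i in range(0, a.shape[0], chunk_size):
            for j in range(0, b.shape[0], chunk_size):
                block = _sq_dists(a[i : i + chunk_size], b[j : j + chunk_size])
                total += float(np.exp(-gamma * block).sum())
        return total / (a.shape[0] * b.shape[0])

    return kernel_mean(x, x) + kernel_mean(y, y) - 2.0 * kernel_mean(x, y)


# Small synthetic inputs; chunk_size is chosen so that the last block is ragged.
rng = np.random.default_rng(0)
x = rng.normal(size=(300, 16))
y = rng.normal(loc=0.5, size=(200, 16))
dense = mmd_rbf_dense(x, y, gamma=0.1)
chunked = mmd_rbf_chunked(x, y, gamma=0.1, chunk_size=64)
assert np.isclose(dense, chunked), (dense, chunked)
```

The two values should agree up to floating-point summation order, which is why the comparison uses np.isclose rather than exact equality.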

Copilot uses AI. Check for mistakes.
Comment on lines +7 to +9
import math
import random
import sys
Copilot AI Feb 11, 2026

math and sys are imported but not used in this script. Removing unused imports avoids lint noise and makes optional-dependency handling a bit clearer.

Suggested change
-import math
-import random
-import sys
+import random

Comment on lines +175 to +211
    match = re.search(
        r"(Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|"
        r"Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|"
        r"Dec(?:ember)?)\\s+(\\d{1,2}),?\\s+(\\d{4})",
        value,
        re.IGNORECASE,
    )
    if match:
        month_map = {
            "jan": 1,
            "feb": 2,
            "mar": 3,
            "apr": 4,
            "may": 5,
            "jun": 6,
            "jul": 7,
            "aug": 8,
            "sep": 9,
            "oct": 10,
            "nov": 11,
            "dec": 12,
        }
        month = match.group(1).lower()[:3]
        day = int(match.group(2))
        year = int(match.group(3))
        month_num = month_map.get(month)
        if month_num:
            try:
                return date(year, month_num, day)
            except ValueError:
                return None
    return None


def _extract_year(value: str) -> str:
    match = re.search(r"(19|20)\\d{2}", value)
    return match.group(0) if match else ""
Copilot AI Feb 11, 2026

The regex patterns in _parse_flexible_date and _extract_year are raw strings but use double-escaped sequences like \\s and \\d. In a raw string, \\s matches a literal backslash followed by "s" (and likewise for \\d) instead of whitespace/digits, so month-name date parsing and year extraction will fail. Use single backslashes (\s, \d) in the raw string patterns.
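
The effect is easy to reproduce in isolation. The sketch below uses only the standard library; extract_year is a simplified stand-in for the PR's _extract_year that takes the pattern as a parameter so both spellings can be compared.

```python
import re

DOUBLE_ESCAPED = r"(19|20)\\d{2}"  # as in the PR: literal backslash, then "dd"
SINGLE_ESCAPED = r"(19|20)\d{2}"   # intended: "19" or "20", then two digits


def extract_year(value: str, pattern: str) -> str:
    """Return the first substring matching pattern, or an empty string."""
    match = re.search(pattern, value)
    return match.group(0) if match else ""


print(repr(extract_year("Approved March 5, 2014", DOUBLE_ESCAPED)))  # ''
print(repr(extract_year("Approved March 5, 2014", SINGLE_ESCAPED)))  # '2014'
print(repr(extract_year(r"20\dd", DOUBLE_ESCAPED)))                  # '20\\dd'
```

The last call shows what the double-escaped pattern actually accepts: the characters "20", a backslash, and two letter d's.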

Comment on lines +481 to +494
class _PairDataset(Dataset):
    def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray):
        self.fp_matrix = fp_matrix
        self.idx1 = idx1
        self.idx2 = idx2
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        fp1 = self.fp_matrix[self.idx1[idx]]
        fp2 = self.fp_matrix[self.idx2[idx]]
        return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx])
Copilot AI Feb 11, 2026

_PairDataset is defined unconditionally as class _PairDataset(Dataset), but when torch isn't installed the optional import sets Dataset = None. Subclassing None raises a TypeError at import time, which prevents even GBDT-only runs. Guard this class definition behind an if Dataset is not None: block (and/or provide a stub) so the script can still be imported and executed when torch is absent.

Suggested change
-class _PairDataset(Dataset):
-    def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray):
-        self.fp_matrix = fp_matrix
-        self.idx1 = idx1
-        self.idx2 = idx2
-        self.labels = labels
-    def __len__(self) -> int:
-        return len(self.labels)
-    def __getitem__(self, idx: int):
-        fp1 = self.fp_matrix[self.idx1[idx]]
-        fp2 = self.fp_matrix[self.idx2[idx]]
-        return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx])
+if Dataset is not None:
+    class _PairDataset(Dataset):
+        def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray):
+            self.fp_matrix = fp_matrix
+            self.idx1 = idx1
+            self.idx2 = idx2
+            self.labels = labels
+        def __len__(self) -> int:
+            return len(self.labels)
+        def __getitem__(self, idx: int):
+            fp1 = self.fp_matrix[self.idx1[idx]]
+            fp2 = self.fp_matrix[self.idx2[idx]]
+            return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx])
+else:  # pragma: no cover - optional dependency
+    class _PairDataset:
+        def __init__(self, *args, **kwargs) -> None:
+            raise RuntimeError("Missing dependency 'torch'. Install it to use the FP-MLP models.")
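
The failure mode can be reproduced without torch at all. This is a minimal sketch that sets Dataset = None by hand to simulate the failed optional import; the class names are illustrative, and the stub mirrors the guarded pattern rather than the PR's actual dataset logic.

```python
Dataset = None  # simulates the fallback taken when torch cannot be imported

# Unguarded definition: the class statement itself fails, before any use.
try:
    class _Unguarded(Dataset):
        pass
except TypeError as exc:
    unguarded_error = exc
else:
    unguarded_error = None

# Guarded definition: the module still imports; the error is deferred to use.
if Dataset is not None:
    class PairDataset(Dataset):
        pass
else:
    class PairDataset:
        def __init__(self, *args, **kwargs) -> None:
            raise RuntimeError("Missing dependency 'torch'.")

print(type(unguarded_error).__name__)  # TypeError
```

With the guard in place, code paths that never construct the dataset (for example a GBDT-only run) are unaffected by the missing dependency.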

Comment on lines +755 to +758
mlp_model = _train_mlp(
    fp_matrix, train_df, val_df, fp_index, num_classes, args, fpmlp_device
)
mlp_probs = _predict_mlp(
Copilot AI Feb 11, 2026

The benchmark always trains the FP-MLP (_train_mlp(...)) for every split, even when torch is not installed (the imports are treated as optional earlier). As written, a missing torch raises at runtime and prevents running GBDT-only benchmarks. Consider skipping FP-MLP training/prediction when torch (or the selected MLP implementation) is unavailable, or add an explicit CLI flag to disable the MLP path.
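
One possible shape for this, sketched with the standard library only: the --skip-mlp flag and the helpers torch_available and plan_models are hypothetical names introduced for illustration and do not exist in the PR. The sketch combines both options from the comment, an explicit opt-out plus automatic skipping when torch is missing.

```python
import argparse
import importlib.util
from typing import List, Optional, Tuple


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="W1-W3 benchmark (sketch)")
    # Hypothetical flag; the PR does not define it.
    parser.add_argument("--skip-mlp", action="store_true", help="Run GBDT only.")
    return parser


def torch_available() -> bool:
    """Check for torch without importing it."""
    return importlib.util.find_spec("torch") is not None


def plan_models(skip_mlp: bool, has_torch: bool) -> Tuple[List[str], Optional[str]]:
    """Return the models to run and, if the MLP is skipped, the reason."""
    models = ["gbdt"]
    if skip_mlp:
        return models, "FP-MLP disabled via --skip-mlp"
    if not has_torch:
        return models, "FP-MLP skipped: torch is not installed"
    return models + ["fp_mlp"], None


args = build_parser().parse_args(["--skip-mlp"])
models, note = plan_models(args.skip_mlp, torch_available())
print(models, note)  # ['gbdt'] FP-MLP disabled via --skip-mlp
```

Passing has_torch as an argument keeps the decision testable on machines with or without torch; whether a missing torch should skip silently, warn, or fail is a design choice left to the author.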

_require_module(torch, "torch")
model_class = _resolve_mlp_class(args)
if model_class is None:
    raise RuntimeError("Missing FP-MLP dependencies; install torch and torch-geometric.")
Copilot AI Feb 11, 2026

The error raised when the FP-MLP class can't be resolved says to install both torch and torch-geometric, but this script also supports the simple fallback MLP which only needs torch. Updating the message (or branching it based on the selected/available implementation) would make failures easier to diagnose.

Suggested change
-    raise RuntimeError("Missing FP-MLP dependencies; install torch and torch-geometric.")
+    # Provide a more precise error message based on the requested implementation
+    if args.mlp_impl == "simple":
+        raise RuntimeError(
+            "Missing FP-MLP dependency for simple implementation: 'torch'. "
+            "Install torch or select a different MLP implementation."
+        )
+    elif args.mlp_impl == "lightning":
+        raise RuntimeError(
+            "Missing FP-MLP dependencies for Lightning implementation. "
+            "Ensure that 'torch', 'torch-geometric', and the 'GPU.models' package "
+            "are installed, or select '--mlp-impl simple' to use the fallback MLP."
+        )
+    else:  # auto
+        missing_parts = []
+        if LightningFPMLP is None:
+            missing_parts.append(
+                "Lightning FP-MLP (requires torch, torch-geometric, and GPU.models)"
+            )
+        if _FallbackFPMLP is None:
+            missing_parts.append("simple FP-MLP (requires torch)")
+        details = "; ".join(missing_parts) if missing_parts else "FP-MLP implementations"
+        raise RuntimeError(
+            f"Could not resolve any FP-MLP implementation: {details}. "
+            "Install the necessary dependencies or specify "
+            "'--mlp-impl simple' or '--mlp-impl lightning'."
+        )

test_idx2 = test_df["Drug2"].map(fp_index).to_numpy()

x_train = _build_features(fp_matrix, train_idx1, train_idx2)
x_val = _build_features(fp_matrix, val_idx1, val_idx2)
Copilot AI Feb 11, 2026

Variable x_val is not used.

Suggested change
-x_val = _build_features(fp_matrix, val_idx1, val_idx2)

x_test = _build_features(fp_matrix, test_idx1, test_idx2)

y_train = train_df["Label"].to_numpy()
y_val = val_df["Label"].to_numpy()
Copilot AI Feb 11, 2026

Variable y_val is not used.

Suggested change
-y_val = val_df["Label"].to_numpy()

import random
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple
Copilot AI Feb 11, 2026

Import of 'Iterable' is not used.

Suggested change
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

@yananlong
Owner Author

@copilot another issue is that the new functionalities were not implemented for the TPU codebase?


Copilot AI commented Feb 11, 2026

@yananlong I've opened a new pull request, #10, to work on those changes. Once the pull request is ready, I'll request review from you.


3 participants