Add DrugBank metadata extraction and W1-W3 runners #9
Conversation
There was a problem hiding this comment.
Pull request overview
Adds new data-prep and experiment-running scripts to support W1–W3 drift benchmarking for fingerprint-based DDI models, including DrugBank-derived metadata/date fields and an orchestrator for multi-seed / FP settings sweeps.
Changes:
- Add `extract_drugbank_metadata.py` to parse DrugBank XML and emit enriched TSV metadata (marketing dates, regulatory flags, external IDs, optional scaffolds).
- Add `run_w1_w3_fp_bench.py` to run W1–W3 split benchmarks (GBDT + FP-MLP) and drift metrics (entropy + MMD) over temporal windows.
- Add `run_w1_w3_controller.py` to orchestrate repeated benchmark runs across seeds / Morgan FP settings.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| scripts/run_w1_w3_fp_bench.py | Implements W1–W3 benchmark runner, feature construction, split strategies, FP-MLP training, and drift metrics/plots. |
| scripts/run_w1_w3_controller.py | Adds a sweep controller to run the benchmark script across parameter grids and manage output directories. |
| scripts/extract_drugbank_metadata.py | Adds DrugBank XML extraction into metadata TSVs used by temporal/scaffold-based sampling. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float) -> float: | ||
| x_sq = np.sum(x * x, axis=1, keepdims=True) | ||
| y_sq = np.sum(y * y, axis=1, keepdims=True) | ||
| dist = x_sq + y_sq.T - 2 * x @ y.T | ||
| kxx = np.exp(-gamma * (x_sq + x_sq.T - 2 * x @ x.T)) | ||
| kyy = np.exp(-gamma * (y_sq + y_sq.T - 2 * y @ y.T)) | ||
| kxy = np.exp(-gamma * dist) | ||
| return float(kxx.mean() + kyy.mean() - 2 * kxy.mean()) | ||
|
|
||
|
|
There was a problem hiding this comment.
_mmd_rbf materializes full NxN kernel matrices (kxx, kyy) and an NxM distance matrix (dist). With the default mmd_sample=3000 this can allocate multiple ~3000x3000 float64 arrays (hundreds of MB), which may exhaust memory or slow the run significantly. Consider computing MMD in chunks and/or using float32 (or reducing the default sample) to keep memory bounded.
| def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float) -> float: | |
| x_sq = np.sum(x * x, axis=1, keepdims=True) | |
| y_sq = np.sum(y * y, axis=1, keepdims=True) | |
| dist = x_sq + y_sq.T - 2 * x @ y.T | |
| kxx = np.exp(-gamma * (x_sq + x_sq.T - 2 * x @ x.T)) | |
| kyy = np.exp(-gamma * (y_sq + y_sq.T - 2 * y @ y.T)) | |
| kxy = np.exp(-gamma * dist) | |
| return float(kxx.mean() + kyy.mean() - 2 * kxy.mean()) | |
| def _mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float, chunk_size: int = 512) -> float: | |
| """Compute RBF-kernel MMD^2 between x and y using chunked kernel means. | |
| This avoids materializing full NxN / NxM kernel matrices to keep memory bounded. | |
| The computation includes diagonal terms, matching the original implementation. | |
| """ | |
| def _kernel_mean(a: np.ndarray, b: np.ndarray) -> float: | |
| n_a = a.shape[0] | |
| n_b = b.shape[0] | |
| # Accumulate sum of kernel values and total number of pairs | |
| total = 0.0 | |
| count = 0 | |
| for i in range(0, n_a, chunk_size): | |
| a_block = a[i : i + chunk_size] | |
| a_sq = np.sum(a_block * a_block, axis=1, keepdims=True) | |
| for j in range(0, n_b, chunk_size): | |
| b_block = b[j : j + chunk_size] | |
| b_sq = np.sum(b_block * b_block, axis=1, keepdims=True) | |
| # Squared Euclidean distances for the block | |
| dist_block = a_sq + b_sq.T - 2.0 * (a_block @ b_block.T) | |
| k_block = np.exp(-gamma * dist_block) | |
| total += float(k_block.sum()) | |
| count += k_block.size | |
| if count == 0: | |
| return 0.0 | |
| return total / float(count) | |
| kxx_mean = _kernel_mean(x, x) | |
| kyy_mean = _kernel_mean(y, y) | |
| kxy_mean = _kernel_mean(x, y) | |
| return float(kxx_mean + kyy_mean - 2.0 * kxy_mean) |
| import math | ||
| import random | ||
| import sys |
There was a problem hiding this comment.
math and sys are imported but not used in this script. Removing unused imports avoids lint noise and makes optional-dependency handling a bit clearer.
| import math | |
| import random | |
| import sys | |
| import random |
| match = re.search( | ||
| r"(Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|" | ||
| r"Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|" | ||
| r"Dec(?:ember)?)\\s+(\\d{1,2}),?\\s+(\\d{4})", | ||
| value, | ||
| re.IGNORECASE, | ||
| ) | ||
| if match: | ||
| month_map = { | ||
| "jan": 1, | ||
| "feb": 2, | ||
| "mar": 3, | ||
| "apr": 4, | ||
| "may": 5, | ||
| "jun": 6, | ||
| "jul": 7, | ||
| "aug": 8, | ||
| "sep": 9, | ||
| "oct": 10, | ||
| "nov": 11, | ||
| "dec": 12, | ||
| } | ||
| month = match.group(1).lower()[:3] | ||
| day = int(match.group(2)) | ||
| year = int(match.group(3)) | ||
| month_num = month_map.get(month) | ||
| if month_num: | ||
| try: | ||
| return date(year, month_num, day) | ||
| except ValueError: | ||
| return None | ||
| return None | ||
|
|
||
|
|
||
| def _extract_year(value: str) -> str: | ||
| match = re.search(r"(19|20)\\d{2}", value) | ||
| return match.group(0) if match else "" |
There was a problem hiding this comment.
The regex patterns in _parse_flexible_date and _extract_year are raw strings but use double-escaped sequences like \\s and \\d. In a raw string, \\s is a regex that matches a literal backslash followed by the letter s (and likewise \\d a backslash followed by d) instead of whitespace/digits, so month-name date parsing and year extraction will never match. Use single backslashes (\s, \d) in the raw string patterns.
| class _PairDataset(Dataset): | ||
| def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray): | ||
| self.fp_matrix = fp_matrix | ||
| self.idx1 = idx1 | ||
| self.idx2 = idx2 | ||
| self.labels = labels | ||
|
|
||
| def __len__(self) -> int: | ||
| return len(self.labels) | ||
|
|
||
| def __getitem__(self, idx: int): | ||
| fp1 = self.fp_matrix[self.idx1[idx]] | ||
| fp2 = self.fp_matrix[self.idx2[idx]] | ||
| return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx]) |
There was a problem hiding this comment.
_PairDataset is defined unconditionally as class _PairDataset(Dataset), but when torch isn't installed the optional import sets Dataset = None, causing an import-time TypeError (and preventing even GBDT-only runs). Guard this class definition behind an `if Dataset is not None:` block (and/or provide a stub) so the script can still be imported/executed when torch is absent.
| class _PairDataset(Dataset): | |
| def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray): | |
| self.fp_matrix = fp_matrix | |
| self.idx1 = idx1 | |
| self.idx2 = idx2 | |
| self.labels = labels | |
| def __len__(self) -> int: | |
| return len(self.labels) | |
| def __getitem__(self, idx: int): | |
| fp1 = self.fp_matrix[self.idx1[idx]] | |
| fp2 = self.fp_matrix[self.idx2[idx]] | |
| return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx]) | |
| if Dataset is not None: | |
| class _PairDataset(Dataset): | |
| def __init__(self, fp_matrix: np.ndarray, idx1: np.ndarray, idx2: np.ndarray, labels: np.ndarray): | |
| self.fp_matrix = fp_matrix | |
| self.idx1 = idx1 | |
| self.idx2 = idx2 | |
| self.labels = labels | |
| def __len__(self) -> int: | |
| return len(self.labels) | |
| def __getitem__(self, idx: int): | |
| fp1 = self.fp_matrix[self.idx1[idx]] | |
| fp2 = self.fp_matrix[self.idx2[idx]] | |
| return torch.from_numpy(fp1), torch.from_numpy(fp2), torch.tensor(self.labels[idx]) | |
| else: # pragma: no cover - optional dependency | |
| class _PairDataset: | |
| def __init__(self, *args, **kwargs) -> None: | |
| raise RuntimeError("Missing dependency 'torch'. Install it to use the FP-MLP models.") |
| mlp_model = _train_mlp( | ||
| fp_matrix, train_df, val_df, fp_index, num_classes, args, fpmlp_device | ||
| ) | ||
| mlp_probs = _predict_mlp( |
There was a problem hiding this comment.
The benchmark always trains the FP-MLP (_train_mlp(...)) for every split, even when torch is not installed (the imports are treated as optional earlier). As written, missing torch will raise at runtime and prevents running GBDT-only benchmarks. Consider skipping FP-MLP training/prediction when torch (or the selected MLP implementation) is unavailable, or add an explicit CLI flag to disable the MLP path.
| _require_module(torch, "torch") | ||
| model_class = _resolve_mlp_class(args) | ||
| if model_class is None: | ||
| raise RuntimeError("Missing FP-MLP dependencies; install torch and torch-geometric.") |
There was a problem hiding this comment.
The error raised when the FP-MLP class can't be resolved says to install both torch and torch-geometric, but this script also supports the simple fallback MLP which only needs torch. Updating the message (or branching it based on the selected/available implementation) would make failures easier to diagnose.
| raise RuntimeError("Missing FP-MLP dependencies; install torch and torch-geometric.") | |
| # Provide a more precise error message based on the requested implementation | |
| if args.mlp_impl == "simple": | |
| raise RuntimeError( | |
| "Missing FP-MLP dependency for simple implementation: 'torch'. " | |
| "Install torch or select a different MLP implementation." | |
| ) | |
| elif args.mlp_impl == "lightning": | |
| raise RuntimeError( | |
| "Missing FP-MLP dependencies for Lightning implementation. " | |
| "Ensure that 'torch', 'torch-geometric', and the 'GPU.models' package " | |
| "are installed, or select '--mlp-impl simple' to use the fallback MLP." | |
| ) | |
| else: # auto | |
| missing_parts = [] | |
| if LightningFPMLP is None: | |
| missing_parts.append( | |
| "Lightning FP-MLP (requires torch, torch-geometric, and GPU.models)" | |
| ) | |
| if _FallbackFPMLP is None: | |
| missing_parts.append("simple FP-MLP (requires torch)") | |
| details = "; ".join(missing_parts) if missing_parts else "FP-MLP implementations" | |
| raise RuntimeError( | |
| f"Could not resolve any FP-MLP implementation: {details}. " | |
| "Install the necessary dependencies or specify " | |
| "'--mlp-impl simple' or '--mlp-impl lightning'." | |
| ) |
| test_idx2 = test_df["Drug2"].map(fp_index).to_numpy() | ||
|
|
||
| x_train = _build_features(fp_matrix, train_idx1, train_idx2) | ||
| x_val = _build_features(fp_matrix, val_idx1, val_idx2) |
There was a problem hiding this comment.
Variable x_val is not used.
| x_val = _build_features(fp_matrix, val_idx1, val_idx2) |
| x_test = _build_features(fp_matrix, test_idx1, test_idx2) | ||
|
|
||
| y_train = train_df["Label"].to_numpy() | ||
| y_val = val_df["Label"].to_numpy() |
There was a problem hiding this comment.
Variable y_val is not used.
| y_val = val_df["Label"].to_numpy() |
| import random | ||
| import sys | ||
| from pathlib import Path | ||
| from typing import Dict, Iterable, List, Optional, Tuple |
There was a problem hiding this comment.
Import of 'Iterable' is not used.
| from typing import Dict, Iterable, List, Optional, Tuple | |
| from typing import Dict, List, Optional, Tuple |
|
@copilot another issue is that the new functionalities were not implemented for the TPU codebase? |
|
@yananlong I've opened a new pull request, #10, to work on those changes. Once the pull request is ready, I'll request review from you. |
Summary
Testing