diff --git a/benchmarks/bench_utils/__init__.py b/benchmarks/bench_utils/__init__.py index 13990f1f..1960f934 100644 --- a/benchmarks/bench_utils/__init__.py +++ b/benchmarks/bench_utils/__init__.py @@ -20,16 +20,18 @@ """ from bench_utils.loaders import load_pickle, load_sdf, load_smarts, load_smiles -from bench_utils.molprep import clone_mols_with_conformers, prep_mols +from bench_utils.molprep import clone_mols_with_conformers, embed_and_jitter, perturb_conformer, prep_mols from bench_utils.timing import TimingResult, time_it __all__ = [ "TimingResult", "clone_mols_with_conformers", + "embed_and_jitter", "load_pickle", "load_sdf", "load_smarts", "load_smiles", + "perturb_conformer", "prep_mols", "time_it", ] diff --git a/benchmarks/bench_utils/molprep.py b/benchmarks/bench_utils/molprep.py index 45f61af7..21b88481 100644 --- a/benchmarks/bench_utils/molprep.py +++ b/benchmarks/bench_utils/molprep.py @@ -15,7 +15,17 @@ """Molecule preparation helpers shared across nvMolKit benchmarks.""" +import random +from functools import partial + from rdkit import Chem +from rdkit.Chem import rdDistGeom +from rdkit.Geometry import Point3D +from tqdm.contrib.concurrent import process_map + +# Manually tuned so the per-conformer jitter recreates an ETKDGv3-like pairwise RMSD spread +JITTER_CENTER = 1.3 +JITTER_SPREAD = 0.6 def prep_mols( @@ -63,3 +73,107 @@ def clone_mols_with_conformers(mols: list[Chem.Mol]) -> list[Chem.RWMol]: pristine input. """ return [Chem.RWMol(mol) for mol in mols] + + +def perturb_conformer( + conf: Chem.Conformer, + seed: int, + center: float = JITTER_CENTER, + spread: float = JITTER_SPREAD, +) -> None: + """Apply per-atom uniform jitter to a conformer in place. + + A single half-width is drawn for the conformer as ``center * (1 + spread * + U(-1, 1))`` and every x/y/z coordinate is then shifted by ``U(-half_width, + half_width)``. Drawing a distinct half-width per conformer (each call uses + a distinct ``seed``) gives a jittered ensemble a range of pairwise RMSDs + rather than a single structure-independent value. + """ + rng = random.Random(seed) + half_width = max(0.0, center * (1.0 + spread * rng.uniform(-1.0, 1.0))) + for atom_idx in range(conf.GetNumAtoms()): + pos = conf.GetAtomPosition(atom_idx) + conf.SetAtomPosition( + atom_idx, + Point3D( + pos.x + rng.uniform(-half_width, half_width), + pos.y + rng.uniform(-half_width, half_width), + pos.z + rng.uniform(-half_width, half_width), + ), + ) + + +def _embed_one(args_tuple: tuple[int, bytes], seed: int, add_hs: bool, min_atoms: int) -> bytes | None: + """Embed a single ETKDGv3 conformer for one mol payload (multiprocessing worker).""" + idx, mol_bytes = args_tuple + mol = Chem.Mol(mol_bytes) + if mol.GetNumAtoms() < min_atoms: + return None + if add_hs: + mol = Chem.AddHs(mol) + params = rdDistGeom.ETKDGv3() + params.useRandomCoords = True + params.randomSeed = seed + idx + try: + conf_id = rdDistGeom.EmbedMolecule(mol, params=params) + except Exception: + return None + if conf_id < 0 or mol.GetNumConformers() == 0: + return None + if add_hs: + mol = Chem.RemoveHs(mol) + return mol.ToBinary() + + +def embed_and_jitter( + mols: list[Chem.Mol], + confs_per_mol: int, + seed: int, + num_workers: int = 1, + add_hs: bool = False, + min_atoms: int = 1, + desc: str = "Embedding base conformers", +) -> list[Chem.Mol]: + """Embed one ETKDGv3 base conformer per mol in parallel, then jitter to ``confs_per_mol``. + + The embed step runs across mols via ``process_map``; the jitter step is + in-process and serial (cheap). Mols whose base embedding fails are + dropped with a printed count. When ``add_hs`` is true, hydrogens are + added before embedding and stripped from the returned mol. + """ + if not mols: + return [] + if confs_per_mol < 1: + raise ValueError(f"confs_per_mol must be >= 1, got {confs_per_mol}") + + workers = max(1, num_workers) + binaries = [(i, mol.ToBinary()) for i, mol in enumerate(mols)] + embedded_binaries = process_map( + partial(_embed_one, seed=seed, add_hs=add_hs, min_atoms=min_atoms), + binaries, + max_workers=workers, + chunksize=max(1, len(binaries) // (workers * 8) or 1), + desc=desc, + ) + + out: list[Chem.Mol] = [] + drop_count = 0 + for raw in embedded_binaries: + if raw is None: + drop_count += 1 + continue + out.append(Chem.Mol(raw)) + if drop_count > 0: + print(f" Dropped {drop_count} molecules during embedding (no conformer generated)") + + if confs_per_mol > 1: + for mol_idx, mol in enumerate(out): + base_conf_id = mol.GetConformer().GetId() + base_conf = mol.GetConformer(base_conf_id) + for conf_idx in range(1, confs_per_mol): + new_conf = Chem.Conformer(base_conf) + perturb_conformer(new_conf, seed=seed + mol_idx * confs_per_mol + conf_idx) + mol.AddConformer(new_conf, assignId=True) + perturb_conformer(mol.GetConformer(base_conf_id), seed=seed + mol_idx * confs_per_mol) + + return out diff --git a/benchmarks/conformer_rmsd_bench.py b/benchmarks/conformer_rmsd_bench.py index 3353d07a..dbd688c5 100644 --- a/benchmarks/conformer_rmsd_bench.py +++ b/benchmarks/conformer_rmsd_bench.py @@ -29,6 +29,7 @@ import numpy as np import torch +from bench_utils import perturb_conformer from benchmark_timing import time_it from rdkit import Chem from rdkit.Chem import AllChem, rdDistGeom @@ -87,12 +88,19 @@ def run_benchmark(smiles, num_confs_list, seed=42): params = rdDistGeom.ETKDGv3() params.randomSeed = seed params.useRandomCoords = True - rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_confs, params=params) - actual_confs = mol.GetNumConformers() - - if actual_confs < 2: + if rdDistGeom.EmbedMolecule(mol, params=params) < 0: print(f"{num_confs:>8} {'skipped (embedding failed)':>50}") continue + if num_confs < 2: + print(f"{num_confs:>8} {'skipped (need >= 2 confs for RMSD)':>50}") + continue + base_conf_id = mol.GetConformer().GetId() + for conf_idx in range(1, num_confs): + new_conf = Chem.Conformer(mol.GetConformer(base_conf_id)) + perturb_conformer(new_conf, seed=seed + conf_idx) + mol.AddConformer(new_conf, assignId=True) + perturb_conformer(mol.GetConformer(base_conf_id), seed=seed) + actual_confs = mol.GetNumConformers() no_h = Chem.RemoveHs(mol) n_pairs = actual_confs * (actual_confs - 1) // 2 diff --git a/benchmarks/ff_optimize_bench.py b/benchmarks/ff_optimize_bench.py index 7324c17c..0292f6a3 100644 --- a/benchmarks/ff_optimize_bench.py +++ b/benchmarks/ff_optimize_bench.py @@ -39,6 +39,7 @@ import torch from bench_utils import ( clone_mols_with_conformers, + embed_and_jitter, load_pickle, load_sdf, load_smiles, @@ -53,32 +54,6 @@ OPTUNA_AVAILABLE = nv_autotune.is_available() -def _embed_conformers(mols: list[Chem.Mol], confs_per_mol: int, seed: int) -> list[Chem.Mol]: - """Generate ``confs_per_mol`` conformers per molecule using RDKit ETKDGv3. - - Molecules where embedding fails to produce at least one conformer are - dropped; a count is printed. - """ - params = rdDistGeom.ETKDGv3() - params.useRandomCoords = True - params.randomSeed = seed - - embedded: list[Chem.Mol] = [] - drop_count = 0 - for mol in mols: - try: - conf_ids = rdDistGeom.EmbedMultipleConfs(mol, numConfs=confs_per_mol, params=params) - if not conf_ids: - drop_count += 1 - continue - embedded.append(mol) - except Exception: - drop_count += 1 - if drop_count > 0: - print(f" Dropped {drop_count} molecules during embedding (no conformer generated)") - return embedded - - def _flatten_energies(per_mol: list[list[float]]) -> list[float]: """Flatten ``[[e0, e1, ...], [e0, e1, ...], ...]`` returned by nvmolkit.""" flat: list[float] = [] @@ -398,7 +373,7 @@ def main() -> None: print(f" {len(mols)} molecules ready") print(f"\nEmbedding {args.confs_per_mol} conformer(s) per molecule with RDKit ETKDGv3...") - mols = _embed_conformers(mols, args.confs_per_mol, args.seed) + mols = embed_and_jitter(mols, args.confs_per_mol, seed=args.seed, num_workers=args.rdkit_threads) if not mols: print("Error: No molecules retained after embedding") sys.exit(1) diff --git a/benchmarks/tfd_bench.py b/benchmarks/tfd_bench.py index aab79a22..036fdc79 100644 --- a/benchmarks/tfd_bench.py +++ b/benchmarks/tfd_bench.py @@ -32,6 +32,7 @@ """ import argparse +import multiprocessing import os import pickle import sys @@ -40,8 +41,9 @@ import pandas as pd import torch +from bench_utils import embed_and_jitter, load_smiles from rdkit import Chem -from rdkit.Chem import AllChem, TorsionFingerprints +from rdkit.Chem import TorsionFingerprints import nvmolkit.tfd as nvmol_tfd @@ -72,30 +74,31 @@ def time_it(func, runs: int = 3, warmups: int = 1) -> Tuple[float, float]: return avg_time / 1.0e6, std_time / 1.0e6 # Return in milliseconds -def generate_conformers(mol: Chem.Mol, num_confs: int, seed: int = 42) -> Chem.Mol: - """Generate conformers for a molecule using ETKDG. - - Args: - mol: RDKit molecule - num_confs: Number of conformers to generate - seed: Random seed +def generate_conformers_batch( + mols: List[Chem.Mol], + num_confs: int, + seed: int = 42, + num_workers: int = 0, +) -> List[Chem.Mol]: + """Generate ``num_confs`` conformers per mol via embed-once-then-perturb. - Returns: - Molecule with conformers (or None if embedding failed) + Wraps the shared :func:`bench_utils.embed_and_jitter` with TFD-specific + constraints: requires ``num_confs >= 2`` (at least one torsion pair) and + drops mols with fewer than 4 atoms; hydrogens are added during embedding + and stripped from the returned mols. """ - mol = Chem.AddHs(mol) - params = AllChem.ETKDGv3() - params.randomSeed = seed - params.numThreads = 1 # Single-threaded for reproducibility - params.useRandomCoords = True - - conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, params=params) - - if len(conf_ids) < 2: - return None - - mol = Chem.RemoveHs(mol) - return mol + if num_confs < 2: + raise ValueError(f"num_confs must be >= 2 for TFD, got {num_confs}") + workers = num_workers if num_workers > 0 else max(1, multiprocessing.cpu_count() // 2) + return embed_and_jitter( + mols, + confs_per_mol=num_confs, + seed=seed, + num_workers=workers, + add_hs=True, + min_atoms=4, + desc=f"Embedding base conformer (1/{num_confs})", + ) def _try_load_pickle(num_confs: int, max_mols: int, smiles_file: str = None) -> List[Chem.Mol]: @@ -117,15 +120,20 @@ def _try_load_pickle(num_confs: int, max_mols: int, smiles_file: str = None) -> def prepare_molecules( - smiles_list: List[str], num_confs: int, max_mols: int = 100, smiles_file: str = None + input_mols: List[Chem.Mol], + num_confs: int, + max_mols: int = 100, + smiles_file: str = None, + num_workers: int = 0, ) -> List[Chem.Mol]: """Prepare molecules with conformers, using precomputed pickle if available. Args: - smiles_list: List of SMILES strings (fallback if no pickle) + input_mols: Parsed RDKit molecules (used when no precomputed pickle is found). num_confs: Number of conformers per molecule max_mols: Maximum number of molecules to prepare - smiles_file: Path to SMILES CSV (used to locate pickle files) + smiles_file: Path to SMILES file (used to locate sibling pickle files) + num_workers: Parallel workers for ETKDG embedding (0 = auto, half of CPUs) Returns: List of molecules with conformers @@ -135,23 +143,15 @@ def prepare_molecules( return cached print(f" No precomputed pickle found, generating from scratch...") - mols = [] - for i, smi in enumerate(smiles_list): - if len(mols) >= max_mols: - break - - mol = Chem.MolFromSmiles(smi) - if mol is None: - continue - - if mol.GetNumAtoms() < 4: + candidates: List[Chem.Mol] = [] + for mol in input_mols: + if mol is None or mol.GetNumAtoms() < 4: continue + candidates.append(mol) + if len(candidates) >= max_mols: + break - mol_with_confs = generate_conformers(mol, num_confs, seed=42 + i) - if mol_with_confs is not None and mol_with_confs.GetNumConformers() >= 2: - mols.append(mol_with_confs) - - return mols + return generate_conformers_batch(candidates, num_confs, seed=42, num_workers=num_workers) def bench_rdkit_single(mol: Chem.Mol) -> None: @@ -224,26 +224,28 @@ def load_pkl_files(pkl_paths: List[str]) -> List[Chem.Mol]: def run_benchmarks( - smiles_list: List[str] | None = None, + input_mols: List[Chem.Mol] | None = None, skip_rdkit: bool = False, output_file: str = "tfd_results.csv", smiles_file: str = None, mol_counts: List[int] = None, conformer_counts: List[int] = None, preloaded_mols: List[Chem.Mol] | None = None, + num_workers: int = 0, ) -> pd.DataFrame: """Run TFD benchmarks with various configurations. Args: - smiles_list: List of SMILES strings (unused when preloaded_mols given) + input_mols: Parsed RDKit molecules without conformers (unused when preloaded_mols given). skip_rdkit: If True, skip RDKit benchmarks (faster for large runs) output_file: Output CSV file path - smiles_file: Path to SMILES CSV (used to locate pickle files) + smiles_file: Path to SMILES file (used to locate sibling pickle files) mol_counts: List of molecule counts to benchmark conformer_counts: List of conformer counts to benchmark preloaded_mols: Pre-loaded molecules with conformers (e.g. from --pkl-file). - When provided, smiles_list/smiles_file/conformer_counts are ignored and + When provided, input_mols/smiles_file/conformer_counts are ignored and the actual conformer count is read from the molecules. + num_workers: Parallel workers for ETKDG embedding (0 = auto, half of CPUs). Returns: DataFrame with benchmark results @@ -273,7 +275,11 @@ def run_benchmarks( else: print(f"\n--- Preparing molecules with {num_confs} conformers ---") all_mols = prepare_molecules( - smiles_list, num_confs, max_mols=max(mol_counts) + 20, smiles_file=smiles_file + input_mols, + num_confs, + max_mols=max(mol_counts) + 20, + smiles_file=smiles_file, + num_workers=num_workers, ) if len(all_mols) < max(mol_counts): @@ -316,7 +322,6 @@ def run_benchmarks( result["rdkit_time_ms"] = None result["rdkit_std_ms"] = None - # nvMolKit GPU list benchmark (return_type="list") try: t, s = time_it(lambda: bench_nvmol_gpu_list(mols)) result["nvmol_gpu_list_time_ms"] = t @@ -326,7 +331,6 @@ def run_benchmarks( print(f" nvMolKit GPU list failed: {e}") result["nvmol_gpu_list_time_ms"] = None - # nvMolKit GPU numpy benchmark (return_type="numpy") try: t, s = time_it(lambda: bench_nvmol_gpu_numpy(mols)) result["nvmol_gpu_numpy_time_ms"] = t @@ -336,7 +340,6 @@ def run_benchmarks( print(f" nvMolKit GPU numpy failed: {e}") result["nvmol_gpu_numpy_time_ms"] = None - # nvMolKit GPU tensor benchmark (return_type="tensor", no D2H) try: t, s = time_it(lambda: bench_nvmol_gpu_tensor(mols)) result["nvmol_gpu_tensor_time_ms"] = t @@ -346,7 +349,6 @@ def run_benchmarks( print(f" nvMolKit GPU tensor failed: {e}") result["nvmol_gpu_tensor_time_ms"] = None - # Calculate speedups vs RDKit speedups = {} for key, label in [ ("nvmol_gpu_list_time_ms", "GPU list"), @@ -424,10 +426,16 @@ def main(): action="store_true", help="Verify correctness and exit (skip benchmarking)", ) + parser.add_argument( + "--prep-workers", + type=int, + default=0, + help="Parallel workers for ETKDG embedding during prep (0 = auto, half of CPUs)", + ) args = parser.parse_args() preloaded_mols = None - smiles_list = None + input_mols = None if args.pkl_file: print("Loading precomputed molecules from pickle file(s)...") @@ -439,19 +447,24 @@ def main(): else: print(f"Loading SMILES from: {args.smiles_file}") try: - df = pd.read_csv(args.smiles_file) - smiles_list = df.iloc[:, 0].tolist() + input_mols = load_smiles(args.smiles_file, max_count=max(args.num_mols) + 100) except Exception as e: print(f"Error loading SMILES file: {e}") sys.exit(1) - print(f"Loaded {len(smiles_list)} SMILES") + print(f"Loaded {len(input_mols)} molecules") if args.verify or args.verify_only: print("\nVerifying correctness...") if preloaded_mols is not None: test_mols = preloaded_mols[:50] else: - test_mols = prepare_molecules(smiles_list, num_confs=5, max_mols=50, smiles_file=args.smiles_file) + test_mols = prepare_molecules( + input_mols, + num_confs=5, + max_mols=50, + smiles_file=args.smiles_file, + num_workers=args.prep_workers, + ) all_correct = True mismatches = 0 for i, mol in enumerate(test_mols): @@ -470,13 +483,14 @@ def main(): sys.exit(0 if all_correct else 1) run_benchmarks( - smiles_list=smiles_list, + input_mols=input_mols, skip_rdkit=args.skip_rdkit, output_file=args.output, smiles_file=args.smiles_file, mol_counts=args.num_mols, conformer_counts=args.num_confs if not args.pkl_file else None, preloaded_mols=preloaded_mols, + num_workers=args.prep_workers, )