Drift-React: One-step Generation of Reaction Pathways via SE(3) Drifting Fields

Drift-React: one-step SE(3)-equivariant reaction pathway generation


📖 Overview

Predicting how a chemical reaction unfolds in 3-D, from reactant geometry $\mathbf{X}_R$ to product geometry $\mathbf{X}_P$ through a transition state, is a central problem in computational chemistry. Existing approaches face a hard trade-off:

  • 🐢 Point-wise transition-state generators (flow matching, diffusion) require expensive ODE/SDE integration, ignore the topology of the reaction valley, and can produce atom clashes.
  • 🪢 Iterative path relaxers (Nudged Elastic Band, NeuralNEB) need $10^4$ to $10^5$ force evaluations per reaction and quickly become computationally prohibitive.

Drift-React is a single-step, SE(3)-equivariant generative model that maps a simple analytic prior $\mathbf{X}_{\text{lin}}$ to the full minimum-energy pathway $\mathbf{Y}$ in ~12 ms per reaction. It is trained with an affinity-kernel drifting objective that yields continuous reaction pathways without test-time integration or iterative relaxation.
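As a rough illustration of the kind of analytic prior $\mathbf{X}_{\text{lin}}$ the generator starts from, here is a minimal NumPy sketch that linearly interpolates $F$ frames between reactant and product coordinates. The function name `linear_interp_prior` and the choice to include both endpoints as frames are assumptions, not the repository's API; the priors in `src/drift_react/drifting/` may differ (e.g. in frame spacing or added noise). Because every frame is an affine combination of $\mathbf{X}_R$ and $\mathbf{X}_P$, applying the same rotation and translation to both endpoints simply rotates and translates the whole prior path, which is consistent with feeding it to an SE(3)-equivariant model.

```python
import numpy as np


def linear_interp_prior(x_r: np.ndarray, x_p: np.ndarray, n_frames: int) -> np.ndarray:
    """Linearly interpolate n_frames geometries between reactant and product.

    x_r, x_p: (N, 3) Cartesian coordinates with matching atom order (assumed).
    Returns an (n_frames, N, 3) array whose first/last frames equal x_r/x_p.
    """
    if x_r.shape != x_p.shape or x_r.ndim != 2 or x_r.shape[1] != 3:
        raise ValueError("x_r and x_p must both have shape (N, 3)")
    if n_frames < 2:
        raise ValueError("n_frames must be >= 2 to include both endpoints")
    # Interpolation weights t_k = k / (F - 1), k = 0..F-1, broadcast over atoms and xyz.
    t = np.linspace(0.0, 1.0, n_frames)[:, None, None]
    return (1.0 - t) * x_r[None] + t * x_p[None]


# Usage: a toy 3-atom system interpolated over 8 frames (the Halo8 frame count).
rng = np.random.default_rng(0)
x_r = rng.normal(size=(3, 3))
x_p = x_r + 0.5
path = linear_interp_prior(x_r, x_p, n_frames=8)
print(path.shape)  # (8, 3, 3)
```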

🚀 Installation

Prerequisites

  • Python 3.12+
  • uv (recommended)

From source

git clone https://github.com/schwallergroup/drift-react.git
cd drift-react

uv sync --all-extras --dev

Optional dependencies

Baselines whose dependencies cannot be resolved by uv sync (restricted Python wheel support or a non-PyPI source) are installed separately. The corresponding baseline modules import these libraries lazily, and the test suite auto-skips when they are absent.

# Geodesic baseline (Zhu et al., 2019): wheels exist for Python ≤ 3.12 only.
uv pip install geodesic-interpolate

# NeuralNEB (Schreiner et al., 2022): GitLab source distribution.
uv pip install git+https://gitlab.com/matschreiner/neuralneb.git
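The lazy-import pattern described above can be sketched as follows. This is an illustrative helper, not code from `src/drift_react/baselines/`: the name `require_optional` and its error message are hypothetical, and the import name `geodesic_interpolate` for the `geodesic-interpolate` distribution is an assumption. Importing inside the function rather than at module top level means the wrapping module can still be imported on a minimal install; only the call that actually needs the library fails, and it fails with an install hint.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency on first use, with an actionable error if missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"Optional dependency '{module_name}' is not installed. "
            f"Install it with: {install_hint}"
        ) from exc


# Intended use inside a (hypothetical) baseline wrapper:
#   geodesic = require_optional("geodesic_interpolate", "uv pip install geodesic-interpolate")
# Usage: a stdlib module resolves normally; a missing package yields the hint.
print(require_optional("json", "n/a").__name__)  # json
try:
    require_optional("not_a_real_package_xyz", "uv pip install not-a-real-package-xyz")
except ImportError as err:
    print(err)
```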

🏃 Quick start

Training

# Default config: LEFTNet + SE(3) drift + linear-interp prior
uv run python scripts/train.py

# Override any Hydra field
uv run python scripts/train.py \
    prior=idpp \
    train.epochs=200 \
    optim.lr=5e-4
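Conceptually, a dotted override such as `train.epochs=200` replaces one leaf of the nested config, whereas a bare group override such as `prior=idpp` makes Hydra load the `idpp` option of the `prior` config group instead. The toy sketch below illustrates only the dotted case on a plain dictionary. It is not Hydra's implementation (Hydra and OmegaConf also handle group selection, the `+`/`~` prefixes, and their own value parsing), and the default values shown are hypothetical.

```python
import ast
from typing import Any


def apply_dotted_override(cfg: dict[str, Any], override: str) -> None:
    """Toy version: set cfg[a][b][c] = value for an override string 'a.b.c=value'."""
    key, sep, raw = override.partition("=")
    if not sep or not key:
        raise ValueError(f"expected 'key=value', got {override!r}")
    try:
        value: Any = ast.literal_eval(raw)  # '200' -> int, '5e-4' -> float
    except (ValueError, SyntaxError):
        value = raw  # not a Python literal: keep it as a string, e.g. 'leftnet'
    *parents, leaf = key.split(".")
    node = cfg
    for part in parents:
        node = node.setdefault(part, {})  # create intermediate levels if missing
    node[leaf] = value


# Usage with hypothetical defaults:
cfg = {"train": {"epochs": 100}, "optim": {"lr": 1e-3}}
for ov in ["train.epochs=200", "optim.lr=5e-4"]:
    apply_dotted_override(cfg, ov)
print(cfg)  # {'train': {'epochs': 200}, 'optim': {'lr': 0.0005}}
```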

Configs live under configs/ and compose four groups:

| Group | Choices |
|-------|---------|
| model | leftnet, mace |
| drift | se3, feature_guided |
| prior | gaussian, linear_interp, linear_interp_envelope, idpp, brownian_bridge |
| wandb | Weights & Biases logging |
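The `idpp` prior presumably refers to the image-dependent pair potential of Smidstrup et al. (2014), which also appears among the baselines. Under that assumption, each frame $k \in \{0,\dots,F-1\}$ receives target interatomic distances interpolated between reactant and product, and its geometry is found by minimising a weighted mismatch to those targets:

$$d^{(k)}_{ij} = d^{R}_{ij} + \frac{k}{F-1}\left(d^{P}_{ij} - d^{R}_{ij}\right), \qquad S^{(k)}(\mathbf{x}) = \sum_{i<j} \frac{\left(d^{(k)}_{ij} - d_{ij}(\mathbf{x})\right)^2}{d_{ij}(\mathbf{x})^4}.$$

The NumPy sketch below computes the targets and evaluates $S^{(k)}$ for one frame. The function names are hypothetical, and the minimisation step (e.g. a gradient-based optimiser started from the linear interpolation) is deliberately left out.

```python
import numpy as np


def pairwise_distances(x: np.ndarray) -> np.ndarray:
    """(N, 3) coordinates -> (N, N) Euclidean distance matrix."""
    diff = x[:, None, :] - x[None, :, :]
    return np.linalg.norm(diff, axis=-1)


def idpp_targets(x_r: np.ndarray, x_p: np.ndarray, n_frames: int) -> np.ndarray:
    """Linearly interpolated target distance matrices, shape (n_frames, N, N)."""
    d_r, d_p = pairwise_distances(x_r), pairwise_distances(x_p)
    t = np.linspace(0.0, 1.0, n_frames)[:, None, None]  # t_k = k / (F - 1)
    return d_r[None] + t * (d_p - d_r)[None]


def idpp_objective(x: np.ndarray, d_target: np.ndarray) -> float:
    """S(x) = sum over i < j of (d_target_ij - d_ij(x))^2 / d_ij(x)^4 for one frame."""
    d = pairwise_distances(x)
    iu = np.triu_indices(len(x), k=1)  # unique atom pairs i < j (skips the zero diagonal)
    return float(np.sum((d_target[iu] - d[iu]) ** 2 / d[iu] ** 4))


# Usage: toy "product" = uniformly stretched reactant.
rng = np.random.default_rng(1)
x_r = rng.normal(size=(4, 3))
x_p = 1.5 * x_r
targets = idpp_targets(x_r, x_p, n_frames=8)
print(idpp_objective(x_r, targets[0]))  # 0.0: the reactant matches its own targets
```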

📈 Evaluation

uv run python scripts/evaluate.py \
    --checkpoint runs/<id>/checkpoints/best.ckpt \
    --test-data data/halo8_test.lmdb \
    --eval-mode full_pathway \
    --n-frames 8

The script reports IRC-RMSD (mean / max), the discrete Fréchet distance, and TS-RMSD.
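The sketch below shows one common way to compute pathway metrics of this kind with NumPy; it is an assumption about, not a copy of, `src/drift_react/metrics/`. Each frame pair is compared by RMSD after optimal rigid superposition (the Kabsch algorithm), the frame-wise values give a mean / max RMSD along the path, and the discrete Fréchet distance uses that same aligned RMSD as the ground distance between frames. Whether the repository aligns frames before measuring, and how TS-RMSD picks the transition-state frame, are not specified here.

```python
import numpy as np


def kabsch_rmsd(a: np.ndarray, b: np.ndarray) -> float:
    """RMSD between (N, 3) geometries after optimally superposing a onto b (Kabsch)."""
    a_c = a - a.mean(axis=0)  # remove translation
    b_c = b - b.mean(axis=0)
    u, _, vt = np.linalg.svd(a_c.T @ b_c)  # SVD of the 3x3 covariance matrix
    # Correct for reflections so that `rot` is a proper rotation (det = +1).
    sign = 1.0 if np.linalg.det(vt.T @ u.T) >= 0 else -1.0
    rot = vt.T @ np.diag([1.0, 1.0, sign]) @ u.T
    diff = a_c @ rot.T - b_c
    return float(np.sqrt((diff**2).sum(axis=1).mean()))


def pathway_rmsd(pred: np.ndarray, ref: np.ndarray) -> tuple[float, float]:
    """Mean and max frame-wise aligned RMSD between (F, N, 3) pathways of equal length."""
    if pred.shape != ref.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {ref.shape}")
    per_frame = np.array([kabsch_rmsd(p, r) for p, r in zip(pred, ref)])
    return float(per_frame.mean()), float(per_frame.max())


def discrete_frechet(pred: np.ndarray, ref: np.ndarray) -> float:
    """Discrete Fréchet distance between two frame sequences (aligned RMSD as ground distance)."""
    dist = np.array([[kabsch_rmsd(p, r) for r in ref] for p in pred])
    n, m = dist.shape
    ca = np.zeros((n, m))  # ca[i, j]: best coupling cost for prefixes pred[:i+1], ref[:j+1]
    for i in range(n):
        for j in range(m):
            if i == 0 and j == 0:
                prev = 0.0
            elif i == 0:
                prev = ca[0, j - 1]
            elif j == 0:
                prev = ca[i - 1, 0]
            else:
                prev = min(ca[i - 1, j], ca[i - 1, j - 1], ca[i, j - 1])
            ca[i, j] = max(prev, dist[i, j])
    return float(ca[-1, -1])


# Usage: a rigidly moved copy of a toy pathway should score ~0 on every metric.
rng = np.random.default_rng(0)
ref = rng.normal(size=(8, 5, 3))  # 8 frames, 5 atoms
c, s = np.cos(0.7), np.sin(0.7)
rot_z = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
moved = ref @ rot_z.T + 2.0
print(pathway_rmsd(moved, ref), discrete_frechet(moved, ref))  # values ~0 (float noise)
```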

🖥️ HPC (SLURM)

sbatch slurm/train.sbatch                 # single run
sbatch slurm/benchmark_train.sbatch       # sweep
sbatch slurm/benchmark_eval.sbatch        # eval after sweep

🗂️ Datasets

Drift-React expects reaction pathways stored as LMDB files of ASE Atoms objects carrying reactant_pos, product_pos, and pathway_pos (the $F$ pathway frames). Conversion utilities live under scripts/data/:

  • pkl_to_lmdb.py: convert a pickle of pathways to LMDB.
  • lmdb_to_asedb.py: round-trip to an ASE database.
  • split_lmdb.py: train / val / test split.
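When converting your own data, a quick shape check per record can catch mismatches before they reach training. The sketch below assumes each record exposes `reactant_pos` and `product_pos` as `(N, 3)` arrays and `pathway_pos` as an `(F, N, 3)` array in the same atom order; that layout is inferred from the field names above, and `check_pathway_record` is a hypothetical helper rather than part of the conversion scripts.

```python
from __future__ import annotations

import numpy as np


def check_pathway_record(record: dict[str, np.ndarray], n_frames: int | None = None) -> tuple[int, int]:
    """Validate the shapes of one pathway record; return (n_frames, n_atoms)."""
    x_r = np.asarray(record["reactant_pos"], dtype=float)
    x_p = np.asarray(record["product_pos"], dtype=float)
    path = np.asarray(record["pathway_pos"], dtype=float)
    if x_r.ndim != 2 or x_r.shape[1] != 3:
        raise ValueError(f"reactant_pos must be (N, 3), got {x_r.shape}")
    if x_p.shape != x_r.shape:
        raise ValueError(f"product_pos {x_p.shape} does not match reactant_pos {x_r.shape}")
    if path.ndim != 3 or path.shape[1:] != x_r.shape:
        raise ValueError(f"pathway_pos must be (F, N, 3) with N={x_r.shape[0]}, got {path.shape}")
    if n_frames is not None and path.shape[0] != n_frames:
        raise ValueError(f"expected {n_frames} frames, got {path.shape[0]}")
    if not np.isfinite(path).all():
        raise ValueError("pathway_pos contains NaN or inf")
    return path.shape[0], x_r.shape[0]


# Usage: a toy 5-atom record with 10 frames (the Transition1x frame count).
rec = {
    "reactant_pos": np.zeros((5, 3)),
    "product_pos": np.ones((5, 3)),
    "pathway_pos": np.zeros((10, 5, 3)),
}
print(check_pathway_record(rec, n_frames=10))  # (10, 5)
```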

Out-of-the-box configurations target two benchmarks:

  • 🧬 Transition1x: full reaction pathways, 10 images per pathway.
  • 🧪 Halo8: 8-image reaction pathways for halogenation reactions.

📁 Repository layout

drift-react/
├── configs/                  # Hydra configs (model / drift / prior / wandb)
├── scripts/                  # train, evaluate, inference, benchmark, data prep
├── slurm/                    # SLURM job scripts
├── src/drift_react/
│   ├── data/                 # ReactionPathwayDataset (LMDB + padding)
│   ├── drifting/             # Drift fields + analytic priors
│   ├── models/               # DriftingGenerator + LEFTNet / MACE backbones
│   ├── baselines/            # Linear, IDPP, geodesic, NeuralNEB
│   ├── metrics/              # RMSD, Fréchet, energy
│   ├── callbacks/            # Pathway plotting for W&B
│   └── utils/
└── tests/
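The `data/` entry mentions padding. Reactions in a batch have different atom counts, so coordinates are typically padded to the largest $N$ and paired with a boolean mask that losses and metrics use to ignore padded atoms. The sketch below shows that idea in NumPy; `pad_pathways` is a hypothetical name, and the exact tensor layout used by `ReactionPathwayDataset` may differ.

```python
from __future__ import annotations

import numpy as np


def pad_pathways(paths: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Pad (F, N_i, 3) pathways to a (B, F, N_max, 3) batch plus a (B, N_max) atom mask.

    All pathways must share the same frame count F; padded coordinates are zeros.
    """
    if not paths:
        raise ValueError("need at least one pathway")
    n_frames = paths[0].shape[0]
    if any(p.ndim != 3 or p.shape[0] != n_frames or p.shape[2] != 3 for p in paths):
        raise ValueError("every pathway must have shape (F, N_i, 3) with a common F")
    n_max = max(p.shape[1] for p in paths)
    batch = np.zeros((len(paths), n_frames, n_max, 3), dtype=float)
    mask = np.zeros((len(paths), n_max), dtype=bool)
    for b, p in enumerate(paths):
        n = p.shape[1]
        batch[b, :, :n] = p  # real atoms first, zero padding after
        mask[b, :n] = True
    return batch, mask


# Usage: two toy reactions with 3 and 5 atoms, 8 frames each.
batch, mask = pad_pathways([np.ones((8, 3, 3)), np.ones((8, 5, 3))])
print(batch.shape, mask.sum(axis=1))  # (2, 8, 5, 3) [3 5]
```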

🛠️ Development

uv run ruff check --fix . && uv run ruff format .   # lint + format
uv run mypy src                                     # type-check
uv run pytest                                       # full suite
uv run pytest tests/test_drift.py::test_name        # single test
uv run tox                                          # mypy + pytest, isolated venv

Tests

Tox runs mypy on src and tests, plus the full pytest suite, in an isolated venv:

uv run --with tox tox -e py313

Optional baseline dependencies (geodesic-interpolate, neuralneb) are installed manually by users who need them. The corresponding test classes auto-skip when they are absent, so CI stays green on a minimal install.
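One way to get this auto-skip behaviour with only the standard library is `unittest.skipUnless` combined with `importlib.util.find_spec`, which reports whether a top-level module is importable without importing it. This is a sketch, not the repository's test code: the class and test names are hypothetical, and the import names `neuralneb` and `geodesic_interpolate` are assumed from the package names. With pytest, `pytest.importorskip("neuralneb")` at module level achieves the same effect.

```python
import importlib.util
import unittest


def has_module(name: str) -> bool:
    """True if the top-level module `name` can be found on the import path."""
    return importlib.util.find_spec(name) is not None


@unittest.skipUnless(has_module("neuralneb"), "neuralneb not installed")
class TestNeuralNEBBaseline(unittest.TestCase):  # hypothetical test class
    def test_import(self) -> None:
        import neuralneb  # noqa: F401  (only reached when the package is present)


@unittest.skipUnless(has_module("geodesic_interpolate"), "geodesic-interpolate not installed")
class TestGeodesicBaseline(unittest.TestCase):  # hypothetical test class
    def test_import(self) -> None:
        import geodesic_interpolate  # noqa: F401
```

pytest also collects `unittest.TestCase` subclasses, so classes like these would be skipped (not failed) under `uv run pytest` on a minimal install.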

📚 How to cite

If you use Drift-React in your research, please cite the preprint:

@misc{schlamaDriftReactOnestepGeneration2026,
      title={Drift-React: One-step Generation of Reaction Pathways via SE(3) Drifting Fields},
      author={Rémi Schlama and Philippe Schwaller},
      year={2026},
      eprint={2605.22990},
      archivePrefix={arXiv},
      primaryClass={physics.chem-ph},
      url={https://arxiv.org/abs/2605.22990},
}

📄 License

Released under the MIT License; see LICENSE for the full text.
