From 7d15649679976be5bb093827777c3a98f9588441 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 9 Mar 2026 15:36:22 -0400 Subject: [PATCH 1/4] fea: first-party mace --- pyproject.toml | 2 +- tests/models/test_mace.py | 20 +- torch_sim/models/mace.py | 423 +++++++++----------------------------- 3 files changed, 105 insertions(+), 340 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4322706b..3e916f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ test = [ ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] symmetry = ["moyopy>=0.7.8"] -mace = ["mace-torch>=0.3.15"] +mace = ["mace-torch @ git+https://github.com/ACEsuit/mace.git@develop"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] orb = ["orb-models>=0.6.0"] diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 322f3d12..0d18b196 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -17,7 +17,7 @@ from mace.calculators import MACECalculator from mace.calculators.foundations_models import mace_mp, mace_off - from torch_sim.models.mace import MaceModel, MaceUrls + from torch_sim.models.mace import MaceModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] @@ -32,8 +32,11 @@ raw_mace_omol = None HAS_MACE_OMOL = False -raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) -raw_mace_off = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) +MACE_MP_SMALL_URL = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" +MACE_OFF_SMALL_URL = "https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true" + +raw_mace_mp = mace_mp(model=MACE_MP_SMALL_URL, return_raw_model=True) +raw_mace_off = mace_off(model=MACE_OFF_SMALL_URL, return_raw_model=True) DTYPE = torch.float64 @@ -41,7 +44,7 @@ def ase_mace_calculator() -> MACECalculator: dtype = str(DTYPE).removeprefix("torch.") return mace_mp( - model=MaceUrls.mace_mp_small, device="cpu", default_dtype=dtype, dispersion=False + model=MACE_MP_SMALL_URL, device="cpu", default_dtype=dtype, dispersion=False ) @@ -81,7 +84,7 @@ def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: @pytest.fixture def ase_mace_off_calculator() -> MACECalculator: return mace_off( - model=MaceUrls.mace_off_small, + model=MACE_OFF_SMALL_URL, device=str(DEVICE), default_dtype=str(DTYPE).removeprefix("torch."), dispersion=False, @@ -114,13 +117,6 @@ def test_mace_off_dtype_working( model.forward(benzene_sim_state.to(DEVICE, dtype)) -def test_mace_urls_enum() -> None: - assert len(MaceUrls) > 2 - for key in MaceUrls: - assert key.value.startswith("https://github.com/ACEsuit/mace-") - assert key.value.endswith((".model", ".model?raw=true")) - - @pytest.mark.skipif(not HAS_MACE_OMOL, reason="mace_omol not available") @pytest.mark.parametrize( ("charge", "spin"), diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 7dfbb765..9b40fe07 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -1,358 +1,127 @@ """Wrapper for MACE model in TorchSim. -This module provides a TorchSim wrapper of the MACE model for computing -energies, forces, and stresses for atomistic systems. It integrates the MACE model -with TorchSim's simulation framework, handling batched computations for multiple -systems simultaneously. +This module re-exports the MACE package's torch-sim integration for convenient +importing. The actual implementation is maintained in the `mace` package. -The implementation supports various features including: - -* Computing energies, forces, and stresses -* Handling periodic boundary conditions (PBC) -* Optional CuEq acceleration for improved performance -* Batched calculations for multiple systems - -Notes: - This module depends on the MACE package and implements the ModelInterface - for compatibility with the broader TorchSim framework. +References: + - MACE Package: https://github.com/ACEsuit/mace """ import traceback import warnings from collections.abc import Callable -from enum import StrEnum from pathlib import Path from typing import Any import torch -import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torchsim_nl try: - from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq - from mace.tools import atomic_numbers_to_indices, utils -except (ImportError, ModuleNotFoundError) as exc: + from mace.calculators.mace_torchsim import MaceTorchSimModel +except ImportError as exc: warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2) class MaceModel(ModelInterface): - """MACE model wrapper for torch-sim. + """Dummy MACE model wrapper for torch-sim to enable safe imports. - This class is a placeholder for the MaceModel class. - It raises an ImportError if MACE is not installed. + NOTE: This class is a placeholder when `mace` is not installed. + It raises an ImportError if accessed. """ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: """Dummy init for type checking.""" raise err - - -def to_one_hot( - indices: torch.Tensor, num_classes: int, dtype: torch.dtype -) -> torch.Tensor: - """Generates one-hot encoding from indices. - - NOTE: this is a modified version of the to_one_hot function in mace.tools, - consider using upstream version if possible after https://github.com/ACEsuit/mace/pull/903/ - is merged. - - Args: - indices: A tensor of shape (N x 1) containing class indices. - num_classes: An integer specifying the total number of classes. - dtype: The desired data type of the output tensor. - - Returns: - torch.Tensor: A tensor of shape (N x num_classes) containing the - one-hot encodings. - """ - shape = (*indices.shape[:-1], num_classes) - oh = torch.zeros(shape, device=indices.device, dtype=dtype).view(shape) - - # scatter_ is the in-place version of scatter - oh.scatter_(dim=-1, index=indices, value=1) - - return oh.view(*shape) - - -class MaceModel(ModelInterface): - """Computes energies for multiple systems using a MACE model. - - This class wraps a MACE model to compute energies, forces, and stresses for - atomic systems within the TorchSim framework. It supports batched calculations - for multiple systems and handles the necessary transformations between - TorchSim's data structures and MACE's expected inputs. - - Attributes: - r_max (float): Cutoff radius for neighbor interactions. - z_table (utils.AtomicNumberTable): Table mapping atomic numbers to indices. - model (torch.nn.Module): The underlying MACE neural network model. - neighbor_list_fn (Callable): Function used to compute neighbor lists. - atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]. - system_idx (torch.Tensor): System indices with shape [n_atoms]. - n_systems (int): Number of systems in the batch. - n_atoms_per_system (list[int]): Number of atoms in each system. - ptr (torch.Tensor): Pointers to the start of each system in the batch with - shape [n_systems + 1]. - total_atoms (int): Total number of atoms across all systems. - node_attrs (torch.Tensor): One-hot encoded atomic types with shape - [n_atoms, n_elements]. - """ - - def __init__( - self, - model: str | Path | torch.nn.Module | None = None, - *, - device: torch.device | None = None, - dtype: torch.dtype = torch.float64, - neighbor_list_fn: Callable = torchsim_nl, - compute_forces: bool = True, - compute_stress: bool = True, - enable_cueq: bool = False, - atomic_numbers: torch.Tensor | None = None, - system_idx: torch.Tensor | None = None, - ) -> None: - """Initialize the MACE model for energy and force calculations. - - Sets up the MACE model for energy, force, and stress calculations within - the TorchSim framework. The model can be initialized with atomic numbers - and system indices, or these can be provided during the forward pass. - - Args: - model (str | Path | torch.nn.Module | None): The MACE neural network model, - either as a path to a saved model or as a loaded torch.nn.Module instance. - device (torch.device | None): The device to run computations on. - Defaults to CUDA if available, otherwise CPU. - dtype (torch.dtype): The data type for tensor operations. - Defaults to torch.float64. - atomic_numbers (torch.Tensor | None): Atomic numbers with shape [n_atoms]. - If provided at initialization, cannot be provided again during forward. - system_idx (torch.Tensor | None): System indices with shape [n_atoms] - indicating which system each atom belongs to. If not provided with - atomic_numbers, all atoms are assumed to be in the same system. - neighbor_list_fn (Callable): Function to compute neighbor lists. - Defaults to torch_nl_linked_cell. - compute_forces (bool): Whether to compute forces. Defaults to True. - compute_stress (bool): Whether to compute stress. Defaults to True. - enable_cueq (bool): Whether to enable CuEq acceleration. Defaults to False. - - Raises: - NotImplementedError: If model is provided as a file path (not - implemented yet). - TypeError: If model is neither a path nor a torch.nn.Module. +else: + # Create a backwards-compatible wrapper around MaceTorchSimModel + class MaceModel(MaceTorchSimModel): + """Computes energies for multiple systems using a MACE model. + + This class wraps the MACE first-party TorchSim interface, providing + backwards compatibility with the previous torch-sim implementation. + + This class wraps a MACE model to compute energies, forces, and stresses for + atomic systems within the TorchSim framework. It supports batched calculations + for multiple systems and handles the necessary transformations between + TorchSim's data structures and MACE's expected inputs. + + Attributes: + r_max (float): Cutoff radius for neighbor interactions. + model (torch.nn.Module): The underlying MACE neural network model. + neighbor_list_fn (Callable): Function used to compute neighbor lists. + atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]. + system_idx (torch.Tensor): System indices with shape [n_atoms]. + n_systems (int): Number of systems in the batch. + n_atoms_per_system (list[int]): Number of atoms in each system. + ptr (torch.Tensor): Pointers to the start of each system in the batch with + shape [n_systems + 1]. + total_atoms (int): Total number of atoms across all systems. + node_attrs (torch.Tensor): One-hot encoded atomic types with shape + [n_atoms, n_elements]. """ - super().__init__() - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self.neighbor_list_fn = neighbor_list_fn - self._memory_scales_with = "n_atoms_x_density" - - # Load model if provided as path - if isinstance(model, str | Path): - self.model = torch.load(model, map_location=self.device, weights_only=False) - elif isinstance(model, torch.nn.Module): - self.model = model.to(self.device) - else: - raise TypeError("Model must be a path or torch.nn.Module") - - self.model = self.model.eval() - - # Move all model components to device - self.model = self.model.to(device=self._device) - if self.dtype is not None: - self.model = self.model.to(dtype=self.dtype) - - if enable_cueq: - print("Converting models to CuEq for acceleration") # noqa: T201 - self.model = run_e3nn_to_cueq(self.model, device=self.device.type) - - # Set model properties - self.r_max = self.model.r_max - atomic_nums = self.model.atomic_numbers - if not isinstance(atomic_nums, torch.Tensor): - raise TypeError("MACE model atomic_numbers must be a tensor") - self.z_table = utils.AtomicNumberTable([int(z) for z in atomic_nums]) - self.model.atomic_numbers = atomic_nums.detach().clone().to(device=self.device) - - self.atomic_numbers_in_init = atomic_numbers is not None - self.system_idx_in_init = system_idx is not None - - if atomic_numbers is not None: - self.atomic_numbers = atomic_numbers - self._setup_node_attrs(atomic_numbers) - - if system_idx is not None: - self.system_idx = system_idx - self._setup_ptr(system_idx) - if ( - atomic_numbers is not None - and system_idx is not None - and system_idx.shape[0] != atomic_numbers.shape[0] - ): - raise ValueError( - f"system_idx length {system_idx.shape[0]} must match " - f"atomic_numbers length {atomic_numbers.shape[0]}." + def __init__( + self, + model: str | Path | torch.nn.Module, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + neighbor_list_fn: Callable | None = None, + compute_forces: bool = True, + compute_stress: bool = True, + enable_cueq: bool = False, + atomic_numbers: torch.Tensor | None = None, + system_idx: torch.Tensor | None = None, + enable_oeq: bool = False, + compile_mode: str | None = None, + ) -> None: + """Initialize the MACE model for energy and force calculations. + + Sets up the MACE model for energy, force, and stress calculations within + the TorchSim framework. The model can be initialized with atomic numbers + and system indices, or these can be provided during the forward pass. + + Args: + model: The MACE neural network model, either as a path to a saved + model or as a loaded torch.nn.Module instance. + device: The device to run computations on. Defaults to CUDA if + available, otherwise CPU. + dtype: The data type for tensor operations. Defaults to + torch.float64. + atomic_numbers: Atomic numbers with shape [n_atoms]. If provided + at initialization, cannot be provided again during forward. + system_idx: System indices with shape [n_atoms] indicating which + system each atom belongs to. If not provided with + atomic_numbers, all atoms are assumed to be in the same + system. + neighbor_list_fn: Function to compute neighbor lists. Defaults to + torchsim_nl from torch-sim. + compute_forces: Whether to compute forces. Defaults to True. + compute_stress: Whether to compute stress. Defaults to True. + enable_cueq: Whether to enable CuEq acceleration. Defaults to + False. + enable_oeq: Whether to enable OEq acceleration. Defaults to + False. + compile_mode: PyTorch compilation mode (e.g., "reduce-overhead"). + Defaults to None (no compilation). + + Raises: + TypeError: If model is neither a path nor a torch.nn.Module. + """ + super().__init__( + model=model, + device=device, + dtype=dtype, + neighbor_list_fn=neighbor_list_fn, + compute_forces=compute_forces, + compute_stress=compute_stress, + enable_cueq=enable_cueq, + enable_oeq=enable_oeq, + compile_mode=compile_mode, + atomic_numbers=atomic_numbers, + system_idx=system_idx, ) - def _setup_ptr(self, system_idx: torch.Tensor) -> None: - """Compute system boundary pointers from system indices. - - Args: - system_idx (torch.Tensor): System indices tensor with shape [n_atoms]. - """ - counts = torch.bincount(system_idx) - self.n_systems = len(counts) - self.n_atoms_per_system = counts.tolist() - self.ptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)]) - - def _setup_node_attrs(self, atomic_numbers: torch.Tensor) -> None: - """Compute one-hot encoded node attributes from atomic numbers. - - Args: - atomic_numbers (torch.Tensor): Atomic numbers tensor with shape [n_atoms]. - """ - self.node_attrs = to_one_hot( - torch.tensor( - atomic_numbers_to_indices( - atomic_numbers.detach().cpu().numpy(), z_table=self.z_table - ), - dtype=torch.long, - device=self.device, - ).unsqueeze(-1), - num_classes=len(self.z_table), - dtype=self.dtype, - ) - - def forward( # noqa: C901 - self, state: ts.SimState, **_kwargs: object - ) -> dict[str, torch.Tensor]: - """Compute energies, forces, and stresses for the given atomic systems. - - Processes the provided state information and computes energies, forces, and - stresses using the underlying MACE model. Handles batched calculations for - multiple systems and constructs the necessary neighbor lists. - - Args: - state (SimState): State object containing positions, cell, and other - system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - 'energy': System energies with shape [n_systems] - - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True - - 'stress': System stresses with shape [n_systems, 3, 3] if - compute_stress=True - - Raises: - ValueError: If atomic numbers are not provided either in the constructor - or in the forward pass, or if provided in both places. - ValueError: If system indices are not provided when needed. - """ - if self.atomic_numbers_in_init: - if state.positions.shape[0] != self.atomic_numbers.shape[0]: - raise ValueError( - f"Expected {self.atomic_numbers.shape[0]} atoms, " - f"got {state.positions.shape[0]}." - ) - elif not hasattr(self, "atomic_numbers") or not torch.equal( - state.atomic_numbers, self.atomic_numbers - ): - self._setup_node_attrs(state.atomic_numbers) - self.atomic_numbers = state.atomic_numbers - - if self.system_idx_in_init: - if state.system_idx.shape[0] != self.system_idx.shape[0]: - raise ValueError( - f"Expected system_idx of length {self.system_idx.shape[0]}, " - f"got {state.system_idx.shape[0]}." - ) - elif not hasattr(self, "system_idx") or not torch.equal( - state.system_idx, self.system_idx - ): - self._setup_ptr(state.system_idx) - self.system_idx = state.system_idx - - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched( - state.positions, - state.cell, - state.system_idx, - state.pbc, - ) - if state.pbc.any() - else state.positions - ) - - # Batched neighbor list using linked-cell algorithm - edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - wrapped_positions, - state.row_vector_cell, - state.pbc, - self.r_max, - state.system_idx, - ) - # Convert unit cell shift indices to Cartesian shifts - shifts = ts.transforms.compute_cell_shifts( - state.row_vector_cell, unit_shifts, mapping_system - ) - - # Build data dict for MACE model - data_dict = dict( - ptr=self.ptr, - node_attrs=self.node_attrs, - batch=state.system_idx, - pbc=state.pbc, - cell=state.row_vector_cell, - positions=wrapped_positions, - edge_index=edge_index, - unit_shifts=unit_shifts, - shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, - ) - - # Get model output - out = self.model( - data_dict, - compute_force=self.compute_forces, - compute_stress=self.compute_stress, - ) - - results: dict[str, torch.Tensor] = {} - - # Process energy - energy = out["energy"] - if energy is not None: - results["energy"] = energy.detach() - else: - results["energy"] = torch.zeros(self.n_systems, device=self.device) - - # Process forces - if self.compute_forces: - forces = out["forces"] - if forces is not None: - results["forces"] = forces.detach() - - # Process stress - if self.compute_stress: - stress = out["stress"] - if stress is not None: - results["stress"] = stress.detach() - - return results - - -class MaceUrls(StrEnum): - """Checkpoint download URLs for MACE models.""" - mace_mp_small = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" - mace_mpa_medium = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - mace_off_small = "https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true" +__all__ = ["MaceModel"] From 6d4aa980fd3e3902be9522ec87d6aeac3e9138a7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 9 Mar 2026 15:39:57 -0400 Subject: [PATCH 2/4] fea: allow kwargs --- torch_sim/models/mace.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 9b40fe07..81321c80 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -75,6 +75,7 @@ def __init__( system_idx: torch.Tensor | None = None, enable_oeq: bool = False, compile_mode: str | None = None, + **kwargs: Any, ) -> None: """Initialize the MACE model for energy and force calculations. @@ -105,6 +106,7 @@ def __init__( False. compile_mode: PyTorch compilation mode (e.g., "reduce-overhead"). Defaults to None (no compilation). + **kwargs: Additional keyword arguments to pass to the MACE model. Raises: TypeError: If model is neither a path nor a torch.nn.Module. @@ -121,6 +123,7 @@ def __init__( compile_mode=compile_mode, atomic_numbers=atomic_numbers, system_idx=system_idx, + **kwargs, ) From a15a2a3f6c36f20348f72e2b657061f1784ec4a5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 9 Mar 2026 16:07:26 -0400 Subject: [PATCH 3/4] make ci work. --- examples/scripts/1_introduction.py | 18 +++---- examples/scripts/2_structural_optimization.py | 43 +++++++-------- examples/scripts/3_dynamics.py | 32 +++++------ examples/scripts/4_high_level_api.py | 37 +++++++------ examples/scripts/5_workflow.py | 49 +++++++++-------- examples/scripts/6_phonons.py | 48 +++++++++-------- examples/scripts/7_others.py | 53 ++++++++++--------- examples/scripts/8_bechmarking.py | 7 ++- examples/tutorials/autobatching_tutorial.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- examples/tutorials/low_level_tutorial.py | 8 +-- 11 files changed, 160 insertions(+), 139 deletions(-) diff --git a/examples/scripts/1_introduction.py b/examples/scripts/1_introduction.py index 2c2f9e4e..14a2d55e 100644 --- a/examples/scripts/1_introduction.py +++ b/examples/scripts/1_introduction.py @@ -6,7 +6,7 @@ """ # /// script -# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"] +# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] # /// import itertools @@ -18,7 +18,7 @@ import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger @@ -33,9 +33,9 @@ # ============================================================================ # SECTION 1: Lennard-Jones Model - Simple Classical Potential # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Lennard-Jones Model") -log.info("=" * 70) + # Create face-centered cubic (FCC) Argon # 5.26 Å is a typical lattice constant for Ar @@ -118,13 +118,14 @@ # ============================================================================ # SECTION 2: MACE Model - Machine Learning Potential (Batched) # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: MACE Model with Batched Input") -log.info("=" * 70) + # Load the raw model from the downloaded model +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), @@ -213,6 +214,5 @@ log.info(f"Max forces difference: {forces_diff}") log.info(f"Max stress difference: {stress_diff}") -log.info("=" * 70) + log.info("Introduction examples completed!") -log.info("=" * 70) diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 35363621..278c7cb2 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -8,7 +8,7 @@ """ # /// script -# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"] +# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] # /// import itertools @@ -21,11 +21,14 @@ import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger from torch_sim.units import UnitConversion +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + configure_logging(log_file="2_structural_optimization.log") log = get_logger(name="2_structural_optimization") @@ -41,9 +44,9 @@ # ============================================================================ # SECTION 1: Lennard-Jones FIRE Optimization # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Lennard-Jones FIRE Optimization") -log.info("=" * 70) + # Set up the random number generator generator = torch.Generator(device=device) @@ -127,13 +130,13 @@ # ============================================================================ # SECTION 2: Batched MACE FIRE Optimization (Atomic Positions Only) # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: Batched MACE FIRE - Positions Only") -log.info("=" * 70) + # Load MACE model loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), @@ -189,9 +192,9 @@ # ============================================================================ # SECTION 3: Batched MACE Gradient Descent Optimization # ============================================================================ -log.info("=" * 70) + log.info("SECTION 3: Batched MACE Gradient Descent") -log.info("=" * 70) + # Reset structures with new perturbations si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -222,9 +225,9 @@ # ============================================================================ # SECTION 4: Unit Cell Filter with Gradient Descent # ============================================================================ -log.info("=" * 70) + log.info("SECTION 4: Unit Cell Filter with Gradient Descent") -log.info("=" * 70) + # Recreate structures with perturbations si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2)) @@ -278,9 +281,9 @@ # ============================================================================ # SECTION 5: Unit Cell Filter with FIRE # ============================================================================ -log.info("=" * 70) + log.info("SECTION 5: Unit Cell Filter with FIRE") -log.info("=" * 70) + # Recreate structures with perturbations si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2)) @@ -330,9 +333,9 @@ # ============================================================================ # SECTION 6: Frechet Cell Filter with FIRE # ============================================================================ -log.info("=" * 70) + log.info("SECTION 6: Frechet Cell Filter with FIRE") -log.info("=" * 70) + # Recreate structures with perturbations si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2)) @@ -392,9 +395,9 @@ # ============================================================================ # SECTION 7: Batched MACE L-BFGS # ============================================================================ -log.info("=" * 70) + log.info("SECTION 7: Batched MACE L-BFGS") -log.info("=" * 70) + # Recreate structures with perturbations si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) @@ -425,9 +428,9 @@ # ============================================================================ # SECTION 8: Batched MACE BFGS # ============================================================================ -log.info("=" * 70) + log.info("SECTION 8: Batched MACE BFGS") -log.info("=" * 70) + # Recreate structures with perturbations si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) @@ -455,6 +458,4 @@ log.info(f"Final energies: {[energy.item() for energy in state.energy]} eV") -log.info("=" * 70) log.info("Structural optimization examples completed!") -log.info("=" * 70) diff --git a/examples/scripts/3_dynamics.py b/examples/scripts/3_dynamics.py index 65952f06..f226f6f0 100644 --- a/examples/scripts/3_dynamics.py +++ b/examples/scripts/3_dynamics.py @@ -8,7 +8,7 @@ """ # /// script -# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"] +# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] # /// import itertools @@ -21,11 +21,14 @@ import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger from torch_sim.units import MetalUnits as Units +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + configure_logging(log_file="3_dynamics.log") log = get_logger(name="3_dynamics") @@ -52,9 +55,9 @@ # ============================================================================ # SECTION 1: Lennard-Jones NVE (Microcanonical Ensemble) # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Lennard-Jones NVE Simulation") -log.info("=" * 70) + # Create face-centered cubic (FCC) Argon a_len = 5.26 # Lattice constant @@ -139,13 +142,13 @@ # ============================================================================ # SECTION 2: MACE NVE Simulation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: MACE NVE Simulation") -log.info("=" * 70) + # Load MACE model loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), @@ -205,9 +208,9 @@ # ============================================================================ # SECTION 3: MACE NVT Langevin Simulation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 3: MACE NVT Langevin Simulation") -log.info("=" * 70) + # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -252,9 +255,9 @@ # ============================================================================ # SECTION 4: MACE NVT Nose-Hoover Simulation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 4: MACE NVT Nose-Hoover Simulation") -log.info("=" * 70) + # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -294,9 +297,9 @@ # ============================================================================ # SECTION 5: MACE NPT Nose-Hoover Simulation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 5: MACE NPT Nose-Hoover Simulation") -log.info("=" * 70) + # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) @@ -404,6 +407,5 @@ ) log.info(f"Final pressure: {final_pressure.item():.4f} eV/ų") -log.info("=" * 70) + log.info("Molecular dynamics examples completed!") -log.info("=" * 70) diff --git a/examples/scripts/4_high_level_api.py b/examples/scripts/4_high_level_api.py index 565eb7c1..2de7f33e 100644 --- a/examples/scripts/4_high_level_api.py +++ b/examples/scripts/4_high_level_api.py @@ -9,7 +9,7 @@ """ # /// script -# dependencies = ["mace-torch>=0.3.12", "pymatgen>=2025.2.18"] +# dependencies = ["torch-sim-atomistic[mace] @ .", "pymatgen>=2025.2.18"] # /// import os @@ -39,9 +39,9 @@ # ============================================================================ # SECTION 1: Basic Integration with Lennard-Jones # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Basic Integration with Lennard-Jones") -log.info("=" * 70) + lj_model = LennardJonesModel( sigma=2.0, # Å, typical for Si-Si interaction @@ -69,9 +69,9 @@ # ============================================================================ # SECTION 2: Integration with Trajectory Logging # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: Integration with Trajectory Logging") -log.info("=" * 70) + trajectory_file = "tmp/lj_trajectory.h5md" @@ -121,9 +121,9 @@ # ============================================================================ # SECTION 3: MACE Model with High-Level API # ============================================================================ -log.info("=" * 70) + log.info("SECTION 3: MACE Model with High-Level API") -log.info("=" * 70) + mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel( @@ -156,9 +156,9 @@ # ============================================================================ # SECTION 4: Batched Integration # ============================================================================ -log.info("=" * 70) + log.info("SECTION 4: Batched Integration") -log.info("=" * 70) + fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True) fe_atoms_supercell = fe_atoms.repeat([2, 2, 2]) @@ -182,9 +182,9 @@ # ============================================================================ # SECTION 5: Batched Integration with Trajectory Reporting # ============================================================================ -log.info("=" * 70) + log.info("SECTION 5: Batched Integration with Trajectory Reporting") -log.info("=" * 70) + systems = [si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell] @@ -217,9 +217,9 @@ # ============================================================================ # SECTION 6: Structure Optimization # ============================================================================ -log.info("=" * 70) + log.info("SECTION 6: Structure Optimization") -log.info("=" * 70) + final_state = ts.optimize( system=systems, @@ -240,9 +240,9 @@ # ============================================================================ # SECTION 7: Optimization with Custom Convergence Criteria # ============================================================================ -log.info("=" * 70) + log.info("SECTION 7: Optimization with Custom Convergence") -log.info("=" * 70) + final_state = ts.optimize( system=systems, @@ -261,9 +261,9 @@ # ============================================================================ # SECTION 8: Pymatgen Structure Support # ============================================================================ -log.info("=" * 70) + log.info("SECTION 8: Pymatgen Structure Support") -log.info("=" * 70) + lattice = [[5.43, 0, 0], [0, 5.43, 0], [0, 0, 5.43]] species = ["Si"] * 8 @@ -292,6 +292,5 @@ log.info(f"Final structure type: {type(final_structure)}") log.info(f"Final energy: {final_state.energy.item():.4f} eV") -log.info("=" * 70) + log.info("High-level API examples completed!") -log.info("=" * 70) diff --git a/examples/scripts/5_workflow.py b/examples/scripts/5_workflow.py index 0ed65c48..b1d4af99 100644 --- a/examples/scripts/5_workflow.py +++ b/examples/scripts/5_workflow.py @@ -7,7 +7,7 @@ """ # /// script -# dependencies = ["mace-torch>=0.3.12", "matbench-discovery>=1.3.1"] +# dependencies = ["torch-sim-atomistic[mace] @ .", "matbench-discovery>=1.3.1"] # /// import os @@ -20,10 +20,13 @@ import torch_sim as ts from torch_sim.elastic import get_bravais_type -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + configure_logging(log_file="5_workflow.log") log = get_logger(name="5_workflow") @@ -41,12 +44,12 @@ # ============================================================================ # SECTION 1: In-Flight Autobatching Workflow # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: In-Flight Autobatching Workflow") -log.info("=" * 70) + log.info("Loading MACE model...") -mace = mace_mp(model=MaceUrls.mace_mpa_medium, return_raw_model=True) +mace = mace_mp(model=MACE_MPA_MEDIUM_URL, return_raw_model=True) mace_model = MaceModel( model=mace, device=device, @@ -147,15 +150,15 @@ # ============================================================================ # SECTION 2: Elastic Constants Calculation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: Elastic Constants Calculation") -log.info("=" * 70) + # Use higher precision for elastic constants dtype_elastic = torch.float64 loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, enable_cueq=False, device=str(device), default_dtype=str(dtype_elastic).removeprefix("torch."), @@ -221,21 +224,25 @@ ) # Print results -log.info("Elastic tensor (GPa):") elastic_tensor_np = elastic_tensor.detach().cpu().numpy() -for row in elastic_tensor_np: - log.info(" ".join(f"{val:10.4f}" for val in row)) - -log.info("Elastic moduli:") -log.info(f" Bulk modulus (GPa): {bulk_modulus:.4f}") -log.info(f" Shear modulus (GPa): {shear_modulus:.4f}") -log.info(f" Poisson's ratio: {poisson_ratio:.4f}") -log.info(f" Pugh's ratio (K/G): {pugh_ratio:.4f}") +elastic_tensor_str = "\n".join( + " ".join(f"{val:10.4f}" for val in row) for row in elastic_tensor_np +) # Interpret Pugh's ratio material_type = "ductile" if pugh_ratio > 1.75 else "brittle" -log.info(f" Material behavior: {material_type}") -log.info("=" * 70) -log.info("Workflow examples completed!") -log.info("=" * 70) +final_summary = ( + "Elastic tensor (GPa):\n" + f"{elastic_tensor_str}\n" + "\nElastic moduli:\n" + f" Bulk modulus (GPa): {bulk_modulus:.4f}\n" + f" Shear modulus (GPa): {shear_modulus:.4f}\n" + f" Poisson's ratio: {poisson_ratio:.4f}\n" + f" Pugh's ratio (K/G): {pugh_ratio:.4f}\n" + f" Material behavior: {material_type}\n" + f"\n{'=' * 70}\n" + "Workflow examples completed!\n" + f"{'=' * 70}" +) +log.info(final_summary) diff --git a/examples/scripts/6_phonons.py b/examples/scripts/6_phonons.py index 9851da86..9cbc2762 100644 --- a/examples/scripts/6_phonons.py +++ b/examples/scripts/6_phonons.py @@ -12,7 +12,7 @@ # /// script # dependencies = [ -# "mace-torch>=0.3.12", +# "torch-sim-atomistic[mace] @ .", # "phonopy>=2.35", # "pymatviz>=0.17.1", # "plotly>=6.3.0", @@ -30,10 +30,13 @@ from phonopy import Phonopy import torch_sim as ts -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + configure_logging(log_file="6_phonons.log") log = get_logger(name="6_phonons") @@ -54,13 +57,13 @@ def require_not_none[T](value: T | None, message: str) -> T: # ============================================================================ # SECTION 1: Structure Relaxation for Phonons # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Structure Relaxation") -log.info("=" * 70) + # Load the MACE model loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), @@ -103,9 +106,9 @@ def require_not_none[T](value: T | None, message: str) -> T: # ============================================================================ # SECTION 2: Phonon DOS Calculation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: Phonon DOS Calculation") -log.info("=" * 70) + # Convert state to Phonopy atoms atoms = ts.io.state_to_phonopy(final_state)[0] @@ -166,9 +169,9 @@ def require_not_none[T](value: T | None, message: str) -> T: # ============================================================================ # SECTION 3: Phonon Band Structure Calculation # ============================================================================ -log.info("=" * 70) + log.info("SECTION 3: Phonon Band Structure Calculation") -log.info("=" * 70) + try: import seekpath @@ -260,16 +263,17 @@ def require_not_none[T](value: T | None, message: str) -> T: # ============================================================================ # SECTION 4: Summary # ============================================================================ -log.info("=" * 70) -log.info("Summary") -log.info("=" * 70) -log.info("Structure: Silicon (diamond)") -log.info("Supercell: 2x2x2") -log.info(f"Number of displaced structures: {len(supercells)}") -log.info("Batched force calculation: Yes") -log.info("Phonon DOS calculated: Yes") -log.info(f"Frequency range: {freq_points.min():.3f} to {freq_points.max():.3f} THz") - -log.info("=" * 70) -log.info("Phonon calculation examples completed!") -log.info("=" * 70) +log.info( + f"{'=' * 70}\n" + "Summary\n" + f"{'=' * 70}\n" + "Structure: Silicon (diamond)\n" + "Supercell: 2x2x2\n" + f"Number of displaced structures: {len(supercells)}\n" + "Batched force calculation: Yes\n" + "Phonon DOS calculated: Yes\n" + f"Frequency range: {freq_points.min():.3f} to {freq_points.max():.3f} THz\n" + f"\n{'=' * 70}\n" + "Phonon calculation examples completed!\n" + f"{'=' * 70}" +) diff --git a/examples/scripts/7_others.py b/examples/scripts/7_others.py index edebfe51..93d47a9b 100644 --- a/examples/scripts/7_others.py +++ b/examples/scripts/7_others.py @@ -36,9 +36,9 @@ # ============================================================================ # SECTION 1: Batched Neighbor List Calculations # ============================================================================ -log.info("=" * 70) + log.info("SECTION 1: Batched Neighbor List Calculations") -log.info("=" * 70) + # Create multiple atomic systems atoms_list = [ @@ -102,9 +102,9 @@ # ============================================================================ # SECTION 2: Velocity Autocorrelation Function (VACF) # ============================================================================ -log.info("=" * 70) + log.info("SECTION 2: Velocity Autocorrelation Function") -log.info("=" * 70) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float64 @@ -221,24 +221,27 @@ # ============================================================================ # SECTION 3: Summary # ============================================================================ -log.info("=" * 70) -log.info("Summary") -log.info("=" * 70) -log.info("Demonstrated features:") -log.info(" 1. Batched neighbor list calculations") -log.info(" - Linked cell method (efficient)") -log.info(" - N^2 method (simple)") -log.info(" 2. Velocity autocorrelation function (VACF)") -log.info(" - NVE molecular dynamics") -log.info(" - Running average over time windows") -log.info(" - Normalized correlation decay") - -log.info("Key capabilities:") -log.info(" - Efficient batched computations") -log.info(" - Multiple neighbor list algorithms") -log.info(" - Advanced property calculations during MD") -log.info(" - Trajectory analysis and correlation functions") - -log.info("=" * 70) -log.info("Miscellaneous examples completed!") -log.info("=" * 70) +final_summary = ( + "=" * 70 + + "\nSummary\n" + + "=" * 70 + + "\nDemonstrated features:\n" + + " 1. Batched neighbor list calculations\n" + + " - Linked cell method (efficient)\n" + + " - N^2 method (simple)\n" + + " 2. Velocity autocorrelation function (VACF)\n" + + " - NVE molecular dynamics\n" + + " - Running average over time windows\n" + + " - Normalized correlation decay\n" + + "\nKey capabilities:\n" + + " - Efficient batched computations\n" + + " - Multiple neighbor list algorithms\n" + + " - Advanced property calculations during MD\n" + + " - Trajectory analysis and correlation functions\n" + + "\n" + + "=" * 70 + + "\nMiscellaneous examples completed!\n" + + "=" * 70 +) + +log.info(final_summary) diff --git a/examples/scripts/8_bechmarking.py b/examples/scripts/8_bechmarking.py index d7cc5a1d..c5be5b11 100644 --- a/examples/scripts/8_bechmarking.py +++ b/examples/scripts/8_bechmarking.py @@ -17,10 +17,13 @@ from pymatgen.io.ase import AseAtomsAdaptor import torch_sim as ts -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel from torch_sim.telemetry import configure_logging, get_logger +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + configure_logging(log_file="8_bechmarking.log") log = get_logger(name="8_bechmarking") @@ -50,7 +53,7 @@ def load_mace_model(device: torch.device) -> MaceModel: """Load MACE model for benchmarking.""" loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype="float64", device=str(device), diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index ca350643..c6431c77 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -1,6 +1,6 @@ # %% # /// script -# dependencies = ["torch_sim_atomistic[mace]"] +# dependencies = ["torch-sim-atomistic[mace] @ ."] # /// diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c7fc6362..f6abefc6 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -1,6 +1,6 @@ # %% # /// script -# dependencies = ["torch_sim_atomistic[mace, io]"] +# dependencies = ["torch-sim-atomistic[mace,io] @ ."] # /// diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index fe95d9e9..5783b2aa 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -1,6 +1,6 @@ # %% # /// script -# dependencies = ["torch_sim_atomistic[mace]"] +# dependencies = ["torch-sim-atomistic[mace] @ ."] # /// @@ -60,11 +60,13 @@ # %% from mace.calculators.foundations_models import mace_mp -from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.models.mace import MaceModel + +MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" # load mace_mp using the mace package loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, + model=MACE_MPA_MEDIUM_URL, return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), From c79a7bf98ec7c08adfb0560ea158ecbe463f8136 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 9 Mar 2026 17:47:11 -0400 Subject: [PATCH 4/4] wip? --- examples/scripts/1_introduction.py | 8 +++--- examples/scripts/2_structural_optimization.py | 10 +++---- examples/scripts/3_dynamics.py | 10 +++---- examples/scripts/4_high_level_api.py | 5 +++- examples/scripts/5_workflow.py | 13 ++++----- examples/scripts/6_phonons.py | 10 +++---- examples/scripts/7_others.py | 4 ++- examples/scripts/8_bechmarking.py | 8 +++--- examples/tutorials/autobatching_tutorial.py | 5 +++- examples/tutorials/diff_sim.py | 9 ++----- examples/tutorials/high_level_tutorial.py | 3 ++- examples/tutorials/hybrid_swap_tutorial.py | 5 +++- examples/tutorials/low_level_tutorial.py | 8 +++--- examples/tutorials/reporting_tutorial.py | 1 + examples/tutorials/state_tutorial.py | 3 ++- tests/models/test_mace.py | 27 +++++-------------- tests/test_optimizers_vs_ase.py | 8 +++--- 17 files changed, 64 insertions(+), 73 deletions(-) diff --git a/examples/scripts/1_introduction.py b/examples/scripts/1_introduction.py index 14a2d55e..c9d2cfef 100644 --- a/examples/scripts/1_introduction.py +++ b/examples/scripts/1_introduction.py @@ -6,7 +6,10 @@ """ # /// script -# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// import itertools @@ -123,9 +126,8 @@ # Load the raw model from the downloaded model -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 278c7cb2..873127ee 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -8,7 +8,10 @@ """ # /// script -# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// import itertools @@ -26,9 +29,6 @@ from torch_sim.units import UnitConversion -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - - configure_logging(log_file="2_structural_optimization.log") log = get_logger(name="2_structural_optimization") @@ -136,7 +136,7 @@ # Load MACE model loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), diff --git a/examples/scripts/3_dynamics.py b/examples/scripts/3_dynamics.py index f226f6f0..f811f50b 100644 --- a/examples/scripts/3_dynamics.py +++ b/examples/scripts/3_dynamics.py @@ -8,7 +8,10 @@ """ # /// script -# dependencies = ["scipy>=1.15", "torch-sim-atomistic[mace] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// import itertools @@ -26,9 +29,6 @@ from torch_sim.units import MetalUnits as Units -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - - configure_logging(log_file="3_dynamics.log") log = get_logger(name="3_dynamics") @@ -148,7 +148,7 @@ # Load MACE model loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), diff --git a/examples/scripts/4_high_level_api.py b/examples/scripts/4_high_level_api.py index 2de7f33e..e942ee07 100644 --- a/examples/scripts/4_high_level_api.py +++ b/examples/scripts/4_high_level_api.py @@ -9,7 +9,10 @@ """ # /// script -# dependencies = ["torch-sim-atomistic[mace] @ .", "pymatgen>=2025.2.18"] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// import os diff --git a/examples/scripts/5_workflow.py b/examples/scripts/5_workflow.py index b1d4af99..b3104648 100644 --- a/examples/scripts/5_workflow.py +++ b/examples/scripts/5_workflow.py @@ -7,7 +7,11 @@ """ # /// script -# dependencies = ["torch-sim-atomistic[mace] @ .", "matbench-discovery>=1.3.1"] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# "matbench-discovery>=1.3.1", +# ] # /// import os @@ -24,9 +28,6 @@ from torch_sim.telemetry import configure_logging, get_logger -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - - configure_logging(log_file="5_workflow.log") log = get_logger(name="5_workflow") @@ -49,7 +50,7 @@ log.info("Loading MACE model...") -mace = mace_mp(model=MACE_MPA_MEDIUM_URL, return_raw_model=True) +mace = mace_mp(model="medium", return_raw_model=True) mace_model = MaceModel( model=mace, device=device, @@ -158,7 +159,7 @@ dtype_elastic = torch.float64 loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", enable_cueq=False, device=str(device), default_dtype=str(dtype_elastic).removeprefix("torch."), diff --git a/examples/scripts/6_phonons.py b/examples/scripts/6_phonons.py index 9cbc2762..6ec3b356 100644 --- a/examples/scripts/6_phonons.py +++ b/examples/scripts/6_phonons.py @@ -12,12 +12,11 @@ # /// script # dependencies = [ -# "torch-sim-atomistic[mace] @ .", -# "phonopy>=2.35", +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", # "pymatviz>=0.17.1", # "plotly>=6.3.0", # "seekpath", -# "ase", # ] # /// @@ -34,9 +33,6 @@ from torch_sim.telemetry import configure_logging, get_logger -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - - configure_logging(log_file="6_phonons.log") log = get_logger(name="6_phonons") @@ -63,7 +59,7 @@ def require_not_none[T](value: T | None, message: str) -> T: # Load the MACE model loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), diff --git a/examples/scripts/7_others.py b/examples/scripts/7_others.py index 93d47a9b..972a3edf 100644 --- a/examples/scripts/7_others.py +++ b/examples/scripts/7_others.py @@ -8,7 +8,9 @@ """ # /// script -# dependencies = ["ase>=3.26", "scipy>=1.15", "matplotlib", "numpy"] +# dependencies = [ +# "torch_sim_atomistic[io]", +# ] # /// import os diff --git a/examples/scripts/8_bechmarking.py b/examples/scripts/8_bechmarking.py index c5be5b11..f4c0d388 100644 --- a/examples/scripts/8_bechmarking.py +++ b/examples/scripts/8_bechmarking.py @@ -3,7 +3,8 @@ # %% # /// script # dependencies = [ -# "torch_sim_atomistic[mace,test]" +# "torch_sim_atomistic[mace, test]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", # ] # /// @@ -21,9 +22,6 @@ from torch_sim.telemetry import configure_logging, get_logger -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - - configure_logging(log_file="8_bechmarking.log") log = get_logger(name="8_bechmarking") @@ -53,7 +51,7 @@ def load_mace_model(device: torch.device) -> MaceModel: """Load MACE model for benchmarking.""" loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype="float64", device=str(device), diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index c6431c77..2cd04185 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -1,6 +1,9 @@ # %% # /// script -# dependencies = ["torch-sim-atomistic[mace] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index cf96c6e6..fcf8d310 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -1,12 +1,7 @@ -# %% [markdown] -#
-# Dependencies +# %% # /// script -# dependencies = [ -# "matplotlib", -# ] +# dependencies = ["matplotlib",] # /// -#
# %% import typing diff --git a/examples/tutorials/high_level_tutorial.py b/examples/tutorials/high_level_tutorial.py index 04591d69..2a816142 100644 --- a/examples/tutorials/high_level_tutorial.py +++ b/examples/tutorials/high_level_tutorial.py @@ -1,7 +1,8 @@ # %% # /// script # dependencies = [ -# "torch_sim_atomistic[mace, io]" +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", # ] # /// diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index f6abefc6..13b3977e 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -1,6 +1,9 @@ # %% # /// script -# dependencies = ["torch-sim-atomistic[mace,io] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 5783b2aa..11402189 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -1,6 +1,9 @@ # %% # /// script -# dependencies = ["torch-sim-atomistic[mace] @ ."] +# dependencies = [ +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", +# ] # /// @@ -62,11 +65,10 @@ from mace.calculators.foundations_models import mace_mp from torch_sim.models.mace import MaceModel -MACE_MPA_MEDIUM_URL = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" # load mace_mp using the mace package loaded_model = mace_mp( - model=MACE_MPA_MEDIUM_URL, + model="medium", return_raw_model=True, default_dtype=str(dtype).removeprefix("torch."), device=str(device), diff --git a/examples/tutorials/reporting_tutorial.py b/examples/tutorials/reporting_tutorial.py index 55a25079..44430c79 100644 --- a/examples/tutorials/reporting_tutorial.py +++ b/examples/tutorials/reporting_tutorial.py @@ -2,6 +2,7 @@ # /// script # dependencies = [ # "torch_sim_atomistic[mace, io]" +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", # ] # /// diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index bccf82ca..acfac8a4 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -1,7 +1,8 @@ # %% # /// script # dependencies = [ -# "torch_sim_atomistic[mace, io]" +# "torch_sim_atomistic[mace, io]", +# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop", # ] # /// diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 0d18b196..5a6c608c 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -15,37 +15,23 @@ try: from mace.calculators import MACECalculator - from mace.calculators.foundations_models import mace_mp, mace_off + from mace.calculators.foundations_models import mace_mp, mace_off, mace_omol from torch_sim.models.mace import MaceModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] -# mace_omol is optional (added in newer MACE versions) -try: - from mace.calculators.foundations_models import mace_omol - - raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) - HAS_MACE_OMOL = True -except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - raw_mace_omol = None - HAS_MACE_OMOL = False - -MACE_MP_SMALL_URL = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" -MACE_OFF_SMALL_URL = "https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true" - -raw_mace_mp = mace_mp(model=MACE_MP_SMALL_URL, return_raw_model=True) -raw_mace_off = mace_off(model=MACE_OFF_SMALL_URL, return_raw_model=True) +raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) +raw_mace_mp = mace_mp(model="small", return_raw_model=True) +raw_mace_off = mace_off(model="small", return_raw_model=True) DTYPE = torch.float64 @pytest.fixture def ase_mace_calculator() -> MACECalculator: dtype = str(DTYPE).removeprefix("torch.") - return mace_mp( - model=MACE_MP_SMALL_URL, device="cpu", default_dtype=dtype, dispersion=False - ) + return mace_mp(model="small", device="cpu", default_dtype=dtype, dispersion=False) @pytest.fixture @@ -84,7 +70,7 @@ def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: @pytest.fixture def ase_mace_off_calculator() -> MACECalculator: return mace_off( - model=MACE_OFF_SMALL_URL, + model="small", device=str(DEVICE), default_dtype=str(DTYPE).removeprefix("torch."), dispersion=False, @@ -117,7 +103,6 @@ def test_mace_off_dtype_working( model.forward(benzene_sim_state.to(DEVICE, dtype)) -@pytest.mark.skipif(not HAS_MACE_OMOL, reason="mace_omol not available") @pytest.mark.parametrize( ("charge", "spin"), [ diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 328507f4..921c089e 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -16,7 +16,7 @@ try: from mace.calculators.foundations_models import mace_mp - from torch_sim.models.mace import MaceModel, MaceUrls + from torch_sim.models.mace import MaceModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] @@ -30,9 +30,7 @@ def ts_mace_mpa() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") - raw_mace = mace_mp( - model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str - ) + raw_mace = mace_mp(model="small", return_raw_model=True, default_dtype=dtype_str) return MaceModel( model=raw_mace, device=torch.device("cpu"), @@ -46,7 +44,7 @@ def ts_mace_mpa() -> MaceModel: def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" # Ensure dtype matches the one used in the torch-sim fixture (float64) - return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") + return mace_mp(model="small", default_dtype="float64") def _compare_ase_and_ts_states(