Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions examples/scripts/1_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop",
# ]
# ///

import itertools
Expand All @@ -18,7 +21,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger


Expand All @@ -33,9 +36,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones Model - Simple Classical Potential
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones Model")
log.info("=" * 70)


# Create face-centered cubic (FCC) Argon
# 5.26 Å is a typical lattice constant for Ar
Expand Down Expand Up @@ -118,13 +121,13 @@
# ============================================================================
# SECTION 2: MACE Model - Machine Learning Potential (Batched)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: MACE Model with Batched Input")
log.info("=" * 70)


# Load the raw model from the downloaded model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -213,6 +216,5 @@
log.info(f"Max forces difference: {forces_diff}")
log.info(f"Max stress difference: {stress_diff}")

log.info("=" * 70)

log.info("Introduction examples completed!")
log.info("=" * 70)
43 changes: 22 additions & 21 deletions examples/scripts/2_structural_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop",
# ]
# ///

import itertools
Expand All @@ -21,7 +24,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger
from torch_sim.units import UnitConversion

Expand All @@ -41,9 +44,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones FIRE Optimization
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones FIRE Optimization")
log.info("=" * 70)


# Set up the random number generator
generator = torch.Generator(device=device)
Expand Down Expand Up @@ -127,13 +130,13 @@
# ============================================================================
# SECTION 2: Batched MACE FIRE Optimization (Atomic Positions Only)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: Batched MACE FIRE - Positions Only")
log.info("=" * 70)


# Load MACE model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -189,9 +192,9 @@
# ============================================================================
# SECTION 3: Batched MACE Gradient Descent Optimization
# ============================================================================
log.info("=" * 70)

log.info("SECTION 3: Batched MACE Gradient Descent")
log.info("=" * 70)


# Reset structures with new perturbations
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -222,9 +225,9 @@
# ============================================================================
# SECTION 4: Unit Cell Filter with Gradient Descent
# ============================================================================
log.info("=" * 70)

log.info("SECTION 4: Unit Cell Filter with Gradient Descent")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -278,9 +281,9 @@
# ============================================================================
# SECTION 5: Unit Cell Filter with FIRE
# ============================================================================
log.info("=" * 70)

log.info("SECTION 5: Unit Cell Filter with FIRE")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -330,9 +333,9 @@
# ============================================================================
# SECTION 6: Frechet Cell Filter with FIRE
# ============================================================================
log.info("=" * 70)

log.info("SECTION 6: Frechet Cell Filter with FIRE")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -392,9 +395,9 @@
# ============================================================================
# SECTION 7: Batched MACE L-BFGS
# ============================================================================
log.info("=" * 70)

log.info("SECTION 7: Batched MACE L-BFGS")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
Expand Down Expand Up @@ -425,9 +428,9 @@
# ============================================================================
# SECTION 8: Batched MACE BFGS
# ============================================================================
log.info("=" * 70)

log.info("SECTION 8: Batched MACE BFGS")
log.info("=" * 70)


# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
Expand Down Expand Up @@ -455,6 +458,4 @@
log.info(f"Final energies: {[energy.item() for energy in state.energy]} eV")


log.info("=" * 70)
log.info("Structural optimization examples completed!")
log.info("=" * 70)
32 changes: 17 additions & 15 deletions examples/scripts/3_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
"""

# /// script
# dependencies = ["scipy>=1.15", "mace-torch>=0.3.12"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop",
# ]
# ///

import itertools
Expand All @@ -21,7 +24,7 @@

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.models.mace import MaceModel
from torch_sim.telemetry import configure_logging, get_logger
from torch_sim.units import MetalUnits as Units

Expand Down Expand Up @@ -52,9 +55,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones NVE (Microcanonical Ensemble)
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Lennard-Jones NVE Simulation")
log.info("=" * 70)


# Create face-centered cubic (FCC) Argon
a_len = 5.26 # Lattice constant
Expand Down Expand Up @@ -139,13 +142,13 @@
# ============================================================================
# SECTION 2: MACE NVE Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: MACE NVE Simulation")
log.info("=" * 70)


# Load MACE model
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
model="medium",
return_raw_model=True,
default_dtype=str(dtype).removeprefix("torch."),
device=str(device),
Expand Down Expand Up @@ -205,9 +208,9 @@
# ============================================================================
# SECTION 3: MACE NVT Langevin Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 3: MACE NVT Langevin Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -252,9 +255,9 @@
# ============================================================================
# SECTION 4: MACE NVT Nose-Hoover Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 4: MACE NVT Nose-Hoover Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -294,9 +297,9 @@
# ============================================================================
# SECTION 5: MACE NPT Nose-Hoover Simulation
# ============================================================================
log.info("=" * 70)

log.info("SECTION 5: MACE NPT Nose-Hoover Simulation")
log.info("=" * 70)


# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
Expand Down Expand Up @@ -404,6 +407,5 @@
)
log.info(f"Final pressure: {final_pressure.item():.4f} eV/ų")

log.info("=" * 70)

log.info("Molecular dynamics examples completed!")
log.info("=" * 70)
40 changes: 21 additions & 19 deletions examples/scripts/4_high_level_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
"""

# /// script
# dependencies = ["mace-torch>=0.3.12", "pymatgen>=2025.2.18"]
# dependencies = [
# "torch_sim_atomistic[mace, io]",
# "mace-torch @ git+https://github.com/ACEsuit/mace.git@develop",
# ]
# ///

import os
Expand Down Expand Up @@ -39,9 +42,9 @@
# ============================================================================
# SECTION 1: Basic Integration with Lennard-Jones
# ============================================================================
log.info("=" * 70)

log.info("SECTION 1: Basic Integration with Lennard-Jones")
log.info("=" * 70)


lj_model = LennardJonesModel(
sigma=2.0, # Å, typical for Si-Si interaction
Expand Down Expand Up @@ -69,9 +72,9 @@
# ============================================================================
# SECTION 2: Integration with Trajectory Logging
# ============================================================================
log.info("=" * 70)

log.info("SECTION 2: Integration with Trajectory Logging")
log.info("=" * 70)


trajectory_file = "tmp/lj_trajectory.h5md"

Expand Down Expand Up @@ -121,9 +124,9 @@
# ============================================================================
# SECTION 3: MACE Model with High-Level API
# ============================================================================
log.info("=" * 70)

log.info("SECTION 3: MACE Model with High-Level API")
log.info("=" * 70)


mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(
Expand Down Expand Up @@ -156,9 +159,9 @@
# ============================================================================
# SECTION 4: Batched Integration
# ============================================================================
log.info("=" * 70)

log.info("SECTION 4: Batched Integration")
log.info("=" * 70)


fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True)
fe_atoms_supercell = fe_atoms.repeat([2, 2, 2])
Expand All @@ -182,9 +185,9 @@
# ============================================================================
# SECTION 5: Batched Integration with Trajectory Reporting
# ============================================================================
log.info("=" * 70)

log.info("SECTION 5: Batched Integration with Trajectory Reporting")
log.info("=" * 70)


systems = [si_atoms, fe_atoms, si_atoms_supercell, fe_atoms_supercell]

Expand Down Expand Up @@ -217,9 +220,9 @@
# ============================================================================
# SECTION 6: Structure Optimization
# ============================================================================
log.info("=" * 70)

log.info("SECTION 6: Structure Optimization")
log.info("=" * 70)


final_state = ts.optimize(
system=systems,
Expand All @@ -240,9 +243,9 @@
# ============================================================================
# SECTION 7: Optimization with Custom Convergence Criteria
# ============================================================================
log.info("=" * 70)

log.info("SECTION 7: Optimization with Custom Convergence")
log.info("=" * 70)


final_state = ts.optimize(
system=systems,
Expand All @@ -261,9 +264,9 @@
# ============================================================================
# SECTION 8: Pymatgen Structure Support
# ============================================================================
log.info("=" * 70)

log.info("SECTION 8: Pymatgen Structure Support")
log.info("=" * 70)


lattice = [[5.43, 0, 0], [0, 5.43, 0], [0, 0, 5.43]]
species = ["Si"] * 8
Expand Down Expand Up @@ -292,6 +295,5 @@
log.info(f"Final structure type: {type(final_structure)}")
log.info(f"Final energy: {final_state.energy.item():.4f} eV")

log.info("=" * 70)

log.info("High-level API examples completed!")
log.info("=" * 70)
Loading
Loading