Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions docs/user/codes/forcefields.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,28 @@ Support is provided for the following models, which can be selected using `atoma

## Using custom forcefields by dictionary

`force_field_name` also accepts a MSONable dictionary for specifying a custom ASE calculator class or function [^calculator-meta-type-annotation].
For example, a `Job` created with the following code snippet instantiates `chgnet.model.dynamics.CHGNetCalculator` as the ASE calculator:
`calculator_meta` accepts an import-like string or an MSONable dictionary to specify a custom ASE calculator class or function [^calculator-meta-type-annotation].
For example, a `Job` created with either of the following two code snippets instantiates a `chgnet.model.dynamics.CHGNetCalculator` as the ASE calculator.
```python
# simple import string
job = ForceFieldStaticMaker(
force_field_name={
calculator_meta="chgnet.model.dynamics.CHGNetCalculator",
).make(structure)
```

or, equivalently, using a monty-style MSONable dictionary:

```python
# monty MSONable style
job = ForceFieldStaticMaker(
calculator_meta={
"@module": "chgnet.model.dynamics",
"@callable": "CHGNetCalculator",
}
).make(structure)
```
Note that one can also specify `force_field_name = {"@module": ...,"@class": ...}` in the second example for backwards compatibility.
However, this may not be preserved in future versions, and `calculator_meta` is preferred.

[^calculator-meta-type-annotation]: In this context, the type annotation of the decoded dict should be either `Type[Calculator]` or `Callable[..., Calculator]`, where `Calculator` is from `ase.calculators.calculator`.

Expand Down
4 changes: 2 additions & 2 deletions src/atomate2/forcefields/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def from_ase_compatible_result(
ase_calculator_name: str,
result: AseResult,
steps: int,
calculator_meta: MLFF | dict | None = None,
calculator_meta: str | MLFF | dict | None = None,
relax_kwargs: dict = None,
optimizer_kwargs: dict = None,
fix_symmetry: bool = False,
Expand Down Expand Up @@ -123,7 +123,7 @@ def from_ase_compatible_result(
Whether to fix the symmetry of the ions during relaxation.
symprec : float
Tolerance for symmetry finding in case of fix_symmetry.
calculator_meta : Optional, MLFF or dict or None
calculator_meta : Optional, str, MLFF, dict, or None
Metadata about the calculator used.
steps : int
Maximum number of ionic steps allowed during relaxation.
Expand Down
117 changes: 87 additions & 30 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
Expand All @@ -13,6 +14,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

from ase.calculators.calculator import Calculator
from ase.units import Bohr
from monty.json import MontyDecoder
from typing_extensions import assert_never, deprecated
Expand All @@ -26,8 +28,6 @@
except ImportError:
torch_dtype = str

from ase.calculators.calculator import Calculator

from atomate2.ase.schemas import AseResult

_FORCEFIELD_DATA_OBJECTS = ["trajectory", "ionic_steps"]
Expand Down Expand Up @@ -159,42 +159,80 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
class ForceFieldMixin:
"""Mix-in class for force-fields.

Attributes
----------
force_field_name : str or MLFF
Name of the forcefield which will be
correctly deserialized/standardized if the forcefield is
a known `MLFF`.
calculator_meta : MLFF or dict
Actual metadata to instantiate the ASE calculator.
calculator_kwargs : dict = field(default_factory=dict)
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs: dict = field(default_factory=dict)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()
or another final document schema.
All basic forcefield jobs should inherit from this class
to easily access `ase_calculator`.
"""

force_field_name: str | MLFF | dict = MLFF.Forcefield
calculator_meta: MLFF | dict = field(init=False)
calculator_kwargs: dict = field(default_factory=dict)
task_document_kwargs: dict = field(default_factory=dict)
calculator_meta: str | MLFF | dict | None = None
calculator_kwargs: dict[str, Any] = field(default_factory=dict)
task_document_kwargs: dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
"""Ensure that force_field_name is correctly assigned."""
"""Validate input data types.

Attributes
----------
force_field_name : str, MLFF, or dict
If a str or MLFF: Name of the forcefield which will be
correctly deserialized/standardized if the forcefield is
a known `MLFF`.
If a dict, a monty-style dict.

calculator_meta : MLFF, str, or dict
Actual metadata to instantiate the ASE calculator.
If a MLFF, that default interface in `ase_calculator` will be used.
If an import-style str or monty-style dict, the calculator will
be dynamically loaded.

calculator_kwargs : dict = {}
Keyword arguments that will get passed to the ASE calculator.

task_document_kwargs: dict = {}
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()
or another final document schema.
"""
if hasattr(super(), "__post_init__"):
super().__post_init__() # type: ignore[misc]

mlff: MLFF = MLFF.Forcefield # Fallback to placeholder
if isinstance(self.force_field_name, dict):
mlff = MLFF.Forcefield # Fallback to placeholder
self.calculator_meta = self.force_field_name.copy()
calculator_meta: str | dict[str, Any] | MLFF = self.force_field_name.copy()

elif (
(
inspect.isclass(self.force_field_name)
and issubclass(self.force_field_name, Calculator)
)
or isinstance(self.force_field_name, Calculator)
or inspect.isfunction(self.force_field_name) # for mace_mp specifically
):
# can happen with deserialization of legacy documents from JSON
calculator_meta = ".".join(
getattr(self.force_field_name, k) for k in ("__module__", "__name__")
)

else:
mlff = _get_standardized_mlff(self.force_field_name)
self.calculator_meta = mlff
# On round-trip deserialization, `calculator_meta` will be a dict
# of the calculator information
calculator_meta = self.calculator_meta or mlff

# avoids unintentional deserialization from monty on round-trip
if isinstance(calculator_meta, dict):
self.calculator_meta: str | MLFF = ".".join(
calculator_meta[k] for k in ("@module", "@callable")
)
else:
try:
self.calculator_meta = _get_standardized_mlff(calculator_meta)
except ValueError:
self.calculator_meta = calculator_meta

self.force_field_name: str = str(mlff) # Narrow-down type for mypy

# Pad calculator_kwargs with default values, but permit user to override them
self.calculator_kwargs = {
self.calculator_kwargs: dict[str, Any] = {
**_DEFAULT_CALCULATOR_KWARGS.get(mlff, {}),
**self.calculator_kwargs,
}
Expand Down Expand Up @@ -228,7 +266,7 @@ def ase_calculator_name(self) -> str:
"""The name of the ASE calculator for schemas."""
if isinstance(self.calculator_meta, MLFF):
return str(self.force_field_name)
if isinstance(self.calculator_meta, dict):
if isinstance(self.calculator_meta, str | dict):
calc_cls = _load_calc_cls(self.calculator_meta)
return calc_cls.__name__
assert_never(self.calculator_meta)
Expand Down Expand Up @@ -402,7 +440,9 @@ def ase_calculator(
**{k: v for k, v in kwargs.items() if k != "predict_unit"},
)

elif isinstance(calculator_meta, dict):
elif isinstance(calculator_meta, dict) or (
isinstance(calculator_meta, str) and calculator_meta.count(".") >= 1
):
calc_cls = _load_calc_cls(calculator_meta)
calculator = calc_cls(**kwargs)

Expand All @@ -413,8 +453,25 @@ def ase_calculator(


def _load_calc_cls(
    calculator_meta: str | dict,
) -> type[Calculator] | Callable[..., Calculator]:
    """Load an ASE calculator class or factory function using monty or importlib.

    Parameters
    ----------
    calculator_meta : str or dict
        If a str, should be a dot-separated import string:
            "chgnet.model.dynamics.CHGNetCalculator"
        If a dict, should be a monty-style JSONable dict:
            {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"}

    Returns
    -------
    type[Calculator] or Callable[..., Calculator]
        For a str, the attribute looked up on the imported module (it is
        not called here). For a dict, whatever ``MontyDecoder`` decodes
        the dict to.

    Raises
    ------
    ValueError
        If `calculator_meta` is a str that contains no "." separator, so
        that it cannot be split into a module path and an attribute name.
    """
    if isinstance(calculator_meta, str):
        # Split on the last dot only: everything before it is the module
        # path, the final component is the class / function name.
        module_path, separator, attr_name = calculator_meta.rpartition(".")
        if not separator:
            raise ValueError(
                "Expected a dot-separated import string of the form "
                f"'package.module.Name', got {calculator_meta!r}."
            )
        return getattr(import_module(module_path), attr_name)
    return MontyDecoder().process_decoded(calculator_meta)


Expand All @@ -434,12 +491,12 @@ def revert_default_dtype() -> Generator[None]:
torch.set_default_dtype(orig)


def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None:
def _get_pkg_name(calculator_meta: MLFF | str | dict[str, Any]) -> str | None:
"""Get the package name for a given force field.

Parameters
----------
calculator_meta : MLFF or JSONable dict
calculator_meta : MLFF, import-style str, or JSONable dict
The calculator metadata used to load the calculator,
or an MLFF enum.

Expand Down Expand Up @@ -480,13 +537,13 @@ def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None:
case _:
ff_pkg = None
return ff_pkg
if isinstance(calculator_meta, dict):
if isinstance(calculator_meta, str | dict):
calc_cls = _load_calc_cls(calculator_meta)
return calc_cls.__module__.split(".", 1)[0]
assert_never(calculator_meta)


def _get_pkg_version(calculator_meta: MLFF | dict[str, Any]) -> str | None:
def _get_pkg_version(calculator_meta: str | dict[str, Any] | MLFF) -> str | None:
"""Try to establish the imported version of a forcefield python package."""
if isinstance(pkg_name := _get_pkg_name(calculator_meta), str):
try:
Expand Down
85 changes: 85 additions & 0 deletions tests/forcefields/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,3 +893,88 @@ def test_ext_load_static_maker(si_structure: Structure):

assert output1.forcefield_name == "mace_mp"
assert output1.forcefield_version == get_imported_version("mace_torch")


@pytest.mark.skipif(not mlff_is_installed("MACE"), reason="MACE is not installed")
@pytest.mark.parametrize("as_str", [True, False])
def test_roundtrip(si_structure: Structure, as_str: bool):
    """Check a custom-calculator maker before and after a JSON round trip."""
    import json

    from ase.calculators.calculator import Calculator
    from mace.calculators import MACECalculator
    from monty.json import MontyDecoder, MontyEncoder

    import_str = "mace.calculators.mace_mp"
    module_name, _, callable_name = import_str.rpartition(".")

    # An import string is only accepted through `calculator_meta`.
    # A monty-style dict may be passed through `calculator_meta` (preferred)
    # or through `force_field_name` (kept for backwards compatibility).
    accepted_kwargs = ["calculator_meta"]
    if not as_str:
        accepted_kwargs.append("force_field_name")

    for kwarg_name in accepted_kwargs:
        if as_str:
            calc_spec = import_str
        else:
            calc_spec = {"@module": module_name, "@callable": callable_name}

        maker = ForceFieldRelaxMaker(
            **{kwarg_name: calc_spec},
            calculator_kwargs={"model": "medium"},
        )
        job = maker.make(si_structure)

        encoded_job = json.dumps(job, cls=MontyEncoder)
        roundtrip_job = MontyDecoder().decode(encoded_job)

        # The original and the decoded job must expose identical metadata.
        for candidate in (job, roundtrip_job):
            candidate_maker = candidate.maker
            assert candidate_maker.calculator_meta == import_str
            assert candidate_maker.force_field_name == str(MLFF.Forcefield)
            assert candidate_maker.mlff == MLFF.Forcefield
            assert isinstance(candidate_maker.calculator, MACECalculator)
            assert isinstance(candidate_maker.calculator, Calculator)


@pytest.mark.skipif(not mlff_is_installed("MACE"), reason="MACE is not installed")
@pytest.mark.parametrize(
    "import_str",
    [
        "mace.calculators.foundations_models.mace_mp",
        "mace.calculators.mace.MACECalculator",
    ],
)
def test_roundtrip_legacy(si_structure: Structure, import_str: str):
    """Check deserialization of a legacy-style serialized job.

    Backwards-compatibility test: legacy documents can contain a dict for
    `force_field_name`, which monty deserializes into an ase `Calculator`.
    `ForceFieldMixin` needs to handle this and narrow types correctly.
    """
    import json

    from ase.calculators.calculator import Calculator
    from mace.calculators import MACECalculator
    from mace.calculators.foundations_models import download_mace_mp_checkpoint
    from monty.json import MontyDecoder, MontyEncoder

    module_name, _, callable_name = import_str.rpartition(".")

    # `mace_mp` takes a model name, whereas `MACECalculator` needs the
    # path to a downloaded checkpoint.
    if callable_name == "mace_mp":
        calc_kwargs = {"model": "medium"}
    else:
        calc_kwargs = {"model_paths": download_mace_mp_checkpoint("medium")}

    maker = ForceFieldRelaxMaker(
        force_field_name={"@module": module_name, "@callable": callable_name},
        calculator_kwargs=calc_kwargs,
    )
    job = maker.make(si_structure)

    # Rewrite the serialized job into the legacy layout: the calculator dict
    # lives under `force_field_name` and there is no `calculator_meta` entry.
    job_dct = json.loads(MontyEncoder().encode(job))
    bound_maker = job_dct["function"]["@bound"]
    bound_maker["force_field_name"] = {
        "@module": module_name,
        "@callable": callable_name,
    }
    bound_maker.pop("calculator_meta")

    deser = MontyDecoder().process_decoded(job_dct)
    restored_maker = deser.maker
    assert restored_maker.calculator_meta == import_str
    assert restored_maker.force_field_name == str(MLFF.Forcefield)
    assert restored_maker.mlff == MLFF.Forcefield
    assert isinstance(restored_maker.calculator, MACECalculator)
    assert isinstance(restored_maker.calculator, Calculator)
Loading