diff --git a/docs/user/codes/forcefields.md b/docs/user/codes/forcefields.md index 5b70893e76..17ad64eac7 100644 --- a/docs/user/codes/forcefields.md +++ b/docs/user/codes/forcefields.md @@ -48,16 +48,28 @@ Support is provided for the following models, which can be selected using `atoma ## Using custom forcefields by dictionary -`force_field_name` also accepts a MSONable dictionary for specifying a custom ASE calculator class or function [^calculator-meta-type-annotation]. -For example, a `Job` created with the following code snippet instantiates `chgnet.model.dynamics.CHGNetCalculator` as the ASE calculator: +`force_field_name` also accepts an import-like string, or MSONable dictionary to specify a custom ASE calculator class or function [^calculator-meta-type-annotation]. +For example, a `Job` created with the either of the following two code snippets instantiates a `chgnet.model.dynamics.CHGNetCalculator` as the ASE calculator. ```python +# simple import string job = ForceFieldStaticMaker( - force_field_name={ + calculator_meta="chgnet.model.dynamics.CHGNetCalculator", +).make(structure) +``` + +or using `force_field_name` when + +```python +# monty MSONable style +job = ForceFieldStaticMaker( + calculator_meta={ "@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator", } ).make(structure) ``` +Note that one can also specify `force_field_name = {"@module": ...,"@class": ...}` in the second example for backwards compatibility. +However, this may not be preserved in future versions, and `calculator_meta` is preferred. [^calculator-meta-type-annotation]: In this context, the type annotation of the decoded dict should be either `Type[Calculator]` or `Callable[..., Calculator]`, where `Calculator` is from `ase.calculators.calculator`. diff --git a/src/atomate2/forcefields/schemas.py b/src/atomate2/forcefields/schemas.py index 255ca634e7..51a0e67308 100644 --- a/src/atomate2/forcefields/schemas.py +++ b/src/atomate2/forcefields/schemas.py @@ -95,7 +95,7 @@ def from_ase_compatible_result( ase_calculator_name: str, result: AseResult, steps: int, - calculator_meta: MLFF | dict | None = None, + calculator_meta: str | MLFF | dict | None = None, relax_kwargs: dict = None, optimizer_kwargs: dict = None, fix_symmetry: bool = False, @@ -123,7 +123,7 @@ def from_ase_compatible_result( Whether to fix the symmetry of the ions during relaxation. symprec : float Tolerance for symmetry finding in case of fix_symmetry. - calculator_meta : Optional, MLFF or dict or None + calculator_meta : Optional, str, MLFF, dict, or None Metadata about the calculator used. steps : int Maximum number of ionic steps allowed during relaxation. diff --git a/src/atomate2/forcefields/utils.py b/src/atomate2/forcefields/utils.py index fc1f0d6250..f9aba0e739 100644 --- a/src/atomate2/forcefields/utils.py +++ b/src/atomate2/forcefields/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import warnings from contextlib import contextmanager from dataclasses import dataclass, field @@ -13,6 +14,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from ase.calculators.calculator import Calculator from ase.units import Bohr from monty.json import MontyDecoder from typing_extensions import assert_never, deprecated @@ -26,8 +28,6 @@ except ImportError: torch_dtype = str - from ase.calculators.calculator import Calculator - from atomate2.ase.schemas import AseResult _FORCEFIELD_DATA_OBJECTS = ["trajectory", "ionic_steps"] @@ -159,42 +159,80 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str: class ForceFieldMixin: """Mix-in class for force-fields. - Attributes - ---------- - force_field_name : str or MLFF - Name of the forcefield which will be - correctly deserialized/standardized if the forcefield is - a known `MLFF`. - calculator_meta : MLFF or dict - Actual metadata to instantiate the ASE calculator. - calculator_kwargs : dict = field(default_factory=dict) - Keyword arguments that will get passed to the ASE calculator. - task_document_kwargs: dict = field(default_factory=dict) - Additional keyword args passed to :obj:`.ForceFieldTaskDocument() - or another final document schema. + All basic forcefield jobs should inherit from this class + to easily access `ase_calculator`. """ force_field_name: str | MLFF | dict = MLFF.Forcefield - calculator_meta: MLFF | dict = field(init=False) - calculator_kwargs: dict = field(default_factory=dict) - task_document_kwargs: dict = field(default_factory=dict) + calculator_meta: str | MLFF | dict | None = None + calculator_kwargs: dict[str, Any] = field(default_factory=dict) + task_document_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: - """Ensure that force_field_name is correctly assigned.""" + """Validate input data types. + + Attributes + ---------- + force_field_name : str, MLFF, or dict + If a str or MLFF: Name of the forcefield which will be + correctly deserialized/standardized if the forcefield is + a known `MLFF`. + If a dict, a monty-style dict. + + calculator_meta : MLFF, str, or dict + Actual metadata to instantiate the ASE calculator. + If a MLFF, that default interface in `ase_calculator` will be used. + If an import-style str or monty-style dict, the calculator will + be dynamically loaded. + + calculator_kwargs : dict = {} + Keyword arguments that will get passed to the ASE calculator. + + task_document_kwargs: dict = {} + Additional keyword args passed to :obj:`.ForceFieldTaskDocument() + or another final document schema. + """ if hasattr(super(), "__post_init__"): super().__post_init__() # type: ignore[misc] + mlff: MLFF = MLFF.Forcefield # Fallback to placeholder if isinstance(self.force_field_name, dict): - mlff = MLFF.Forcefield # Fallback to placeholder - self.calculator_meta = self.force_field_name.copy() + calculator_meta: str | dict[str, Any] | MLFF = self.force_field_name.copy() + + elif ( + ( + inspect.isclass(self.force_field_name) + and issubclass(self.force_field_name, Calculator) + ) + or isinstance(self.force_field_name, Calculator) + or inspect.isfunction(self.force_field_name) # for mace_mp specifically + ): + # can happen with deserialization of legacy documents from JSON + calculator_meta = ".".join( + getattr(self.force_field_name, k) for k in ("__module__", "__name__") + ) + else: mlff = _get_standardized_mlff(self.force_field_name) - self.calculator_meta = mlff + # On round-trip deserialization, `calculator_meta` will be a dict + # of the calculator information + calculator_meta = self.calculator_meta or mlff + + # avoids unintentional deserialization from monty on round-trip + if isinstance(calculator_meta, dict): + self.calculator_meta: str | MLFF = ".".join( + calculator_meta[k] for k in ("@module", "@callable") + ) + else: + try: + self.calculator_meta = _get_standardized_mlff(calculator_meta) + except ValueError: + self.calculator_meta = calculator_meta self.force_field_name: str = str(mlff) # Narrow-down type for mypy # Pad calculator_kwargs with default values, but permit user to override them - self.calculator_kwargs = { + self.calculator_kwargs: dict[str, Any] = { **_DEFAULT_CALCULATOR_KWARGS.get(mlff, {}), **self.calculator_kwargs, } @@ -228,7 +266,7 @@ def ase_calculator_name(self) -> str: """The name of the ASE calculator for schemas.""" if isinstance(self.calculator_meta, MLFF): return str(self.force_field_name) - if isinstance(self.calculator_meta, dict): + if isinstance(self.calculator_meta, str | dict): calc_cls = _load_calc_cls(self.calculator_meta) return calc_cls.__name__ assert_never(self.calculator_meta) @@ -402,7 +440,9 @@ def ase_calculator( **{k: v for k, v in kwargs.items() if k != "predict_unit"}, ) - elif isinstance(calculator_meta, dict): + elif isinstance(calculator_meta, dict) or ( + isinstance(calculator_meta, str) and calculator_meta.count(".") >= 1 + ): calc_cls = _load_calc_cls(calculator_meta) calculator = calc_cls(**kwargs) @@ -413,8 +453,25 @@ def ase_calculator( def _load_calc_cls( - calculator_meta: dict, + calculator_meta: str | dict, ) -> type[Calculator] | Callable[..., Calculator]: + """Load an ASE calculator using monty or importlib. + + Parameters + ---------- + calculator_meta : str or dict + If a str, should be a dot-separated import string: + "chgnet.model.dynamics.CHGNetCalculator" + If a dict, should be a monty-style JSONable dict: + {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"} + + Returns + ------- + ase Calculator + """ + if isinstance(calculator_meta, str): + module, klass = calculator_meta.rsplit(".", 1) + return getattr(import_module(module), klass) return MontyDecoder().process_decoded(calculator_meta) @@ -434,12 +491,12 @@ def revert_default_dtype() -> Generator[None]: torch.set_default_dtype(orig) -def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None: +def _get_pkg_name(calculator_meta: MLFF | str | dict[str, Any]) -> str | None: """Get the package name for a given force field. Parameters ---------- - calculator_meta : MLFF or JSONable dict + calculator_meta : MLFF, import-style str, or JSONable dict The calculator metadata used to load the calculator, or an MLFF enum. @@ -480,13 +537,13 @@ def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None: case _: ff_pkg = None return ff_pkg - if isinstance(calculator_meta, dict): + if isinstance(calculator_meta, str | dict): calc_cls = _load_calc_cls(calculator_meta) return calc_cls.__module__.split(".", 1)[0] assert_never(calculator_meta) -def _get_pkg_version(calculator_meta: MLFF | dict[str, Any]) -> str | None: +def _get_pkg_version(calculator_meta: str | dict[str, Any] | MLFF) -> str | None: """Try to establish the imported version of a forcefield python package.""" if isinstance(pkg_name := _get_pkg_name(calculator_meta), str): try: diff --git a/tests/forcefields/test_jobs.py b/tests/forcefields/test_jobs.py index 7cbe644e60..36dfe2ab5c 100644 --- a/tests/forcefields/test_jobs.py +++ b/tests/forcefields/test_jobs.py @@ -893,3 +893,88 @@ def test_ext_load_static_maker(si_structure: Structure): assert output1.forcefield_name == "mace_mp" assert output1.forcefield_version == get_imported_version("mace_torch") + + +@pytest.mark.skipif(not mlff_is_installed("MACE"), reason="MACE is not installed") +@pytest.mark.parametrize("as_str", [True, False]) +def test_roundtrip(si_structure: Structure, as_str: bool): + + import json + + from ase.calculators.calculator import Calculator + from mace.calculators import MACECalculator + from monty.json import MontyDecoder, MontyEncoder + + import_str = "mace.calculators.mace_mp" + module, klass = import_str.rsplit(".", 1) + + # If using an import string, one must specify this through `calculator_meta` + # If using a monty-style dict, one can use either `calculator_meta` (preferred) + # or `force_field_name` (for backwards compatibility) + valid_kwargs = ["calculator_meta"] + ([] if as_str else ["force_field_name"]) + + for calc_kwarg in valid_kwargs: + job = ForceFieldRelaxMaker( + **{ + calc_kwarg: ( + import_str if as_str else {"@module": module, "@callable": klass} + ) + }, + calculator_kwargs={"model": "medium"}, + ).make(si_structure) + + roundtrip_job = MontyDecoder().decode(json.dumps(job, cls=MontyEncoder)) + + for j in (job, roundtrip_job): + assert j.maker.calculator_meta == import_str + assert j.maker.force_field_name == str(MLFF.Forcefield) + assert j.maker.mlff == MLFF.Forcefield + assert isinstance(j.maker.calculator, MACECalculator) + assert isinstance(j.maker.calculator, Calculator) + + +@pytest.mark.skipif(not mlff_is_installed("MACE"), reason="MACE is not installed") +@pytest.mark.parametrize( + "import_str", + [ + "mace.calculators.foundations_models.mace_mp", + "mace.calculators.mace.MACECalculator", + ], +) +def test_roundtrip_legacy(si_structure: Structure, import_str: str): + # Test backwards compatibility. Legacy docs can contain dict for + # `force_field_name` which will be deserialized by monty into an ase + # `Calculator`. `ForceFieldMixin` needs to handle this and + # narrow types correctly + + import json + + from ase.calculators.calculator import Calculator + from mace.calculators import MACECalculator + from mace.calculators.foundations_models import download_mace_mp_checkpoint + from monty.json import MontyDecoder, MontyEncoder + + module, klass = import_str.rsplit(".", 1) + + job = ForceFieldRelaxMaker( + force_field_name={"@module": module, "@callable": klass}, + calculator_kwargs=( + {"model": "medium"} + if klass == "mace_mp" + else {"model_paths": download_mace_mp_checkpoint("medium")} + ), + ).make(si_structure) + + job_dct = json.loads(MontyEncoder().encode(job)) + job_dct["function"]["@bound"]["force_field_name"] = { + "@module": module, + "@callable": klass, + } + job_dct["function"]["@bound"].pop("calculator_meta") + + deser = MontyDecoder().process_decoded(job_dct) + assert deser.maker.calculator_meta == import_str + assert deser.maker.force_field_name == str(MLFF.Forcefield) + assert deser.maker.mlff == MLFF.Forcefield + assert isinstance(deser.maker.calculator, MACECalculator) + assert isinstance(deser.maker.calculator, Calculator)