From e72c0420d64dc324677d7fdbf64d8093d722b66b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 19 Sep 2024 16:46:02 -0400 Subject: [PATCH 01/11] feat: support atomic virials Signed-off-by: Jinzhe Zeng --- deepmd_gnn/mace.py | 40 ++++++++++------------------------------ tests/test_model.py | 2 ++ 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index 03bee6b..d86cd54 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -17,6 +17,7 @@ ) from deepmd.pt.model.model.transform_output import ( communicate_extended_output, + fit_output_to_model_output, ) from deepmd.pt.utils import env from deepmd.pt.utils.nlist import ( @@ -659,7 +660,7 @@ def forward_lower_common( mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, # noqa: ARG002 + do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """Forward lower common pass of the model. @@ -788,39 +789,18 @@ def forward_lower_common( if atom_energy is None: msg = "atom_energy is None" raise ValueError(msg) - atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] - energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) - grad_outputs: list[Optional[torch.Tensor]] = [ - torch.ones_like(energy), - ] - force = torch.autograd.grad( - outputs=[energy], - inputs=[extended_coord_ff], - grad_outputs=grad_outputs, - retain_graph=True, + atom_energy = atom_energy.view(nf, nall, 1) + atom_energy = atom_energy[:, :nloc, :] + model_ret = fit_output_to_model_output( + {"energy": atom_energy.to(extended_coord_.dtype)}, + self.fitting_output_def(), + extended_coord_ff.view(nf, nall, 3), + do_atomic_virial=do_atomic_virial, create_graph=self.training, - )[0] - if force is None: - msg = "force is None" - raise ValueError(msg) - force = -force - atomic_virial = force.unsqueeze(-1).to( - extended_coord_.dtype, - ) @ extended_coord_ff.unsqueeze(-2).to( - extended_coord_.dtype, ) - force = force.view(nf, nall, 3).to(extended_coord_.dtype) - atomic_virial = atomic_virial.view(nf, nall, 1, 9) - virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) return { - "energy_redu": energy.view(nf, 1), - "energy_derv_r": force.view(nf, nall, 1, 3), - "energy_derv_c_redu": virial.view(nf, 1, 9), - # take the first nloc atoms to match other models - "energy": atom_energy.view(nf, nloc, 1), - # fake atom_virial - "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), + kk: vv.to(extended_coord_.dtype) for kk, vv in model_ret.items() } def serialize(self) -> dict: diff --git a/tests/test_model.py b/tests/test_model.py index b6aa6dc..c285abf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -276,6 +276,7 @@ def test_forward(self) -> None: "box": cell, "aparam": aparam, "fparam": fparam, + "do_atomic_virial": True, } if test_spin: input_dict["spin"] = spin @@ -288,6 +289,7 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, "mapping": mapping_large, + "do_atomic_virial": True, } if test_spin: input_dict_lower["extended_spin"] = spin_ext From 07c60d9a34532c0aadcf491c593ca262c77f87d0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 19 Sep 2024 17:07:38 -0400 Subject: [PATCH 02/11] fix shape issue Signed-off-by: Jinzhe Zeng --- deepmd_gnn/mace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index d86cd54..e97af6b 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -799,6 +799,7 @@ def forward_lower_common( create_graph=self.training, ) + return { kk: vv.to(extended_coord_.dtype) for kk, vv in model_ret.items() } From 577225efdd0f8f960af2cb2d013904441705ec4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Apr 2026 19:01:35 +0000 Subject: [PATCH 03/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd_gnn/mace.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index e97af6b..ef6db33 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -799,10 +799,7 @@ def forward_lower_common( create_graph=self.training, ) - - return { - kk: vv.to(extended_coord_.dtype) for kk, vv in model_ret.items() - } + return {kk: vv.to(extended_coord_.dtype) for kk, vv in model_ret.items()} def serialize(self) -> dict: """Serialize the model.""" From 699a154772a346e6ddf865d9cdd85b4ddc221f77 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Sun, 26 Apr 2026 19:40:47 +0000 Subject: [PATCH 04/11] fix(test): scope atomic virial path to MACE --- deepmd_gnn/mace.py | 49 +++++++++++++++++++++++++++++++++++++-------- tests/test_model.py | 6 ++++-- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index ef6db33..b36e4ce 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -789,17 +789,50 @@ def forward_lower_common( if atom_energy is None: msg = "atom_energy is None" raise ValueError(msg) - atom_energy = atom_energy.view(nf, nall, 1) - atom_energy = atom_energy[:, :nloc, :] - model_ret = fit_output_to_model_output( - {"energy": atom_energy.to(extended_coord_.dtype)}, - self.fitting_output_def(), - extended_coord_ff.view(nf, nall, 3), - do_atomic_virial=do_atomic_virial, + atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] + energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) + grad_outputs: list[Optional[torch.Tensor]] = [ + torch.ones_like(energy), + ] + force = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff], + grad_outputs=grad_outputs, + retain_graph=True, create_graph=self.training, + )[0] + if force is None: + msg = "force is None" + raise ValueError(msg) + force = -force + atomic_virial = force.unsqueeze(-1).to( + extended_coord_.dtype, + ) @ extended_coord_ff.unsqueeze(-2).to( + extended_coord_.dtype, ) + force = force.view(nf, nall, 3).to(extended_coord_.dtype) + atomic_virial = atomic_virial.view(nf, nall, 1, 9) + virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) + if do_atomic_virial: + model_ret = fit_output_to_model_output( + {"energy": atom_energy.view(nf, nloc, 1)}, + self.fitting_output_def(), + extended_coord_ff.view(nf, nall, 3), + do_atomic_virial=True, + create_graph=self.training, + ) + atomic_virial = model_ret["energy_derv_c"].to(extended_coord_.dtype) - return {kk: vv.to(extended_coord_.dtype) for kk, vv in model_ret.items()} + return { + "energy_redu": energy.view(nf, 1), + "energy_derv_r": force.view(nf, nall, 1, 3), + "energy_derv_c_redu": virial.view(nf, 1, 9), + # take the first nloc atoms to match other models + "energy": atom_energy.view(nf, nloc, 1), + # fake atom_virial when do_atomic_virial is False; + # corrected one when do_atomic_virial is True. + "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), + } def serialize(self) -> dict: """Serialize the model.""" diff --git a/tests/test_model.py b/tests/test_model.py index c285abf..f3108f2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -276,8 +276,9 @@ def test_forward(self) -> None: "box": cell, "aparam": aparam, "fparam": fparam, - "do_atomic_virial": True, } + if "atom_virial" in self.output_def: + input_dict["do_atomic_virial"] = True if test_spin: input_dict["spin"] = spin ret.append(module(**input_dict)) @@ -289,8 +290,9 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, "mapping": mapping_large, - "do_atomic_virial": True, } + if "atom_virial" in self.output_def: + input_dict_lower["do_atomic_virial"] = True if test_spin: input_dict_lower["extended_spin"] = spin_ext From 318c219d2a84751357d159708a901953b41192b2 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Sun, 26 Apr 2026 20:02:21 +0000 Subject: [PATCH 05/11] fix(model): align atomic virial gradients with runtime paths --- deepmd_gnn/mace.py | 7 ++++--- tests/test_model.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index b36e4ce..c4523a3 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -722,8 +722,9 @@ def forward_lower_common( # cast to float32 default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype - extended_coord_ff = extended_coord_ff.to(default_dtype) - extended_coord_ff.requires_grad_(True) # noqa: FBT003 + extended_coord_grad = extended_coord.to(default_dtype) + extended_coord_grad.requires_grad_(True) # noqa: FBT003 + extended_coord_ff = extended_coord_grad.view(nf * nall, 3) nedge = edge_index.shape[1] if self.num_interactions > 1 and mapping is not None and nloc < nall: # shift the edges for ghost atoms, and map the ghost atoms to real atoms @@ -817,7 +818,7 @@ def forward_lower_common( model_ret = fit_output_to_model_output( {"energy": atom_energy.view(nf, nloc, 1)}, self.fitting_output_def(), - extended_coord_ff.view(nf, nall, 3), + extended_coord_grad, do_atomic_virial=True, create_graph=self.training, ) diff --git a/tests/test_model.py b/tests/test_model.py index f3108f2..00260d8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -307,6 +307,8 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, } + if "atom_virial" in self.output_def: + input_dict_lower["do_atomic_virial"] = True if test_spin: input_dict_lower["extended_spin"] = spin_ext From 3c1863c3b63a82373376b9c8e8a5676d83dbbb6b Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Mon, 27 Apr 2026 02:00:54 +0000 Subject: [PATCH 06/11] refactor(model): avoid duplicate autodiff in atomic virial path --- deepmd_gnn/mace.py | 14 +++++--------- tests/test_model.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index c4523a3..87d18dc 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -16,8 +16,8 @@ BaseModel, ) from deepmd.pt.model.model.transform_output import ( + atomic_virial_corr, communicate_extended_output, - fit_output_to_model_output, ) from deepmd.pt.utils import env from deepmd.pt.utils.nlist import ( @@ -813,16 +813,12 @@ def forward_lower_common( ) force = force.view(nf, nall, 3).to(extended_coord_.dtype) atomic_virial = atomic_virial.view(nf, nall, 1, 9) - virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) if do_atomic_virial: - model_ret = fit_output_to_model_output( - {"energy": atom_energy.view(nf, nloc, 1)}, - self.fitting_output_def(), + atomic_virial = atomic_virial + atomic_virial_corr( extended_coord_grad, - do_atomic_virial=True, - create_graph=self.training, - ) - atomic_virial = model_ret["energy_derv_c"].to(extended_coord_.dtype) + atom_energy.view(nf, nloc, 1), + ).view(nf, nall, 1, 9).to(extended_coord_.dtype) + virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) return { "energy_redu": energy.view(nf, 1), diff --git a/tests/test_model.py b/tests/test_model.py index 00260d8..d872531 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -976,12 +976,21 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } - rfv = module(**input_dict)["virial"] + if "atom_virial" in self.output_def: + input_dict["do_atomic_virial"] = True + ret = module(**input_dict) + rfv = ret["virial"] np.testing.assert_almost_equal( fdv.reshape(-1, 9), rfv.reshape(-1, 9), decimal=places, ) + if "atom_virial" in self.output_def: + np.testing.assert_almost_equal( + ret["atom_virial"].sum(axis=1).reshape(-1, 9), + rfv.reshape(-1, 9), + decimal=places, + ) else: # not support virial by far pass From 65cb82d64a12d972fb79c10b903ec32f62bb2b4b Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:27:56 +0000 Subject: [PATCH 07/11] test(model): check atomic virial by finite difference --- tests/test_model.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index d872531..2588184 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -959,6 +959,16 @@ def ff_cell(bb): } return module(**input_dict)["energy"] + def ff_cell_atom(bb): + input_dict = { + "coord": stretch_box(coord, cell, bb), + "atype": atype, + "box": bb, + "aparam": aparam, + "fparam": fparam, + } + return module(**input_dict)["atom_energy"] + fdv = ( -( finite_difference(ff_cell, cell, delta=delta) @@ -986,6 +996,20 @@ def ff_cell(bb): decimal=places, ) if "atom_virial" in self.output_def: + fdav = ( + -( + finite_difference(ff_cell_atom, cell, delta=delta) + .reshape(-1, 3, 3) + .transpose(0, 2, 1) + @ cell.reshape(-1, 3, 3) + ) + .reshape(-1, 9) + ) + np.testing.assert_almost_equal( + fdav, + ret["atom_virial"].reshape(-1, 9), + decimal=places, + ) np.testing.assert_almost_equal( ret["atom_virial"].sum(axis=1).reshape(-1, 9), rfv.reshape(-1, 9), From 0feefd05734e7e1a284f23ba2120eb3ca8bd3fc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:30:42 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 2588184..cb4e876 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -996,15 +996,12 @@ def ff_cell_atom(bb): decimal=places, ) if "atom_virial" in self.output_def: - fdav = ( - -( - finite_difference(ff_cell_atom, cell, delta=delta) - .reshape(-1, 3, 3) - .transpose(0, 2, 1) - @ cell.reshape(-1, 3, 3) - ) - .reshape(-1, 9) - ) + fdav = -( + finite_difference(ff_cell_atom, cell, delta=delta) + .reshape(-1, 3, 3) + .transpose(0, 2, 1) + @ cell.reshape(-1, 3, 3) + ).reshape(-1, 9) np.testing.assert_almost_equal( fdav, ret["atom_virial"].reshape(-1, 9), From 1b161b5f7f43a95b79bfd3bac5513b9f72fe2fab Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:13:55 +0000 Subject: [PATCH 09/11] fix(model): derive atomic virial via output transform --- deepmd_gnn/mace.py | 45 +++++++++------------------------------------ 1 file changed, 9 insertions(+), 36 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index 87d18dc..ea3f99f 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -16,8 +16,8 @@ BaseModel, ) from deepmd.pt.model.model.transform_output import ( - atomic_virial_corr, communicate_extended_output, + fit_output_to_model_output, ) from deepmd.pt.utils import env from deepmd.pt.utils.nlist import ( @@ -791,44 +791,17 @@ def forward_lower_common( msg = "atom_energy is None" raise ValueError(msg) atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] - energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) - grad_outputs: list[Optional[torch.Tensor]] = [ - torch.ones_like(energy), - ] - force = torch.autograd.grad( - outputs=[energy], - inputs=[extended_coord_ff], - grad_outputs=grad_outputs, - retain_graph=True, + + model_ret = fit_output_to_model_output( + {"energy": atom_energy.view(nf, nloc, 1)}, + self.fitting_output_def(), + extended_coord_grad, + do_atomic_virial=do_atomic_virial, create_graph=self.training, - )[0] - if force is None: - msg = "force is None" - raise ValueError(msg) - force = -force - atomic_virial = force.unsqueeze(-1).to( - extended_coord_.dtype, - ) @ extended_coord_ff.unsqueeze(-2).to( - extended_coord_.dtype, ) - force = force.view(nf, nall, 3).to(extended_coord_.dtype) - atomic_virial = atomic_virial.view(nf, nall, 1, 9) - if do_atomic_virial: - atomic_virial = atomic_virial + atomic_virial_corr( - extended_coord_grad, - atom_energy.view(nf, nloc, 1), - ).view(nf, nall, 1, 9).to(extended_coord_.dtype) - virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) - return { - "energy_redu": energy.view(nf, 1), - "energy_derv_r": force.view(nf, nall, 1, 3), - "energy_derv_c_redu": virial.view(nf, 1, 9), - # take the first nloc atoms to match other models - "energy": atom_energy.view(nf, nloc, 1), - # fake atom_virial when do_atomic_virial is False; - # corrected one when do_atomic_virial is True. - "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), + kk: vv.to(extended_coord_.dtype) if vv is not None else vv + for kk, vv in model_ret.items() } def serialize(self) -> dict: From 47c938b6d24246b709071091bf0188b64b664d09 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:11:36 +0000 Subject: [PATCH 10/11] fix(mace): derive virial from box deformation Authored by OpenClaw (model: gpt-5.4) --- deepmd_gnn/mace.py | 222 +++++++++++++++++++++++++++++++---------- tests/test_mace_off.py | 114 ++++++++++++--------- 2 files changed, 235 insertions(+), 101 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index ea3f99f..7f83f8b 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -557,6 +557,7 @@ def forward( extended_atype, nlist, mapping=mapping, + box=box, fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, @@ -572,9 +573,11 @@ def forward( model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) - model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + model_predict["virial"] = model_ret_lower["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret_lower["energy_derv_c"][ + :, :nloc + ].squeeze(-3) return model_predict @torch.jit.export @@ -637,6 +640,7 @@ def forward_lower( extended_atype, nlist, mapping, + None, fparam, aparam, do_atomic_virial, @@ -658,6 +662,7 @@ def forward_lower_common( extended_atype: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, @@ -700,7 +705,6 @@ def forward_lower_common( extended_atype = extended_atype.to(torch.int64) nall = extended_coord.shape[1] - # fake as one frame extended_coord_ff = extended_coord.view(nf * nall, 3) extended_atype_ff = extended_atype.view(nf * nall) edge_index = torch.ops.deepmd_gnn.edge_index( @@ -709,25 +713,26 @@ def forward_lower_common( torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"), ) edge_index = edge_index.T - # to one hot indices = extended_atype_ff.unsqueeze(-1) oh = torch.zeros( (nf * nall, self.ntypes), device=extended_atype.device, dtype=torch.float64, ) - # scatter_ is the in-place version of scatter oh.scatter_(dim=-1, index=indices, value=1) one_hot = oh.view((nf * nall, self.ntypes)) - # cast to float32 default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype extended_coord_grad = extended_coord.to(default_dtype) extended_coord_grad.requires_grad_(True) # noqa: FBT003 extended_coord_ff = extended_coord_grad.view(nf * nall, 3) nedge = edge_index.shape[1] + shifts = torch.zeros( + (nedge, 3), + dtype=default_dtype, + device=extended_coord_ff.device, + ) if self.num_interactions > 1 and mapping is not None and nloc < nall: - # shift the edges for ghost atoms, and map the ghost atoms to real atoms mapping_ff = mapping.view(nf * nall) + torch.arange( 0, nf * nall, @@ -738,47 +743,79 @@ def forward_lower_common( shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] edge_index = mapping_ff[edge_index] - else: - shifts = torch.zeros( - (nedge, 3), - dtype=torch.float64, - device=extended_coord_.device, - ) shifts = shifts.to(default_dtype) one_hot = one_hot.to(default_dtype) - # it seems None is not allowed for data - box = ( - torch.eye( - 3, + + batch = torch.arange( + nf, + dtype=torch.int64, + device=extended_coord_ff.device, + ).unsqueeze(-1).expand(nf, nall).reshape(-1) + ptr = torch.arange( + 0, + (nf + 1) * nall, + nall, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + weight = torch.ones( + [nf], + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + + compute_displacement = box is not None + input_dict: dict[str, torch.Tensor] = { + "edge_index": edge_index, + "batch": batch, + "node_attrs": one_hot.to(default_dtype), + "ptr": ptr, + "weight": weight, + } + displacement = torch.jit.annotate(Optional[torch.Tensor], None) + if box is not None: + box_tensor = box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") + inv_box = torch.linalg.inv(box_tensor) + unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) + displacement = torch.zeros( + (nf, 3, 3), dtype=extended_coord_ff.dtype, device=extended_coord_ff.device, ) - * 1000.0 - ) - - ret = self.model.forward( - { - "positions": extended_coord_ff, - "shifts": shifts, - "cell": box, - "edge_index": edge_index, - "batch": torch.zeros( - [nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "node_attrs": one_hot, - "ptr": torch.tensor( - [0, nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "weight": torch.tensor( - [1.0], + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) + positions = extended_coord_ff + torch.einsum( + "be,bec->bc", + extended_coord_ff, + symmetric_displacement[batch], + ) + cell = box_tensor + torch.matmul(box_tensor, symmetric_displacement) + input_dict["positions"] = positions + input_dict["cell"] = cell + input_dict["shifts"] = torch.einsum( + "be,bec->bc", + torch.round(unit_shifts).to(default_dtype), + cell[edge_batch], + ) + else: + input_dict["positions"] = extended_coord_ff + input_dict["cell"] = ( + torch.eye( + 3, dtype=extended_coord_ff.dtype, device=extended_coord_ff.device, - ), - }, + ) + .unsqueeze(0) + .expand(nf, 3, 3) + * 1000.0 + ) + input_dict["shifts"] = shifts + + ret = self.model.forward( + input_dict, compute_force=False, compute_virials=False, compute_stress=False, @@ -786,22 +823,103 @@ def forward_lower_common( training=self.training, ) - atom_energy = ret["node_energy"] - if atom_energy is None: + atom_energy_all = ret["node_energy"] + if atom_energy_all is None: msg = "atom_energy is None" raise ValueError(msg) - atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] + atom_energy_all = atom_energy_all.view(nf, nall) + atom_energy = atom_energy_all[:, :nloc] + energy = torch.sum(atom_energy, dim=1) + grad_outputs = torch.jit.annotate( + list[Optional[torch.Tensor]], + [torch.ones_like(energy)], + ) + retain_graph = self.training or do_atomic_virial - model_ret = fit_output_to_model_output( - {"energy": atom_energy.view(nf, nloc, 1)}, - self.fitting_output_def(), - extended_coord_grad, - do_atomic_virial=do_atomic_virial, - create_graph=self.training, + atomic_virial_fallback = torch.zeros( + (nf, nall, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + if compute_displacement and displacement is not None: + grads = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff, displacement], + grad_outputs=grad_outputs, + retain_graph=retain_graph, + create_graph=self.training, + allow_unused=True, + ) + force_ff = grads[0] + virial_tensor = grads[1] + if force_ff is None: + force_ff = torch.zeros_like(extended_coord_ff) + if virial_tensor is None: + virial_tensor = torch.zeros( + (nf, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + force = -force_ff.view(nf, nall, 3) + virial = -virial_tensor.view(nf, 1, 9) + else: + force_ff = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff], + grad_outputs=grad_outputs, + retain_graph=retain_graph, + create_graph=self.training, + allow_unused=True, + )[0] + if force_ff is None: + force_ff = torch.zeros_like(extended_coord_ff) + force = -force_ff.view(nf, nall, 3) + atomic_virial_fallback = force.unsqueeze(-1) @ extended_coord_ff.view( + nf, + nall, + 3, + ).unsqueeze(-2) + virial = torch.sum(atomic_virial_fallback, dim=1).view(nf, 1, 9) + + atomic_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, ) + if do_atomic_virial: + if compute_displacement and displacement is not None: + atomic_virial_local = torch.zeros( + (nf, nloc, 9), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + for ii in range(nloc): + atom_energy_ii = atom_energy[:, ii] + atom_grad_outputs = torch.jit.annotate( + list[Optional[torch.Tensor]], + [torch.ones_like(atom_energy_ii)], + ) + atom_virial_ii = torch.autograd.grad( + outputs=[atom_energy_ii], + inputs=[displacement], + grad_outputs=atom_grad_outputs, + retain_graph=True, + create_graph=self.training, + allow_unused=True, + )[0] + if atom_virial_ii is None: + atom_virial_ii = torch.zeros_like(displacement) + atomic_virial_local[:, ii, :] = (-atom_virial_ii).view(nf, 9) + atomic_virial[:, :nloc, 0, :] = atomic_virial_local + else: + atomic_virial[:, :, 0, :] = atomic_virial_fallback.view(nf, nall, 9) + return { - kk: vv.to(extended_coord_.dtype) if vv is not None else vv - for kk, vv in model_ret.items() + "energy_redu": energy.view(nf, 1).to(extended_coord_.dtype), + "energy_derv_r": force.view(nf, nall, 1, 3).to(extended_coord_.dtype), + "energy_derv_c_redu": virial.to(extended_coord_.dtype), + "energy": atom_energy.view(nf, nloc, 1).to(extended_coord_.dtype), + "energy_derv_c": atomic_virial.to(extended_coord_.dtype), } def serialize(self) -> dict: diff --git a/tests/test_mace_off.py b/tests/test_mace_off.py index bd887a9..ac41969 100644 --- a/tests/test_mace_off.py +++ b/tests/test_mace_off.py @@ -61,6 +61,11 @@ def _native_mace_reference_outputs( torch.tensor([], dtype=torch.int64, device="cpu"), ).T + shifts = torch.zeros( + (edge_index.shape[1], 3), + dtype=default_dtype, + device=extended_coord_ff.device, + ) if mapping is not None and nloc < nall: mapping_ff = mapping.view(nf * nall) + torch.arange( 0, @@ -72,12 +77,6 @@ def _native_mace_reference_outputs( shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] edge_index = mapping_ff[edge_index] - else: - shifts = torch.zeros( - (edge_index.shape[1], 3), - dtype=default_dtype, - device=extended_coord_ff.device, - ) one_hot = torch.zeros( (nf * nall, ntypes), @@ -90,33 +89,57 @@ def _native_mace_reference_outputs( value=1, ) + batch = torch.arange( + nf, + dtype=torch.int64, + device=extended_coord_ff.device, + ).unsqueeze(-1).expand(nf, nall).reshape(-1) + ptr = torch.arange( + 0, + (nf + 1) * nall, + nall, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + weight = torch.ones( + [nf], + dtype=default_dtype, + device=extended_coord_ff.device, + ) + cell = box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") + inv_box = torch.linalg.inv(cell) + unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) + displacement = torch.zeros( + (nf, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) + positions = extended_coord_ff + torch.einsum( + "be,bec->bc", + extended_coord_ff, + symmetric_displacement[batch], + ) + cell_def = cell + torch.matmul(cell, symmetric_displacement) + ret = mace_model.forward( { - "positions": extended_coord_ff, - "shifts": shifts, - "cell": torch.eye( - 3, - dtype=default_dtype, - device=extended_coord_ff.device, - ) - * 1000.0, - "edge_index": edge_index, - "batch": torch.zeros( - [nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, + "positions": positions, + "cell": cell_def, + "shifts": torch.einsum( + "be,bec->bc", + torch.round(unit_shifts).to(default_dtype), + cell_def[edge_batch], ), + "edge_index": edge_index, + "batch": batch, "node_attrs": one_hot, - "ptr": torch.tensor( - [0, nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "weight": torch.tensor( - [1.0], - dtype=default_dtype, - device=extended_coord_ff.device, - ), + "ptr": ptr, + "weight": weight, }, compute_force=False, compute_virials=False, @@ -130,24 +153,26 @@ def _native_mace_reference_outputs( msg = "Native MACE model returned no node_energy" raise ValueError(msg) atom_energy = atom_energy.view(nf, nall)[:, :nloc] - energy = atom_energy.sum(dim=1).view(nf, 1) + energy = atom_energy.sum(dim=1).view(nf) - force = torch.autograd.grad( + grads = torch.autograd.grad( outputs=[energy], - inputs=[extended_coord_ff], + inputs=[extended_coord_ff, displacement], grad_outputs=[torch.ones_like(energy)], - retain_graph=True, + retain_graph=False, create_graph=False, - )[0] + allow_unused=True, + ) + force = grads[0] + virial = grads[1] if force is None: msg = "Native MACE model returned no force gradient" raise ValueError(msg) force = -force.view(nf, nall, 3) - atomic_virial = force.to(coord.dtype).unsqueeze(-1) @ extended_coord_ff.view( - nf, - nall, - 3, - ).to(coord.dtype).unsqueeze(-2) + if virial is None: + virial_out = torch.zeros((nf, 9), dtype=coord.dtype, device=force.device) + else: + virial_out = (-virial).view(nf, 9).to(coord.dtype) if mapping is not None: force_local = torch.scatter_reduce( @@ -157,21 +182,12 @@ def _native_mace_reference_outputs( force.to(coord.dtype), reduce="sum", ) - atomic_virial_local = torch.scatter_reduce( - torch.zeros((nf, nloc, 9), dtype=coord.dtype, device=force.device), - 1, - mapping.unsqueeze(-1).expand(nf, nall, 9), - atomic_virial.view(nf, nall, 9), - reduce="sum", - ) force_out = force_local - virial_out = atomic_virial_local.sum(dim=1).view(nf, 9) else: force_out = force[:, :nloc, :].to(coord.dtype) - virial_out = atomic_virial.sum(dim=1).view(nf, 9) return { - "energy": energy.to(coord.dtype), + "energy": energy.view(nf, 1).to(coord.dtype), "force": force_out, "virial": virial_out, } From eb19073dde50fa8036d0b411804a011799b20707 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:12:40 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd_gnn/mace.py | 23 +++++++++++++++-------- tests/test_mace_off.py | 19 +++++++++++-------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index 7f83f8b..5a98e56 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -17,7 +17,6 @@ ) from deepmd.pt.model.model.transform_output import ( communicate_extended_output, - fit_output_to_model_output, ) from deepmd.pt.utils import env from deepmd.pt.utils.nlist import ( @@ -576,7 +575,8 @@ def forward( model_predict["virial"] = model_ret_lower["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret_lower["energy_derv_c"][ - :, :nloc + :, + :nloc, ].squeeze(-3) return model_predict @@ -746,11 +746,16 @@ def forward_lower_common( shifts = shifts.to(default_dtype) one_hot = one_hot.to(default_dtype) - batch = torch.arange( - nf, - dtype=torch.int64, - device=extended_coord_ff.device, - ).unsqueeze(-1).expand(nf, nall).reshape(-1) + batch = ( + torch.arange( + nf, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + .unsqueeze(-1) + .expand(nf, nall) + .reshape(-1) + ) ptr = torch.arange( 0, (nf + 1) * nall, @@ -774,7 +779,9 @@ def forward_lower_common( } displacement = torch.jit.annotate(Optional[torch.Tensor], None) if box is not None: - box_tensor = box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + box_tensor = ( + box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + ) edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") inv_box = torch.linalg.inv(box_tensor) unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) diff --git a/tests/test_mace_off.py b/tests/test_mace_off.py index ac41969..a6e7101 100644 --- a/tests/test_mace_off.py +++ b/tests/test_mace_off.py @@ -89,11 +89,16 @@ def _native_mace_reference_outputs( value=1, ) - batch = torch.arange( - nf, - dtype=torch.int64, - device=extended_coord_ff.device, - ).unsqueeze(-1).expand(nf, nall).reshape(-1) + batch = ( + torch.arange( + nf, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + .unsqueeze(-1) + .expand(nf, nall) + .reshape(-1) + ) ptr = torch.arange( 0, (nf + 1) * nall, @@ -116,9 +121,7 @@ def _native_mace_reference_outputs( device=extended_coord_ff.device, ) displacement.requires_grad_(True) - symmetric_displacement = 0.5 * ( - displacement + displacement.transpose(-1, -2) - ) + symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2)) positions = extended_coord_ff + torch.einsum( "be,bec->bc", extended_coord_ff,