fix: derive atomic virial via DeePMD output transform#108
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
10b3c9e to
07c60d9
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #108 +/- ##
==========================================
+ Coverage 81.46% 81.98% +0.52%
==========================================
Files 9 9
Lines 863 866 +3
==========================================
+ Hits 703 710 +7
+ Misses 160 156 -4 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
for more information, see https://pre-commit.ci
Authored by OpenClaw (model: gpt-5.4)
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Pull request overview
This PR updates the MACE wrapper’s virial / atomic-virial derivation to follow a displacement/strain-gradient path (instead of constructing atomic virials from (F \otimes r)), and adjusts tests and the MACE-OFF native reference to match the new semantics—aiming to improve correctness around ghost/extended atoms.
Changes:
- Enable
do_atomic_virialin model forward tests whenatom_virialis part of the expected outputs. - Update MACE-OFF native reference evaluation to compute virial via a displacement-gradient path.
- Refactor
deepmd_gnn/mace.pyto compute virial (and optionally atomic virial) via autograd against a symmetric displacement whenboxis provided.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
tests/test_model.py |
Turns on atomic-virial computation in test inputs and adds finite-difference checks for atom_virial. |
tests/test_mace_off.py |
Updates the native MACE reference to compute virial from displacement gradients (and adjusts batching/cell handling). |
deepmd_gnn/mace.py |
Introduces optional box handling into forward_lower_common() and derives virial/atomic-virial using displacement/strain gradients. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| compute_displacement = box is not None | ||
| input_dict: dict[str, torch.Tensor] = { | ||
| "edge_index": edge_index, | ||
| "batch": batch, | ||
| "node_attrs": one_hot.to(default_dtype), | ||
| "ptr": ptr, | ||
| "weight": weight, | ||
| } | ||
| displacement = torch.jit.annotate(Optional[torch.Tensor], None) | ||
| if box is not None: | ||
| box_tensor = ( | ||
| box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) | ||
| ) | ||
| edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") | ||
| inv_box = torch.linalg.inv(box_tensor) | ||
| unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) | ||
| displacement = torch.zeros( | ||
| (nf, 3, 3), | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ) | ||
| displacement.requires_grad_(True) | ||
| symmetric_displacement = 0.5 * ( | ||
| displacement + displacement.transpose(-1, -2) | ||
| ) | ||
| positions = extended_coord_ff + torch.einsum( | ||
| "be,bec->bc", | ||
| extended_coord_ff, | ||
| symmetric_displacement[batch], | ||
| ) | ||
| cell = box_tensor + torch.matmul(box_tensor, symmetric_displacement) | ||
| input_dict["positions"] = positions | ||
| input_dict["cell"] = cell | ||
| input_dict["shifts"] = torch.einsum( | ||
| "be,bec->bc", | ||
| torch.round(unit_shifts).to(default_dtype), | ||
| cell[edge_batch], | ||
| ) | ||
| else: | ||
| input_dict["positions"] = extended_coord_ff | ||
| input_dict["cell"] = ( | ||
| torch.eye( | ||
| 3, | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ), | ||
| }, | ||
| ) | ||
| .unsqueeze(0) | ||
| .expand(nf, 3, 3) | ||
| * 1000.0 | ||
| ) | ||
| input_dict["shifts"] = shifts | ||
|
|
||
| ret = self.model.forward( | ||
| input_dict, | ||
| compute_force=False, | ||
| compute_virials=False, | ||
| compute_stress=False, | ||
| compute_displacement=False, | ||
| training=self.training, | ||
| ) | ||
|
|
||
| atom_energy = ret["node_energy"] | ||
| if atom_energy is None: | ||
| atom_energy_all = ret["node_energy"] | ||
| if atom_energy_all is None: | ||
| msg = "atom_energy is None" | ||
| raise ValueError(msg) | ||
| atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] | ||
| energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) | ||
| grad_outputs: list[Optional[torch.Tensor]] = [ | ||
| torch.ones_like(energy), | ||
| ] | ||
| force = torch.autograd.grad( | ||
| outputs=[energy], | ||
| inputs=[extended_coord_ff], | ||
| grad_outputs=grad_outputs, | ||
| retain_graph=True, | ||
| create_graph=self.training, | ||
| )[0] | ||
| if force is None: | ||
| msg = "force is None" | ||
| raise ValueError(msg) | ||
| force = -force | ||
| atomic_virial = force.unsqueeze(-1).to( | ||
| extended_coord_.dtype, | ||
| ) @ extended_coord_ff.unsqueeze(-2).to( | ||
| extended_coord_.dtype, | ||
| atom_energy_all = atom_energy_all.view(nf, nall) | ||
| atom_energy = atom_energy_all[:, :nloc] | ||
| energy = torch.sum(atom_energy, dim=1) | ||
| grad_outputs = torch.jit.annotate( | ||
| list[Optional[torch.Tensor]], | ||
| [torch.ones_like(energy)], | ||
| ) | ||
| retain_graph = self.training or do_atomic_virial | ||
|
|
||
| atomic_virial_fallback = torch.zeros( | ||
| (nf, nall, 3, 3), | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ) | ||
| if compute_displacement and displacement is not None: | ||
| grads = torch.autograd.grad( | ||
| outputs=[energy], | ||
| inputs=[extended_coord_ff, displacement], | ||
| grad_outputs=grad_outputs, | ||
| retain_graph=retain_graph, | ||
| create_graph=self.training, | ||
| allow_unused=True, | ||
| ) | ||
| force_ff = grads[0] | ||
| virial_tensor = grads[1] | ||
| if force_ff is None: | ||
| force_ff = torch.zeros_like(extended_coord_ff) | ||
| if virial_tensor is None: | ||
| virial_tensor = torch.zeros( | ||
| (nf, 3, 3), | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ) | ||
| force = -force_ff.view(nf, nall, 3) | ||
| virial = -virial_tensor.view(nf, 1, 9) | ||
| else: |
| @@ -289,6 +291,8 @@ def test_forward(self) -> None: | |||
| "fparam": fparam, | |||
| "mapping": mapping_large, | |||
| } | |||
| if "atom_virial" in self.output_def: | |||
| input_dict_lower["do_atomic_virial"] = True | |||
| if test_spin: | |||
| model_predict = {} | ||
| model_predict["atom_energy"] = model_ret["energy"] | ||
| model_predict["energy"] = model_ret["energy_redu"] | ||
| model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) | ||
| model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) | ||
| model_predict["virial"] = model_ret_lower["energy_derv_c_redu"].squeeze(-2) | ||
| if do_atomic_virial: | ||
| model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) | ||
| model_predict["atom_virial"] = model_ret_lower["energy_derv_c"][ | ||
| :, | ||
| :nloc, | ||
| ].squeeze(-3) |
| if compute_displacement and displacement is not None: | ||
| grads = torch.autograd.grad( | ||
| outputs=[energy], | ||
| inputs=[extended_coord_ff, displacement], | ||
| grad_outputs=grad_outputs, | ||
| retain_graph=retain_graph, | ||
| create_graph=self.training, | ||
| allow_unused=True, | ||
| ) | ||
| force_ff = grads[0] | ||
| virial_tensor = grads[1] | ||
| if force_ff is None: | ||
| force_ff = torch.zeros_like(extended_coord_ff) | ||
| if virial_tensor is None: | ||
| virial_tensor = torch.zeros( | ||
| (nf, 3, 3), | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ) | ||
| force = -force_ff.view(nf, nall, 3) | ||
| virial = -virial_tensor.view(nf, 1, 9) | ||
| else: | ||
| force_ff = torch.autograd.grad( | ||
| outputs=[energy], | ||
| inputs=[extended_coord_ff], | ||
| grad_outputs=grad_outputs, | ||
| retain_graph=retain_graph, | ||
| create_graph=self.training, | ||
| allow_unused=True, | ||
| )[0] | ||
| if force_ff is None: | ||
| force_ff = torch.zeros_like(extended_coord_ff) | ||
| force = -force_ff.view(nf, nall, 3) |
| atomic_virial_local = torch.zeros( | ||
| (nf, nloc, 9), | ||
| dtype=extended_coord_ff.dtype, | ||
| device=extended_coord_ff.device, | ||
| ) | ||
| for ii in range(nloc): | ||
| atom_energy_ii = atom_energy[:, ii] | ||
| atom_grad_outputs = torch.jit.annotate( | ||
| list[Optional[torch.Tensor]], | ||
| [torch.ones_like(atom_energy_ii)], | ||
| ) | ||
| atom_virial_ii = torch.autograd.grad( | ||
| outputs=[atom_energy_ii], | ||
| inputs=[displacement], | ||
| grad_outputs=atom_grad_outputs, | ||
| retain_graph=True, | ||
| create_graph=self.training, | ||
| allow_unused=True, | ||
| )[0] | ||
| if atom_virial_ii is None: | ||
| atom_virial_ii = torch.zeros_like(displacement) | ||
| atomic_virial_local[:, ii, :] = (-atom_virial_ii).view(nf, 9) |
Summary
MaceModel.forward_lower_common()fit_output_to_model_output()to derive force, virial, and atom virial from atomic energiesWhy
The current implementation in PR #14 builds
atom_virialmanually fromF ⊗ rand then applies an extra correction. That path is easy to get subtly wrong for ghost atoms / extended atoms.This patch removes the custom atomic-virial path and lets DeePMD-kit handle:
Validation
python3 -m py_compile deepmd_mace/mace.pypytestvalidation in a local nox environment, but test collection is currently blocked by a dynamic-library loading issue in the environment (libcudart.so.12during extension loading), so I could not complete a trustworthy local numerical pass hereAuthored by OpenClaw (model: gpt-5.4)