diff --git a/deepmd_gnn/mace.py b/deepmd_gnn/mace.py index 03bee6b..5a98e56 100644 --- a/deepmd_gnn/mace.py +++ b/deepmd_gnn/mace.py @@ -556,6 +556,7 @@ def forward( extended_atype, nlist, mapping=mapping, + box=box, fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, @@ -571,9 +572,12 @@ def forward( model_predict["atom_energy"] = model_ret["energy"] model_predict["energy"] = model_ret["energy_redu"] model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) - model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + model_predict["virial"] = model_ret_lower["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret_lower["energy_derv_c"][ + :, + :nloc, + ].squeeze(-3) return model_predict @torch.jit.export @@ -636,6 +640,7 @@ def forward_lower( extended_atype, nlist, mapping, + None, fparam, aparam, do_atomic_virial, @@ -657,9 +662,10 @@ def forward_lower_common( extended_atype: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, # noqa: ARG002 + do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, ) -> dict[str, torch.Tensor]: """Forward lower common pass of the model. @@ -699,7 +705,6 @@ def forward_lower_common( extended_atype = extended_atype.to(torch.int64) nall = extended_coord.shape[1] - # fake as one frame extended_coord_ff = extended_coord.view(nf * nall, 3) extended_atype_ff = extended_atype.view(nf * nall) edge_index = torch.ops.deepmd_gnn.edge_index( @@ -708,24 +713,26 @@ def forward_lower_common( torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"), ) edge_index = edge_index.T - # to one hot indices = extended_atype_ff.unsqueeze(-1) oh = torch.zeros( (nf * nall, self.ntypes), device=extended_atype.device, dtype=torch.float64, ) - # scatter_ is the in-place version of scatter oh.scatter_(dim=-1, index=indices, value=1) one_hot = oh.view((nf * nall, self.ntypes)) - # cast to float32 default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype - extended_coord_ff = extended_coord_ff.to(default_dtype) - extended_coord_ff.requires_grad_(True) # noqa: FBT003 + extended_coord_grad = extended_coord.to(default_dtype) + extended_coord_grad.requires_grad_(True) # noqa: FBT003 + extended_coord_ff = extended_coord_grad.view(nf * nall, 3) nedge = edge_index.shape[1] + shifts = torch.zeros( + (nedge, 3), + dtype=default_dtype, + device=extended_coord_ff.device, + ) if self.num_interactions > 1 and mapping is not None and nloc < nall: - # shift the edges for ghost atoms, and map the ghost atoms to real atoms mapping_ff = mapping.view(nf * nall) + torch.arange( 0, nf * nall, @@ -736,47 +743,86 @@ def forward_lower_common( shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] edge_index = mapping_ff[edge_index] - else: - shifts = torch.zeros( - (nedge, 3), - dtype=torch.float64, - device=extended_coord_.device, - ) shifts = shifts.to(default_dtype) one_hot = one_hot.to(default_dtype) - # it seems None is not allowed for data - box = ( - torch.eye( - 3, - dtype=extended_coord_ff.dtype, + + batch = ( + torch.arange( + nf, + dtype=torch.int64, device=extended_coord_ff.device, ) - * 1000.0 + .unsqueeze(-1) + .expand(nf, nall) + .reshape(-1) + ) + ptr = torch.arange( + 0, + (nf + 1) * nall, + nall, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + weight = torch.ones( + [nf], + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, ) - ret = self.model.forward( - { - "positions": extended_coord_ff, - "shifts": shifts, - "cell": box, - "edge_index": edge_index, - "batch": torch.zeros( - [nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "node_attrs": one_hot, - "ptr": torch.tensor( - [0, nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "weight": torch.tensor( - [1.0], + compute_displacement = box is not None + input_dict: dict[str, torch.Tensor] = { + "edge_index": edge_index, + "batch": batch, + "node_attrs": one_hot.to(default_dtype), + "ptr": ptr, + "weight": weight, + } + displacement = torch.jit.annotate(Optional[torch.Tensor], None) + if box is not None: + box_tensor = ( + box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + ) + edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") + inv_box = torch.linalg.inv(box_tensor) + unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) + displacement = torch.zeros( + (nf, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) + positions = extended_coord_ff + torch.einsum( + "be,bec->bc", + extended_coord_ff, + symmetric_displacement[batch], + ) + cell = box_tensor + torch.matmul(box_tensor, symmetric_displacement) + input_dict["positions"] = positions + input_dict["cell"] = cell + input_dict["shifts"] = torch.einsum( + "be,bec->bc", + torch.round(unit_shifts).to(default_dtype), + cell[edge_batch], + ) + else: + input_dict["positions"] = extended_coord_ff + input_dict["cell"] = ( + torch.eye( + 3, dtype=extended_coord_ff.dtype, device=extended_coord_ff.device, - ), - }, + ) + .unsqueeze(0) + .expand(nf, 3, 3) + * 1000.0 + ) + input_dict["shifts"] = shifts + + ret = self.model.forward( + input_dict, compute_force=False, compute_virials=False, compute_stress=False, @@ -784,43 +830,103 @@ def forward_lower_common( training=self.training, ) - atom_energy = ret["node_energy"] - if atom_energy is None: + atom_energy_all = ret["node_energy"] + if atom_energy_all is None: msg = "atom_energy is None" raise ValueError(msg) - atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] - energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) - grad_outputs: list[Optional[torch.Tensor]] = [ - torch.ones_like(energy), - ] - force = torch.autograd.grad( - outputs=[energy], - inputs=[extended_coord_ff], - grad_outputs=grad_outputs, - retain_graph=True, - create_graph=self.training, - )[0] - if force is None: - msg = "force is None" - raise ValueError(msg) - force = -force - atomic_virial = force.unsqueeze(-1).to( - extended_coord_.dtype, - ) @ extended_coord_ff.unsqueeze(-2).to( - extended_coord_.dtype, + atom_energy_all = atom_energy_all.view(nf, nall) + atom_energy = atom_energy_all[:, :nloc] + energy = torch.sum(atom_energy, dim=1) + grad_outputs = torch.jit.annotate( + list[Optional[torch.Tensor]], + [torch.ones_like(energy)], + ) + retain_graph = self.training or do_atomic_virial + + atomic_virial_fallback = torch.zeros( + (nf, nall, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + if compute_displacement and displacement is not None: + grads = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff, displacement], + grad_outputs=grad_outputs, + retain_graph=retain_graph, + create_graph=self.training, + allow_unused=True, + ) + force_ff = grads[0] + virial_tensor = grads[1] + if force_ff is None: + force_ff = torch.zeros_like(extended_coord_ff) + if virial_tensor is None: + virial_tensor = torch.zeros( + (nf, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + force = -force_ff.view(nf, nall, 3) + virial = -virial_tensor.view(nf, 1, 9) + else: + force_ff = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff], + grad_outputs=grad_outputs, + retain_graph=retain_graph, + create_graph=self.training, + allow_unused=True, + )[0] + if force_ff is None: + force_ff = torch.zeros_like(extended_coord_ff) + force = -force_ff.view(nf, nall, 3) + atomic_virial_fallback = force.unsqueeze(-1) @ extended_coord_ff.view( + nf, + nall, + 3, + ).unsqueeze(-2) + virial = torch.sum(atomic_virial_fallback, dim=1).view(nf, 1, 9) + + atomic_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, ) - force = force.view(nf, nall, 3).to(extended_coord_.dtype) - atomic_virial = atomic_virial.view(nf, nall, 1, 9) - virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) + if do_atomic_virial: + if compute_displacement and displacement is not None: + atomic_virial_local = torch.zeros( + (nf, nloc, 9), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + for ii in range(nloc): + atom_energy_ii = atom_energy[:, ii] + atom_grad_outputs = torch.jit.annotate( + list[Optional[torch.Tensor]], + [torch.ones_like(atom_energy_ii)], + ) + atom_virial_ii = torch.autograd.grad( + outputs=[atom_energy_ii], + inputs=[displacement], + grad_outputs=atom_grad_outputs, + retain_graph=True, + create_graph=self.training, + allow_unused=True, + )[0] + if atom_virial_ii is None: + atom_virial_ii = torch.zeros_like(displacement) + atomic_virial_local[:, ii, :] = (-atom_virial_ii).view(nf, 9) + atomic_virial[:, :nloc, 0, :] = atomic_virial_local + else: + atomic_virial[:, :, 0, :] = atomic_virial_fallback.view(nf, nall, 9) return { - "energy_redu": energy.view(nf, 1), - "energy_derv_r": force.view(nf, nall, 1, 3), - "energy_derv_c_redu": virial.view(nf, 1, 9), - # take the first nloc atoms to match other models - "energy": atom_energy.view(nf, nloc, 1), - # fake atom_virial - "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), + "energy_redu": energy.view(nf, 1).to(extended_coord_.dtype), + "energy_derv_r": force.view(nf, nall, 1, 3).to(extended_coord_.dtype), + "energy_derv_c_redu": virial.to(extended_coord_.dtype), + "energy": atom_energy.view(nf, nloc, 1).to(extended_coord_.dtype), + "energy_derv_c": atomic_virial.to(extended_coord_.dtype), } def serialize(self) -> dict: diff --git a/tests/test_mace_off.py b/tests/test_mace_off.py index bd887a9..a6e7101 100644 --- a/tests/test_mace_off.py +++ b/tests/test_mace_off.py @@ -61,6 +61,11 @@ def _native_mace_reference_outputs( torch.tensor([], dtype=torch.int64, device="cpu"), ).T + shifts = torch.zeros( + (edge_index.shape[1], 3), + dtype=default_dtype, + device=extended_coord_ff.device, + ) if mapping is not None and nloc < nall: mapping_ff = mapping.view(nf * nall) + torch.arange( 0, @@ -72,12 +77,6 @@ def _native_mace_reference_outputs( shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] edge_index = mapping_ff[edge_index] - else: - shifts = torch.zeros( - (edge_index.shape[1], 3), - dtype=default_dtype, - device=extended_coord_ff.device, - ) one_hot = torch.zeros( (nf * nall, ntypes), @@ -90,33 +89,60 @@ def _native_mace_reference_outputs( value=1, ) + batch = ( + torch.arange( + nf, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + .unsqueeze(-1) + .expand(nf, nall) + .reshape(-1) + ) + ptr = torch.arange( + 0, + (nf + 1) * nall, + nall, + dtype=torch.int64, + device=extended_coord_ff.device, + ) + weight = torch.ones( + [nf], + dtype=default_dtype, + device=extended_coord_ff.device, + ) + cell = box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device) + edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor") + inv_box = torch.linalg.inv(cell) + unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch]) + displacement = torch.zeros( + (nf, 3, 3), + dtype=extended_coord_ff.dtype, + device=extended_coord_ff.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2)) + positions = extended_coord_ff + torch.einsum( + "be,bec->bc", + extended_coord_ff, + symmetric_displacement[batch], + ) + cell_def = cell + torch.matmul(cell, symmetric_displacement) + ret = mace_model.forward( { - "positions": extended_coord_ff, - "shifts": shifts, - "cell": torch.eye( - 3, - dtype=default_dtype, - device=extended_coord_ff.device, - ) - * 1000.0, - "edge_index": edge_index, - "batch": torch.zeros( - [nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, + "positions": positions, + "cell": cell_def, + "shifts": torch.einsum( + "be,bec->bc", + torch.round(unit_shifts).to(default_dtype), + cell_def[edge_batch], ), + "edge_index": edge_index, + "batch": batch, "node_attrs": one_hot, - "ptr": torch.tensor( - [0, nf * nall], - dtype=torch.int64, - device=extended_coord_ff.device, - ), - "weight": torch.tensor( - [1.0], - dtype=default_dtype, - device=extended_coord_ff.device, - ), + "ptr": ptr, + "weight": weight, }, compute_force=False, compute_virials=False, @@ -130,24 +156,26 @@ def _native_mace_reference_outputs( msg = "Native MACE model returned no node_energy" raise ValueError(msg) atom_energy = atom_energy.view(nf, nall)[:, :nloc] - energy = atom_energy.sum(dim=1).view(nf, 1) + energy = atom_energy.sum(dim=1).view(nf) - force = torch.autograd.grad( + grads = torch.autograd.grad( outputs=[energy], - inputs=[extended_coord_ff], + inputs=[extended_coord_ff, displacement], grad_outputs=[torch.ones_like(energy)], - retain_graph=True, + retain_graph=False, create_graph=False, - )[0] + allow_unused=True, + ) + force = grads[0] + virial = grads[1] if force is None: msg = "Native MACE model returned no force gradient" raise ValueError(msg) force = -force.view(nf, nall, 3) - atomic_virial = force.to(coord.dtype).unsqueeze(-1) @ extended_coord_ff.view( - nf, - nall, - 3, - ).to(coord.dtype).unsqueeze(-2) + if virial is None: + virial_out = torch.zeros((nf, 9), dtype=coord.dtype, device=force.device) + else: + virial_out = (-virial).view(nf, 9).to(coord.dtype) if mapping is not None: force_local = torch.scatter_reduce( @@ -157,21 +185,12 @@ def _native_mace_reference_outputs( force.to(coord.dtype), reduce="sum", ) - atomic_virial_local = torch.scatter_reduce( - torch.zeros((nf, nloc, 9), dtype=coord.dtype, device=force.device), - 1, - mapping.unsqueeze(-1).expand(nf, nall, 9), - atomic_virial.view(nf, nall, 9), - reduce="sum", - ) force_out = force_local - virial_out = atomic_virial_local.sum(dim=1).view(nf, 9) else: force_out = force[:, :nloc, :].to(coord.dtype) - virial_out = atomic_virial.sum(dim=1).view(nf, 9) return { - "energy": energy.to(coord.dtype), + "energy": energy.view(nf, 1).to(coord.dtype), "force": force_out, "virial": virial_out, } diff --git a/tests/test_model.py b/tests/test_model.py index b6aa6dc..cb4e876 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -277,6 +277,8 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, } + if "atom_virial" in self.output_def: + input_dict["do_atomic_virial"] = True if test_spin: input_dict["spin"] = spin ret.append(module(**input_dict)) @@ -289,6 +291,8 @@ def test_forward(self) -> None: "fparam": fparam, "mapping": mapping_large, } + if "atom_virial" in self.output_def: + input_dict_lower["do_atomic_virial"] = True if test_spin: input_dict_lower["extended_spin"] = spin_ext @@ -303,6 +307,8 @@ def test_forward(self) -> None: "aparam": aparam, "fparam": fparam, } + if "atom_virial" in self.output_def: + input_dict_lower["do_atomic_virial"] = True if test_spin: input_dict_lower["extended_spin"] = spin_ext @@ -953,6 +959,16 @@ def ff_cell(bb): } return module(**input_dict)["energy"] + def ff_cell_atom(bb): + input_dict = { + "coord": stretch_box(coord, cell, bb), + "atype": atype, + "box": bb, + "aparam": aparam, + "fparam": fparam, + } + return module(**input_dict)["atom_energy"] + fdv = ( -( finite_difference(ff_cell, cell, delta=delta) @@ -970,12 +986,32 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } - rfv = module(**input_dict)["virial"] + if "atom_virial" in self.output_def: + input_dict["do_atomic_virial"] = True + ret = module(**input_dict) + rfv = ret["virial"] np.testing.assert_almost_equal( fdv.reshape(-1, 9), rfv.reshape(-1, 9), decimal=places, ) + if "atom_virial" in self.output_def: + fdav = -( + finite_difference(ff_cell_atom, cell, delta=delta) + .reshape(-1, 3, 3) + .transpose(0, 2, 1) + @ cell.reshape(-1, 3, 3) + ).reshape(-1, 9) + np.testing.assert_almost_equal( + fdav, + ret["atom_virial"].reshape(-1, 9), + decimal=places, + ) + np.testing.assert_almost_equal( + ret["atom_virial"].sum(axis=1).reshape(-1, 9), + rfv.reshape(-1, 9), + decimal=places, + ) else: # not support virial by far pass