Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 181 additions & 75 deletions deepmd_gnn/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def forward(
extended_atype,
nlist,
mapping=mapping,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
Expand All @@ -571,9 +572,12 @@ def forward(
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
model_predict["virial"] = model_ret_lower["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
model_predict["atom_virial"] = model_ret_lower["energy_derv_c"][
:,
:nloc,
].squeeze(-3)
return model_predict

@torch.jit.export
Expand Down Expand Up @@ -636,6 +640,7 @@ def forward_lower(
extended_atype,
nlist,
mapping,
None,
fparam,
aparam,
do_atomic_virial,
Expand All @@ -657,9 +662,10 @@ def forward_lower_common(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False, # noqa: ARG002
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
) -> dict[str, torch.Tensor]:
"""Forward lower common pass of the model.
Expand Down Expand Up @@ -699,7 +705,6 @@ def forward_lower_common(
extended_atype = extended_atype.to(torch.int64)
nall = extended_coord.shape[1]

# fake as one frame
extended_coord_ff = extended_coord.view(nf * nall, 3)
extended_atype_ff = extended_atype.view(nf * nall)
edge_index = torch.ops.deepmd_gnn.edge_index(
Expand All @@ -708,24 +713,26 @@ def forward_lower_common(
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
)
edge_index = edge_index.T
# to one hot
indices = extended_atype_ff.unsqueeze(-1)
oh = torch.zeros(
(nf * nall, self.ntypes),
device=extended_atype.device,
dtype=torch.float64,
)
# scatter_ is the in-place version of scatter
oh.scatter_(dim=-1, index=indices, value=1)
one_hot = oh.view((nf * nall, self.ntypes))

# cast to float32
default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype
extended_coord_ff = extended_coord_ff.to(default_dtype)
extended_coord_ff.requires_grad_(True) # noqa: FBT003
extended_coord_grad = extended_coord.to(default_dtype)
extended_coord_grad.requires_grad_(True) # noqa: FBT003
extended_coord_ff = extended_coord_grad.view(nf * nall, 3)
nedge = edge_index.shape[1]
shifts = torch.zeros(
(nedge, 3),
dtype=default_dtype,
device=extended_coord_ff.device,
)
if self.num_interactions > 1 and mapping is not None and nloc < nall:
# shift the edges for ghost atoms, and map the ghost atoms to real atoms
mapping_ff = mapping.view(nf * nall) + torch.arange(
0,
nf * nall,
Expand All @@ -736,91 +743,190 @@ def forward_lower_common(
shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff]
shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
edge_index = mapping_ff[edge_index]
else:
shifts = torch.zeros(
(nedge, 3),
dtype=torch.float64,
device=extended_coord_.device,
)
shifts = shifts.to(default_dtype)
one_hot = one_hot.to(default_dtype)
# it seems None is not allowed for data
box = (
torch.eye(
3,
dtype=extended_coord_ff.dtype,

batch = (
torch.arange(
nf,
dtype=torch.int64,
device=extended_coord_ff.device,
)
* 1000.0
.unsqueeze(-1)
.expand(nf, nall)
.reshape(-1)
)
ptr = torch.arange(
0,
(nf + 1) * nall,
nall,
dtype=torch.int64,
device=extended_coord_ff.device,
)
weight = torch.ones(
[nf],
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)

ret = self.model.forward(
{
"positions": extended_coord_ff,
"shifts": shifts,
"cell": box,
"edge_index": edge_index,
"batch": torch.zeros(
[nf * nall],
dtype=torch.int64,
device=extended_coord_ff.device,
),
"node_attrs": one_hot,
"ptr": torch.tensor(
[0, nf * nall],
dtype=torch.int64,
device=extended_coord_ff.device,
),
"weight": torch.tensor(
[1.0],
compute_displacement = box is not None
input_dict: dict[str, torch.Tensor] = {
"edge_index": edge_index,
"batch": batch,
"node_attrs": one_hot.to(default_dtype),
"ptr": ptr,
"weight": weight,
}
displacement = torch.jit.annotate(Optional[torch.Tensor], None)
if box is not None:
box_tensor = (
box.view(nf, 3, 3).to(default_dtype).to(extended_coord_ff.device)
)
edge_batch = torch.div(edge_index[0], nall, rounding_mode="floor")
inv_box = torch.linalg.inv(box_tensor)
unit_shifts = torch.einsum("ec,ecb->eb", shifts, inv_box[edge_batch])
displacement = torch.zeros(
(nf, 3, 3),
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)
displacement.requires_grad_(True)
symmetric_displacement = 0.5 * (
displacement + displacement.transpose(-1, -2)
)
positions = extended_coord_ff + torch.einsum(
"be,bec->bc",
extended_coord_ff,
symmetric_displacement[batch],
)
cell = box_tensor + torch.matmul(box_tensor, symmetric_displacement)
input_dict["positions"] = positions
input_dict["cell"] = cell
input_dict["shifts"] = torch.einsum(
"be,bec->bc",
torch.round(unit_shifts).to(default_dtype),
cell[edge_batch],
)
else:
input_dict["positions"] = extended_coord_ff
input_dict["cell"] = (
torch.eye(
3,
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
),
},
)
.unsqueeze(0)
.expand(nf, 3, 3)
* 1000.0
)
input_dict["shifts"] = shifts

ret = self.model.forward(
input_dict,
compute_force=False,
compute_virials=False,
compute_stress=False,
compute_displacement=False,
training=self.training,
)

atom_energy = ret["node_energy"]
if atom_energy is None:
atom_energy_all = ret["node_energy"]
if atom_energy_all is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc]
energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
force = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)[0]
if force is None:
msg = "force is None"
raise ValueError(msg)
force = -force
atomic_virial = force.unsqueeze(-1).to(
extended_coord_.dtype,
) @ extended_coord_ff.unsqueeze(-2).to(
extended_coord_.dtype,
atom_energy_all = atom_energy_all.view(nf, nall)
atom_energy = atom_energy_all[:, :nloc]
energy = torch.sum(atom_energy, dim=1)
grad_outputs = torch.jit.annotate(
list[Optional[torch.Tensor]],
[torch.ones_like(energy)],
)
retain_graph = self.training or do_atomic_virial

atomic_virial_fallback = torch.zeros(
(nf, nall, 3, 3),
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)
if compute_displacement and displacement is not None:
grads = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff, displacement],
grad_outputs=grad_outputs,
retain_graph=retain_graph,
create_graph=self.training,
allow_unused=True,
)
force_ff = grads[0]
virial_tensor = grads[1]
if force_ff is None:
force_ff = torch.zeros_like(extended_coord_ff)
if virial_tensor is None:
virial_tensor = torch.zeros(
(nf, 3, 3),
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)
force = -force_ff.view(nf, nall, 3)
virial = -virial_tensor.view(nf, 1, 9)
else:
Comment on lines +772 to +872
force_ff = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=retain_graph,
create_graph=self.training,
allow_unused=True,
)[0]
if force_ff is None:
force_ff = torch.zeros_like(extended_coord_ff)
force = -force_ff.view(nf, nall, 3)
Comment on lines +851 to +883
atomic_virial_fallback = force.unsqueeze(-1) @ extended_coord_ff.view(
nf,
nall,
3,
).unsqueeze(-2)
virial = torch.sum(atomic_virial_fallback, dim=1).view(nf, 1, 9)

atomic_virial = torch.zeros(
(nf, nall, 1, 9),
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)
force = force.view(nf, nall, 3).to(extended_coord_.dtype)
atomic_virial = atomic_virial.view(nf, nall, 1, 9)
virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype)
if do_atomic_virial:
if compute_displacement and displacement is not None:
atomic_virial_local = torch.zeros(
(nf, nloc, 9),
dtype=extended_coord_ff.dtype,
device=extended_coord_ff.device,
)
for ii in range(nloc):
atom_energy_ii = atom_energy[:, ii]
atom_grad_outputs = torch.jit.annotate(
list[Optional[torch.Tensor]],
[torch.ones_like(atom_energy_ii)],
)
atom_virial_ii = torch.autograd.grad(
outputs=[atom_energy_ii],
inputs=[displacement],
grad_outputs=atom_grad_outputs,
retain_graph=True,
create_graph=self.training,
allow_unused=True,
)[0]
if atom_virial_ii is None:
atom_virial_ii = torch.zeros_like(displacement)
atomic_virial_local[:, ii, :] = (-atom_virial_ii).view(nf, 9)
Comment on lines +898 to +919
atomic_virial[:, :nloc, 0, :] = atomic_virial_local
else:
atomic_virial[:, :, 0, :] = atomic_virial_fallback.view(nf, nall, 9)

return {
"energy_redu": energy.view(nf, 1),
"energy_derv_r": force.view(nf, nall, 1, 3),
"energy_derv_c_redu": virial.view(nf, 1, 9),
# take the first nloc atoms to match other models
"energy": atom_energy.view(nf, nloc, 1),
# fake atom_virial
"energy_derv_c": atomic_virial.view(nf, nall, 1, 9),
"energy_redu": energy.view(nf, 1).to(extended_coord_.dtype),
"energy_derv_r": force.view(nf, nall, 1, 3).to(extended_coord_.dtype),
"energy_derv_c_redu": virial.to(extended_coord_.dtype),
"energy": atom_energy.view(nf, nloc, 1).to(extended_coord_.dtype),
"energy_derv_c": atomic_virial.to(extended_coord_.dtype),
}

def serialize(self) -> dict:
Expand Down
Loading
Loading