Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions models/rfd3/src/rfd3/inference/input_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,16 @@ def break_unindexed(unindex: InputSelection):
def _append_ligand(self, atom_array, atom_array_input_annotated):
"""Append ligand if specified."""
if exists(self.ligand):
fixed_atoms = {}
if exists(self.select_fixed_atoms):
fixed_atoms = {
component: atoms
for component, atoms in self.select_fixed_atoms.data.items()
}
ligand_array = extract_ligand_array(
atom_array_input_annotated,
self.ligand,
fixed_atoms={},
fixed_atoms=fixed_atoms,
set_defaults=False,
additional_annotations=set(
list(atom_array.get_annotation_categories())
Expand Down Expand Up @@ -734,9 +740,15 @@ def _set_origin(self, atom_array):
infer_ori_strategy=self.infer_ori_strategy,
)
# Diffused atoms are always initialized at origin during regular diffusion (all information removed)
atom_array.coord[
~atom_array.is_motif_atom_with_fixed_coord.astype(bool)
] = 0.0
unfixed_mask = ~atom_array.is_motif_atom_with_fixed_coord.astype(bool)
# Don't zero out unfixed ligand atoms - they must keep their original
# coordinates to maintain molecular connectivity with fixed ligand atoms.
# They will still receive noise during diffusion (since is_motif_atom_with_fixed_coord=False)
# but start near their true position rather than at the origin.
if exists(self.ligand):
is_ligand = atom_array.hetero.astype(bool)
unfixed_mask = unfixed_mask & ~is_ligand
atom_array.coord[unfixed_mask] = 0.0
return atom_array

def _apply_globals(self, atom_array):
Expand Down
1 change: 0 additions & 1 deletion models/rfd3/src/rfd3/inference/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def from_any_(v: Any, atom_array: AtomArray):

# Split to atom names
data_split[idx] = token.atom_name[comp_mask_subset].tolist()
# TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented

# Update mask & token dictionary
mask[comp_mask] = comp_mask_subset
Expand Down
Loading
Loading