From 520945e43cd6218c080c2f4f87385d2c74a9daa6 Mon Sep 17 00:00:00 2001 From: Chen-Jie7 Date: Thu, 22 Jan 2026 12:42:13 -0800 Subject: [PATCH 1/3] removed old files and wrap hbplus output in tempdir --- .../rfd3/src/rfd3/metrics/hbonds_metrics.py | 389 ----------------- models/rfd3/src/rfd3/transforms/hbonds.py | 407 ------------------ .../rfd3/src/rfd3/transforms/hbonds_hbplus.py | 100 ++--- 3 files changed, 51 insertions(+), 845 deletions(-) delete mode 100644 models/rfd3/src/rfd3/metrics/hbonds_metrics.py delete mode 100644 models/rfd3/src/rfd3/transforms/hbonds.py diff --git a/models/rfd3/src/rfd3/metrics/hbonds_metrics.py b/models/rfd3/src/rfd3/metrics/hbonds_metrics.py deleted file mode 100644 index f680bc05..00000000 --- a/models/rfd3/src/rfd3/metrics/hbonds_metrics.py +++ /dev/null @@ -1,389 +0,0 @@ -import logging -from typing import Literal - -import biotite.structure as struc -import numpy as np -from atomworks.enums import ChainType -from atomworks.io.transforms.atom_array import remove_hydrogens -from rfd3.constants import ( - ATOM14_ATOM_NAMES, - SELECTION_NONPROTEIN, - SELECTION_PROTEIN, - association_schemes_stripped, -) -from rfd3.transforms.hbonds import ( - add_hydrogen_atom_positions, - calculate_hbonds, -) - -from foundry.metrics.base import Metric -from foundry.utils.ddp import RankedLogger - -logging.basicConfig(level=logging.INFO) -global_logger = RankedLogger(__name__, rank_zero_only=False) - - -def simplified_processing_atom_array(atom_arrays, central_atom="CB", threshold=0.5): - """ - Allows for sequence extraction from cleaned up virtual atoms. Needed for hbond metrics. 
- """ - final_atom_array = [] - for atom_array in atom_arrays: - cur_atom_array_list = [] - - res_ids = atom_array.res_id - res_start_indices = np.concatenate( - [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1] - ) - res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]]) - - for start, end in zip(res_start_indices, res_end_indices): - cur_res_atom_array = atom_array[start:end] - - # Check if the current residue is after padding (seq unknown): - if_seq_known = not any( - atom_name.startswith("V") for atom_name in cur_res_atom_array.atom_name - ) - - if not if_seq_known: - # For Glycine: it doesn't have CB, so set the virtual atom as CA. - # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA. - # There might be a better way to do this. - CA_coord = cur_res_atom_array.coord[ - cur_res_atom_array.atom_name == "CA" - ] - CB_coord = cur_res_atom_array.coord[ - cur_res_atom_array.atom_name == "CB" - ] - if np.linalg.norm(CA_coord - CB_coord) < threshold: - central_atom = "CA" - - central_mask = cur_res_atom_array.atom_name == central_atom - - # ... Calculate the distance to the central atom - central_coord = cur_res_atom_array.coord[central_mask][ - 0 - ] # Should only have one central atom anyway - dists = np.linalg.norm( - cur_res_atom_array.coord - central_coord, axis=-1 - ) - - # ... Select virtual atom by the distance. Shouldn't count the central atom itself. - is_virtual = (dists < threshold) & ~central_mask - - cur_res_atom_array = cur_res_atom_array[~is_virtual] - cur_pred_res_atom_names = ( - cur_res_atom_array.atom_name - ) # e.g. 
[N, CA, C, O, CB, V6, V2] - - has_restype_assigned = False - for restype, atom_names in association_schemes_stripped[ - "atom14" - ].items(): - atom_names = np.array(atom_names) - if restype in ["UNK", "MSK"]: - continue - - atom_name_idx_in_atom14_scheme = np.array( - [ - np.where(ATOM14_ATOM_NAMES == atom_name)[0][0] - for atom_name in cur_pred_res_atom_names - ] - ) # [0, 1, 2, 3, 4, 11, 7] - atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool) - atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True - if all( - x is not None for x in atom_names[atom14_scheme_mask] - ) and all(x is None for x in atom_names[~atom14_scheme_mask]): - cur_res_atom_array.res_name = np.array( - [restype] * len(cur_res_atom_array) - ) - cur_res_atom_array.atom_name = np.asarray( - atom_names[atom14_scheme_mask], dtype=str - ) - cur_atom_array_list.append(cur_res_atom_array) - has_restype_assigned = True - break - else: - cur_atom_array_list.append(cur_res_atom_array) - has_restype_assigned = True - - if not has_restype_assigned: - cur_res_atom_array.res_name = np.array( - ["UNK"] * len(cur_res_atom_array) - ) - cur_atom_array_list.append(cur_res_atom_array) - - cur_atom_array = struc.concatenate(cur_atom_array_list) - cur_atom_array.element = struc.infer_elements(cur_atom_array.atom_name) - - final_atom_array.append(cur_atom_array) - - return final_atom_array - - -# Training comparison -def calculate_hbond_stats( - input_atom_array_stack, - output_atom_array_stack, - selection1, - selection2, - selection1_type, - cutoff_dist, - cutoff_angle, - donor_elements, - acceptor_elements, - periodic, -): - """ - Compare the number of hbonds correctly recapitualted in the output atom array. - - Args: - input_atom_array_stack: Input atom array stack - output_atom_array_stack: Output atom array stack - selection1: Selection of atom types allowed to be donors (5,6) - selection2: Selection of atom types allowed to be acceptors (1,2,3...) 
- cutoff_dist: Cutoff distance for hbonds - cutoff_angle: Cutoff angle for hbonds - """ - # Used the latest function above, should check if it works correctly - output_atom_array_stack = simplified_processing_atom_array(output_atom_array_stack) - - assert len(input_atom_array_stack) == len( - output_atom_array_stack - ), "Input and output atom arrays must have the same length" - - total_correct_donors_percent = 0.0 - total_correct_acceptors_percent = 0.0 - total_number_hbonds = 0 - num_valid_samples = 0 - for i in range(len(input_atom_array_stack)): - correct_donors = 0 - correct_acceptors = 0 - - input_atom_array = input_atom_array_stack[i] - output_atom_array = output_atom_array_stack[i] - - if not ( - "active_donor" in input_atom_array.get_annotation_categories() - or "active_acceptor" in input_atom_array.get_annotation_categories() - ): - # print("active donor/acceptor not in annotation") - continue - if np.sum(input_atom_array.active_donor == 0) and np.sum( - input_atom_array.active_acceptor == 0 - ): - continue - - # Select possible donors and acceptors for the model output - if selection1 is None or selection2 is None: - continue - - # Hack: Temporarily use biotite to infer bonds, should be replaced with cifutils? - output_atom_array.bonds = struc.connect_via_distances( - output_atom_array, default_bond_type=1 - ) - - # Hack: delete coords_to_be_diffused (if exists) to temporarily solve a weird bug in create hydrogens. Anyway it will not be used. 
- if "coord_to_be_noised" in input_atom_array.get_annotation_categories(): - input_atom_array.del_annotation("coord_to_be_noised") - if "coord_to_be_noised" in output_atom_array.get_annotation_categories(): - output_atom_array.del_annotation("coord_to_be_noised") - - output_atom_array = add_hydrogen_atom_positions(output_atom_array) - - cur_selection1 = np.isin(output_atom_array.chain_type, selection1) - cur_selection2 = ( - np.isin(output_atom_array.chain_type, selection2) - | get_motif_features(output_atom_array)["is_motif_atom"] - ) - - hbonds, hbond_types, output_atom_array = calculate_hbonds( - output_atom_array, - cur_selection1, - cur_selection2, - selection1_type=selection1_type, - cutoff_dist=cutoff_dist, - cutoff_angle=cutoff_angle, - donor_elements=donor_elements, - acceptor_elements=acceptor_elements, - periodic=periodic, - ) - - output_atom_array.set_annotation("active_donor", hbond_types[:, 0]) - output_atom_array.set_annotation("active_acceptor", hbond_types[:, 1]) - - output_atom_array = remove_hydrogens(output_atom_array) - - given_hbond_donors = np.array(input_atom_array.active_donor, dtype=bool) - given_hbond_acceptors = np.array(input_atom_array.active_acceptor, dtype=bool) - given_hbond_donors_index = np.where(input_atom_array.active_donor == 1)[0] - given_hbond_acceptors_index = np.where(input_atom_array.active_acceptor == 1)[0] - - # Ensure the produced hbonds matches input hbond requirements: have the same atom type, residue name, and atom name - for idx in given_hbond_donors_index: - if bool( - output_atom_array[ - (output_atom_array.chain_id == input_atom_array.chain_id[idx]) - & (output_atom_array.res_id == input_atom_array.res_id[idx]) - & ( - output_atom_array.atom_name - == input_atom_array.gt_atom_name[idx] - ) - ].active_donor - ): - correct_donors += 1 - - for idx in given_hbond_acceptors_index: - if bool( - output_atom_array[ - (output_atom_array.chain_id == input_atom_array.chain_id[idx]) - & (output_atom_array.res_id == 
input_atom_array.res_id[idx]) - & ( - output_atom_array.atom_name - == input_atom_array.gt_atom_name[idx] - ) - ].active_acceptor - ): - correct_acceptors += 1 - - correct_hbond_donors_percent = ( - correct_donors / np.sum(given_hbond_donors) - if np.sum(given_hbond_donors) > 0 - else 1.0 - ) - correct_hbond_acceptors_percent = ( - correct_acceptors / np.sum(given_hbond_acceptors) - if np.sum(given_hbond_acceptors) > 0 - else 1.0 - ) - - total_correct_donors_percent += correct_hbond_donors_percent - total_correct_acceptors_percent += correct_hbond_acceptors_percent - total_number_hbonds += len(hbonds) - num_valid_samples += 1 - - if num_valid_samples == 0: - return 0, 0, 0 - return ( - total_correct_donors_percent / num_valid_samples, - total_correct_acceptors_percent / num_valid_samples, - total_number_hbonds / num_valid_samples, - ) - - -# Inference comparison -> tempportary fix to test out sm_hbonds, should be merged with hbond in transforms down the line -def get_hbond_metrics(atom_array=None): - if atom_array is None: - print("WARNING: atom_array is None") - return None # Or raise a more descriptive error - - curr_copy = atom_array.copy() - o = {} - selection1 = np.array([ChainType.as_enum(item).value for item in SELECTION_PROTEIN]) - selection2 = np.array( - [ChainType.as_enum(item).value for item in SELECTION_NONPROTEIN] - ) - # Hack: Temporarily use biotite to infer bonds, should be replaced with cifutils? - curr_copy.bonds = struc.connect_via_distances(curr_copy, default_bond_type=1) - # Hack: delete coords_to_be_diffused (if exists) to temporarily solve a weird bug in create hydrogens. Anyway it will not be used. 
- if "coord_to_be_noised" in curr_copy.get_annotation_categories(): - curr_copy.del_annotation("coord_to_be_noised") - - try: - curr_copy = add_hydrogen_atom_positions(curr_copy) - except Exception as e: - print("WARNING: problem adding hydrogen", e) - - if selection1 is not None: - selection1 = np.isin(curr_copy.chain_type, selection1) - else: - selection1 = selection1 - if selection2 is not None: - selection2 = np.isin(curr_copy.chain_type, selection2) - else: - selection2 = selection2 - - # Always include fixed motif atoms for hbond calculations - selection2 |= np.array(curr_copy.is_motif_atom, dtype=bool) - selection1 = ~selection2 - - hbonds, hbond_types, curr_copy = calculate_hbonds( - curr_copy, - selection1=selection1, - selection2=selection2, - ) - - o["num_hbonds"] = int(len(hbonds)) - o["num_donors"] = int(np.sum(hbond_types[:, 0])) - o["num_acceptors"] = int(np.sum(hbond_types[:, 1])) - - return o - - -class HbondMetrics(Metric): - def __init__( - self, - selection1: list[str] = SELECTION_PROTEIN, - selection2: list[str] = SELECTION_NONPROTEIN, - selection1_type: Literal["acceptor", "donor", "both"] = "both", - cutoff_dist: float = 3.0, - cutoff_angle: float = 120.0, - donor_elements: list[str] = ["N", "O", "S", "F"], - acceptor_elements: list[str] = ["N", "O", "S", "F"], - periodic: bool = False, - ): - super().__init__() - - self.selection1 = np.array( - [ChainType.as_enum(item).value for item in selection1] - ) - self.selection2 = np.array( - [ChainType.as_enum(item).value for item in selection2] - ) - - self.selection1_type = selection1_type - self.cutoff_dist = cutoff_dist - self.cutoff_angle = cutoff_angle - self.donor_elements = donor_elements - self.acceptor_elements = acceptor_elements - self.periodic = periodic - - @property - def kwargs_to_compute_args(self): - return { - "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",), - "predicted_atom_array_stack": ("predicted_atom_array_stack",), - } - - def compute(self, *, 
ground_truth_atom_array_stack, predicted_atom_array_stack): - try: - ( - mean_correct_donors_percent, - mean_correct_acceptors_percent, - mean_num_hbonds, - ) = calculate_hbond_stats( - input_atom_array_stack=ground_truth_atom_array_stack, - output_atom_array_stack=predicted_atom_array_stack, - selection1=self.selection1, - selection2=self.selection2, - selection1_type=self.selection1_type, - cutoff_dist=self.cutoff_dist, - cutoff_angle=self.cutoff_angle, - donor_elements=self.donor_elements, - acceptor_elements=self.acceptor_elements, - periodic=self.periodic, - ) - except Exception as e: - global_logger.error( - f"Error calculating hydrogen bond metrics: {e} | Skipping" - ) - return {} - - # Aggregate output for batch-level metrics - o = { - "mean_correct_donors_percent": float(mean_correct_donors_percent), - "mean_correct_acceptors_percent": float(mean_correct_acceptors_percent), - "mean_num_hbonds": float(mean_num_hbonds), - } - return o diff --git a/models/rfd3/src/rfd3/transforms/hbonds.py b/models/rfd3/src/rfd3/transforms/hbonds.py deleted file mode 100644 index d4ad3e08..00000000 --- a/models/rfd3/src/rfd3/transforms/hbonds.py +++ /dev/null @@ -1,407 +0,0 @@ -from typing import Any, Literal, Tuple - -import biotite.structure as struc -import hydride -import numpy as np -from atomworks.io.transforms.atom_array import remove_hydrogens -from atomworks.io.utils.ccd import atom_array_from_ccd_code -from atomworks.ml.transforms._checks import ( - check_atom_array_annotation, - check_contains_keys, - check_is_instance, -) -from atomworks.ml.transforms.base import Transform -from biotite.structure import AtomArray, AtomArrayStack -from rfd3.constants import SELECTION_NONPROTEIN, SELECTION_PROTEIN - -from foundry.utils.ddp import RankedLogger - -ranked_logger = RankedLogger() - -HYDROGEN_LIKE_SYMBOLS = ("H", "H2", "D", "T") - - -# TODO: Once the cifutils submodule is bumped, we can use the built-in add_hydrogen_atom_positions function -def 
add_hydrogen_atom_positions( - atom_array: AtomArray | AtomArrayStack, -) -> AtomArray | AtomArrayStack: - """Add hydrogens using biotite supported hydride library - - Args: - atom_array (AtomArray | AtomArrayStack): The atom array containing the chain information. - - Returns: - AtomArray: The updated atom array with hydrogens added. - """ - - def _get_charge_from_ccd_code(atom): - try: - ccd_array = atom_array_from_ccd_code(atom.res_name) - charge = ccd_array[ - ccd_array.atom_name.tolist().index(atom.atom_name) - ].charge - except Exception: - ## res_name not found in ccd or atom_name not found in ccd_array - charge = 0 - return charge - - if "charge" not in atom_array.get_annotation_categories(): - charges = np.vectorize(_get_charge_from_ccd_code)(atom_array) - atom_array.set_annotation("charge", charges) - - # Add as a custom annotation - - array = remove_hydrogens(atom_array) - - fields_to_copy_from_residue_if_present = [ - "auth_seq_id", - "label_entity_id", - "is_can_prot", - "is_can_nucl", - "is_sm", - "chain_type", - ] - fields_to_copy_from_residue_if_present = list( - set(fields_to_copy_from_residue_if_present).intersection( - set(atom_array.get_annotation_categories()) - ) - ) - - def _copy_missing_annotations_residue_wise( - arr_to_copy_from: AtomArray, - arr_to_update: AtomArray, - fields_to_copy_from_residue_if_present: list[str], - ) -> AtomArray: - """Copy specified annotations residue-wise from one AtomArray to another. 
Updates annotations in-place.""" - residue_starts = struc.get_residue_starts(arr_to_copy_from) - residue_starts_atom_array = arr_to_copy_from[residue_starts] - annot = { - item: getattr(residue_starts_atom_array, item) - for item in fields_to_copy_from_residue_if_present - } - for field in fields_to_copy_from_residue_if_present: - updated_field = struc.spread_residue_wise(arr_to_update, annot[field]) - arr_to_update.set_annotation(field, updated_field) - return arr_to_update - - def _handle_nan_coords(atom_array, noise_level=1e-3): - coords = atom_array.coord - - # Find NaNs - nan_mask = np.isnan(coords) - - # Replace NaNs with 0 + small random offset - coords[nan_mask] = np.random.uniform( - -noise_level, noise_level, size=nan_mask.sum() - ) - - # Update atom_array in-place - atom_array.coord = coords - return atom_array, nan_mask - - if isinstance(array, AtomArrayStack): - updated_arrays = [] - for old_arr in array: - if old_arr.bonds is None: - old_arr.bonds = struc.connect_via_distances(old_arr) - - ## give some values to nan - old_arr, nan_mask = _handle_nan_coords(old_arr) - arr, mask = hydride.add_hydrogen(old_arr) - ## put back nans - arr.coord[mask, :][nan_mask] = np.nan - arr = _copy_missing_annotations_residue_wise( - old_arr, arr, fields_to_copy_from_residue_if_present - ) - updated_arrays.append(arr) - - ret_array = struc.stack(updated_arrays) - - elif isinstance(array, AtomArray): - if array.bonds is None: - array.bonds = struc.connect_via_distances(array) - ## give some values to nan - array, nan_mask = _handle_nan_coords(array) - arr, mask = hydride.add_hydrogen(array) - ## put back nans - arr.coord[mask, :][nan_mask] = np.nan - ret_array = _copy_missing_annotations_residue_wise( - array, arr, fields_to_copy_from_residue_if_present - ) - return ret_array - - -def check_atom_array_has_hydrogen(data: dict[str, Any]): - """Check if `atom_array` key has bonds.""" - import numpy as np - - if not np.any(data["atom_array"].element == "H"): - raise 
ValueError("Key `atom_array` in data has no hydrogens.") - - -def calculate_hbonds( - atom_array: AtomArray, - selection1: np.ndarray = None, - selection2: np.ndarray = None, - selection1_type: Literal["acceptor", "donor", "both"] = "both", - cutoff_dist: float = 3, - cutoff_angle: float = 120, - donor_elements: Tuple[str] = ("O", "N", "S", "F"), - acceptor_elements: Tuple[str] = ("O", "N", "S", "F"), - periodic: bool = False, -) -> Tuple[np.ndarray, np.ndarray, AtomArray]: - """ - Calculates Hbonds with biotite.struc.Hbond. - Assigns donor, acceptor annotation for each heavy atom involved. - Args: - atom_array (AtomArray):Expects the atom_array that contains hydrogens. - - selection1 and selection2 (np.ndarray, optional): (Boolean mask for atoms to limit the hydrogen bond search to specific sections of the model. - The shape must match the shape of the atoms argument. If None is given, the whole atoms stack is used instead. (Default: None)) - - selection1_type (Literal, optional): Determines the type of selection1. The type of selection2 is chosen accordingly (‘both’ or the opposite). - (Default: 'both') - cutoff_dist (float, optional): The maximal distance between the hydrogen and acceptor to be considered a hydrogen bond. (Default: 2.5) - cutoff_angle (float, optional): The angle cutoff in degree between Donor-H..Acceptor to be considered a hydrogen bond. (Default: 120) - donor_elements, acceptor_elements (tuple of str): Elements to be considered as possible donors or acceptors. (Default: O, N, S) - periodic (bool, optional): If true, hydrogen bonds can also be detected in periodic boundary conditions. The box attribute of atoms is required in this case. 
(Default: False) - - - """ - # Remove NaN coordinates - has_resolved_coordinates = ~np.isnan(atom_array.coord).any(axis=-1) - nonNaN_array = atom_array[has_resolved_coordinates] - - # update selections if any - if selection1 is not None: - selection1 = selection1[has_resolved_coordinates] - if selection2 is not None: - selection2 = selection2[has_resolved_coordinates] - - ## index map from nonNaN_array to original - index_map = { - counter: i for counter, i in enumerate(has_resolved_coordinates.nonzero()[0]) - } - - if selection1.sum() == 0 or selection2.sum() == 0: - # no ligand, or ligand is of same type as selection1 (e.g. 6) (peptide) - triplets = np.array([]) - else: - # Compute H bonds - triplets = struc.hbond( ## assuming AtomArray, not AtomArrayStack (returns an extra masks in that case) - nonNaN_array, - selection1=selection1, - selection2=selection2, - selection1_type=selection1_type, - cutoff_dist=cutoff_dist, - cutoff_angle=cutoff_angle, - donor_elements=donor_elements, - acceptor_elements=acceptor_elements, - periodic=periodic, - ) - - ## map back triplet indices, nonNaN indices to original indices - flattened = triplets.flatten() - triplets = np.array([index_map[i] for i in flattened]).reshape(-1, 3) - - ## add back NaNs - - donor_array = np.array([[0.0] * len(atom_array)]) - acceptor_array = np.array([[0.0] * len(atom_array)]) - - if len(triplets) > 0: - donor_array[:, triplets[:, 0]] = 1.0 - acceptor_array[:, triplets[:, 2]] = 1.0 - - ## [is_active_donor, is_active_acceptor] per atom - types = np.vstack((donor_array, acceptor_array)).T - - return triplets, types, atom_array - - -class CalculateHbonds(Transform): - """Transform for calculating Hbonds, expects an AtomArray containing hydrogens.""" - - def __init__( - self, - selection1_type: Literal["acceptor", "donor", "both"] = "both", - cutoff_dist: float = 3, - cutoff_angle: float = 120, - donor_elements: Tuple[str] = ("O", "N", "S", "F"), - acceptor_elements: Tuple[str] = ("O", "N", "S", "F"), - 
periodic: bool = False, - make2d: bool = False, - ): - """ - Initialize the Hbonds transform. - - Args: - - selection1 and selection2 (list[str], optional): Specify a list of ChainTypes as in atomworks.enums. e.g. selectoin1 = ['POLYPEPTIDE(L)'], selection2 = ['NON-POLYMER', 'POLYRIBONUCLEOTIDE'] - Allowed values: {'PEPTIDE NUCLEIC ACID', 'BRANCHED', 'POLYDEOXYRIBONUCLEOTIDE', 'POLYRIBONUCLEOTIDE', 'CYCLIC-PSEUDO-PEPTIDE', 'MACROLIDE', 'POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID', 'OTHER', 'POLYPEPTIDE(L)', 'NON-POLYMER', 'POLYPEPTIDE(D)', 'WATER'} - - selection1_type (Literal, optional): Determines the type of selection1. The type of selection2 is chosen accordingly (‘both’ or the opposite). - (Default: 'both') - cutoff_dist (float, optional): The maximal distance between the hydrogen and acceptor to be considered a hydrogen bond. (Default: 2.5) - cutoff_angle (float, optional): The angle cutoff in degree between Donor-H..Acceptor to be considered a hydrogen bond. (Default: 120) - donor_elements, acceptor_elements (tuple of str): Elements to be considered as possible donors or acceptors. (Default: O, N, S) - periodic (bool, optional): If true, hydrogen bonds can also be detected in periodic boundary conditions. The box attribute of atoms is required in this case. (Default: False) - """ - self.selection1_type = selection1_type - self.cutoff_dist = cutoff_dist - self.cutoff_angle = cutoff_angle - self.donor_elements = donor_elements - self.acceptor_elements = acceptor_elements - self.periodic = periodic - self.make2d = make2d - - def check_input(self, data: dict[str, Any]) -> None: - check_contains_keys(data, ["atom_array"]) - check_is_instance(data, "atom_array", AtomArray) - check_atom_array_annotation(data, ["res_name"]) - - ## turn off cause H addition debug ongoing - # check_atom_array_has_hydrogen(data) - - def forward(self, data: dict) -> dict: - """ - Calculates Hbonds and adds it to the data dictionary under the key `hbonds`. 
- - Args: - data: dict - A dictionary containing the input data atomarray. - Expects the atom_array in data["atom_array"] contains hydrogens. - - - Returns: - dict: The data dictionary with hbonds added. - Sets hbond_type = [Donor, Acceptor] annotation to each atom. Donor, Acceptor can be both 0 or 1 (float). size: Lx2 (L: length of AtomArray) - """ - - atom_array: AtomArray = data["atom_array"] - - try: - atom_array = add_hydrogen_atom_positions(atom_array) - - except Exception as e: - print( - f"WARNING: problem adding hydrogens: {e}.\nThis example will get no hydrogen bond annotations." - ) - atom_array.set_annotation( - "active_donor", np.zeros(atom_array.array_length(), dtype=bool) - ) - atom_array.set_annotation( - "active_acceptor", np.zeros(atom_array.array_length(), dtype=bool) - ) - data["atom_array"] = atom_array - return data - - ## These are the only two use-cases we have so far. Can be extended as needed - - if data["sampled_condition_name"] == "ppi": - selection1_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"] - selection2_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"] - separate_selections_for_motif_and_diffused = True - else: - selection1_chain_types = SELECTION_PROTEIN - selection2_chain_types = SELECTION_NONPROTEIN - separate_selections_for_motif_and_diffused = False - - selection1 = np.isin(atom_array.chain_type, selection1_chain_types) - selection2 = np.isin(atom_array.chain_type, selection2_chain_types) - - # Optionally restrict to Hbonds between motif and diffused regions - if separate_selections_for_motif_and_diffused: - selection1 = selection1 & atom_array.is_motif_atom - selection2 = selection2 & ~atom_array.is_motif_atom - else: - # Include fixed motif atoms for hbond calculations - selection2 |= np.array(atom_array.is_motif_atom, dtype=bool) - selection1 = ~selection2 - - hbonds, hbond_types, atom_array = calculate_hbonds( - atom_array, - selection1=selection1, - selection2=selection2, - selection1_type=self.selection1_type, - 
cutoff_dist=self.cutoff_dist, - cutoff_angle=self.cutoff_angle, - donor_elements=self.donor_elements, - acceptor_elements=self.acceptor_elements, - periodic=self.periodic, - ) - - # Initialize log_dict if not present - data.setdefault("log_dict", {}) - log_dict = data["log_dict"] - - # Log hbond statistics - log_dict["hbond_total_count"] = len(hbonds) - log_dict["hbond_total_atoms"] = hbond_types.sum() - - # Subsample if hbond_subsample is set and number of atoms is bigger than 3 - final_hbond_types = hbond_types - final_hbond_types[:, 0] = final_hbond_types[:, 0] * np.array( - atom_array.is_motif_atom - ) - final_hbond_types[:, 1] = final_hbond_types[:, 1] * np.array( - atom_array.is_motif_atom - ) - - if data["conditions"]["hbond_subsample"] and np.sum(hbond_types) > 3: - # Linear correlation: fewer hbonds = higher fraction - base_fraction = 0.1 # minimum fraction (when many hbonds) - max_fraction = 0.9 # maximum fraction (when few hbonds) - n_hbonds = len(hbonds) - max_hbonds = 50 # Expected maximum number of hbonds for scaling - - # Linear interpolation: fraction decreases linearly with number of hbonds - fraction = max_fraction - (max_fraction - base_fraction) * min( - n_hbonds / max_hbonds, 1.0 - ) - final_hbond_types = subsample_one_hot_np(hbond_types, fraction) - - # Set annotations and log subsample atoms - atom_array.set_annotation("active_donor", final_hbond_types[:, 0]) - atom_array.set_annotation("active_acceptor", final_hbond_types[:, 1]) - log_dict["hbond_subsample_atoms"] = final_hbond_types.sum() - - # Remove hydrogens after processing - atom_array = remove_hydrogens(atom_array) - data["log_dict"] = log_dict - data["atom_array"] = atom_array - return data - - -def subsample_one_hot_np(array, fraction): - """ - Subsamples a one-hot encoded NumPy array by randomly keeping a given fraction of the 1s. - - Args: - array (np.ndarray): One-hot array of 0s and 1s. - fraction (float): Fraction of 1s to keep (0 < fraction <= 1). 
- - Returns: - np.ndarray: Subsampled array with same shape. - """ - if not (0 < fraction <= 1): - raise ValueError("Fraction must be in the range (0, 1].") - - array = array.copy() # Don't modify original - one_indices = np.argwhere(array == 1) - num_ones = len(one_indices) - - keep_count = int(num_ones * fraction) - - # Shuffle and choose a subset of indices to keep - np.random.shuffle(one_indices) - keep_indices = one_indices[:keep_count] - - # Create new zero array - new_array = np.zeros_like(array) - - # Set selected indices to 1 - for i, j in keep_indices: - new_array[i, j] = 1 - - return new_array diff --git a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py index 24231c35..56053120 100644 --- a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py +++ b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py @@ -1,6 +1,7 @@ import os import string import subprocess +import tempfile from datetime import datetime from typing import Any, Tuple @@ -66,10 +67,6 @@ def calculate_hbonds( cutoff_HA_dist: float = 3, cutoff_DA_distance: float = 3.5, ) -> Tuple[np.ndarray, np.ndarray, AtomArray]: - dtstr = datetime.now().strftime("%Y%m%d%H%M%S") - pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb" - atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path) - hbplus_exe = os.environ.get("HBPLUS_PATH") if hbplus_exe is None or hbplus_exe == "": @@ -78,49 +75,56 @@ def calculate_hbonds( "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds." 
) - subprocess.call( - [ - hbplus_exe, - "-h", - str(cutoff_HA_dist), - "-d", - str(cutoff_DA_distance), - pdb_path, - pdb_path, - ], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - - HB = open(pdb_path.replace("pdb", "hb2"), "r").readlines() - hbonds = [] - for i in range(8, len(HB)): - d_chain = HB[i][0] - d_resi = str(int(HB[i][1:5].strip())) - d_resn = HB[i][6:9].strip() - d_ins = HB[i][5].replace("-", " ") - d_atom = HB[i][9:13].strip() - a_chain = HB[i][14] - a_resi = str(int(HB[i][15:19].strip())) - a_ins = HB[i][19].replace("-", " ") - a_resn = HB[i][20:23].strip() - a_atom = HB[i][23:27].strip() - dist = float(HB[i][27:32].strip()) - - items = { - "d_chain": chain_map[d_chain], - "d_resi": d_resi, - "d_resn": d_resn, - "d_ins": d_ins, - "d_atom": d_atom, - "a_chain": chain_map[a_chain], - "a_resi": a_resi, - "a_resn": a_resn, - "a_ins": a_ins, - "a_atom": a_atom, - "dist": dist, - } - hbonds.append(items) + with tempfile.TemporaryDirectory() as tmpdir: + dtstr = datetime.now().strftime("%Y%m%d%H%M%S") + pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb" + pdb_path = os.path.join(tmpdir, pdb_filename) + atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path) + + subprocess.call( + [ + hbplus_exe, + "-h", + str(cutoff_HA_dist), + "-d", + str(cutoff_DA_distance), + pdb_path, + pdb_path, + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + hb2_path = pdb_path.replace(".pdb", ".hb2") + HB = open(hb2_path, "r").readlines() + hbonds = [] + for i in range(8, len(HB)): + d_chain = HB[i][0] + d_resi = str(int(HB[i][1:5].strip())) + d_resn = HB[i][6:9].strip() + d_ins = HB[i][5].replace("-", " ") + d_atom = HB[i][9:13].strip() + a_chain = HB[i][14] + a_resi = str(int(HB[i][15:19].strip())) + a_ins = HB[i][19].replace("-", " ") + a_resn = HB[i][20:23].strip() + a_atom = HB[i][23:27].strip() + dist = float(HB[i][27:32].strip()) + + items = { + "d_chain": chain_map[d_chain], + "d_resi": d_resi, + "d_resn": 
d_resn, + "d_ins": d_ins, + "d_atom": d_atom, + "a_chain": chain_map[a_chain], + "a_resi": a_resi, + "a_resn": a_resn, + "a_ins": a_ins, + "a_atom": a_atom, + "dist": dist, + } + hbonds.append(items) donor_array = np.zeros(len(atom_array)) acceptor_array = np.zeros(len(atom_array)) @@ -162,8 +166,6 @@ def calculate_hbonds( donor_array[donor_mask] = 1 acceptor_array[acceptor_mask] = 1 - os.remove(pdb_path) - os.remove(pdb_path.replace("pdb", "hb2")) atom_array.set_annotation("active_donor", donor_array) atom_array.set_annotation("active_acceptor", acceptor_array) From d8d3f8ebbf5eb3df4033238b2e0c372109fc478b Mon Sep 17 00:00:00 2001 From: Chen-Jie7 Date: Tue, 17 Feb 2026 20:57:51 -0800 Subject: [PATCH 2/3] Fix ligand fragmentation when partially fixing ligand atoms in rfd3 When a user specifies select_fixed_atoms for a ligand with only a subset of its atoms, the unfixed atoms were ending up physically disconnected from the fixed ones (e.g., C11-C12 bond at 5-11A instead of ~1.5A). Three root causes were addressed: 1. _set_origin() zeroed coordinates for ALL unfixed atoms including ligand atoms, placing them at the origin. Now excludes ligand atoms (identified via atom_array.hetero) from zeroing. 2. centre_random_augment_around_motif() applied centering asymmetrically when centering_affects_motif=False, pulling unfixed ligand atoms away from fixed ones. Now accepts an is_ligand_atom mask and excludes ligand atoms from asymmetric centering. 3. Both initial noise (_get_initial_structure) and per-step noise (epsilon_L) were applied at full scale to unfixed ligand atoms, placing them ~160A away. Now zeroes noise for unfixed ligand atoms so they stay near their original positions while the model can still softly adjust them through denoising predictions. To support this, is_ligand_atom is now extracted as an atom-level feature in AddIsXFeats and propagated through the diffusion pipeline to the inference sampler. 
Verified with biotin_2.json partial fix (16/21 atoms fixed): C11-C12 bond improved from 5-11A to ~2.5A. Co-Authored-By: Claude Opus 4.6 --- .../rfd3/src/rfd3/inference/input_parsing.py | 20 +++++-- models/rfd3/src/rfd3/inference/parsing.py | 1 - .../rfd3/src/rfd3/model/inference_sampler.py | 37 +++++++++++- .../src/rfd3/transforms/design_transforms.py | 4 ++ models/rfd3/src/rfd3/transforms/pipelines.py | 2 + models/rfd3/tests/test_ligand_partial_fix.py | 57 +++++++++++++++++++ 6 files changed, 114 insertions(+), 7 deletions(-) create mode 100644 models/rfd3/tests/test_ligand_partial_fix.py diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index 777387d9..c82131b3 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -661,10 +661,16 @@ def break_unindexed(unindex: InputSelection): def _append_ligand(self, atom_array, atom_array_input_annotated): """Append ligand if specified.""" if exists(self.ligand): + fixed_atoms = {} + if exists(self.select_fixed_atoms): + fixed_atoms = { + component: atoms + for component, atoms in self.select_fixed_atoms.data.items() + } ligand_array = extract_ligand_array( atom_array_input_annotated, self.ligand, - fixed_atoms={}, + fixed_atoms=fixed_atoms, set_defaults=False, additional_annotations=set( list(atom_array.get_annotation_categories()) @@ -714,9 +720,15 @@ def _set_origin(self, atom_array): infer_ori_strategy=self.infer_ori_strategy, ) # Diffused atoms are always initialized at origin during regular diffusion (all information removed) - atom_array.coord[ - ~atom_array.is_motif_atom_with_fixed_coord.astype(bool) - ] = 0.0 + unfixed_mask = ~atom_array.is_motif_atom_with_fixed_coord.astype(bool) + # Don't zero out unfixed ligand atoms - they must keep their original + # coordinates to maintain molecular connectivity with fixed ligand atoms. 
+ # They remain unfixed (is_motif_atom_with_fixed_coord=False), so the model can still move them, + # but they start near their true position rather than at the origin. + if exists(self.ligand): + is_ligand = atom_array.hetero.astype(bool) + unfixed_mask = unfixed_mask & ~is_ligand + atom_array.coord[unfixed_mask] = 0.0 return atom_array def _apply_globals(self, atom_array): diff --git a/models/rfd3/src/rfd3/inference/parsing.py b/models/rfd3/src/rfd3/inference/parsing.py index 65d5eb41..9e200eb7 100644 --- a/models/rfd3/src/rfd3/inference/parsing.py +++ b/models/rfd3/src/rfd3/inference/parsing.py @@ -117,7 +117,6 @@ def from_any_(v: Any, atom_array: AtomArray): # Split to atom names data_split[idx] = token.atom_name[comp_mask_subset].tolist() - # TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented # Update mask & token dictionary mask[comp_mask] = comp_mask_subset diff --git a/models/rfd3/src/rfd3/model/inference_sampler.py b/models/rfd3/src/rfd3/model/inference_sampler.py index ac7fe9ec..0c70bced 100644 --- a/models/rfd3/src/rfd3/model/inference_sampler.py +++ b/models/rfd3/src/rfd3/model/inference_sampler.py @@ -132,9 +132,17 @@ def _get_initial_structure( L: int, coord_atom_lvl_to_be_noised: torch.Tensor, is_motif_atom_with_fixed_coord, + is_ligand_atom=None, ) -> torch.Tensor: noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device) noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in + # Zero out noise for unfixed ligand atoms so they start near their + # original positions, maintaining connectivity with fixed ligand atoms. + # They are NOT reinserted (unlike fixed atoms), so the model can still + # adjust their positions through denoising predictions.
+ if is_ligand_atom is not None: + unfixed_ligand = is_ligand_atom & ~is_motif_atom_with_fixed_coord + noise[..., unfixed_ligand, :] = 0 X_L = noise + coord_atom_lvl_to_be_noised return X_L @@ -151,6 +159,7 @@ def sample_diffusion_like_af3( ) -> dict[str, Any]: # Motif setup to recenter the motif at every step is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"] + is_ligand_atom = f.get("is_ligand_atom", None) # Book-keeping noise_schedule = self._construct_inference_noise_schedule( @@ -167,6 +176,7 @@ def sample_diffusion_like_af3( L=L, coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(), is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord, + is_ligand_atom=is_ligand_atom, ) # (D, L, 3) if self.s_jitter_origin > 0.0: @@ -203,6 +213,7 @@ def sample_diffusion_like_af3( # If keeping the motif position wrt the origin fixed, we can't do translational augmentation # We want to keep this position fixed in the interval where the model is not allowed to change it s_trans=self.s_trans if step_num >= threshold_step else 0.0, + is_ligand_atom=is_ligand_atom, ) # Update gamma & step scale @@ -221,6 +232,11 @@ def sample_diffusion_like_af3( epsilon_L[..., is_motif_atom_with_fixed_coord, :] = ( 0 # No noise injection for fixed atoms ) + # No noise injection for unfixed ligand atoms either - they must + # stay near their original positions to maintain molecular connectivity. 
+ if is_ligand_atom is not None: + unfixed_ligand = is_ligand_atom & ~is_motif_atom_with_fixed_coord + epsilon_L[..., unfixed_ligand, :] = 0 X_noisy_L = X_L + epsilon_L # Denoise the coordinates @@ -324,6 +340,7 @@ def sample_diffusion_like_af3( coord_atom_lvl_to_be_noised, is_motif_atom_with_fixed_coord, reinsert_motif=self.insert_motif_at_end, + is_ligand_atom=is_ligand_atom, ) # Align prediction to original motif @@ -386,6 +403,7 @@ def sample_diffusion_like_af3( ) -> dict[str, Any]: # Motif setup to recenter the motif at every step is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"] + is_ligand_atom = f.get("is_ligand_atom", None) # Book-keeping noise_schedule = self._construct_inference_noise_schedule( device=coord_atom_lvl_to_be_noised.device, @@ -400,6 +418,7 @@ def sample_diffusion_like_af3( L=L, coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(), is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord, + is_ligand_atom=is_ligand_atom, ) # (D, L, 3) X_noisy_L_traj = [] @@ -428,6 +447,7 @@ def sample_diffusion_like_af3( X_L, coord_atom_lvl_to_be_noised, is_motif_atom_with_fixed_coord, + is_ligand_atom=is_ligand_atom, ) # Update gamma & step scale @@ -446,6 +466,11 @@ def sample_diffusion_like_af3( epsilon_L[..., is_motif_atom_with_fixed_coord, :] = ( 0 # No noise injection for fixed atoms ) + # No noise injection for unfixed ligand atoms either - they must + # stay near their original positions to maintain molecular connectivity. 
+ if is_ligand_atom is not None: + unfixed_ligand = is_ligand_atom & ~is_motif_atom_with_fixed_coord + epsilon_L[..., unfixed_ligand, :] = 0 # NOTE: no symmetry applied to the noisy structure X_noisy_L = X_L + epsilon_L @@ -526,6 +551,7 @@ def sample_diffusion_like_af3( coord_atom_lvl_to_be_noised, is_motif_atom_with_fixed_coord, reinsert_motif=self.insert_motif_at_end, + is_ligand_atom=is_ligand_atom, ) # apply symmetry frame shift to X_L @@ -600,6 +626,7 @@ def centre_random_augment_around_motif( center_option: str = "all", centering_affects_motif: bool = True, reinsert_motif=True, + is_ligand_atom: torch.Tensor | None = None, # (L,) mask for ligand atoms ): D, L, _ = X_L.shape @@ -633,8 +660,14 @@ def centre_random_augment_around_motif( if centering_affects_motif: X_L = X_L - center else: - X_L[..., ~is_motif_atom_with_fixed_coord, :] = ( - X_L[..., ~is_motif_atom_with_fixed_coord, :] - center + # When centering only affects non-motif atoms, exclude ligand atoms + # from centering to prevent unfixed ligand atoms from drifting away + # from their fixed counterparts (which would fragment the molecule). + atoms_to_center = ~is_motif_atom_with_fixed_coord + if is_ligand_atom is not None: + atoms_to_center = atoms_to_center & ~is_ligand_atom + X_L[..., atoms_to_center, :] = ( + X_L[..., atoms_to_center, :] - center ) # ... 
Random augmentation diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index 38c79979..c75860eb 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -475,6 +475,10 @@ def forward(self, data: dict) -> dict: mask = token_level_array.get_annotation(x).copy().astype(bool) data["feats"][x] = mask + # Atom-level ligand mask for partial-ligand-fix handling in the sampler + if "is_ligand_atom" in self.X: + data["feats"]["is_ligand_atom"] = atom_array.is_ligand.copy().astype(bool) + if "is_motif_token_with_fully_fixed_coord" in self.X: mask = apply_token_wise( atom_array, diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py index 89bc7289..187da66a 100644 --- a/models/rfd3/src/rfd3/transforms/pipelines.py +++ b/models/rfd3/src/rfd3/transforms/pipelines.py @@ -509,6 +509,8 @@ def build_atom14_base_pipeline_( "is_motif_atom_unindexed", "is_motif_atom_with_fixed_seq", "is_motif_token_with_fully_fixed_coord", + # Ligand mask (atom-level) for partial-ligand-fix handling + "is_ligand_atom", ], central_atom=central_atom, ), diff --git a/models/rfd3/tests/test_ligand_partial_fix.py b/models/rfd3/tests/test_ligand_partial_fix.py new file mode 100644 index 00000000..56cfd478 --- /dev/null +++ b/models/rfd3/tests/test_ligand_partial_fix.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import numpy as np +from beartype import beartype +from jaxtyping import Bool, Shaped +from rfd3.inference.input_parsing import DesignInputSpecification + +PDB_CONTENT = """\ +ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 10.00 N +ATOM 2 CA ALA A 1 1.500 0.000 0.000 1.00 10.00 C +ATOM 3 C ALA A 1 2.000 1.500 0.000 1.00 10.00 C +ATOM 4 O ALA A 1 2.000 2.500 0.000 1.00 10.00 O +HETATM 5 C1 LIG B 1 5.000 0.000 0.000 1.00 10.00 C +HETATM 6 O1 LIG B 1 6.200 0.000 0.000 1.00 10.00 O +HETATM 7 N1 LIG B 1 5.000 1.200 0.000 
1.00 10.00 N +HETATM 8 C2 LIG B 1 6.200 1.200 0.000 1.00 10.00 C +TER +END +""" + + +@beartype +def _ligand_fixed_lookup( + atom_names: Shaped[np.ndarray, "n"], + fixed_mask: Bool[np.ndarray, "n"], +) -> dict[str, bool]: + """Map ligand atom names to fixed-coordinate flags.""" + return { + str(name): bool(is_fixed) + for name, is_fixed in zip(atom_names.tolist(), fixed_mask.tolist()) + } + + +def test_partial_ligand_fixed_atoms_respected(tmp_path: Path) -> None: + pdb_path = tmp_path / "ligand.pdb" + pdb_path.write_text(PDB_CONTENT) + + spec = DesignInputSpecification.safe_init( + input=pdb_path, + length=1, + ligand="LIG", + select_fixed_atoms={"LIG": "C1,O1"}, + ) + atom_array = spec.build(return_metadata=False) + + ligand_mask = atom_array.res_name == "LIG" + assert ligand_mask.any(), "Expected ligand atoms in output atom array." + + ligand_names = atom_array.atom_name[ligand_mask] + fixed_mask = atom_array.is_motif_atom_with_fixed_coord[ligand_mask].astype(bool) + fixed_lookup = _ligand_fixed_lookup(ligand_names, fixed_mask) + + assert set(fixed_lookup.keys()) == {"C1", "O1", "N1", "C2"} + assert fixed_lookup["C1"] + assert fixed_lookup["O1"] + assert not fixed_lookup["N1"] + assert not fixed_lookup["C2"] From 342387713040f67142de3c6ef0a8d15dedfe16d5 Mon Sep 17 00:00:00 2001 From: Chen-Jie7 Date: Tue, 24 Feb 2026 23:18:52 -0800 Subject: [PATCH 3/3] explicit temp dir --- models/rfd3/src/rfd3/transforms/hbonds_hbplus.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py index 56053120..c043c089 100644 --- a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py +++ b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py @@ -75,7 +75,9 @@ def calculate_hbonds( "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds." 
) - with tempfile.TemporaryDirectory() as tmpdir: + # Use $TMPDIR (falling back to /tmp) as the base dir so temp files are never written to the working directory + temp_base_dir = os.environ.get("TMPDIR", "/tmp") + with tempfile.TemporaryDirectory(dir=temp_base_dir) as tmpdir: dtstr = datetime.now().strftime("%Y%m%d%H%M%S") pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb" pdb_path = os.path.join(tmpdir, pdb_filename)