From 07e26606bb89e1435a3f30ab9a6e16a360b56052 Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Thu, 12 Feb 2026 17:22:53 +0100 Subject: [PATCH] Fix #865 --- src/biotite/structure/filter.py | 23 +++++++++++++---------- tests/structure/test_filter.py | 25 ++++++++++++++++--------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/biotite/structure/filter.py b/src/biotite/structure/filter.py index dd0467725..ea8ac84cd 100644 --- a/src/biotite/structure/filter.py +++ b/src/biotite/structure/filter.py @@ -28,9 +28,9 @@ ] -from functools import partial +import itertools import numpy as np -from biotite.structure.atoms import array as atom_array +from biotite.structure.atoms import AtomArrayStack from biotite.structure.info.groups import ( amino_acid_names, carbohydrate_names, @@ -390,14 +390,17 @@ def filter_polymer(array, min_size=2, pol_type="peptide"): # Import `check_res_id_continuity` here to avoid circular imports from biotite.structure.integrity import check_res_id_continuity - split_idx = check_res_id_continuity(array) - - check_pol = partial(_is_polymer, min_size=min_size, pol_type=pol_type) - bool_idx = map( - lambda a: np.full(len(a), check_pol(atom_array(a)), dtype=bool), - np.split(array, split_idx), - ) - return np.concatenate(list(bool_idx)) + if isinstance(array, AtomArrayStack): + array = array[0] + + mask = np.zeros(len(array), dtype=bool) + discontinuity_idx = check_res_id_continuity(array) + for start, stop in itertools.pairwise( + itertools.chain([0], discontinuity_idx, [len(array)]) + ): + segment = array[..., start:stop] + mask[start:stop] = _is_polymer(segment, min_size, pol_type) + return mask def filter_intersection(array, intersect, categories=None): diff --git a/tests/structure/test_filter.py b/tests/structure/test_filter.py index f0322c51c..047e2cba3 100644 --- a/tests/structure/test_filter.py +++ b/tests/structure/test_filter.py @@ -150,32 +150,39 @@ def test_linear_bond_continuity_filter(canonical_sample_protein): assert len(pro[struc.filter_linear_bond_continuity(pro)]) == 6 -def test_polymer_filter(canonical_sample_nucleotide, sample_carbohydrate): +@pytest.mark.parametrize("as_stack", [False, True]) +def test_polymer_filter(canonical_sample_nucleotide, sample_carbohydrate, as_stack): + if as_stack: + canonical_sample_nucleotide = struc.stack([canonical_sample_nucleotide] * 2) + sample_carbohydrate = struc.stack([sample_carbohydrate] * 2) + a = canonical_sample_nucleotide # Check for nucleotide filtering - a_nuc = a[struc.filter_polymer(a, pol_type="n")] + a_nuc = a[..., struc.filter_polymer(a, pol_type="n")] # Take three nucleic acids chains and remove solvent => the result should # encompass all nucleotide polymer atoms, which is exactly the output of the # `filter_polymer()`. In the structure file, the filtered atoms are 1-651. - a_nuc_manual = a[np.isin(a.chain_id, ["D", "P", "T"]) & ~struc.filter_solvent(a)] - assert len(a_nuc) == len(a_nuc_manual) == 651 + a_nuc_manual = a[ + ..., np.isin(a.chain_id, ["D", "P", "T"]) & ~struc.filter_solvent(a) + ] + assert a_nuc.array_length() == a_nuc_manual.array_length() == 651 assert set(a_nuc.chain_id) == {"D", "P", "T"} # chain D should be absent - a_nuc = a_nuc[struc.filter_polymer(a_nuc, min_size=6, pol_type="n")] + a_nuc = a_nuc[..., struc.filter_polymer(a_nuc, min_size=6, pol_type="n")] assert set(a_nuc.chain_id) == {"P", "T"} # Single protein chain A: residues 10-335 - a_pep = a[struc.filter_polymer(a, pol_type="p")] - assert len(a_pep) == len( - a[(a.res_id >= 10) & (a.res_id <= 335) & (a.chain_id == "A")] + a_pep = a[..., struc.filter_polymer(a, pol_type="p")] + assert a_pep.array_length() == np.count_nonzero( + (a.res_id >= 10) & (a.res_id <= 335) & (a.chain_id == "A") ) # Chain B has five carbohydrate residues # Chain C has four # => Only chain B is selected a = sample_carbohydrate - a_carb = a[struc.filter_polymer(a, min_size=4, pol_type="carb")] + a_carb = a[..., struc.filter_polymer(a, min_size=4, pol_type="carb")] assert set(a_carb.chain_id) == {"B"} assert struc.get_residue_count(a_carb) == 5