Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions src/biotite/structure/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
]


from functools import partial
import itertools
import numpy as np
from biotite.structure.atoms import array as atom_array
from biotite.structure.atoms import AtomArrayStack
from biotite.structure.info.groups import (
amino_acid_names,
carbohydrate_names,
Expand Down Expand Up @@ -390,14 +390,17 @@ def filter_polymer(array, min_size=2, pol_type="peptide"):
# Import `check_res_id_continuity` here to avoid circular imports
from biotite.structure.integrity import check_res_id_continuity

split_idx = check_res_id_continuity(array)

check_pol = partial(_is_polymer, min_size=min_size, pol_type=pol_type)
bool_idx = map(
lambda a: np.full(len(a), check_pol(atom_array(a)), dtype=bool),
np.split(array, split_idx),
)
return np.concatenate(list(bool_idx))
if isinstance(array, AtomArrayStack):
array = array[0]

mask = np.zeros(len(array), dtype=bool)
discontinuity_idx = check_res_id_continuity(array)
for start, stop in itertools.pairwise(
itertools.chain([0], discontinuity_idx, [len(array)])
):
segment = array[..., start:stop]
mask[start:stop] = _is_polymer(segment, min_size, pol_type)
return mask


def filter_intersection(array, intersect, categories=None):
Expand Down
25 changes: 16 additions & 9 deletions tests/structure/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,32 +150,39 @@ def test_linear_bond_continuity_filter(canonical_sample_protein):
assert len(pro[struc.filter_linear_bond_continuity(pro)]) == 6


def test_polymer_filter(canonical_sample_nucleotide, sample_carbohydrate):
@pytest.mark.parametrize("as_stack", [False, True])
def test_polymer_filter(canonical_sample_nucleotide, sample_carbohydrate, as_stack):
if as_stack:
canonical_sample_nucleotide = struc.stack([canonical_sample_nucleotide] * 2)
sample_carbohydrate = struc.stack([sample_carbohydrate] * 2)

a = canonical_sample_nucleotide

# Check for nucleotide filtering
a_nuc = a[struc.filter_polymer(a, pol_type="n")]
a_nuc = a[..., struc.filter_polymer(a, pol_type="n")]
# Take the three nucleic acid chains and remove solvent => the result should
# encompass all nucleotide polymer atoms, which is exactly the output of
# `filter_polymer()`. In the structure file, the filtered atoms are 1-651.
a_nuc_manual = a[np.isin(a.chain_id, ["D", "P", "T"]) & ~struc.filter_solvent(a)]
assert len(a_nuc) == len(a_nuc_manual) == 651
a_nuc_manual = a[
..., np.isin(a.chain_id, ["D", "P", "T"]) & ~struc.filter_solvent(a)
]
assert a_nuc.array_length() == a_nuc_manual.array_length() == 651
assert set(a_nuc.chain_id) == {"D", "P", "T"}
# chain D should be absent
a_nuc = a_nuc[struc.filter_polymer(a_nuc, min_size=6, pol_type="n")]
a_nuc = a_nuc[..., struc.filter_polymer(a_nuc, min_size=6, pol_type="n")]
assert set(a_nuc.chain_id) == {"P", "T"}

# Single protein chain A: residues 10-335
a_pep = a[struc.filter_polymer(a, pol_type="p")]
assert len(a_pep) == len(
a[(a.res_id >= 10) & (a.res_id <= 335) & (a.chain_id == "A")]
a_pep = a[..., struc.filter_polymer(a, pol_type="p")]
assert a_pep.array_length() == np.count_nonzero(
(a.res_id >= 10) & (a.res_id <= 335) & (a.chain_id == "A")
)

# Chain B has five carbohydrate residues
# Chain C has four
# => Only chain B is selected
a = sample_carbohydrate
a_carb = a[struc.filter_polymer(a, min_size=4, pol_type="carb")]
a_carb = a[..., struc.filter_polymer(a, min_size=4, pol_type="carb")]
assert set(a_carb.chain_id) == {"B"}
assert struc.get_residue_count(a_carb) == 5

Expand Down
Loading