From ffd767c7a9c43367ae79a11f187a51c1178ffb8a Mon Sep 17 00:00:00 2001 From: Simon Mathis Date: Sun, 20 Jul 2025 13:45:24 +0100 Subject: [PATCH 1/3] Add pandas-style query functionality to `AtomArray` and `AtomArrayStack` objects --- src/biotite/structure/atoms.py | 67 +++++ src/biotite/structure/query.py | 516 +++++++++++++++++++++++++++++++++ tests/structure/test_query.py | 447 ++++++++++++++++++++++++++++ 3 files changed, 1030 insertions(+) create mode 100644 src/biotite/structure/query.py create mode 100644 tests/structure/test_query.py diff --git a/src/biotite/structure/atoms.py b/src/biotite/structure/atoms.py index b5f6c395a..e98a9e51a 100644 --- a/src/biotite/structure/atoms.py +++ b/src/biotite/structure/atoms.py @@ -26,6 +26,7 @@ from collections.abc import Sequence import numpy as np from biotite.copyable import Copyable +from biotite.structure import query from biotite.structure.bonds import BondList @@ -440,6 +441,72 @@ def _copy_annotations(self, clone): if self._bonds is not None: clone._bonds = self._bonds.copy() + def query(self, expr): + """ + Query the AtomArray using a pandas-like expression string. + + Parameters + ---------- + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + filtered_array : AtomArray or AtomArrayStack + Filtered atom array containing only atoms that match the query expression. + + Examples + -------- + Select all CA atoms in chain A: + + >>> ca_atoms = atom_array.query("(chain_id == 'A') & (atom_name == 'CA')") + + Select atoms without NaN coordinates: + + >>> valid_atoms = atom_array.query("~has_nan_coord()") + """ + return query.query(self, expr) + + def mask(self, expr): + """ + Query the AtomArray using a pandas-like expression string and return a boolean mask. + + Parameters + ---------- + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + mask : ndarray, dtype=bool + Boolean numpy array indicating which atoms match the query. + + Examples + -------- + >>> mask = atom_array.mask("atom_name == 'CA'") + """ + return query.mask(self, expr) + + def idxs(self, expr): + """ + Query the AtomArray using a pandas-like expression string and return the indices of matching atoms. + + Parameters + ---------- + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + indices : ndarray, dtype=int + Numpy array of indices for atoms that match the query expression. + + Examples + -------- + >>> idxs = atom_array.idxs("atom_name == 'CA'") + """ + return query.idxs(self, expr) + class Atom(Copyable): """ diff --git a/src/biotite/structure/query.py b/src/biotite/structure/query.py new file mode 100644 index 000000000..7f96d6fae --- /dev/null +++ b/src/biotite/structure/query.py @@ -0,0 +1,516 @@ +""" +This module provides pandas-like query functionality for AtomArray and AtomArrayStack objects. + +The main class :class:`QueryExpression` allows filtering atom arrays using +string-based expressions similar to pandas query syntax. Functions are also +provided for direct querying without creating a QueryExpression object. +""" + +__name__ = "biotite.structure.query" +__author__ = "The Biotite contributors" +__all__ = ["QueryExpression", "query", "mask", "idxs"] + +import ast +import operator +from types import MappingProxyType +import numpy as np + + +class QueryExpression: + """ + Query evaluator for biotite AtomArrays using pandas-like syntax. + + This class provides a way to filter atom arrays using string expressions + that are parsed and evaluated as boolean masks. The expressions support + comparison operators, logical operators, and built-in functions. + + Parameters + ---------- + expr : str + The query expression string to parse and evaluate. + + Examples + -------- + Select all CA atoms in chain A: + + >>> expr = QueryExpression("(chain_id == 'A') & (atom_name == 'CA')") + >>> ca_atoms = expr.query(atom_array) + + Select atoms without NaN coordinates: + + >>> expr = QueryExpression("~has_nan_coord()") + >>> valid_atoms = expr.query(atom_array) + + Select bonded atoms in specific residues: + + >>> expr = QueryExpression("has_bonds() & (res_name in ['ALA', 'GLY', 'VAL'])") + >>> bonded = expr.query(atom_array) + """ + + # Map string operators to functions + OPS = MappingProxyType( + { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, + # Special handling for In/NotIn will be done in _eval_node + ast.In: None, + ast.NotIn: None, + # Logical operators + ast.And: np.logical_and, + ast.Or: np.logical_or, + ast.Not: np.logical_not, + # Bitwise operators (which act as logical for boolean arrays) + ast.BitAnd: np.bitwise_and, + ast.BitOr: np.bitwise_or, + ast.Invert: np.invert, + ast.UAdd: operator.pos, + ast.USub: operator.neg, + } + ) + + def __init__(self, expr): + self.expr = expr + # Parse once during initialization for efficiency + self.tree = ast.parse(expr, mode="eval") + + def mask(self, atom_array): + """ + Apply the query expression to an AtomArray and return a boolean mask. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + + Returns + ------- + mask : ndarray, dtype=bool + Boolean numpy array indicating which atoms match the query. + """ + namespace = self._build_namespace(atom_array) + functions = self._build_functions(atom_array) + mask = self._eval_node(self.tree.body, namespace, functions, atom_array) + + # Ensure result is boolean array of correct length + mask = self._ensure_bool_array(mask, atom_array.array_length()) + + return mask + + def query(self, atom_array): + """ + Apply the query expression to an AtomArray and return a filtered AtomArray. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + + Returns + ------- + filtered_array : AtomArray or AtomArrayStack + Filtered atom array containing only atoms that match the query expression. + """ + mask = self.mask(atom_array) + return atom_array[..., mask] + + def idxs(self, atom_array): + """ + Apply the query expression to an AtomArray and return the indices of matching atoms. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + + Returns + ------- + indices : ndarray, dtype=int + Numpy array of indices for atoms that match the query expression. + """ + mask = self.mask(atom_array) + return np.where(mask)[0] + + @staticmethod + def _build_namespace(atom_array): + """ + Build namespace of queryable attributes. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to build namespace from. + + Returns + ------- + namespace : dict + Dictionary mapping attribute names to their values. + """ + namespace = {} + + # Add all annotation arrays as queryable attributes + for attr in atom_array.get_annotation_categories(): + namespace[attr] = getattr(atom_array, attr) + + return namespace + + @staticmethod + def _build_functions(atom_array): + """ + Build available functions that can be called in queries. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to build functions for. + + Returns + ------- + functions : dict + Dictionary mapping function names to callable functions. + """ + functions = { + "has_nan_coord": lambda: QueryExpression._has_nan_coord(atom_array), + "has_bonds": lambda: QueryExpression._has_bonds(atom_array), + } + return functions + + @staticmethod + def _has_nan_coord(atom_array): + """ + Check if atom has NaN coordinates. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to check. + + Returns + ------- + has_nan : ndarray, dtype=bool + Boolean numpy array indicating which atoms have NaN coordinates. + """ + # To handle both AtomArray & AtomArrayStack, we need to reduce + # over all axes except the second to last (i.e. the atom dimension) + reduce_axes = tuple( + i for i in range(atom_array.coord.ndim) if i != atom_array.coord.ndim - 2 + ) + return np.isnan(atom_array.coord).any(axis=reduce_axes) + + @staticmethod + def _has_bonds(atom_array): + """ + Check if atom is involved in a bond. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to check. + + Returns + ------- + has_bonds : ndarray, dtype=bool + Boolean numpy array indicating which atoms are involved in bonds. + """ + if not hasattr(atom_array, "bonds") or atom_array.bonds is None: + return np.zeros(atom_array.array_length(), dtype=bool) + _bonded_idxs = np.unique(atom_array.bonds.as_array()[:, :2]) + return np.isin(np.arange(atom_array.array_length()), _bonded_idxs) + + @staticmethod + def _ensure_bool_array(mask, expected_length): + """ + Ensure mask is a boolean numpy array of the correct length. + + Parameters + ---------- + mask : array-like + The mask to ensure is a boolean array. + expected_length : int + The expected length of the array. + + Returns + ------- + mask : ndarray, dtype=bool + Boolean numpy array of the correct length. + + Raises + ------ + ValueError + If the mask length doesn't match the expected length. + """ + # Convert to numpy array if needed + if not isinstance(mask, np.ndarray): + mask = np.array(mask, dtype=bool) + + # Handle scalar boolean result + if mask.shape == () or mask.ndim == 0: + mask = np.full(expected_length, bool(mask), dtype=bool) + + # Ensure boolean dtype + if mask.dtype != bool: + mask = mask.astype(bool) + + # Check length + if len(mask) != expected_length: + raise ValueError( + f"Query resulted in mask of length {len(mask)}, " + f"but AtomArray has length {expected_length}" + ) + + return mask + + def _handle_in_operator(self, left, right, invert=False): + """ + Handle 'in' and 'not in' operators with numpy arrays. + + Parameters + ---------- + left : array-like or scalar + Left operand of the in/not in operation. + right : array-like + Right operand of the in/not in operation. + invert : bool, optional + Whether to invert the result (for 'not in'). Default is False. + + Returns + ------- + result : ndarray, dtype=bool + Boolean numpy array result of the in/not in operation. + + Raises + ------ + TypeError + If the right operand is not iterable. + """ + # Convert right to list/array if needed + if isinstance(right, (list, tuple, np.ndarray)): + # Use numpy's isin for array operations + if isinstance(left, np.ndarray): + return np.isin(left, right, invert=invert) + else: + # Single value + return (left not in right) if invert else (left in right) + else: + raise TypeError(f"Argument of type '{type(right)}' is not iterable") + + def _eval_node(self, node, namespace, functions, atom_array): + """ + Recursively evaluate an AST node. + + Parameters + ---------- + node : ast.AST + The AST node to evaluate. + namespace : dict + Dictionary of available variables. + functions : dict + Dictionary of available functions. + atom_array : AtomArray or AtomArrayStack + The atom array being queried. + + Returns + ------- + result : any + The result of evaluating the AST node. + + Raises + ------ + ValueError + If an unsupported operation or node type is encountered. + NameError + If a name or function is not defined. + """ + if isinstance(node, ast.Compare): + left = self._eval_node(node.left, namespace, functions, atom_array) + results = [] + + for op, comparator in zip(node.ops, node.comparators, strict=False): + right = self._eval_node(comparator, namespace, functions, atom_array) + + # Special handling for In/NotIn operators + if isinstance(op, ast.In): + results.append(self._handle_in_operator(left, right, invert=False)) + elif isinstance(op, ast.NotIn): + results.append(self._handle_in_operator(left, right, invert=True)) + else: + op_func = self.OPS[type(op)] + results.append(op_func(left, right)) + + left = right + + # Chain multiple comparisons with AND + if len(results) > 1: + result = results[0] + for r in results[1:]: + result = np.logical_and(result, r) + return result + else: + return results[0] + + elif isinstance(node, ast.BoolOp): + op_func = self.OPS[type(node.op)] + values = [ + self._eval_node(value, namespace, functions, atom_array) + for value in node.values + ] + + # Ensure all values are boolean arrays of correct length + values = [ + self._ensure_bool_array(v, atom_array.array_length()) for v in values + ] + + # Use numpy operations for boolean arrays + result = values[0] + for val in values[1:]: + result = op_func(result, val) + return result + + elif isinstance(node, ast.BinOp): + # Handle bitwise operations (& and |) + if type(node.op) in [ast.BitAnd, ast.BitOr]: + left = self._eval_node(node.left, namespace, functions, atom_array) + right = self._eval_node(node.right, namespace, functions, atom_array) + + # Ensure boolean arrays + left = self._ensure_bool_array(left, atom_array.array_length()) + right = self._ensure_bool_array(right, atom_array.array_length()) + + op_func = self.OPS[type(node.op)] + return op_func(left, right) + else: + raise ValueError(f"Unsupported binary operation: {type(node.op)}") + + elif isinstance(node, ast.UnaryOp): + op_func = self.OPS[type(node.op)] + operand = self._eval_node(node.operand, namespace, functions, atom_array) + + # Ensure boolean array for logical operations + if type(node.op) in [ast.Not, ast.Invert]: + operand = self._ensure_bool_array(operand, atom_array.array_length()) + + return op_func(operand) + + elif isinstance(node, ast.Call): + # Handle function calls + if isinstance(node.func, ast.Name): + func_name = node.func.id + if func_name in functions: + # Call the function (no arguments supported for now) + if node.args or node.keywords: + raise ValueError( + f"Function '{func_name}' does not accept arguments" + ) + result = functions[func_name]() + # Ensure it returns a boolean array of correct length + return self._ensure_bool_array(result, atom_array.array_length()) + else: + raise NameError(f"Function '{func_name}' is not defined") + else: + raise ValueError("Complex function calls not supported") + + elif isinstance(node, ast.Name): + if node.id in namespace: + return namespace[node.id] + raise NameError(f"Name '{node.id}' is not defined") + + elif isinstance(node, ast.Constant): + return node.value + + elif isinstance(node, ast.List): + return [ + self._eval_node(elt, namespace, functions, atom_array) + for elt in node.elts + ] + + elif isinstance(node, ast.Tuple): + return tuple( + self._eval_node(elt, namespace, functions, atom_array) + for elt in node.elts + ) + + else: + raise ValueError(f"Unsupported node type: {type(node)}") + + def __str__(self): + return self.expr + + def __repr__(self): + return f"QueryExpression('{self.expr}')" + + +def query(atom_array, expr): + """ + Query the AtomArray using pandas-like syntax. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + filtered_array : AtomArray or AtomArrayStack + Filtered atom array containing only atoms that match the query expression. + + Examples + -------- + Select all CA atoms in chain A: + + >>> ca_atoms = query(atom_array, "(chain_id == 'A') & (atom_name == 'CA')") + + Select atoms without NaN coordinates: + + >>> valid_atoms = query(atom_array, "~has_nan_coord()") + + Select bonded atoms in specific residues: + + >>> bonded = query(atom_array, "has_bonds() & (res_name in ['ALA', 'GLY', 'VAL'])") + """ + querier = QueryExpression(expr) + return querier.query(atom_array) + + +def mask(atom_array, expr): + """ + Query the AtomArray using pandas-like syntax and return a boolean mask. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + mask : ndarray, dtype=bool + Boolean numpy array indicating which atoms match the query. + """ + querier = QueryExpression(expr) + return querier.mask(atom_array) + + +def idxs(atom_array, expr): + """ + Query the AtomArray using pandas-like syntax and return the indices of matching atoms. + + Parameters + ---------- + atom_array : AtomArray or AtomArrayStack + The atom array to query. + expr : str + Query expression in pandas-like syntax. + + Returns + ------- + indices : ndarray, dtype=int + Numpy array of indices for atoms that match the query expression. + """ + querier = QueryExpression(expr) + return querier.idxs(atom_array) diff --git a/tests/structure/test_query.py b/tests/structure/test_query.py new file mode 100644 index 000000000..b34d1ec71 --- /dev/null +++ b/tests/structure/test_query.py @@ -0,0 +1,447 @@ +import numpy as np +import pytest +import biotite.structure as struc +from biotite.structure import query +from biotite.structure.bonds import BondList + + +@pytest.fixture +def atom_array(): + """Create a test AtomArray with diverse annotations.""" + chain_id = ["A", "A", "B", "B", "B", "C", "C"] + res_id = [1, 1, 2, 2, 3, 3, 3] + ins_code = ["", "", "", "A", "", "", ""] + res_name = ["ALA", "ALA", "GLY", "GLY", "PRO", "PRO", "PRO"] + hetero = [False, False, False, False, False, True, False] + atom_name = ["N", "CA", "N", "CA", "N", "CA", "C"] + element = ["N", "C", "N", "C", "N", "C", "C"] + + atom_list = [] + for i in range(7): + atom_list.append( + struc.Atom( + [i, i + 1, i + 2], + chain_id=chain_id[i], + res_id=res_id[i], + ins_code=ins_code[i], + res_name=res_name[i], + hetero=hetero[i], + atom_name=atom_name[i], + element=element[i], + ) + ) + array = struc.array(atom_list) + + # Add some NaN coordinates to test has_nan_coord + array.coord[3] = [np.nan, np.nan, np.nan] + + # Add bonds to test has_bonds function + bonds = BondList(array.array_length()) + bonds.add_bond(0, 1) # N-CA bond in residue 1 + bonds.add_bond(2, 3) # N-CA bond in residue 2 + bonds.add_bond(4, 5) # N-CA bond in residue 3 + array.bonds = bonds + + return array + + +@pytest.fixture +def atom_array_stack(atom_array): + """Create a test AtomArrayStack.""" + array2 = atom_array.copy() + array2.coord += 10 + array3 = atom_array.copy() + array3.coord += 20 + return struc.stack([atom_array, array2, array3]) + + +class TestQueryExpression: + """Test QueryExpression class functionality.""" + + def test_initialization(self): + """Test QueryExpression initialization.""" + expr = query.QueryExpression("chain_id == 'A'") + assert expr.expr == "chain_id == 'A'" + assert expr.tree is not None + + def test_simple_equality(self, atom_array): + """Test simple equality queries.""" + expr = query.QueryExpression("chain_id == 'A'") + result = expr.query(atom_array) + expected_indices = [0, 1] # First two atoms are in chain A + assert result.array_length() == len(expected_indices) + assert (result.chain_id == "A").all() + + def test_inequality(self, atom_array): + """Test inequality queries.""" + expr = query.QueryExpression("res_id != 1") + result = expr.query(atom_array) + assert (result.res_id != 1).all() + assert result.array_length() == 5 # 7 total - 2 from res_id=1 + + def test_comparison_operators(self, atom_array): + """Test comparison operators (<, <=, >, >=).""" + # Test greater than + expr = query.QueryExpression("res_id > 1") + result = expr.query(atom_array) + assert (result.res_id > 1).all() + + # Test less than or equal + expr = query.QueryExpression("res_id <= 2") + result = expr.query(atom_array) + assert (result.res_id <= 2).all() + + def test_logical_operators(self, atom_array): + """Test logical AND (&) and OR (|) operators.""" + # Test AND with & + expr = query.QueryExpression("(chain_id == 'A') & (atom_name == 'CA')") + result = expr.query(atom_array) + assert result.array_length() == 1 + assert result.chain_id[0] == "A" + assert result.atom_name[0] == "CA" + + # Test OR with | + expr = query.QueryExpression("(chain_id == 'A') | (chain_id == 'C')") + result = expr.query(atom_array) + assert result.array_length() == 4 # 2 from A + 2 from C + + # Test NOT with ~ + expr = query.QueryExpression("~(chain_id == 'A')") + result = expr.query(atom_array) + assert result.array_length() == 5 # 7 total - 2 from A + assert not (result.chain_id == "A").any() + + def test_in_operator(self, atom_array): + """Test 'in' and 'not in' operators.""" + # Test 'in' with list + expr = query.QueryExpression("chain_id in ['A', 'C']") + result = expr.query(atom_array) + assert result.array_length() == 4 + assert set(result.chain_id.tolist()) == {"A", "C"} + + # Test 'not in' + expr = query.QueryExpression("res_name not in ['ALA', 'GLY']") + result = expr.query(atom_array) + assert (result.res_name == "PRO").all() + + def test_functions(self, atom_array): + """Test built-in functions like has_nan_coord and has_bonds.""" + # Test has_nan_coord + expr = query.QueryExpression("has_nan_coord()") + result = expr.query(atom_array) + assert result.array_length() == 1 # Only one atom has NaN coords + + expr = query.QueryExpression("~has_nan_coord()") + result = expr.query(atom_array) + assert result.array_length() == 6 # 7 total - 1 with NaN + + # Test has_bonds + expr = query.QueryExpression("has_bonds()") + result = expr.query(atom_array) + # Atoms 0,1,2,3,4,5 are involved in bonds (but atom 3 has NaN coords) + expected_bonded = 6 + assert result.array_length() == expected_bonded + + def test_complex_queries(self, atom_array): + """Test complex combined queries.""" + expr = query.QueryExpression( + "(chain_id == 'A') & (atom_name == 'CA') & ~has_nan_coord()" + ) + result = expr.query(atom_array) + assert result.array_length() == 1 + assert result.chain_id[0] == "A" + assert result.atom_name[0] == "CA" + + # Complex query with functions and operators + expr = query.QueryExpression( + "has_bonds() & (res_name in ['ALA', 'GLY']) & ~has_nan_coord()" + ) + result = expr.query(atom_array) + # Should get atoms 0, 1, 2 (atom 3 has NaN coords) + assert result.array_length() == 3 + + def test_mask_method(self, atom_array): + """Test the mask method returns correct boolean array.""" + expr = query.QueryExpression("chain_id == 'A'") + mask = expr.mask(atom_array) + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.shape == (atom_array.array_length(),) + assert mask.sum() == 2 # Two atoms in chain A + assert mask[0] and mask[1] and not mask[2] + + def test_idxs_method(self, atom_array): + """Test the idxs method returns correct indices.""" + expr = query.QueryExpression("chain_id == 'A'") + indices = expr.idxs(atom_array) + assert isinstance(indices, np.ndarray) + assert indices.dtype == np.int64 + assert indices.tolist() == [0, 1] + + def test_atom_array_stack(self, atom_array_stack): + """Test queries work with AtomArrayStack.""" + expr = query.QueryExpression("chain_id == 'A'") + result = expr.query(atom_array_stack) + assert isinstance(result, struc.AtomArrayStack) + assert result.stack_depth() == atom_array_stack.stack_depth() + assert result.array_length() == 2 # Two atoms in chain A + + def test_error_handling(self, atom_array): + """Test error handling for invalid queries.""" + # Test undefined name + expr = query.QueryExpression("undefined_attr == 'test'") + with pytest.raises(NameError, match="Name 'undefined_attr' is not defined"): + expr.query(atom_array) + + # Test undefined function + expr = query.QueryExpression("undefined_func()") + with pytest.raises(NameError, match="Function 'undefined_func' is not defined"): + expr.query(atom_array) + + # Test function with arguments (not supported) + expr = query.QueryExpression("has_nan_coord(True)") + with pytest.raises(ValueError, match="does not accept arguments"): + expr.query(atom_array) + + # Test invalid 'in' operand + expr = query.QueryExpression("chain_id in 42") + with pytest.raises(TypeError, match="is not iterable"): + expr.query(atom_array) + + +class TestStandaloneFunctions: + """Test standalone query functions.""" + + def test_query_function(self, atom_array): + """Test the standalone query function.""" + result = query.query(atom_array, "chain_id == 'A'") + assert isinstance(result, struc.AtomArray) + assert result.array_length() == 2 + assert (result.chain_id == "A").all() + + def test_mask_function(self, atom_array): + """Test the standalone mask function.""" + mask = query.mask(atom_array, "chain_id == 'A'") + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.sum() == 2 + + def test_idxs_function(self, atom_array): + """Test the standalone idxs function.""" + indices = query.idxs(atom_array, "chain_id == 'A'") + assert isinstance(indices, np.ndarray) + assert indices.tolist() == [0, 1] + + +class TestAtomArrayMethods: + """Test query methods added to AtomArray and AtomArrayStack.""" + + def test_atom_array_query_method(self, atom_array): + """Test AtomArray.query method.""" + result = atom_array.query("chain_id == 'A'") + assert isinstance(result, struc.AtomArray) + assert result.array_length() == 2 + assert (result.chain_id == "A").all() + + def test_atom_array_mask_method(self, atom_array): + """Test AtomArray.mask method.""" + mask = atom_array.mask("chain_id == 'A'") + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.sum() == 2 + + def test_atom_array_idxs_method(self, atom_array): + """Test AtomArray.idxs method.""" + indices = atom_array.idxs("chain_id == 'A'") + assert isinstance(indices, np.ndarray) + assert indices.tolist() == [0, 1] + + def test_atom_array_stack_methods(self, atom_array_stack): + """Test AtomArrayStack query methods.""" + result = atom_array_stack.query("chain_id == 'A'") + assert isinstance(result, struc.AtomArrayStack) + assert result.array_length() == 2 + + mask = atom_array_stack.mask("chain_id == 'A'") + assert isinstance(mask, np.ndarray) + assert mask.dtype == bool + assert mask.sum() == 2 + + indices = atom_array_stack.idxs("chain_id == 'A'") + assert isinstance(indices, np.ndarray) + assert indices.tolist() == [0, 1] + + +class TestEdgeCases: + """Test edge cases and special scenarios.""" + + def test_empty_results(self, atom_array): + """Test queries that return empty results.""" + expr = query.QueryExpression("chain_id == 'Z'") # Non-existent chain + result = expr.query(atom_array) + assert result.array_length() == 0 + + mask = expr.mask(atom_array) + assert not mask.any() + + indices = expr.idxs(atom_array) + assert len(indices) == 0 + + def test_all_atoms_match(self, atom_array): + """Test queries where all atoms match.""" + expr = query.QueryExpression("res_id >= 1") + result = expr.query(atom_array) + assert result.array_length() == atom_array.array_length() + + mask = expr.mask(atom_array) + assert mask.all() + + def test_no_bonds(self): + """Test has_bonds function on array without bonds.""" + atom = struc.Atom([0, 0, 0], chain_id="A") + array = struc.array([atom]) + + expr = query.QueryExpression("has_bonds()") + result = expr.query(array) + assert result.array_length() == 0 + + expr = query.QueryExpression("~has_bonds()") + result = expr.query(array) + assert result.array_length() == 1 + + def test_chained_comparisons(self, atom_array): + """Test chained comparison operators.""" + expr = query.QueryExpression("1 <= res_id <= 2") + result = expr.query(atom_array) + expected_count = sum(1 <= rid <= 2 for rid in atom_array.res_id) + assert result.array_length() == expected_count + + def test_scalar_to_array_broadcasting(self, atom_array): + """Test scalar boolean results getting broadcast to array.""" + # This should create a scalar True result that gets broadcast + expr = query.QueryExpression("True") + mask = expr.mask(atom_array) + assert mask.all() + assert len(mask) == atom_array.array_length() + + expr = query.QueryExpression("False") + mask = expr.mask(atom_array) + assert not mask.any() + assert len(mask) == atom_array.array_length() + + +class TestBuiltinFunctions: + """Test built-in query functions in detail.""" + + def test_has_nan_coord_edge_cases(self): + """Test has_nan_coord with various coordinate patterns.""" + # Array with partial NaN + array = struc.AtomArray(3) + array.coord = np.array( + [ + [1.0, 2.0, 3.0], # No NaN + [np.nan, 2.0, 3.0], # Partial NaN + [np.nan, np.nan, np.nan], # All NaN + ] + ) + + expr = query.QueryExpression("has_nan_coord()") + result = expr.query(array) + assert result.array_length() == 2 # Two atoms have NaN coords + + # Test with AtomArrayStack + stack = struc.stack([array, array]) + result = expr.query(stack) + assert result.array_length() == 2 + + def test_has_bonds_edge_cases(self): + """Test has_bonds with various bond configurations.""" + array = struc.AtomArray(4) + array.coord = np.random.rand(4, 3) + + # Add bonds only to some atoms + bonds = BondList(4) + bonds.add_bond(0, 1) # Atoms 0 and 1 are bonded + # Atoms 2 and 3 have no bonds + array.bonds = bonds + + expr = query.QueryExpression("has_bonds()") + result = expr.query(array) + assert result.array_length() == 2 # Only atoms 0 and 1 + + expr = query.QueryExpression("~has_bonds()") + result = expr.query(array) + assert result.array_length() == 2 # Only atoms 2 and 3 + + +@pytest.mark.parametrize( + "query_str,expected_count", + [ + ("chain_id == 'A'", 2), + ("res_name == 'PRO'", 3), + ("hetero == True", 1), + ("atom_name == 'CA'", 3), + ("element == 'N'", 3), + ("ins_code == 'A'", 1), + ("res_id == 3", 3), + ], +) +def test_parametrized_queries(atom_array, query_str, expected_count): + """Test various queries with parametrized inputs.""" + expr = query.QueryExpression(query_str) + result = expr.query(atom_array) + assert result.array_length() == expected_count + + +class TestStringRepresentation: + """Test string representations of QueryExpression.""" + + def test_str_and_repr(self): + """Test __str__ and __repr__ methods.""" + expr_str = "chain_id == 'A'" + expr = query.QueryExpression(expr_str) + + assert str(expr) == expr_str + assert repr(expr) == f"QueryExpression('{expr_str}')" + + +class TestCompatibilityWithExistingCode: + """Test that query functionality works well with existing Biotite patterns.""" + + def test_with_filtering_and_slicing(self, atom_array): + """Test query combined with traditional filtering.""" + # First filter with query, then with traditional indexing + ca_atoms = atom_array.query("atom_name == 'CA'") + chain_a_ca = ca_atoms[ca_atoms.chain_id == "A"] + assert chain_a_ca.array_length() == 1 + + # Combine with slicing + first_two = atom_array[:2] + result = first_two.query("chain_id == 'A'") + assert result.array_length() == 2 + + def test_with_structure_operations(self, atom_array): + """Test query with structure operations like concatenation.""" + chain_a = atom_array.query("chain_id == 'A'") + chain_b = atom_array.query("chain_id == 'B'") + + combined = chain_a + chain_b + assert combined.array_length() == 5 + + # Query the combined structure + result = combined.query("atom_name == 'CA'") + assert result.array_length() == 2 + + def test_compound_queries(self, atom_array): + """Test compound queries.""" + result = atom_array.query( + "(chain_id == 'A') & (atom_name == 'CA') & ~has_nan_coord()" + ) + assert result.array_length() == 1 + assert result.chain_id[0] == "A" + assert result.atom_name[0] == "CA" + + idxs = atom_array.idxs( + "(chain_id == 'A') & (atom_name == 'CA') & ~has_nan_coord()" + ) + assert idxs.tolist() == [1] From 3f4d38026158499e27f097f303ff6f3d5cdaf883 Mon Sep 17 00:00:00 2001 From: Simon Mathis Date: Sun, 20 Jul 2025 14:00:13 +0100 Subject: [PATCH 2/3] tests: add further tests --- tests/structure/test_query.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/structure/test_query.py b/tests/structure/test_query.py index b34d1ec71..3f796c9ca 100644 --- a/tests/structure/test_query.py +++ b/tests/structure/test_query.py @@ -445,3 +445,12 @@ def test_compound_queries(self, atom_array): "(chain_id == 'A') & (atom_name == 'CA') & ~has_nan_coord()" ) assert idxs.tolist() == [1] + + def test_compound_queries2(self, atom_array): + atom_array.set_annotation("b_factor", np.array([10, 51, 93, 40, 50, 60, 70], dtype=np.float32)) + result = atom_array.query("(res_name in ['ALA', 'GLY']) & (b_factor > 50)") + assert result.array_length() == 2 + assert result.res_name[0] == "ALA" + assert result.res_name[1] == "GLY" + assert result.b_factor[0] > 50 + assert result.b_factor[1] > 50 \ No newline at end of file From 6ede054e2124c2219b7f0e1921cceafc668092bf Mon Sep 17 00:00:00 2001 From: Simon Mathis Date: Sun, 20 Jul 2025 14:00:28 +0100 Subject: [PATCH 3/3] chore: ruff --- tests/structure/test_query.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/structure/test_query.py b/tests/structure/test_query.py index 3f796c9ca..d178dbc6c 100644 --- a/tests/structure/test_query.py +++ b/tests/structure/test_query.py @@ -447,10 +447,12 @@ def test_compound_queries(self, atom_array): assert idxs.tolist() == [1] def test_compound_queries2(self, atom_array): - atom_array.set_annotation("b_factor", np.array([10, 51, 93, 40, 50, 60, 70], dtype=np.float32)) + atom_array.set_annotation( + "b_factor", np.array([10, 51, 93, 40, 50, 60, 70], dtype=np.float32) + ) result = atom_array.query("(res_name in ['ALA', 'GLY']) & (b_factor > 50)") assert result.array_length() == 2 assert result.res_name[0] == "ALA" assert result.res_name[1] == "GLY" assert result.b_factor[0] > 50 - assert result.b_factor[1] > 50 \ No newline at end of file + assert result.b_factor[1] > 50