From 989630de30c00b6be03e27d4fce371bd74b4e8d7 Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Tue, 3 Feb 2026 17:13:12 +0100 Subject: [PATCH 1/4] Implement `CellList` in Rust --- Cargo.toml | 4 + doc/apidoc.py | 42 +- doc/contribution/development.rst | 4 +- src/biotite/structure/__init__.py | 6 +- src/biotite/structure/basepairs.py | 2 +- src/biotite/structure/bonds.pyx | 20 +- src/biotite/structure/box.py | 12 +- src/biotite/structure/celllist.pyx | 864 ------------------ src/biotite/structure/compare.py | 53 +- src/biotite/structure/hbond.py | 6 +- src/biotite/structure/rdf.py | 6 +- src/biotite/structure/sasa.pyx | 4 +- src/biotite/structure/sse.py | 2 +- src/rust/structure/celllist.rs | 1333 ++++++++++++++++++++++++++++ src/rust/structure/io/pdb/file.rs | 2 +- src/rust/structure/mod.rs | 5 + tests/structure/test_celllist.py | 212 ++++- tests/test_modname.py | 5 +- 18 files changed, 1589 insertions(+), 993 deletions(-) delete mode 100644 src/biotite/structure/celllist.pyx create mode 100644 src/rust/structure/celllist.rs diff --git a/Cargo.toml b/Cargo.toml index 6c3a82a9c..c82070384 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,7 @@ features = ["extension-module"] name = "rust" crate-type = ["cdylib"] path = "src/rust/lib.rs" + +# Enable optimizations in dev mode for better performance during development +[profile.dev] +opt-level = 3 diff --git a/doc/apidoc.py b/doc/apidoc.py index fc9940be5..b5255f0e2 100644 --- a/doc/apidoc.py +++ b/doc/apidoc.py @@ -237,27 +237,51 @@ def skip_nonrelevant(app, what, name, obj, skip, options): return True if not _is_relevant_type(obj): return True - if obj.__module__ is None: - # Some built-in functions have '__module__' set to None + module = _get_module(obj) + if module is None: return True - package_name = obj.__module__.split(".")[0] + package_name = module.split(".")[0] if package_name != "biotite": return True return False +def _get_module(obj): + """ + Get the module name of an object. 
+ + For most objects, this is simply ``obj.__module__``. + However, some extension types (e.g. PyO3 method descriptors) + don't have a ``__module__`` attribute, so we fall back to + ``obj.__objclass__.__module__``. + """ + module = getattr(obj, "__module__", None) + if module is not None: + return module + # Fallback for PyO3 method descriptors that don't have __module__ + objclass = getattr(obj, "__objclass__", None) + if objclass is not None: + return getattr(objclass, "__module__", None) + return None + + def _is_relevant_type(obj): - if type(obj).__name__ == "method_descriptor": - # These are some special built-in Python methods - return False return ( ( # Functions - type(obj) - in [types.FunctionType, types.BuiltinFunctionType, types.MethodType] + isinstance( + obj, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodType, + types.BuiltinMethodType, + types.MethodDescriptorType, + ), + ) ) | ( - # Functions from C-extensions and wrapped functions + # Functions from Cython extensions and wrapped functions type(obj).__name__ in [ "cython_function_or_method", diff --git a/doc/contribution/development.rst b/doc/contribution/development.rst index bb56d630d..879f9e368 100644 --- a/doc/contribution/development.rst +++ b/doc/contribution/development.rst @@ -90,9 +90,9 @@ Therefore, the code should be vectorized as much as possible using *NumPy*. In cases the problem cannot be reasonably or conveniently solved this way, writing modules in *Rust* using `PyO3 ` is the preferred way to go. -The *Rust* part of the codebase is located in ``src/biotite/rust/`` and mirrors the +The *Rust* part of the codebase is located in ``src/rust/`` and mirrors the structure of the *Python* side. -For example *Rust* functionalities for ``biotite.structure.io.pdb`` residue in +For example *Rust* functionalities for ``biotite.structure.io.pdb`` reside in ``biotite.rust.structure.io.pdb``. 
``biotite.rust`` itself is not publicly exposed, but its functionalities are instead internally used or reexported in ``biotite`` subpackages. diff --git a/src/biotite/structure/__init__.py b/src/biotite/structure/__init__.py index 82168deec..217f42320 100644 --- a/src/biotite/structure/__init__.py +++ b/src/biotite/structure/__init__.py @@ -94,6 +94,10 @@ If no :class:`BondList` is associated, the ``bonds`` attribute is ``None``. +A unit cell or simulation box can be associated by setting the ``box`` attribute with a +``(3, 3)``-shaped :class:`ndarray` for an :class:`AtomArray` or a +``(m, 3, 3)``-shaped :class:`ndarray` for an :class:`AtomArrayStack`. + Based on the implementation in *NumPy* arrays, this package furthermore contains a comprehensive set of functions for structure analysis, manipulation and visualization. @@ -104,11 +108,11 @@ __name__ = "biotite.structure" __author__ = "Patrick Kunzmann" +from biotite.rust.structure import * from .atoms import * from .basepairs import * from .bonds import * from .box import * -from .celllist import * from .chains import * from .charges import * from .compare import * diff --git a/src/biotite/structure/basepairs.py b/src/biotite/structure/basepairs.py index 8ab54e265..c2c8a687e 100644 --- a/src/biotite/structure/basepairs.py +++ b/src/biotite/structure/basepairs.py @@ -21,8 +21,8 @@ import warnings from enum import IntEnum import numpy as np +from biotite.rust.structure import CellList from biotite.structure.atoms import Atom, array -from biotite.structure.celllist import CellList from biotite.structure.compare import rmsd from biotite.structure.error import ( BadStructureError, diff --git a/src/biotite/structure/bonds.pyx b/src/biotite/structure/bonds.pyx index 9d0673e92..02b04c797 100644 --- a/src/biotite/structure/bonds.pyx +++ b/src/biotite/structure/bonds.pyx @@ -52,16 +52,16 @@ class BondType(IntEnum): """ This enum type represents the type of a chemical bond. 
- - `ANY` - Used if the actual type is unknown - - `SINGLE` - Single bond - - `DOUBLE` - Double bond - - `TRIPLE` - Triple bond - - `QUADRUPLE` - A quadruple bond - - `AROMATIC_SINGLE` - Aromatic bond with a single formal bond - - `AROMATIC_DOUBLE` - Aromatic bond with a double formal bond - - `AROMATIC_TRIPLE` - Aromatic bond with a triple formal bond - - `AROMATIC` - Aromatic bond without specification of the formal bond - - `COORDINATION` - Coordination complex involving a metal atom + - ``ANY`` - Used if the actual type is unknown + - ``SINGLE`` - Single bond + - ``DOUBLE`` - Double bond + - ``TRIPLE`` - Triple bond + - ``QUADRUPLE`` - A quadruple bond + - ``AROMATIC_SINGLE`` - Aromatic bond with a single formal bond + - ``AROMATIC_DOUBLE`` - Aromatic bond with a double formal bond + - ``AROMATIC_TRIPLE`` - Aromatic bond with a triple formal bond + - ``AROMATIC`` - Aromatic bond without specification of the formal bond + - ``COORDINATION`` - Coordination complex involving a metal atom """ ANY = 0 SINGLE = 1 diff --git a/src/biotite/structure/box.py b/src/biotite/structure/box.py index d0f460afe..4c6d116c0 100644 --- a/src/biotite/structure/box.py +++ b/src/biotite/structure/box.py @@ -453,7 +453,7 @@ def move_inside_box(coord, box): Parameters ---------- - coord : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + coord : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The coordinates for one or multiple models. box : ndarray, dtype=float, shape=(3,3) or shape=(m,3,3) The box(es) for one or multiple models. @@ -462,7 +462,7 @@ def move_inside_box(coord, box): Returns ------- - moved_coord : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + moved_coord : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The moved coordinates. Has the same shape is the input `coord`. 
@@ -619,7 +619,7 @@ def coord_to_fraction(coord, box): Parameters ---------- - coord : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + coord : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The coordinates for one or multiple models. box : ndarray, dtype=float, shape=(3,3) or shape=(m,3,3) The box(es) for one or multiple models. @@ -628,7 +628,7 @@ def coord_to_fraction(coord, box): Returns ------- - fraction : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + fraction : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The fractions of the box vectors. See Also @@ -666,7 +666,7 @@ def fraction_to_coord(fraction, box): Parameters ---------- - fraction : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + fraction : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The fractions of the box vectors for one or multiple models. box : ndarray, dtype=float, shape=(3,3) or shape=(m,3,3) The box(es) for one or multiple models. @@ -675,7 +675,7 @@ def fraction_to_coord(fraction, box): Returns ------- - coord : ndarray, dtype=float, shape=(n,3) or shape=(m,n,3) + coord : ndarray, dtype=float, shape=(3,) or shape=(n,3) or shape=(m,n,3) The coordinates. See Also diff --git a/src/biotite/structure/celllist.pyx b/src/biotite/structure/celllist.pyx deleted file mode 100644 index b2ec3f09b..000000000 --- a/src/biotite/structure/celllist.pyx +++ /dev/null @@ -1,864 +0,0 @@ -# This source code is part of the Biotite package and is distributed -# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further -# information. - -""" -This module allows efficient search of atoms in a defined radius around -a location. 
-""" - -__name__ = "biotite.structure" -__author__ = "Patrick Kunzmann" -__all__ = ["CellList"] - -cimport cython -cimport numpy as np -from libc.stdlib cimport realloc, malloc, free - -import numpy as np -from .atoms import coord as to_coord -from .atoms import AtomArrayStack -from .box import repeat_box_coord, move_inside_box - -ctypedef np.uint64_t ptr -ctypedef np.float32_t float32 -ctypedef np.uint8_t uint8 - - -cdef class CellList: - """ - __init__(atom_array, cell_size, periodic=False, box=None, selection=None) - - This class enables the efficient search of atoms in vicinity of a - defined location. - - This class stores the indices of an atom array in virtual "cells", - each corresponding to a specific coordinate interval. - If the atoms in vicinity of a specific location are searched, only - the atoms in the relevant cells are checked. - Effectively this decreases the operation time for finding atoms - with a maximum distance to given coordinates from *O(n)* to *O(1)*, - after the :class:`CellList` has been created. - Therefore a :class:`CellList` saves calculation time in those - cases, where vicinity is checked for multiple locations. - - Parameters - ---------- - atom_array : AtomArray or ndarray, dtype=float, shape=(n,3) - The :class:`AtomArray` to create the :class:`CellList` for. - Alternatively the atom coordinates are accepted directly. - In this case `box` must be set, if `periodic` is true. - cell_size : float - The coordinate interval each cell has for x, y and z axis. - The amount of cells depends on the range of coordinates in the - `atom_array` and the `cell_size`. - periodic : bool, optional - If true, the cell list considers periodic copies of atoms. - The periodicity is based on the `box` attribute of `atom_array`. - box : ndarray, dtype=float, shape=(3,3), optional - If provided, the periodicity is based on this parameter instead - of the :attr:`box` attribute of `atom_array`. - Only has an effect, if `periodic` is ``True``. 
- selection : ndarray, dtype=bool, shape=(n,), optional - If provided, only the atoms masked by this array are stored in - the cell list. However, the indices stored in the cell list - will still refer to the original unfiltered `atom_array`. - - Examples - -------- - - >>> cell_list = CellList(atom_array, cell_size=5) - >>> near_atoms = atom_array[cell_list.get_atoms(np.array([1,2,3]), radius=7.0)] - """ - - # The atom coordinates - cdef float32[:,:] _coord - # A boolean mask that covers the selected atoms - cdef uint8[:] _selection - cdef bint _has_selection - # The cells to store the coordinates in; an ndarray of pointers - cdef ptr[:,:,:] _cells - # The amount elements in each C-array in '_cells' - cdef int[:,:,:] _cell_length - # The maximum value of '_cell_length' over all cells, - # required for worst case assumption on size of output arrays - cdef int _max_cell_length - # The length of the cell in each direction (x,y,z) - cdef float _cellsize - # The minimum and maximum coordinates for all atoms - # Used as origin ('_min_coord' is at _cells[0,0,0]) - # and for bound checks - cdef float32[:] _min_coord - cdef float32[:] _max_coord - # Indicates whether the cell list takes periodicity into account - cdef bint _periodic - cdef np.ndarray _box - # The length of the array before appending periodic copies - # if 'periodic' is true - cdef int _orig_length - cdef float32[:] _orig_min_coord - cdef float32[:] _orig_max_coord - - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - def __cinit__(self, atom_array not None, float cell_size, - bint periodic=False, box=None, np.ndarray selection=None): - cdef float32 x, y, z - cdef int i, j, k - cdef int atom_array_i - cdef int* cell_ptr = NULL - cdef int length - - if isinstance(atom_array, AtomArrayStack): - raise TypeError("Expected 'AtomArray' but got 'AtomArrayStack'") - coord = to_coord(atom_array) - # the length of the array before appending periodic copies - # if 'periodic' 
is true - self._orig_length = coord.shape[0] - self._box = None - if selection is None: - _check_coord(coord) - else: - _check_coord(coord[selection]) - - if periodic: - if box is not None: - self._box = box - elif atom_array.box is not None: - if atom_array.box.shape != (3,3): - raise ValueError( - "Box has invalid shape" - ) - self._box = atom_array.box - else: - raise ValueError( - "AtomArray must have a box to enable periodicity" - ) - if np.isnan(self._box).any(): - raise ValueError("Box contains NaN values") - coord = move_inside_box(coord, self._box) - coord, indices = repeat_box_coord(coord, self._box) - - if self._has_initialized_cells(): - raise Exception("Duplicate call of constructor") - self._cells = None - if cell_size <= 0: - raise ValueError("Cell size must be greater than 0") - self._periodic = periodic - self._coord = coord.astype(np.float32, copy=False) - self._cellsize = cell_size - # calculate how many cells are required for each dimension - min_coord = np.nanmin(coord, axis=0).astype(np.float32) - max_coord = np.nanmax(coord, axis=0).astype(np.float32) - self._min_coord = min_coord - self._max_coord = max_coord - cell_count = (((max_coord - min_coord) / cell_size) +1).astype(int) - if self._periodic: - self._orig_min_coord = np.nanmin(coord[:self._orig_length], axis=0) \ - .astype(np.float32) - self._orig_max_coord = np.nanmax(coord[:self._orig_length], axis=0) \ - .astype(np.float32) - - # ndarray of pointers to C-arrays - # containing indices to atom array - self._cells = np.zeros(cell_count, dtype=np.uint64) - # Stores the length of the C-arrays - self._cell_length = np.zeros(cell_count, dtype=np.int32) - - # Prepare selection - if selection is not None: - self._has_selection = True - self._selection = np.frombuffer(selection, dtype=np.uint8) - if self._selection.shape[0] != self._orig_length: - raise IndexError( - f"Atom array has length {self._orig_length}, " - f"but selection has length {self._selection.shape[0]}" - ) - else: - 
self._has_selection = False - - # Fill cells - for atom_array_i in range(self._coord.shape[0]): - # Only put selected atoms into cell list - if not self._has_selection \ - or self._selection[atom_array_i % self._orig_length]: - x = self._coord[atom_array_i, 0] - y = self._coord[atom_array_i, 1] - z = self._coord[atom_array_i, 2] - # Get cell indices for coordinates - self._get_cell_index(x, y, z, &i, &j, &k) - # Increment cell length and reallocate - length = self._cell_length[i,j,k] + 1 - cell_ptr = self._cells[i,j,k] - cell_ptr = realloc(cell_ptr, length * sizeof(int)) - if not cell_ptr: - raise MemoryError() - # Potentially increase max cell length - if length > self._max_cell_length: - self._max_cell_length = length - # Store atom array index in respective cell - cell_ptr[length-1] = atom_array_i - # Store new cell pointer and length - self._cell_length[i,j,k] = length - self._cells[i,j,k] = cell_ptr - - - def __dealloc__(self): - if self._has_initialized_cells(): - deallocate_ptrs(self._cells) - - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - def create_adjacency_matrix(self, float32 threshold_distance): - """ - create_adjacency_matrix(threshold_distance) - - Create an adjacency matrix for the atoms in this cell list. - - An adjacency matrix depicts which atoms *i* and *j* have - a distance lower than a given threshold distance. - The values in the adjacency matrix ``m`` are - ``m[i,j] = 1 if distance(i,j) <= threshold else 0`` - - Parameters - ---------- - threshold_distance : float - The threshold distance. All atom pairs that have a distance - lower than this value are indicated by ``True`` values in - the resulting matrix. - - Returns - ------- - matrix : ndarray, dtype=bool, shape=(n,n) - An *n x n* adjacency matrix. 
- If a `selection` was given to the constructor of the - :class:`CellList`, the rows and columns corresponding to - atoms, that are not masked by the selection, have all - elements set to ``False``. - - Notes - ----- - The highest performance is achieved when the the cell size is - equal to the threshold distance. However, this is purely - optinal: The resulting adjacency matrix is the same for every - cell size. - - Although the adjacency matrix should be symmetric in most cases, - it may occur that ``m[i,j] != m[j,i]``, when ``distance(i,j)`` - is very close to the `threshold_distance` due to numerical - errors. - The matrix can be symmetrized with ``numpy.maximum(a, a.T)``. - - Examples - -------- - Create adjacency matrix for CA atoms in a structure: - - >>> atom_array = atom_array[atom_array.atom_name == "CA"] - >>> cell_list = CellList(atom_array, 5) - >>> matrix = cell_list.create_adjacency_matrix(5) - """ - if threshold_distance < 0: - raise ValueError("Threshold must be a positive value") - cdef int i=0 - - # Get atom position for all original positions - # (no periodic copies) - coord = np.asarray(self._coord[:self._orig_length]) - - if self._has_selection: - selection = np.asarray(self._selection, dtype=bool) - # Create matrix with all elements set to False - matrix = np.zeros( - (self._orig_length, self._orig_length), dtype=bool - ) - # Set only those rows that belong to masked atoms - matrix[selection, :] = self.get_atoms( - coord[selection], threshold_distance, as_mask=True - ) - return matrix - else: - return self.get_atoms(coord, threshold_distance, as_mask=True) - - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - def get_atoms(self, np.ndarray coord, radius, bint as_mask=False): - """ - get_atoms(coord, radius, as_mask=False) - - Find atoms with a maximum distance from given coordinates. 
- - Parameters - ---------- - coord : ndarray, dtype=float, shape=(3,) or shape=(m,3) - The central coordinates, around which the atoms are - searched. - If a single position is given, the indices of atoms in its - radius are returned. - Multiple positions (2-D :class:`ndarray`) have a vectorized - behavior: - Each row in the resulting :class:`ndarray` contains the - indices for the corresponding position. - Since the positions may have different amounts of adjacent - atoms, trailing `-1` values are used to indicate nonexisting - indices. - radius : float or ndarray, shape=(n,), dtype=float, optional - The radius around `coord`, in which the atoms are searched, - i.e. all atoms in `radius` distance to `coord` are returned. - Either a single radius can be given as scalar, or individual - radii for each position in `coord` can be provided as - :class:`ndarray`. - as_mask : bool, optional - If true, the result is returned as boolean mask, instead - of an index array. - - Returns - ------- - indices : ndarray, dtype=int32, shape=(p,) or shape=(m,p) - The indices of the atom array, where the atoms are in the - defined `radius` around `coord`. - If `coord` contains multiple positions, this return value is - two-dimensional with trailing `-1` values for empty values. - Only returned with `as_mask` set to false. - mask : ndarray, dtype=bool, shape=(m,n) or shape=(n,) - Same as `indices`, but as boolean mask. - The values are true for atoms in the atom array, - that are in the defined vicinity. - Only returned with `as_mask` set to true. - - See Also - -------- - get_atoms_in_cells - - Notes - ----- - In case of a :class:`CellList` with `periodic` set to `True`: - If more than one periodic copy of an atom is within the - threshold radius, the returned `indices` array contains the - corresponding index multiple times. - Please use ``numpy.unique()``, if this is undesireable. 
- - Examples - -------- - Get adjacent atoms for a single position: - - >>> cell_list = CellList(atom_array, 3) - >>> pos = np.array([1.0, 2.0, 3.0]) - >>> indices = cell_list.get_atoms(pos, radius=2.0) - >>> print(indices) - [102 104 112] - >>> print(atom_array[indices]) - A 6 TRP CE3 C 0.779 0.524 2.812 - A 6 TRP CZ3 C 1.439 0.433 4.053 - A 6 TRP HE3 H -0.299 0.571 2.773 - >>> indices = cell_list.get_atoms(pos, radius=3.0) - >>> print(atom_array[indices]) - A 6 TRP CD2 C 1.508 0.564 1.606 - A 6 TRP CE3 C 0.779 0.524 2.812 - A 6 TRP CZ3 C 1.439 0.433 4.053 - A 6 TRP HE3 H -0.299 0.571 2.773 - A 6 TRP HZ3 H 0.862 0.400 4.966 - A 3 TYR CZ C -0.639 3.053 5.043 - A 3 TYR HH H 1.187 3.395 5.567 - A 19 PRO HD2 H 0.470 3.937 1.260 - A 6 TRP CE2 C 2.928 0.515 1.710 - A 6 TRP CH2 C 2.842 0.407 4.120 - A 18 PRO HA H 2.719 3.181 1.316 - A 18 PRO HB3 H 2.781 3.223 3.618 - A 18 PRO CB C 3.035 4.190 3.187 - - Get adjacent atoms for multiple positions: - - >>> cell_list = CellList(atom_array, 3) - >>> pos = np.array([[1.0,2.0,3.0], [2.0,3.0,4.0], [3.0,4.0,5.0]]) - >>> indices = cell_list.get_atoms(pos, radius=3.0) - >>> print(indices) - [[ 99 102 104 112 114 45 55 290 101 105 271 273 268 -1 -1] - [104 114 45 46 55 44 54 105 271 273 265 268 269 272 275] - [ 46 55 273 268 269 272 274 275 -1 -1 -1 -1 -1 -1 -1]] - >>> # Convert to list of arrays and remove trailing -1 - >>> indices = [row[row != -1] for row in indices] - >>> for row in indices: - ... 
print(row) - [ 99 102 104 112 114 45 55 290 101 105 271 273 268] - [104 114 45 46 55 44 54 105 271 273 265 268 269 272 275] - [ 46 55 273 268 269 272 274 275] - """ - cdef int i=0, j=0 - cdef int array_i = 0 - cdef int max_array_length = 0 - cdef int coord_index - cdef float32 x1, y1, z1, x2, y2, z2 - cdef float32 sq_dist - cdef float32 sq_radius - cdef float32[:] sq_radii - cdef np.ndarray cell_radii - - cdef int[:,:] all_indices - cdef int[:,:] indices - cdef float32[:,:] coord_v - - if len(coord) == 0: - return _empty_result(as_mask) - - # Handle periodicity for the input coordinates - if self._periodic: - coord = move_inside_box(coord, self._box) - # Convert input parameters into a uniform format - coord, radius, is_multi_coord, is_multi_radius \ - = _prepare_vectorization(coord, radius, np.float32) - if is_multi_radius: - sq_radii = radius * radius - cell_radii = np.ceil(radius / self._cellsize).astype(np.int32) - else: - # All radii are equal - sq_radii = np.full( - len(coord), radius[0]*radius[0], dtype=np.float32 - ) - cell_radii = np.full( - len(coord), - int(np.ceil(radius[0] / self._cellsize)), - dtype=np.int32 - ) - - # Get indices for adjacent atoms, based on a cell radius - all_indices = self._get_atoms_in_cells( - coord, cell_radii, is_multi_radius - ) - # These have to be narrowed down in the next step - # using the Euclidian distance - - # Filter all indices from all_indices - # where squared distance is smaller than squared radius - # Using the squared distance is computationally cheaper than - # calculating the sqaure root for every distance - indices = np.full( - (all_indices.shape[0], all_indices.shape[1]), -1, dtype=np.int32 - ) - coord_v = coord - for i in range(all_indices.shape[0]): - sq_radius = sq_radii[i] - x1 = coord_v[i,0] - y1 = coord_v[i,1] - z1 = coord_v[i,2] - array_i = 0 - for j in range(all_indices.shape[1]): - coord_index = all_indices[i,j] - if coord_index != -1: - x2 = self._coord[coord_index, 0] - y2 = 
self._coord[coord_index, 1] - z2 = self._coord[coord_index, 2] - sq_dist = squared_distance(x1, y1, z1, x2, y2, z2) - if sq_dist <= sq_radius: - indices[i, array_i] = coord_index - array_i += 1 - if array_i > max_array_length: - max_array_length = array_i - - return self._post_process( - np.asarray(indices)[:, :max_array_length], - as_mask, is_multi_coord - ) - - - @cython.boundscheck(False) - @cython.wraparound(False) - def get_atoms_in_cells(self, np.ndarray coord, - cell_radius=1, bint as_mask=False): - """ - get_atoms_in_cells(coord, cell_radius=1, as_mask=False) - - Find atoms with a maximum cell distance from given - coordinates. - - Instead of using the radius as maximum euclidian distance to the - given coordinates, - the radius is measured as the amount of cells: - A radius of 0 means, that only the atoms in the same cell - as the given coordinates are considered. A radius of 1 means, - that the atoms indices from this cell and the 8 surrounding - cells are returned and so forth. - This is more efficient than `get_atoms()`. - - Parameters - ---------- - coord : ndarray, dtype=float, shape=(3,) or shape=(m,3) - The central coordinates, around which the atoms are - searched. - If a single position is given, the indices of atoms in its - cell radius are returned. - Multiple positions (2-D :class:`ndarray`) have a vectorized - behavior: - Each row in the resulting :class:`ndarray` contains the - indices for the corresponding position. - Since the positions may have different amounts of adjacent - atoms, trailing `-1` values are used to indicate nonexisting - indices. - cell_radius : int or ndarray, shape=(n,), dtype=int, optional - The radius around `coord` (in amount of cells), in which - the atoms are searched. This does not correspond to the - Euclidian distance used in `get_atoms()`. In this case, all - atoms in the cell corresponding to `coord` and in adjacent - cells are returned. 
- Either a single radius can be given as scalar, or individual - radii for each position in `coord` can be provided as - :class:`ndarray`. - By default atoms are searched in the cell of `coord` - and directly adjacent cells (cell_radius = 1). - as_mask : bool, optional - If true, the result is returned as boolean mask, instead - of an index array. - - Returns - ------- - indices : ndarray, dtype=int32, shape=(p,) or shape=(m,p) - The indices of the atom array, where the atoms are in the - defined `radius` around `coord`. - If `coord` contains multiple positions, this return value is - two-dimensional with trailing `-1` values for empty values. - Only returned with `as_mask` set to false. - mask : ndarray, dtype=bool, shape=(m,n) or shape=(n,) - Same as `indices`, but as boolean mask. - The values are true for atoms in the atom array, - that are in the defined vicinity. - Only returned with `as_mask` set to true. - - See Also - -------- - get_atoms - - Notes - ----- - In case of a :class:`CellList` with `periodic` set to `True`: - If more than one periodic copy of an atom is within the - threshold radius, the returned `indices` array contains the - corresponding index multiple times. - Please use ``numpy.unique()``, if this is undesireable. 
- """ - # This function is a thin wrapper around the private method - # with the same name, with addition of handling periodicty - # and the ability to return a mask instead of indices - - if len(coord) == 0: - return _empty_result(as_mask) - - # Handle periodicity for the input coordinates - if self._periodic: - coord = move_inside_box(coord, self._box) - # Convert input parameters into a uniform format - coord, cell_radius, is_multi_coord, is_multi_radius \ - = _prepare_vectorization(coord, cell_radius, np.int32) - # Get adjacent atom indices - array_indices = self._get_atoms_in_cells( - coord, cell_radius, is_multi_radius - ) - return self._post_process(array_indices, as_mask, is_multi_coord) - - - @cython.boundscheck(False) - @cython.wraparound(False) - def _get_atoms_in_cells(self, - np.ndarray coord, - np.ndarray cell_radii, - bint is_multi_radius): - """ - Get the indices of atoms in `cell_radii` adjacency of `coord`. - - Parameters - ---------- - coord : ndarray, dtype=float32, shape=(n,3) - The position to find adjacent atoms for. - cell_radii : ndarray, dtype=int32, shape=(n) - The radius for each position. - is_multi_radius : bool - True indicates, that all values in `cell_radii` are the - same. - - Returns - ------- - array_indices : ndarray, dtype=int32, shape=(m,p) - Indices of adjancent atoms. 
- """ - - cdef int max_cell_radius - if is_multi_radius: - max_cell_radius = np.max(cell_radii) - else: - # All radii are equal - max_cell_radius = cell_radii[0] - # Worst case assumption on index array length requirement: - # At maximum, the amount of adjacent atoms can only be the - # maximum amount of atoms per cell times the amount of cells - # Since the cells extend in 3 dimensions the amount of cells is - # (2*r + 1)**3 - cdef int length = (2*max_cell_radius + 1)**3 * self._max_cell_length - array_indices = np.full((len(coord), length), -1, dtype=np.int32) - # Fill index array - cdef int max_array_length \ - = self._find_adjacent_atoms(coord, array_indices, cell_radii) - return array_indices[:, :max_array_length] - - - @cython.boundscheck(False) - @cython.wraparound(False) - cdef int _find_adjacent_atoms(self, - float32[:,:] coord, - int[:,:] indices, - int[:] cell_radius): - """ - This method fills the given empty index array - with actual indices of adjacent atoms. - - Since the length of 'indices' (second dimension) is - the worst case assumption, this method returns the actual - required length, i.e. the highest length of all arrays - in this 'array of arrays'. 
- """ - cdef int length - cdef int* list_ptr - cdef float32 x, y,z - cdef int i=0, j=0, k=0 - cdef int adj_i, adj_j, adj_k - cdef int pos_i, array_i, cell_i - cdef int max_array_length = 0 - cdef int cell_r - - cdef ptr[:,:,:] cells = self._cells - cdef int[:,:,:] cell_length = self._cell_length - cdef uint8[:] finite_mask = ( - np.isfinite(np.asarray(coord)).all(axis=-1).astype(np.uint8, copy=False) - ) - - for pos_i in range(coord.shape[0]): - if not finite_mask[pos_i]: - # For non-finite coordinates, there are no adjacent atoms - continue - array_i = 0 - cell_r = cell_radius[pos_i] - x = coord[pos_i, 0] - y = coord[pos_i, 1] - z = coord[pos_i, 2] - self._get_cell_index(x, y, z, &i, &j, &k) - # Look into cells of the indices and adjacent cells - # in all 3 dimensions - - for adj_i in range(i-cell_r, i+cell_r+1): - if (adj_i >= 0 and adj_i < cells.shape[0]): - for adj_j in range(j-cell_r, j+cell_r+1): - if (adj_j >= 0 and adj_j < cells.shape[1]): - for adj_k in range(k-cell_r, k+cell_r+1): - if (adj_k >= 0 and adj_k < cells.shape[2]): - # Fill index array - # with indices in cell - list_ptr = cells[adj_i, adj_j, adj_k] - length = cell_length[adj_i, adj_j, adj_k] - for cell_i in range(length): - indices[pos_i, array_i] = \ - list_ptr[cell_i] - array_i += 1 - if array_i > max_array_length: - max_array_length = array_i - return max_array_length - - - @cython.boundscheck(False) - @cython.wraparound(False) - def _post_process(self, - np.ndarray indices, - bint as_mask, - bint is_multi_coord): - """ - Post process the resulting indices of adjacent atoms, - including periodicity handling and optional conversion into a - boolean matrix. - """ - # Handle periodicity for the output indices - if self._periodic: - # Map indices of repeated coordinates to original - # coordinates, i.e. 
the coordinates in the central box - # -> Remainder of dividing index by original array length - # Furthermore this ensures, that the indices have valid - # values for '_as_mask()' - indices[indices != -1] %= self._orig_length - if as_mask: - matrix = self._as_mask(indices) - if is_multi_coord: - return matrix - else: - return matrix[0] - else: - if is_multi_coord: - return indices - else: - return indices[0] - - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - @cython.cdivision(True) - cdef inline void _get_cell_index(self, float32 x, float32 y, float32 z, - int* i, int* j, int* k): - i[0] = ((x - self._min_coord[0]) / self._cellsize) - j[0] = ((y - self._min_coord[1]) / self._cellsize) - k[0] = ((z - self._min_coord[2]) / self._cellsize) - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - cdef inline bint _check_coord(self, float32 x, float32 y, float32 z): - if x < self._min_coord[0] or x > self._max_coord[0]: - return False - if y < self._min_coord[1] or y > self._max_coord[1]: - return False - if z < self._min_coord[2] or z > self._max_coord[2]: - return False - return True - - @cython.initializedcheck(False) - @cython.boundscheck(False) - @cython.wraparound(False) - cdef np.ndarray _as_mask(self, int[:,:] indices): - cdef int i,j - cdef int index - cdef uint8[:,:] matrix = np.zeros( - (indices.shape[0], self._orig_length), dtype=np.uint8 - ) - # Fill matrix - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - index = indices[i,j] - if index == -1: - # End of list -> jump to next position - break - matrix[i, index] = True - return np.asarray(matrix, dtype=bool) - - cdef inline bint _has_initialized_cells(self): - # Memoryviews are not initialized on class creation - # This method checks if the _cells memoryview was initialized - # and is not None - try: - if self._cells is not None: - return True - else: - return False - except AttributeError: - return 
False - - -def _check_coord(coord): - """ - Perform checks on validity of coordinates. - """ - if coord.ndim != 2: - raise ValueError("Coordinates must have shape (n,3)") - if coord.shape[0] == 0: - raise ValueError("Coordinates must not be empty") - if coord.shape[1] != 3: - raise ValueError("Coordinates must have form (x,y,z)") - if not np.isfinite(coord).all(): - raise ValueError("Coordinates contain non-finite values") - - -def _empty_result(as_mask): - """ - Create return value for :func:`get_atoms()` and - :func:`get_atoms_in_cells()`, if no coordinates are given. - """ - if as_mask: - return np.array([], dtype=bool) - else: - return np.array([], dtype=np.int32) - - -def _prepare_vectorization(np.ndarray coord, radius, radius_dtype): - """ - Since `get_atoms()` and `get_atoms_in_cells()`, may take different - amount of dimensions for the coordinates and the radius to enable - vectorized compuation, each of these functions would need to handle - the different cases. - - This function converts the input radius and coordinates into a - uniform format and also return, whether single/multiple - radii/coordinates were given. - - The shapes before and after conversion are: - - - coord: (3, ), radius: scalar -> coord: (1,3), radius: (1,) - - coord: (n,3), radius: scalar -> coord: (n,3), radius: (n,) - - coord: (n,3), radius: (n, ) -> coord: (n,3), radius: (n,) - - Thes resulting values have the same dimensionality for all cases and - can be handeled uniformly by `get_atoms()` and - `get_atoms_in_cells()`. 
- """ - cdef bint is_multi_coord - cdef bint is_multi_radius - - if coord.ndim == 1 and coord.shape[0] == 3: - # Single position - coord = coord[np.newaxis, :].astype(np.float32, copy=False) - is_multi_coord = False - elif coord.ndim == 2 and coord.shape[1] == 3: - # Multiple positions - coord = coord.astype(np.float32, copy=False) - is_multi_coord = True - else: - raise ValueError( - f"Invalid shape for input coordinates" - ) - - if isinstance(radius, np.ndarray): - # Multiple radii - # Check whether amount of coordinates match amount of radii - if not is_multi_coord: - raise ValueError( - "Cannot accept array of radii, if a single position is given" - ) - if radius.ndim != 1: - raise ValueError("Array of radii must be one-dimensional") - if radius.shape[0] != coord.shape[0]: - raise ValueError( - f"Amount of radii ({radius.shape[0]}) " - f"and coordinates ({coord.shape[0]}) are not equal" - ) - if (radius < 0).any(): - raise ValueError("Radii must be a positive values") - radius = radius.astype(radius_dtype, copy=False) - is_multi_radius = True - else: - # Single radius - if radius < 0: - raise ValueError("Radius must be a positive value") - # If only a single integer is given, - # create numpy array filled with identical values - # with the same length as the coordinates - radius = np.full(coord.shape[0], radius, dtype=radius_dtype) - is_multi_radius = False - - return coord, radius, is_multi_coord, is_multi_radius - - -cdef inline void deallocate_ptrs(ptr[:,:,:] ptrs): - cdef int i, j, k - cdef int* cell_ptr - # Free cell pointers - for i in range(ptrs.shape[0]): - for j in range(ptrs.shape[1]): - for k in range(ptrs.shape[2]): - cell_ptr = ptrs[i,j,k] - free(cell_ptr) - - -cdef inline float32 squared_distance(float32 x1, float32 y1, float32 z1, - float32 x2, float32 y2, float32 z2): - cdef float32 diff_x = x2 - x1 - cdef float32 diff_y = y2 - y1 - cdef float32 diff_z = z2 - z1 - return diff_x*diff_x + diff_y*diff_y + diff_z*diff_z diff --git 
a/src/biotite/structure/compare.py b/src/biotite/structure/compare.py index 69f99a227..5da956d35 100644 --- a/src/biotite/structure/compare.py +++ b/src/biotite/structure/compare.py @@ -14,8 +14,8 @@ import collections.abc import warnings import numpy as np +from biotite.rust.structure import CellList, CellListResult from biotite.structure.atoms import AtomArray, AtomArrayStack, coord -from biotite.structure.celllist import CellList from biotite.structure.chains import get_chain_count, get_chain_positions from biotite.structure.geometry import index_distance from biotite.structure.residues import get_residue_count, get_residue_positions @@ -514,37 +514,6 @@ def _sq_euclidian(reference, subject): return vector_dot(dif, dif) -def _to_sparse_indices(all_contacts): - """ - Create tuples of contact indices from the :meth:`CellList.get_atoms()` return value. - - In other words, they would mark the non-zero elements in a dense contact matrix. - - Parameters - ---------- - all_contacts : ndarray, dtype=int, shape=(m,n) - The contact indices as returned by :meth:`CellList.get_atoms()`. - Padded with -1, in the second dimension. - Dimension *m* marks the query atoms, dimension *n* marks the contact atoms. - - Returns - ------- - combined_indices : ndarray, dtype=int, shape=(l,2) - The contact indices. - Each column contains the query and contact atom index. 
- """ - # Find rows where a query atom has at least one contact - non_empty_indices = np.where(np.any(all_contacts != -1, axis=1))[0] - # Take those rows and flatten them - contact_indices = all_contacts[non_empty_indices].flatten() - # For each row the corresponding query atom is the same - # Hence in the flattened form the query atom index is simply repeated - query_indices = np.repeat(non_empty_indices, all_contacts.shape[1]) - combined_indices = np.stack([query_indices, contact_indices], axis=1) - # Remove the padding values - return combined_indices[contact_indices != -1] - - def _find_contacts( atoms=None, atom_mask=None, @@ -600,20 +569,16 @@ def _find_contacts( cell_list = CellList(coords, inclusion_radius, selection=selection) # Pairs of indices for atoms within the inclusion radius if atom_mask is None: - all_contacts = cell_list.get_atoms(coords, inclusion_radius) + contacts = cell_list.get_atoms( + coords, inclusion_radius, result_format=CellListResult.PAIRS + ) else: - filtered_contacts = cell_list.get_atoms(coords[atom_mask], inclusion_radius) - # Map the contacts for the masked atoms to the original coordinates - # Rows that were filtered out by the mask are fully padded with -1 - # consistent with the padding of `get_atoms()` - all_contacts = np.full( - (coords.shape[0], filtered_contacts.shape[-1]), - -1, - dtype=filtered_contacts.dtype, + contacts = cell_list.get_atoms( + coords[atom_mask], inclusion_radius, result_format=CellListResult.PAIRS ) - all_contacts[atom_mask] = filtered_contacts - # Convert into pairs of indices - contacts = _to_sparse_indices(all_contacts) + # Map indices from masked indices back to original indices + mapping = np.nonzero(atom_mask)[0] + contacts[:, 0] = mapping[contacts[:, 0]] if exclude_same_chain: # Do the same for the chain level diff --git a/src/biotite/structure/hbond.py b/src/biotite/structure/hbond.py index 75b5c75a7..f3bf39963 100644 --- a/src/biotite/structure/hbond.py +++ b/src/biotite/structure/hbond.py @@ 
-12,8 +12,8 @@ import warnings import numpy as np +from biotite.rust.structure import CellList, CellListResult from biotite.structure.atoms import AtomArrayStack, stack -from biotite.structure.celllist import CellList from biotite.structure.filter import filter_heavy from biotite.structure.geometry import angle, distance @@ -295,7 +295,9 @@ def _hbond( cell_list = CellList( donor_h_coord, cell_size=cutoff_dist, periodic=periodic, box=box_for_model ) - possible_bonds |= cell_list.get_atoms_in_cells(acceptor_coord, as_mask=True) + possible_bonds |= cell_list.get_atoms_in_cells( + acceptor_coord, result_format=CellListResult.MASK + ) possible_bonds_i = np.where(possible_bonds) # Narrow down acceptor_i = acceptor_i[possible_bonds_i[0]] diff --git a/src/biotite/structure/rdf.py b/src/biotite/structure/rdf.py index 448a81ffa..ff1c436d3 100644 --- a/src/biotite/structure/rdf.py +++ b/src/biotite/structure/rdf.py @@ -12,9 +12,9 @@ from numbers import Integral import numpy as np +from biotite.rust.structure import CellList, CellListResult from biotite.structure.atoms import AtomArray, coord, stack from biotite.structure.box import box_volume -from biotite.structure.celllist import CellList from biotite.structure.geometry import displacement from biotite.structure.util import vector_dot @@ -202,7 +202,9 @@ def rdf( # This is enough to find all atoms that are in the given # interval (and more), since the size of each cell is as large # as the last edge of the bins - near_atom_mask = cell_list.get_atoms_in_cells(center[i], as_mask=True) + near_atom_mask = cell_list.get_atoms_in_cells( + center[i], result_format=CellListResult.MASK + ) # Calculate distances of each center to preselected atoms # for each center for j in range(center.shape[1]): diff --git a/src/biotite/structure/sasa.pyx b/src/biotite/structure/sasa.pyx index e9c478bc2..a8561a6e3 100644 --- a/src/biotite/structure/sasa.pyx +++ b/src/biotite/structure/sasa.pyx @@ -16,7 +16,7 @@ cimport numpy as np from libc.stdlib 
cimport malloc, free import numpy as np -from .celllist import CellList +from biotite.rust.structure import CellList from .filter import filter_solvent, filter_monoatomic_ions, filter_heavy from .info.radii import vdw_radius_protor, vdw_radius_single @@ -218,7 +218,7 @@ def sasa(array, float probe_radius=1.4, np.ndarray atom_filter=None, # Therefore intersecting atoms are always in the same or adjacent cell. cell_list = CellList(occl_array, np.max(radii[occl_filter])*2) cdef np.ndarray cell_indices - cdef int[:,:] cell_indices_view + cdef int64[:,:] cell_indices_view cdef int length cdef int max_adj_list_length = 0 cdef int array_length = array.array_length() diff --git a/src/biotite/structure/sse.py b/src/biotite/structure/sse.py index c04b8fcd6..0fcdcc3de 100644 --- a/src/biotite/structure/sse.py +++ b/src/biotite/structure/sse.py @@ -12,7 +12,7 @@ __all__ = ["annotate_sse"] import numpy as np -from biotite.structure.celllist import CellList +from biotite.rust.structure import CellList from biotite.structure.filter import filter_amino_acids from biotite.structure.geometry import angle, dihedral, distance from biotite.structure.integrity import check_res_id_continuity diff --git a/src/rust/structure/celllist.rs b/src/rust/structure/celllist.rs new file mode 100644 index 000000000..f45bd1896 --- /dev/null +++ b/src/rust/structure/celllist.rs @@ -0,0 +1,1333 @@ +use numpy::ndarray::Array2; +use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyInt, PyTuple}; +use std::convert::TryInto; +use std::ops::{Index, IndexMut}; + +// Label as a separate module to indicate that this exception comes +// from biotite +mod biotite { + pyo3::import_exception!(biotite.structure, BadStructureError); +} + +/// The desired type of result to be returned by some :class:`CellList` methods. +/// +/// - ``MAPPING`` - +/// An ``(q, k)`` array mapping each query coordinate ``q`` to the indices of its +/// neighbors. 
+/// Since the coordinates may have different amounts of adjacent +/// atoms, trailing ``-1`` values are used to indicate nonexistent +/// indices. +/// If only a single query coordinate is provided, the return value has shape +/// ``(k,)``. +/// - ``MASK`` - +/// A boolean mask of shape ``(q, n)`` where the value at ``(i, j)`` is ``True`` +/// if the query coordinate ``i`` and atom ``j`` are neighbors. +/// If only a single query coordinate is provided, the return value has shape +/// ``(n,)``. +/// - ``PAIRS`` - +/// An array of shape ``(k, 2)`` where each row contains the ``(query_idx, atom_idx)`` +/// tuple of a found neighboring atom. +/// This is basically a sparse representation of the ``MASK``. +/// If only a single query coordinate is provided, the return value has shape +/// ``(k,)``, equivalent to ``MAPPING``. +/// +/// Examples +/// -------- +/// +/// >>> # Create a CellList for a small molecule +/// >>> from biotite.structure.info import residue +/// >>> atoms = residue("ALA") +/// >>> cell_list = CellList(atoms, cell_size=2) +/// >>> # Demonstrate results for both single and multiple query coordinates +/// >>> single_coord = atoms.coord[0] +/// >>> multiple_coords = atoms.coord[:2] +/// >>> # MAPPING: indices of neighboring atoms, -1 indicates padding values to be ignored +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.MAPPING)) +/// [6 1 0 7] +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.MAPPING)) +/// [[ 6 1 0 7 -1] +/// [ 2 1 0 4 8]] +/// >>> # MASK: boolean mask indicating neighbors +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.MASK)) +/// [ True True False False False False True True False False False False +/// False] +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.MASK)) +/// [[ True True False False False False True True False False False False +/// False] +/// [ True True 
True False True False False False True False False False +/// False]] +/// >>> # PAIRS: (query_idx, atom_idx) tuples +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.PAIRS)) +/// [6 1 0 7] +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.PAIRS)) +/// [[0 6] +/// [0 1] +/// [0 0] +/// [0 7] +/// [1 2] +/// [1 1] +/// [1 0] +/// [1 4] +/// [1 8]] +#[allow(clippy::upper_case_acronyms)] +#[derive(Clone, Copy, PartialEq)] +#[pyclass] +pub enum CellListResult { + MAPPING, + MASK, + PAIRS, +} + +/// Internal enum for efficient support of both single and multiple radii +/// given in :meth:`CellList.get_atoms` and :meth:`CellList.get_atoms_in_cells`. +/// +/// This avoids the overhead of always allocating a ``Vec`` when a single radius +/// is provided, which is the common case. +#[derive(Clone, Debug)] +enum Radius { + /// A single radius value applied to all coordinates. + Single(T), + /// Individual radius values for each coordinate. + Multiple(Vec), +} + +/// A three-dimensional grid of cells, where each cell can store a variable number of elements. +/// +/// All elements are stored in a single contiguous ``Vec``, with +/// separate offsets tracking which elements belong to each cell. +/// The offset layout uses interleaved *(start, end)* pairs: for cell *[i, j, k]*, +/// the start offset is at ``offsets[2 * linear_idx]`` and the end offset is at +/// ``offsets[2 * linear_idx + 1]``, where ``linear_idx = i * stride_i + j * stride_j + k``. +struct CellGrid { + /// All elements stored contiguously. Elements for each cell are grouped together. + data: Vec, + /// Flattened array of *(start, end)* offset pairs for each cell. + /// Layout: for cell *[i, j, k]*, the offset pair is at index + /// ``2 * (i * stride_i + j * stride_j + k)``. + offsets: Vec, + /// Dimensions of the grid *[dim_i, dim_j, dim_k]*. 
+ dimensions: [usize; 3], + /// Precomputed stride for the first dimension: ``dimensions[1] * dimensions[2]``. + stride_i: usize, + /// Precomputed stride for the second dimension: ``dimensions[2]``. + stride_j: usize, +} + +impl CellGrid { + /// Create a :class:`CellGrid` directly from raw data and offsets. + /// + /// This constructor is primarily used for unpickling, where the internal + /// state has been serialized and needs to be reconstructed without + /// recomputing the cell assignments. + /// + /// Parameters + /// ---------- + /// data + /// The flattened element data. + /// offsets + /// The interleaved *(start, end)* offset pairs for each cell. + /// dimensions + /// The number of cells in each dimension *[dim_i, dim_j, dim_k]*. + pub fn new(data: Vec, offsets: Vec, dimensions: [usize; 3]) -> Self { + let stride_i = dimensions[1] * dimensions[2]; + let stride_j = dimensions[2]; + CellGrid { + data, + offsets, + dimensions, + stride_i, + stride_j, + } + } + + /// Create a new :class:`CellGrid` by assigning elements to cells based on their positions. + /// + /// This is the primary constructor used during :class:`CellList` creation. + /// It performs three passes over the data: + /// + /// 1. Count elements per cell to determine offsets. + /// 2. Allocate the data vector and compute offset ranges. + /// 3. Place elements into their respective cells. + /// + /// Parameters + /// ---------- + /// dimensions : [usize; 3] + /// The number of cells in each dimension *[dim_i, dim_j, dim_k]*. + /// positions : Vec<[usize; 3]> + /// The cell position *[i, j, k]* for each element. + /// elements : Vec + /// The elements to store. 
+ /// + pub fn from_elements( + dimensions: [usize; 3], + positions: Vec<[usize; 3]>, + elements: Vec, + ) -> Self { + if positions.len() != elements.len() { + panic!("Positions and elements must have the same length"); + } + + let n_cells = dimensions[0] * dimensions[1] * dimensions[2]; + let stride_i = dimensions[1] * dimensions[2]; + let stride_j = dimensions[2]; + + // First iteration: Count the number of elements in each cell + let mut counts: Vec = vec![0; n_cells]; + for position in positions.iter() { + let linear_idx = position[0] * stride_i + position[1] * stride_j + position[2]; + counts[linear_idx] += 1; + } + + // Second iteration: Calculate offsets + // Each cell has a (start, end) pair, stored as offsets[2*idx] and offsets[2*idx + 1] + let mut offsets: Vec = vec![0; n_cells * 2]; + let mut current_offset = 0; + for (idx, count) in counts.iter().enumerate() { + offsets[idx * 2] = current_offset; // start + offsets[idx * 2 + 1] = current_offset; // end (will be incremented when filling) + current_offset += count; + } + + // Pre-allocate the data vector with default values + let total_count = current_offset; + let mut data: Vec = vec![T::default(); total_count]; + + // Third iteration: Fill the data vector + for (position, element) in positions.iter().zip(elements.into_iter()) { + let linear_idx = position[0] * stride_i + position[1] * stride_j + position[2]; + let end_idx = linear_idx * 2 + 1; + let current_index = offsets[end_idx]; + data[current_index] = element; + offsets[end_idx] += 1; + } + + CellGrid { + data, + offsets, + dimensions, + stride_i, + stride_j, + } + } + + /// Check if a cell position is within the valid bounds of the grid. + /// + /// Returns ``False`` if any index is negative or exceeds the corresponding dimension. 
+ #[inline(always)] + fn is_valid_cell_position(&self, cell_position: [isize; 3]) -> bool { + for (i, &element) in cell_position.iter().enumerate() { + if element < 0 || element as usize >= self.dimensions[i] { + return false; + } + } + true + } + + /// Convert a 3D cell position to a linear index into the offsets array. + /// + /// Notes + /// ----- + /// The caller must ensure that ``cell_position`` is within bounds before + /// using the returned index to access ``self.offsets``. + #[inline(always)] + fn get_linear_index(&self, cell_position: [isize; 3]) -> usize { + (cell_position[0] as usize) * self.stride_i + + (cell_position[1] as usize) * self.stride_j + + (cell_position[2] as usize) + } +} + +/// Immutable indexing into a :class:`CellGrid`. +/// +/// Returns a slice of all elements in the cell at the given position. +/// If the position is out of bounds (negative or exceeding dimensions), +/// an empty slice is returned instead of panicking. This allows safe +/// iteration over adjacent cells without explicit bounds checking. +impl Index<[isize; 3]> for CellGrid { + type Output = [T]; + + #[inline(always)] + fn index(&self, cell_position: [isize; 3]) -> &Self::Output { + if !self.is_valid_cell_position(cell_position) { + return &[]; + } + + let linear_idx = self.get_linear_index(cell_position); + // SAFETY: bounds are checked above, and offsets was constructed with 2 * n_cells elements + unsafe { + let start = *self.offsets.get_unchecked(linear_idx * 2); + let end = *self.offsets.get_unchecked(linear_idx * 2 + 1); + self.data.get_unchecked(start..end) + } + } +} + +/// Mutable indexing into a :class:`CellGrid`. +/// +/// Returns a mutable slice of all elements in the cell at the given position. +/// +/// Notes +/// ----- +/// Panics if the position is out of bounds. +/// Unlike the immutable ``Index``, mutable access requires valid bounds since returning +/// an empty mutable slice would be semantically incorrect. 
+impl IndexMut<[isize; 3]> for CellGrid { + #[inline(always)] + fn index_mut(&mut self, cell_position: [isize; 3]) -> &mut Self::Output { + if !self.is_valid_cell_position(cell_position) { + panic!( + "Cell position ({},{},{}) is out of bounds", + cell_position[0], cell_position[1], cell_position[2], + ); + } + + let linear_idx = self.get_linear_index(cell_position); + // SAFETY: bounds are checked above + unsafe { + let start = *self.offsets.get_unchecked(linear_idx * 2); + let end = *self.offsets.get_unchecked(linear_idx * 2 + 1); + self.data.get_unchecked_mut(start..end) + } + } +} + +/// __init__(atom_array, cell_size, periodic=False, box=None, selection=None) +/// +/// This class enables the efficient search of atoms in vicinity of a +/// defined location. +/// +/// This class stores the indices of an atom array in virtual "cells", +/// each corresponding to a specific coordinate interval. +/// If the atoms in vicinity of a specific location are searched, only +/// the atoms in the relevant cells are checked. +/// Effectively this decreases the operation time for finding atoms +/// with a maximum distance to given coordinates from *O(n)* to *O(1)*, +/// after the :class:`CellList` has been created. +/// Therefore a :class:`CellList` saves calculation time in those +/// cases, where vicinity is checked for multiple locations. +/// +/// Parameters +/// ---------- +/// atom_array : AtomArray or ndarray, dtype=float, shape=(n,3) +/// The :class:`AtomArray` to create the :class:`CellList` for. +/// Alternatively the atom coordinates are accepted directly. +/// In this case `box` must be set, if `periodic` is true. +/// cell_size : float +/// The coordinate interval each cell has for x, y and z axis. +/// The amount of cells depends on the range of coordinates in the +/// `atom_array` and the `cell_size`. +/// periodic : bool, optional +/// If true, the cell list considers periodic copies of atoms. 
+/// The periodicity is based on the `box` attribute of `atom_array`. +/// box : ndarray, dtype=float, shape=(3,3), optional +/// If provided, the periodicity is based on this parameter instead +/// of the :attr:`box` attribute of `atom_array`. +/// Only has an effect, if `periodic` is ``True``. +/// selection : ndarray, dtype=bool, shape=(n,), optional +/// If provided, only the atoms masked by this array are stored in +/// the cell list. However, the indices stored in the cell list +/// will still refer to the original unfiltered `atom_array`. +/// +/// Examples +/// -------- +/// +/// >>> cell_list = CellList(atom_array, cell_size=5) +/// >>> near_atoms = atom_array[cell_list.get_atoms(np.array([1,2,3]), radius=7.0)] +#[pyclass(module = "biotite.structure")] +pub struct CellList { + coord: Vec<[f32; 3]>, + // A boolean mask that covers the selected atoms + selection: Option>, + cells: CellGrid, + cell_size: f32, + // The minimum and maximum coordinates for all atoms + // Used as origin ('coord_range[0]' is at 'cells[[0,0,0]]') + // and for bound checks + coord_range: [[f32; 3]; 2], + // The box dimensions if periodicity is taken into account + periodic_box: Option<[[f32; 3]; 3]>, + // The length of the array before appending periodic copies + orig_length: usize, +} + +#[pymethods] +impl CellList { + #[new] + #[pyo3(signature = (atom_array, cell_size, periodic=false, r#box=None, selection=None))] + fn new<'py>( + py: Python<'py>, + atom_array: Bound<'py, PyAny>, + cell_size: f32, + periodic: bool, + r#box: Option>, + selection: Option>, + ) -> PyResult { + let struc = PyModule::import(py, "biotite.structure")?; + let np = PyModule::import(py, "numpy")?; + + // Input validation + if !atom_array.is_instance(&struc.getattr("AtomArray")?)? + && !atom_array.is_instance(&np.getattr("ndarray")?)? + { + return Err(exceptions::PyTypeError::new_err(format!( + "Expected 'AtomArray' but got '{}'", + atom_array + .getattr("__class__")? + .getattr("__name__")? 
+ .extract::<&str>()? + ))); + } + if cell_size <= 0.0 { + return Err(exceptions::PyValueError::new_err( + "Cell size must be greater than 0", + )); + } + + let orig_length: usize; + let mut coord_object: Bound<'py, PyAny> = struc.call_method1("coord", (&atom_array,))?; + let box_object: Bound<'py, PyAny>; + let periodic_box: Option<[[f32; 3]; 3]>; + if periodic { + match r#box { + Some(wrapped_box) => { + box_object = wrapped_box; + } + None => { + // Use `box` attribute of `AtomArray` + match atom_array + .getattr("box")? + .extract::>>()? + { + Some(atoms_box) => { + box_object = atoms_box; + } + None => { + return Err(exceptions::PyValueError::new_err( + "AtomArray must have a box to enable periodicity", + )); + } + } + } + } + + orig_length = coord_object.call_method0("__len__")?.extract()?; + coord_object = struc + .getattr("move_inside_box")? + .call1((&coord_object, &box_object))?; + coord_object = struc + .getattr("repeat_box_coord")? + .call1((&coord_object, &box_object))? + .extract::<[Bound<'py, PyAny>; 2]>()? + .first() + .ok_or_else(|| { + exceptions::PyRuntimeError::new_err("Could not get repeated coords") + })? + .clone(); + periodic_box = Some( + extract_coord(box_object.extract::>()?)? 
+ .try_into() + .map_err(|_| { + exceptions::PyTypeError::new_err("Box must be a 3x3 float32 matrix") + })?, + ); + } else { + periodic_box = None; + orig_length = coord_object.call_method0("__len__")?.extract()?; + } + let coord: Vec<[f32; 3]> = extract_coord(coord_object.extract::>()?)?; + let coord_range: [[f32; 3]; 2] = calculate_coord_range(&coord)?; + + if let Some(ref sel) = selection { + if sel.len() != orig_length { + return Err(exceptions::PyIndexError::new_err(format!( + "Atom array has length {}, but selection has length {}", + orig_length, + sel.len() + ))); + } + } + + let mut positions: Vec<[usize; 3]> = Vec::with_capacity(coord.len()); + let mut elements: Vec = Vec::with_capacity(coord.len()); + for (i, coord) in coord.iter().enumerate() { + if let Some(ref selection) = selection { + if !selection[i % orig_length] { + continue; + } + } + positions + .push(as_usize(Self::get_cell_position(*coord, &coord_range, cell_size)).unwrap()); + elements.push(i); + } + // To get the number of cells in each dimension, use the position of the maximum coordinate + let dimensions: [usize; 3] = + Self::get_cell_position(coord_range[1], &coord_range, cell_size) + .iter() + .map(|x| (x + 1) as usize) + .collect::>() + .try_into() + .unwrap(); + let cells = CellGrid::from_elements(dimensions, positions, elements); + + Ok(CellList { + coord, + selection, + cells, + cell_size, + coord_range, + periodic_box, + orig_length, + }) + } + + /// Reconstruct a :class:`CellList` from its serialized state. + /// + /// This is used for unpickling. It reconstructs the :class:`CellList` + /// from the internal state that was serialized by :meth:`__reduce__`. 
+ #[staticmethod] + #[pyo3(name = "_from_state")] + #[allow(clippy::too_many_arguments)] + fn from_state( + coord: Vec<[f32; 3]>, + selection: Option>, + cells_data: Vec, + cells_offsets: Vec, + cells_dimensions: [usize; 3], + cell_size: f32, + coord_range: [[f32; 3]; 2], + periodic_box: Option<[[f32; 3]; 3]>, + orig_length: usize, + ) -> PyResult { + let cells = CellGrid::new(cells_data, cells_offsets, cells_dimensions); + + Ok(CellList { + coord, + selection, + cells, + cell_size, + coord_range, + periodic_box, + orig_length, + }) + } + + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult> { + // Get the class from the module to ensure pickle can find it + let struc = PyModule::import(py, "biotite.structure")?; + let cls = struc.getattr("CellList")?; + let from_state = cls.getattr("_from_state")?; + + // Build the args tuple + let args = PyTuple::new( + py, + [ + self.coord.clone().into_pyobject(py)?.into_any().unbind(), + self.selection + .clone() + .into_pyobject(py)? + .into_any() + .unbind(), + self.cells + .data + .clone() + .into_pyobject(py)? + .into_any() + .unbind(), + self.cells + .offsets + .clone() + .into_pyobject(py)? + .into_any() + .unbind(), + self.cells.dimensions.into_pyobject(py)?.into_any().unbind(), + self.cell_size.into_pyobject(py)?.into_any().unbind(), + self.coord_range.into_pyobject(py)?.into_any().unbind(), + self.periodic_box.into_pyobject(py)?.into_any().unbind(), + self.orig_length.into_pyobject(py)?.into_any().unbind(), + ], + )?; + + // Return (callable, args) tuple + PyTuple::new(py, [from_state.unbind(), args.into_any().unbind()]) + } + + /// Create an adjacency matrix for the atoms in this cell list. + /// + /// An adjacency matrix depicts which atoms *i* and *j* have a distance + /// lower than a given threshold distance. + /// The values in the adjacency matrix ``m`` are + /// ``m[i,j] = True if distance(i,j) <= threshold else False``. 
+ /// + /// Parameters + /// ---------- + /// threshold_distance : float + /// The threshold distance. + /// All atom pairs that have a distance lower than or equal to this value are + /// indicated by ``True`` values in the resulting matrix. + /// + /// Returns + /// ------- + /// matrix : ndarray, dtype=bool, shape=(n,n) + /// An *n x n* adjacency matrix. + /// If a `selection` was given to the constructor of the + /// :class:`CellList`, the rows and columns corresponding to + /// atoms that are not masked by the selection have all + /// elements set to ``False``. + /// + /// Notes + /// ----- + /// The highest performance is achieved when the `cell_size` is + /// equal to the `threshold_distance`. + /// However, this is purely optional: the resulting adjacency matrix is the same for + /// every `cell_size`. + /// + /// Although the adjacency matrix should be symmetric in most cases, + /// it may occur that ``m[i,j] != m[j,i]`` when ``distance(i,j)`` + /// is very close to the `threshold_distance` due to numerical + /// errors. The matrix can be symmetrized with ``numpy.maximum(m, m.T)``. 
+ /// + /// Examples + /// -------- + /// Create adjacency matrix for CA atoms in a structure: + /// + /// >>> atom_array = atom_array[atom_array.atom_name == "CA"] + /// >>> cell_list = CellList(atom_array, 5) + /// >>> matrix = cell_list.create_adjacency_matrix(5) + fn create_adjacency_matrix<'py>( + &self, + py: Python<'py>, + threshold_distance: f32, + ) -> PyResult>> { + if threshold_distance < 0.0 { + return Err(exceptions::PyValueError::new_err( + "Threshold distance must be a positive value", + )); + } + + match self.selection { + Some(ref selection) => { + let coord: Vec<[f32; 3]> = self + .coord + .iter() + .take(self.orig_length) + .enumerate() + .filter(|(i, _)| selection[*i]) + .map(|(_, c)| *c) + .collect(); + let mut sub_matrix: Array2 = Array2::default((coord.len(), self.orig_length)); + for matrix_index in + self.get_atoms_from_slice(py, &coord, &Radius::Single(threshold_distance))? + { + // Map index to potentially periodic copy to original atom index + sub_matrix[[matrix_index[0], matrix_index[1] % self.orig_length]] = true; + } + let mut matrix: Array2 = + Array2::default((self.orig_length, self.orig_length)); + let mut sub_row_idx = 0; + for (i, is_selected) in selection.iter().enumerate().take(self.orig_length) { + if *is_selected { + matrix.row_mut(i).assign(&sub_matrix.row(sub_row_idx)); + sub_row_idx += 1; + } + } + Ok(matrix.into_pyarray(py)) + } + None => { + let mut matrix: Array2 = + Array2::default((self.orig_length, self.orig_length)); + for matrix_index in self.get_atoms_from_slice( + py, + &self.coord[..self.orig_length], + &Radius::Single(threshold_distance), + )? { + // Map index to potentially periodic copy to original atom index + matrix[[matrix_index[0], matrix_index[1] % self.orig_length]] = true; + } + Ok(matrix.into_pyarray(py)) + } + } + } + + /// get_atoms(coord, radius, as_mask=False, result_format=CellListResult.MAPPING) + /// + /// Find atoms with a maximum distance from given coordinates. 
+ /// + /// Parameters + /// ---------- + /// coord : ndarray, dtype=float, shape=(3,) or shape=(m,3) + /// One or more coordinates around which the atoms are searched. + /// radius : float or ndarray, shape=(m,), dtype=float + /// The radius around `coord`, in which the atoms are searched, + /// i.e. all atoms within `radius` distance to `coord` are returned. + /// Either a single radius can be given as scalar, or individual + /// radii for each position in `coord` can be provided as + /// :class:`ndarray`. + /// as_mask : bool, optional + /// **Deprecated:** Use ``result_format=CellListResult.MASK`` instead. + /// If true, the result is returned as boolean mask instead + /// of an index array. + /// result_format : CellListResult, optional + /// The format of the result. See :class:`CellListResult` for options. + /// Default is ``CellListResult.MAPPING``. + /// + /// Returns + /// ------- + /// result : ndarray + /// The result format depends on `result_format`. + /// See :class:`CellListResult` for details. + /// + /// See Also + /// -------- + /// get_atoms_in_cells + /// + /// Notes + /// ----- + /// In case of a :class:`CellList` with `periodic` set to ``True``: + /// If more than one periodic copy of an atom is within the + /// threshold radius, the returned indices array may contain the + /// corresponding index multiple times. + /// Use ``numpy.unique()`` if this is undesirable. 
+ /// + /// Examples + /// -------- + /// Get adjacent atoms for a single position: + /// + /// >>> cell_list = CellList(atom_array, 3) + /// >>> pos = np.array([1.0, 2.0, 3.0]) + /// >>> indices = cell_list.get_atoms(pos, radius=2.0) + /// >>> print(indices) + /// [102 104 112] + #[pyo3(signature = (coord, radius, as_mask=false, result_format=CellListResult::MAPPING))] + fn get_atoms<'py>( + &self, + py: Python<'py>, + coord: &Bound<'py, PyAny>, + radius: &Bound<'py, PyAny>, + as_mask: bool, + result_format: CellListResult, + ) -> PyResult> { + let (converted_coord, is_multi_coord) = self.prepare_coord_from_python(py, coord)?; + let converted_radius = Self::prepare_radius_from_python::(radius)?; + let pairs = self.get_atoms_from_slice(py, &converted_coord, &converted_radius)?; + format_result( + py, + converted_coord.len(), + self.orig_length, + pairs, + result_format, + as_mask, + is_multi_coord, + ) + } + + /// get_atoms_in_cells(coord, cell_radius=1, as_mask=False, result_format=CellListResult.MAPPING) + /// + /// Find atoms with a maximum cell distance from given coordinates. + /// + /// Instead of using the radius as maximum Euclidean distance to the + /// given coordinates, the radius is measured as the number of cells: + /// A radius of ``0`` means that only the atoms in the same cell + /// as the given coordinates are considered. A radius of ``1`` means + /// that the atom indices from this cell and the 26 surrounding + /// cells are returned, and so forth. + /// This is more efficient than :meth:`get_atoms`. + /// + /// Parameters + /// ---------- + /// coord : ndarray, dtype=float, shape=(3,) or shape=(m,3) + /// One or more coordinates around which the atoms are searched. + /// cell_radius : int or ndarray, shape=(n,), dtype=int, optional + /// The radius around `coord` (in number of cells), in which + /// the atoms are searched. This does not correspond to the + /// Euclidean distance used in :meth:`get_atoms`. 
In this case, all + /// atoms in the cell corresponding to `coord` and in adjacent + /// cells are returned. + /// Either a single radius can be given as scalar, or individual + /// radii for each position in `coord` can be provided as + /// :class:`ndarray`. + /// By default, atoms are searched in the cell of `coord` + /// and directly adjacent cells (``cell_radius=1``). + /// as_mask : bool, optional + /// **Deprecated:** Use ``result_format=CellListResult.MASK`` instead. + /// If true, the result is returned as boolean mask instead + /// of an index array. + /// result_format : CellListResult, optional + /// The format of the result. See :class:`CellListResult` for options. + /// Default is ``CellListResult.MAPPING``. + /// + /// Returns + /// ------- + /// result : ndarray + /// The result format depends on `result_format`. + /// See :class:`CellListResult` for details. + /// + /// See Also + /// -------- + /// get_atoms + /// + /// Notes + /// ----- + /// In case of a :class:`CellList` with `periodic` set to ``True``: + /// If more than one periodic copy of an atom is within the + /// threshold radius, the returned indices array may contain the + /// corresponding index multiple times. + /// Use ``numpy.unique()`` if this is undesirable. 
+ #[pyo3(signature = (coord, cell_radius=None, as_mask=false, result_format=CellListResult::MAPPING))] + fn get_atoms_in_cells<'py>( + &self, + py: Python<'py>, + coord: &Bound<'py, PyAny>, + cell_radius: Option<&Bound<'py, PyAny>>, + as_mask: bool, + result_format: CellListResult, + ) -> PyResult> { + let (converted_coord, is_multi_coord) = self.prepare_coord_from_python(py, coord)?; + + let converted_radius = Self::prepare_radius_from_python::( + cell_radius.unwrap_or(&PyInt::new(py, 1).into_any()), + )?; + let pairs = self.get_atoms_in_cells_from_slice(py, &converted_coord, &converted_radius)?; + format_result( + py, + converted_coord.len(), + self.orig_length, + pairs, + result_format, + as_mask, + is_multi_coord, + ) + } +} + +impl CellList { + /// Convert a 3D coordinate to the corresponding cell position in the grid. + /// + /// Given a coordinate *[x, y, z]*, this computes the cell indices *[i, j, k]* + /// based on the coordinate range and cell size. + /// + /// Returns + /// ------- + /// The cell position *[i, j, k]*. Note that negative indices or indices + /// exceeding the grid dimensions can be returned if the coordinate is + /// outside the original coordinate range. The :class:`CellGrid` handles + /// out-of-bounds access gracefully by returning empty slices. + #[inline(always)] + fn get_cell_position( + coord: [f32; 3], + coord_range: &[[f32; 3]; 2], + cellsize: f32, + ) -> [isize; 3] { + [ + // Conversion to 'isize' automatically floors the result + ((coord[0] - coord_range[0][0]) / cellsize) as isize, + ((coord[1] - coord_range[0][1]) / cellsize) as isize, + ((coord[2] - coord_range[0][2]) / cellsize) as isize, + ] + } + + /// Find atoms within a Euclidean distance from given coordinates. + /// + /// This method first uses :meth:`get_atoms_in_cells_from_slice` to find candidate + /// atoms based on cell adjacency, then filters them by actual Euclidean distance. 
+    ///
+    /// Returns
+    /// -------
+    /// atoms:
+    ///     A vector of *[query_idx, atom_idx]* pairs where ``atom_idx`` may refer to
+    ///     periodic copies (``indices >= orig_length``). The caller is responsible for
+    ///     mapping these back to original indices using ``atom_idx % orig_length``.
+    fn get_atoms_from_slice(
+        &self,
+        py: Python<'_>,
+        coord: &[[f32; 3]],
+        radius: &Radius<f32>,
+    ) -> PyResult<Vec<[usize; 2]>> {
+        // The number of cells that must be searched to cover the Euclidean radius
+        let to_cell_radius = |r: f32| (r / self.cell_size).ceil() as i32;
+
+        match radius {
+            Radius::Single(r) => {
+                let sq_r = r.powi(2);
+                let mut pairs = self.get_atoms_in_cells_from_slice(
+                    py,
+                    coord,
+                    &Radius::Single(to_cell_radius(*r)),
+                )?;
+                // Filter the candidates in place by their actual distance
+                pairs.retain(|pair| distance_squared(coord[pair[0]], self.coord[pair[1]]) <= sq_r);
+                Ok(pairs)
+            }
+            Radius::Multiple(rs) => {
+                let cell_radii: Vec<i32> = rs.iter().map(|r| to_cell_radius(*r)).collect();
+                let sq_rs: Vec<f32> = rs.iter().map(|r| r.powi(2)).collect();
+                // The callee verifies that there is exactly one radius per coordinate,
+                // so indexing `sq_rs` with the query index below is in bounds
+                let mut pairs =
+                    self.get_atoms_in_cells_from_slice(py, coord, &Radius::Multiple(cell_radii))?;
+                pairs.retain(|pair| {
+                    distance_squared(coord[pair[0]], self.coord[pair[1]]) <= sq_rs[pair[0]]
+                });
+                Ok(pairs)
+            }
+        }
+    }
+
+    /// Find atoms in cells adjacent to given coordinates.
+    ///
+    /// For each input coordinate, this method identifies the corresponding cell
+    /// and iterates over all cells within `cell_radius` distance (in cells, not
+    /// Euclidean distance). All atom indices in those cells are collected.
+    ///
+    /// Parameters
+    /// ----------
+    /// coord
+    ///     The query coordinates.
+    /// cell_radii
+    ///     Either a single cell radius for all coordinates, or
+    ///     individual radii for each coordinate.
+    ///
+    /// Returns
+    /// -------
+    /// atoms:
+    ///     A vector of *[query_idx, atom_idx]* pairs. The ``atom_idx`` values refer to
+    ///     the full coordinate array, which may include periodic copies when
+    ///     ``periodic=True``. These indices can exceed ``orig_length`` and should be
+    ///     mapped back using ``atom_idx % orig_length`` when needed.
+    ///
+    /// Raises
+    /// ------
+    /// PyValueError
+    ///     If individual radii are given and their number differs from the
+    ///     number of coordinates.
+    fn get_atoms_in_cells_from_slice(
+        &self,
+        py: Python<'_>,
+        coord: &[[f32; 3]],
+        cell_radii: &Radius<i32>,
+    ) -> PyResult<Vec<[usize; 2]>> {
+        if let Radius::Multiple(rs) = cell_radii {
+            if rs.len() != coord.len() {
+                return Err(exceptions::PyValueError::new_err(format!(
+                    "{} coordinates were provided, but {} radii",
+                    coord.len(),
+                    rs.len()
+                )));
+            }
+        }
+
+        let mut adjacent_atoms = Vec::new();
+        for (coord_idx, c) in coord.iter().enumerate() {
+            // Skip non-finite coordinates
+            if !c[0].is_finite() || !c[1].is_finite() || !c[2].is_finite() {
+                continue;
+            }
+
+            let cell_pos = Self::get_cell_position(*c, &self.coord_range, self.cell_size);
+
+            // Iterate over all adjacent cells within the cell_radius
+            let cell_radius = match cell_radii {
+                Radius::Single(r) => *r as isize,
+                Radius::Multiple(rs) => rs[coord_idx] as isize,
+            };
+            for adj_i in (cell_pos[0] - cell_radius)..=(cell_pos[0] + cell_radius) {
+                for adj_j in (cell_pos[1] - cell_radius)..=(cell_pos[1] + cell_radius) {
+                    for adj_k in (cell_pos[2] - cell_radius)..=(cell_pos[2] + cell_radius) {
+                        // Get all atoms in this cell (returns empty slice if out of bounds)
+                        let cell_atoms = &self.cells[[adj_i, adj_j, adj_k]];
+                        adjacent_atoms
+                            .extend(cell_atoms.iter().map(|&atom_idx| [coord_idx, atom_idx]));
+                    }
+                }
+            }
+            // The size of `coord` is arbitrary
+            // -> the function may take a long time to complete
+            // Check for interrupts periodically (every 256 coords)
+            if coord_idx & 0xFF == 0 {
+                py.check_signals()?;
+            }
+        }
+        Ok(adjacent_atoms)
+    }
+
+    /// Convert Python coordinate input to a uniform Rust representation.
+    ///
+    /// This method handles the various input formats accepted by the Python API:
+    ///
+    /// - Single coordinate: ``shape=(3,)`` -> converted to ``Vec<[f32; 3]>`` with one element
+    /// - Multiple coordinates: ``shape=(n, 3)`` -> converted to ``Vec<[f32; 3]>`` with n elements
+    ///
+    /// It also handles:
+    ///
+    /// - Data type conversion to ``float32`` if necessary
+    /// - Periodic boundary conditions (moving coordinates inside the box)
+    ///
+    /// Returns
+    /// -------
+    /// coordinates:
+    ///     A vector of converted coordinates.
+    /// is_multi_coord:
+    ///     A boolean indicating whether the input was a 2D array (multiple coordinates)
+    ///     or a 1D array (single coordinate).
+    ///
+    /// Raises
+    /// ------
+    /// PyValueError
+    ///     If the coordinates contain non-finite values.
+    /// PyTypeError
+    ///     If the input is neither a single coordinate nor a 2D array.
+    fn prepare_coord_from_python<'py>(
+        &self,
+        py: Python<'py>,
+        coord: &Bound<'py, PyAny>,
+    ) -> PyResult<(Vec<[f32; 3]>, bool)> {
+        // Convert into expected f32 array
+        // (`copy=False` avoids a copy if the dtype already matches)
+        let kwargs = PyDict::new(py);
+        kwargs.set_item("dtype", "float32")?;
+        kwargs.set_item("copy", false)?;
+        let coord = coord.call_method("astype", (), Some(&kwargs))?;
+
+        // Handle periodicity if needed
+        let coord = match self.periodic_box {
+            Some(box_matrix) => {
+                let box_array = Array2::from_shape_vec((3, 3), box_matrix.concat())
+                    .map_err(|_| exceptions::PyRuntimeError::new_err("Failed to create box array"))?;
+                // The Python module is only required in the periodic case
+                PyModule::import(py, "biotite.structure")?
+                    .getattr("move_inside_box")?
+                    .call1((coord, box_array.into_pyarray(py)))?
+            }
+            None => coord,
+        };
+
+        // Consistently use 2 dimensions
+        // -> if only one coordinate is given, convert it to a 2D array
+        if let Ok(single_coord) = coord.extract::<[f32; 3]>() {
+            if !single_coord.iter().all(|c| c.is_finite()) {
+                return Err(exceptions::PyValueError::new_err(
+                    "Coordinates contain non-finite values",
+                ));
+            }
+            return Ok((vec![single_coord], false));
+        }
+        if let Ok(array) = coord.extract::<PyReadonlyArray2<f32>>() {
+            return Ok((extract_coord(array)?, true));
+        }
+        Err(exceptions::PyTypeError::new_err(
+            "Coordinates must be a single coordinate or a 2D array of coordinates",
+        ))
+    }
+
+    /// Convert Python radius input to the internal :class:`Radius` enum.
+    ///
+    /// This method handles two input formats:
+    ///
+    /// - Single radius: a scalar value -> ``Radius::Single(value)``
+    /// - Multiple radii: an array of values -> ``Radius::Multiple(vec)``
+    ///
+    /// The generic type ``T`` allows this to work for both ``f32`` (Euclidean radius)
+    /// and ``i32`` (cell radius).
+    ///
+    /// Raises
+    /// ------
+    /// PyTypeError
+    ///     If the input can be converted neither into ``T`` nor into ``Vec<T>``.
+    fn prepare_radius_from_python<'py, T>(radius: &Bound<'py, PyAny>) -> PyResult<Radius<T>>
+    where
+        T: for<'a> FromPyObject<'a, 'py>,
+        Vec<T>: for<'a> FromPyObject<'a, 'py>,
+    {
+        // Try the scalar interpretation first, as it is the common case
+        if let Ok(r) = radius.extract::<T>() {
+            return Ok(Radius::Single(r));
+        }
+        if let Ok(rs) = radius.extract::<Vec<T>>() {
+            return Ok(Radius::Multiple(rs));
+        }
+        Err(exceptions::PyTypeError::new_err(
+            "Radius must be a single value or an array of values",
+        ))
+    }
+}
+
+/// Calculate the bounding box (min and max coordinates) from a coordinate list.
+///
+/// This is used to determine the spatial extent of the atoms, which in turn
+/// determines the origin of the cell grid and the number of cells needed.
+///
+/// Returns
+/// -------
+/// coord_range:
+///     A vector containing the minimum and maximum coordinates.
+///
+/// Raises
+/// ------
+/// PyValueError
+///     If the coordinate list is empty.
+fn calculate_coord_range(coord: &[[f32; 3]]) -> PyResult<[[f32; 3]; 2]> {
+    if coord.is_empty() {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates must not be empty",
+        ));
+    }
+
+    // Start with an "inverted" box, so that the first coordinate replaces both bounds
+    let mut min_coord = [f32::INFINITY, f32::INFINITY, f32::INFINITY];
+    let mut max_coord = [f32::NEG_INFINITY, f32::NEG_INFINITY, f32::NEG_INFINITY];
+
+    for c in coord.iter() {
+        for i in 0..3 {
+            if c[i] < min_coord[i] {
+                min_coord[i] = c[i];
+            }
+            if c[i] > max_coord[i] {
+                max_coord[i] = c[i];
+            }
+        }
+    }
+
+    Ok([min_coord, max_coord])
+}
+
+/// Extract coordinates from a 2D NumPy array into a ``Vec<[f32; 3]>``.
+///
+/// Raises
+/// ------
+/// PyValueError
+///     If the array does not have shape *(n, 3)* or contains non-finite values (NaN or Inf).
+fn extract_coord(coord_array: PyReadonlyArray2<f32>) -> PyResult<Vec<[f32; 3]>> {
+    let coord_ndarray = coord_array.as_array();
+    // The array type already guarantees two dimensions
+    // -> only the length of the second dimension needs to be checked
+    if coord_ndarray.shape()[1] != 3 {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates must have form (x,y,z)",
+        ));
+    }
+    // Reject the input if *any* value is NaN or infinite.
+    // (`!is_all_infinite()` would be true for every ordinary finite array
+    // and hence would reject all valid input.)
+    if coord_ndarray.iter().any(|x| !x.is_finite()) {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates contain non-finite values",
+        ));
+    }
+
+    let result: Vec<[f32; 3]> = match coord_ndarray.as_slice() {
+        // Data is contiguous in memory -> read the triples directly from the slice
+        Some(slice) => slice.chunks_exact(3).map(|c| [c[0], c[1], c[2]]).collect(),
+        // Fallback for non-contiguous arrays (e.g., transposed or sliced)
+        None => coord_ndarray
+            .outer_iter()
+            .map(|row| [row[0], row[1], row[2]])
+            .collect(),
+    };
+    Ok(result)
+}
+
+/// Convert an ``[isize; 3]`` to ``[usize; 3]`` if all elements are non-negative.
+///
+/// Returns ``None`` if any element is negative.
+#[inline(always)]
+fn as_usize(x: [isize; 3]) -> Option<[usize; 3]> {
+    if x.iter().any(|&e| e < 0) {
+        return None;
+    }
+    Some([x[0] as usize, x[1] as usize, x[2] as usize])
+}
+
+/// Compute the squared Euclidean distance between two 3D points.
+///
+/// Using squared distance avoids the expensive square root operation,
+/// which is sufficient for distance comparisons.
+#[inline(always)]
+fn distance_squared(a: [f32; 3], b: [f32; 3]) -> f32 {
+    (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2) + (a[2] - b[2]).powi(2)
+}
+
+/// Format query results as a mapping array (the default format).
+///
+/// Creates a 2D array where each row corresponds to a query coordinate,
+/// and the columns contain the indices of neighboring atoms. Since different
+/// queries may have different numbers of neighbors, the array is padded with
+/// ``-1`` values to indicate empty slots.
+///
+/// Parameters
+/// ----------
+/// n_query
+///     Number of query coordinates.
+/// n_atoms
+///     Number of atoms in the original (non-periodic) structure,
+///     used for mapping periodic copy indices back to original indices.
+/// pairs
+///     Vector of *[query_idx, atom_idx]* pairs.
+/// is_multi_coord
+///     If ``False``, return a 1D array for the single query.
+fn format_as_mapping(
+    py: Python<'_>,
+    n_query: usize,
+    n_atoms: usize,
+    pairs: Vec<[usize; 2]>,
+    is_multi_coord: bool,
+) -> PyResult<Bound<'_, PyAny>> {
+    // First pass: count elements per query to find max length
+    let mut counts: Vec<usize> = vec![0; n_query];
+    for &[query_idx, _] in pairs.iter() {
+        counts[query_idx] += 1;
+    }
+    // Without any query there are also no columns
+    let max_length = counts.iter().copied().max().unwrap_or(0);
+
+    // Allocate final matrix directly, pre-filled with the padding value
+    let mut mapping_matrix = Array2::from_elem((n_query, max_length), -1i64);
+
+    // Reset counts to use as current position tracker
+    counts.fill(0);
+
+    // Second pass: fill matrix directly
+    for &[row, atom_idx] in pairs.iter() {
+        let col = counts[row];
+        // Map index to potentially periodic copy to original atom index
+        mapping_matrix[[row, col]] = (atom_idx % n_atoms) as i64;
+        counts[row] += 1;
+    }
+
+    if is_multi_coord {
+        Ok(mapping_matrix.into_pyarray(py).into_any())
+    } else {
+        // A single query coordinate is represented by the first (and only) row
+        Ok(mapping_matrix
+            .row(0)
+            .into_owned()
+            .into_pyarray(py)
+            .into_any())
+    }
+}
+
+/// Format query results as a boolean mask.
+///
+/// Creates a 2D boolean array of shape *(n_query, n_atoms)* where
+/// ``mask[i, j] = True`` indicates that atom *j* is a neighbor of query *i*.
+///
+/// Parameters
+/// ----------
+/// n_query
+///     Number of query coordinates.
+/// n_atoms
+///     Number of atoms in the original (non-periodic) structure.
+/// pairs
+///     Vector of *[query_idx, atom_idx]* pairs.
+/// is_multi_coord
+///     If ``False``, return a 1D boolean array for the single query.
+fn format_as_mask(
+    py: Python<'_>,
+    n_query: usize,
+    n_atoms: usize,
+    pairs: Vec<[usize; 2]>,
+    is_multi_coord: bool,
+) -> PyResult<Bound<'_, PyAny>> {
+    let mut mask = Array2::from_elem((n_query, n_atoms), false);
+    for &[query_idx, atom_idx] in pairs.iter() {
+        // Map index to potentially periodic copy to original atom index
+        mask[[query_idx, atom_idx % n_atoms]] = true;
+    }
+
+    if is_multi_coord {
+        Ok(mask.into_pyarray(py).into_any())
+    } else {
+        // A single query coordinate is represented by the first (and only) row
+        Ok(mask.row(0).into_owned().into_pyarray(py).into_any())
+    }
+}
+
+/// Format query results as an array of index pairs.
+///
+/// Creates a 2D array of shape *(p, 2)* where each row contains
+/// *[query_idx, atom_idx]* for a neighboring pair.
+///
+/// Parameters
+/// ----------
+/// n_atoms
+///     Number of atoms in the original (non-periodic) structure,
+///     used for mapping periodic copy indices back to original indices.
+/// pairs
+///     Vector of *[query_idx, atom_idx]* pairs.
+/// is_multi_coord
+///     If ``False``, return just the atom indices as a 1D array.
+fn format_as_pairs(
+    py: Python<'_>,
+    n_atoms: usize,
+    pairs: Vec<[usize; 2]>,
+    is_multi_coord: bool,
+) -> PyResult<Bound<'_, PyAny>> {
+    let n_pairs = pairs.len();
+
+    if is_multi_coord {
+        // Create a flat vector of i64 directly
+        let mut flat_data: Vec<i64> = Vec::with_capacity(n_pairs * 2);
+        for &[query_idx, atom_idx] in pairs.iter() {
+            flat_data.push(query_idx as i64);
+            // Map index to potentially periodic copy to original atom index
+            flat_data.push((atom_idx % n_atoms) as i64);
+        }
+        // Cannot fail: exactly two values were pushed per pair
+        let pairs_array = Array2::from_shape_vec((n_pairs, 2), flat_data).expect("Shape mismatch");
+        Ok(pairs_array.into_pyarray(py).into_any())
+    } else {
+        // For single coord, return just the second column (atom indices)
+        let atoms: Vec<i64> = pairs
+            .iter()
+            .map(|pair| (pair[1] % n_atoms) as i64)
+            .collect();
+        Ok(atoms.into_pyarray(py).into_any())
+    }
+}
+
+/// Format query results according to the requested output format.
+///
+/// This is the main dispatch function that routes to the appropriate
+/// formatting function based on the `result_format` parameter.
+///
+/// Parameters
+/// ----------
+/// n_query
+///     Number of query coordinates.
+/// n_atoms
+///     Number of atoms in the original (non-periodic) structure.
+/// pairs
+///     Vector of *[query_idx, atom_idx]* pairs.
+/// result_format
+///     The desired output format.
+/// as_mask
+///     If ``True``, overrides `result_format` to ``MASK``
+///     and emits a deprecation warning.
+/// is_multi_coord
+///     Whether multiple query coordinates were provided.
+fn format_result(
+    py: Python<'_>,
+    n_query: usize,
+    n_atoms: usize,
+    pairs: Vec<[usize; 2]>,
+    mut result_format: CellListResult,
+    as_mask: bool,
+    is_multi_coord: bool,
+) -> PyResult<Bound<'_, PyAny>> {
+    if as_mask {
+        // Emit a DeprecationWarning when `as_mask` is used
+        let warnings = py.import("warnings")?;
+        // Propagate a failure of `warnings.warn()` instead of discarding it:
+        // if warnings are configured to be raised as errors, the caller
+        // must see the resulting exception
+        warnings.call_method1(
+            "warn",
+            (
+                "The 'as_mask' parameter is deprecated, use 'result_format' instead.",
+                py.import("builtins")?.getattr("DeprecationWarning")?,
+            ),
+        )?;
+        result_format = CellListResult::MASK;
+    }
+    match result_format {
+        CellListResult::MAPPING => format_as_mapping(py, n_query, n_atoms, pairs, is_multi_coord),
+        CellListResult::MASK => format_as_mask(py, n_query, n_atoms, pairs, is_multi_coord),
+        CellListResult::PAIRS => format_as_pairs(py, n_atoms, pairs, is_multi_coord),
+    }
+}
diff --git a/src/rust/structure/io/pdb/file.rs b/src/rust/structure/io/pdb/file.rs
index 0cecb1a6c..b67708472 100644
--- a/src/rust/structure/io/pdb/file.rs
+++ b/src/rust/structure/io/pdb/file.rs
@@ -55,7 +55,7 @@ mod biotite {
 ///
 /// It contains efficient Rust implementation of the methods that would otherwise
 /// become major bottlenecks
-#[pyclass(subclass)]
+#[pyclass(subclass, module = "biotite.structure")]
 pub struct PDBFile {
     /// Lines of text from the PDB file.
#[pyo3(get)] diff --git a/src/rust/structure/mod.rs b/src/rust/structure/mod.rs index 62dd4f515..9b6180efc 100644 --- a/src/rust/structure/mod.rs +++ b/src/rust/structure/mod.rs @@ -1,10 +1,15 @@ use crate::add_subpackage; use pyo3::prelude::*; +mod celllist; mod io; +use celllist::*; + pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult> { let module = PyModule::new(parent_module.py(), "structure")?; + module.add_class::()?; + module.add_class::()?; add_subpackage(&module, &io::module(&module)?, "biotite.rust.structure.io")?; Ok(module) } diff --git a/tests/structure/test_celllist.py b/tests/structure/test_celllist.py index 8ddef530b..29d2a825f 100644 --- a/tests/structure/test_celllist.py +++ b/tests/structure/test_celllist.py @@ -2,15 +2,23 @@ # under the 3-Clause BSD License. Please see 'LICENSE.rst' for further # information. -import itertools -from os.path import join +import pickle +from pathlib import Path import numpy as np import pytest import biotite.structure as struc -import biotite.structure.io as strucio +import biotite.structure.io.pdbx as pdbx from tests.util import data_dir +@pytest.fixture +def atoms(): + pdbx_file = pdbx.BinaryCIFFile.read(Path(data_dir("structure")) / "1l2y.bcif") + atoms = pdbx.get_structure(pdbx_file, model=1) + atoms = atoms[struc.filter_heavy(atoms)] + return atoms + + # Result should be independent of cell size @pytest.mark.parametrize("cell_size", [0.5, 1, 2, 5, 10]) def test_get_atoms(cell_size): @@ -37,41 +45,34 @@ def test_get_atoms(cell_size): assert indices[indices != -1].tolist() == expected_indices -@pytest.mark.parametrize( - "cell_size, threshold, periodic, use_selection", - itertools.product( - [0.5, 1, 2, 5, 10], - [2, 5, 10], - [False, True], - [False, True], - ), -) -def test_adjacency_matrix(cell_size, threshold, periodic, use_selection): +@pytest.mark.parametrize("use_selection", [False, True]) +@pytest.mark.parametrize("periodic", [False, True]) +@pytest.mark.parametrize("threshold", [2, 
5, 10]) +@pytest.mark.parametrize("cell_size", [0.5, 1, 2, 5, 10]) +def test_adjacency_matrix(atoms, cell_size, threshold, periodic, use_selection): """ Compare the construction of an adjacency matrix using a cell list and using a computationally expensive but simpler distance matrix. """ - array = strucio.load_structure(join(data_dir("structure"), "3o5r.bcif")) - if periodic: # Create an orthorhombic box # with the outer coordinates as bounds - array.box = np.diag(np.max(array.coord, axis=-2) - np.min(array.coord, axis=-2)) + atoms.box = np.diag(np.max(atoms.coord, axis=-2) - np.min(atoms.coord, axis=-2)) if use_selection: np.random.seed(0) - selection = np.random.choice((False, True), array.array_length()) + selection = np.random.choice((False, True), atoms.array_length()) else: selection = None cell_list = struc.CellList( - array, cell_size=cell_size, periodic=periodic, selection=selection + atoms, cell_size=cell_size, periodic=periodic, selection=selection ) test_matrix = cell_list.create_adjacency_matrix(threshold) - length = array.array_length() + length = atoms.array_length() distance = struc.index_distance( - array, + atoms, np.stack( [np.repeat(np.arange(length), length), np.tile(np.arange(length), length)], axis=-1, @@ -80,62 +81,179 @@ def test_adjacency_matrix(cell_size, threshold, periodic, use_selection): ) distance = np.reshape(distance, (length, length)) # Create adjacency matrix from distance matrix - exp_matrix = distance <= threshold + ref_matrix = distance <= threshold if use_selection: # Set rows and columns to False for filtered out atoms - exp_matrix[~selection, :] = False - exp_matrix[:, ~selection] = False + ref_matrix[~selection, :] = False + ref_matrix[:, ~selection] = False # Both ways to create an adjacency matrix # should give the same result - assert np.array_equal(test_matrix, exp_matrix) + assert np.array_equal(test_matrix, ref_matrix) + + +@pytest.mark.parametrize( + "result_format", + [ + struc.CellListResult.MAPPING, + 
struc.CellListResult.MASK, + struc.CellListResult.PAIRS, + ], + ids=lambda format: str(format).split(".")[-1], +) +def test_result_format_consistency(atoms, result_format): + """ + Test that independent of the result format, the matrix constructed from it is + consistent. + """ + CELL_SIZE = 5 + + cell_list = struc.CellList(atoms, cell_size=CELL_SIZE) + ref_matrix = cell_list.create_adjacency_matrix(threshold_distance=CELL_SIZE) + + result = cell_list.get_atoms(atoms.coord, CELL_SIZE, result_format=result_format) + match result_format: + case struc.CellListResult.MAPPING: + mapping = result + test_matrix = np.zeros( + (atoms.array_length(), atoms.array_length()), dtype=bool + ) + for i, adjacent_atoms in enumerate(mapping): + for j in adjacent_atoms[adjacent_atoms != -1]: + test_matrix[i, j] = True + case struc.CellListResult.MASK: + test_matrix = result + case struc.CellListResult.PAIRS: + pairs = result + test_matrix = np.zeros( + (atoms.array_length(), atoms.array_length()), dtype=bool + ) + test_matrix[pairs[:, 0], pairs[:, 1]] = True + + assert np.array_equal(test_matrix, ref_matrix) -def test_outside_location(): +@pytest.mark.parametrize("displacement", [-1000, 1000]) +def test_outside_location(atoms, displacement): """ Test result for location outside any cell. 
""" - array = strucio.load_structure(join(data_dir("structure"), "3o5r.bcif")) - array = array[struc.filter_amino_acids(array)] - cell_list = struc.CellList(array, cell_size=5) - outside_coord = np.min(array.coord, axis=0) - 100 + CELL_SIZE = 5 + + cell_list = struc.CellList(atoms, cell_size=CELL_SIZE) + outside_coord = np.mean(atoms.coord, axis=0) + displacement # Expect empty array - assert len(cell_list.get_atoms(outside_coord, 5)) == 0 + assert len(cell_list.get_atoms(outside_coord, CELL_SIZE)) == 0 + + +@pytest.mark.parametrize( + "method", [struc.CellList.get_atoms, struc.CellList.get_atoms_in_cells] +) +def test_single_and_multiple_radii(atoms, method): + """ + Check if getting neighbors with multiple radii results in the same result as getting neighbors + with a single radius for each coordinate. + """ + CELL_SIZE = 5 + RADIUS_RANGE = (2, 10) + N_SAMPLES = 100 + + cell_list = struc.CellList(atoms, cell_size=CELL_SIZE) + + # Pick random radii + rng = np.random.default_rng(0) + radii = rng.integers(RADIUS_RANGE[0], RADIUS_RANGE[1], size=N_SAMPLES) + + mutliple_radius_result = method( + cell_list, + atoms.coord[:N_SAMPLES], + radii, + result_format=struc.CellListResult.MAPPING, + ) + max_neighbors = mutliple_radius_result.shape[-1] + + multiple_radius_result = method( + cell_list, + atoms.coord[:N_SAMPLES], + radii, + result_format=struc.CellListResult.MAPPING, + ) + + single_radius_result = np.full((N_SAMPLES, max_neighbors), -1, dtype=int) + for i in range(N_SAMPLES): + result = method( + cell_list, + atoms.coord[i], + radii[i], + result_format=struc.CellListResult.MAPPING, + ) + single_radius_result[i, : len(result)] = result + + assert np.array_equal(multiple_radius_result, single_radius_result) -def test_selection(): +def test_selection(atoms): """ Test whether the `selection` parameter in the constructor works. This is tested by comparing the selection done prior to cell list creation with the selection done in the cell list construction. 
""" - array = strucio.load_structure(join(data_dir("structure"), "3o5r.bcif")) - selection = np.array([False, True] * (array.array_length() // 2)) + selection = np.array([False, True] * (atoms.array_length() // 2)) # Selection prior to cell list creation - selected = array[selection] + selected = atoms[selection] cell_list = struc.CellList(selected, cell_size=10) - ref_near_atoms = selected[cell_list.get_atoms(array.coord[0], 20.0)] + ref_near_atoms = selected[cell_list.get_atoms(atoms.coord[0], 20.0)] # Selection in cell list creation - cell_list = struc.CellList(array, cell_size=10, selection=selection) - test_near_atoms = array[cell_list.get_atoms(array.coord[0], 20.0)] + cell_list = struc.CellList(atoms, cell_size=10, selection=selection) + test_near_atoms = atoms[cell_list.get_atoms(atoms.coord[0], 20.0)] assert test_near_atoms == ref_near_atoms -def test_empty_coordinates(): +@pytest.mark.parametrize( + "method", [struc.CellList.get_atoms, struc.CellList.get_atoms_in_cells] +) +@pytest.mark.parametrize( + "result_format", + [ + struc.CellListResult.MAPPING, + struc.CellListResult.MASK, + struc.CellListResult.PAIRS, + ], + ids=lambda format: str(format).split(".")[-1], +) +def test_empty_coordinates(atoms, method, result_format): """ Test whether empty input coordinates result in an empty output array/mask. """ - array = strucio.load_structure(join(data_dir("structure"), "3o5r.bcif")) - cell_list = struc.CellList(array, cell_size=10) + n_atoms = atoms.array_length() + cell_list = struc.CellList(atoms, cell_size=10) + + result = method(cell_list, np.zeros((0, 3)), 1, result_format=result_format) + match result_format: + case struc.CellListResult.MAPPING: + assert len(result) == 0 + case struc.CellListResult.MASK: + assert result.shape == (0, n_atoms) + case struc.CellListResult.PAIRS: + assert result.shape == (0, 2) + + +def test_pickle(atoms): + """ + Test whether the CellList can be pickled and unpickled. 
+ The unpickled cell list should give the same results as the original cell list. + """ + N_SAMPLES = 10 + CELL_SIZE = 5 - for method in (struc.CellList.get_atoms, struc.CellList.get_atoms_in_cells): - indices = method(cell_list, np.array([]), 1, as_mask=False) - mask = method(cell_list, np.array([]), 1, as_mask=True) - assert len(indices) == 0 - assert len(mask) == 0 - assert indices.dtype == np.int32 - assert mask.dtype == bool + original_cell_list = struc.CellList(atoms, cell_size=CELL_SIZE) + pickled = pickle.dumps(original_cell_list) + unpickled_cell_list = pickle.loads(pickled) + assert np.array_equal( + original_cell_list.get_atoms(atoms.coord[:N_SAMPLES], CELL_SIZE), + unpickled_cell_list.get_atoms(atoms.coord[:N_SAMPLES], CELL_SIZE), + ) diff --git a/tests/test_modname.py b/tests/test_modname.py index 131c533bf..67f83d952 100644 --- a/tests/test_modname.py +++ b/tests/test_modname.py @@ -17,7 +17,10 @@ def find_all_modules(package_name, src_dir): module_names = [] for _, module_name, is_package in pkgutil.iter_modules([src_dir]): if module_name == "setup_ccd": - # This module is not intended to be imported + # This script is not intended to be imported + continue + if module_name == "rust": + # The Rust extension module is not directly user-facing continue full_module_name = f"{package_name}.{module_name}" if is_package: From 758b3464086a689b7558817b0934ca8a8173b8ac Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Wed, 11 Feb 2026 19:13:02 +0100 Subject: [PATCH 2/4] Move `CellListResult` into `CellList.Result` --- benchmarks/structure/benchmark_celllist.py | 13 ++++++-- src/biotite/structure/__init__.py | 2 +- src/biotite/structure/celllist.py | 14 ++++++++ src/biotite/structure/compare.py | 6 ++-- src/biotite/structure/hbond.py | 4 +-- src/biotite/structure/rdf.py | 4 +-- src/rust/structure/celllist.rs | 38 +++++++++++----------- tests/structure/test_celllist.py | 30 ++++++++--------- 8 files changed, 67 insertions(+), 44 deletions(-) create mode 
100644 src/biotite/structure/celllist.py diff --git a/benchmarks/structure/benchmark_celllist.py b/benchmarks/structure/benchmark_celllist.py index 1d41a4405..3c57896f2 100644 --- a/benchmarks/structure/benchmark_celllist.py +++ b/benchmarks/structure/benchmark_celllist.py @@ -24,9 +24,18 @@ def benchmark_cell_list_creation(atoms): struc.CellList(atoms, 5.0) +@pytest.mark.parametrize( + "result_format", + [ + struc.CellList.Result.MAPPING, + struc.CellList.Result.MASK, + struc.CellList.Result.PAIRS, + ], + ids=lambda format: str(format).split(".")[-1], +) @pytest.mark.benchmark -def benchmark_cell_list_compute_contacts(cell_list, atoms): +def benchmark_cell_list_compute_contacts(cell_list, atoms, result_format): """ Find all contacts in a structure using an existing cell list. """ - cell_list.get_atoms(atoms.coord, 5.0) + cell_list.get_atoms(atoms.coord, 5.0, result_format=result_format) diff --git a/src/biotite/structure/__init__.py b/src/biotite/structure/__init__.py index 217f42320..2fce5fc40 100644 --- a/src/biotite/structure/__init__.py +++ b/src/biotite/structure/__init__.py @@ -108,11 +108,11 @@ __name__ = "biotite.structure" __author__ = "Patrick Kunzmann" -from biotite.rust.structure import * from .atoms import * from .basepairs import * from .bonds import * from .box import * +from .celllist import * from .chains import * from .charges import * from .compare import * diff --git a/src/biotite/structure/celllist.py b/src/biotite/structure/celllist.py new file mode 100644 index 000000000..63129721d --- /dev/null +++ b/src/biotite/structure/celllist.py @@ -0,0 +1,14 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. 
+ +__name__ = "biotite.structure" +__author__ = "Patrick Kunzmann" +__all__ = ["CellList"] + +from biotite.rust.structure import CellList, CellListResult + +# Expose the `CellListResult` enum as more ergonomic `CellList.Result` to the user +CellListResult.__name__ = "Result" +CellListResult.__qualname__ = "CellList.Result" +CellList.Result = CellListResult diff --git a/src/biotite/structure/compare.py b/src/biotite/structure/compare.py index 5da956d35..83b066cf2 100644 --- a/src/biotite/structure/compare.py +++ b/src/biotite/structure/compare.py @@ -14,7 +14,7 @@ import collections.abc import warnings import numpy as np -from biotite.rust.structure import CellList, CellListResult +from biotite.rust.structure import CellList from biotite.structure.atoms import AtomArray, AtomArrayStack, coord from biotite.structure.chains import get_chain_count, get_chain_positions from biotite.structure.geometry import index_distance @@ -570,11 +570,11 @@ def _find_contacts( # Pairs of indices for atoms within the inclusion radius if atom_mask is None: contacts = cell_list.get_atoms( - coords, inclusion_radius, result_format=CellListResult.PAIRS + coords, inclusion_radius, result_format=CellList.Result.PAIRS ) else: contacts = cell_list.get_atoms( - coords[atom_mask], inclusion_radius, result_format=CellListResult.PAIRS + coords[atom_mask], inclusion_radius, result_format=CellList.Result.PAIRS ) # Map indices from masked indices back to original indices mapping = np.nonzero(atom_mask)[0] diff --git a/src/biotite/structure/hbond.py b/src/biotite/structure/hbond.py index f3bf39963..60ace1e02 100644 --- a/src/biotite/structure/hbond.py +++ b/src/biotite/structure/hbond.py @@ -12,7 +12,7 @@ import warnings import numpy as np -from biotite.rust.structure import CellList, CellListResult +from biotite.rust.structure import CellList from biotite.structure.atoms import AtomArrayStack, stack from biotite.structure.filter import filter_heavy from biotite.structure.geometry import angle, 
distance @@ -296,7 +296,7 @@ def _hbond( donor_h_coord, cell_size=cutoff_dist, periodic=periodic, box=box_for_model ) possible_bonds |= cell_list.get_atoms_in_cells( - acceptor_coord, result_format=CellListResult.MASK + acceptor_coord, result_format=CellList.Result.MASK ) possible_bonds_i = np.where(possible_bonds) # Narrow down diff --git a/src/biotite/structure/rdf.py b/src/biotite/structure/rdf.py index ff1c436d3..aa6790cbb 100644 --- a/src/biotite/structure/rdf.py +++ b/src/biotite/structure/rdf.py @@ -12,7 +12,7 @@ from numbers import Integral import numpy as np -from biotite.rust.structure import CellList, CellListResult +from biotite.rust.structure import CellList from biotite.structure.atoms import AtomArray, coord, stack from biotite.structure.box import box_volume from biotite.structure.geometry import displacement @@ -203,7 +203,7 @@ def rdf( # interval (and more), since the size of each cell is as large # as the last edge of the bins near_atom_mask = cell_list.get_atoms_in_cells( - center[i], result_format=CellListResult.MASK + center[i], result_format=CellList.Result.MASK ) # Calculate distances of each center to preselected atoms # for each center diff --git a/src/rust/structure/celllist.rs b/src/rust/structure/celllist.rs index f45bd1896..f5a97f6b7 100644 --- a/src/rust/structure/celllist.rs +++ b/src/rust/structure/celllist.rs @@ -45,24 +45,24 @@ mod biotite { /// >>> single_coord = atoms.coord[0] /// >>> multiple_coords = atoms.coord[:2] /// >>> # MAPPING: indices of neighboring atoms, -1 indicates padding values to be ignored -/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.MAPPING)) +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellList.Result.MAPPING)) /// [6 1 0 7] -/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.MAPPING)) +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellList.Result.MAPPING)) /// [[ 6 1 0 7 
-1] /// [ 2 1 0 4 8]] /// >>> # MASK: boolean mask indicating neighbors -/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.MASK)) +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellList.Result.MASK)) /// [ True True False False False False True True False False False False /// False] -/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.MASK)) +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellList.Result.MASK)) /// [[ True True False False False False True True False False False False /// False] /// [ True True True False True False False False True False False False /// False]] /// >>> # PAIRS: (query_idx, atom_idx) tuples -/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellListResult.PAIRS)) +/// >>> print(cell_list.get_atoms(single_coord, radius=2, result_format=CellList.Result.PAIRS)) /// [6 1 0 7] -/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellListResult.PAIRS)) +/// >>> print(cell_list.get_atoms(multiple_coords, radius=2, result_format=CellList.Result.PAIRS)) /// [[0 6] /// [0 1] /// [0 0] @@ -74,7 +74,7 @@ mod biotite { /// [1 8]] #[allow(clippy::upper_case_acronyms)] #[derive(Clone, Copy, PartialEq)] -#[pyclass] +#[pyclass(module = "biotite.structure")] pub enum CellListResult { MAPPING, MASK, @@ -651,7 +651,7 @@ impl CellList { } } - /// get_atoms(coord, radius, as_mask=False, result_format=CellListResult.MAPPING) + /// get_atoms(coord, radius, as_mask=False, result_format=CellList.Result.MAPPING) /// /// Find atoms with a maximum distance from given coordinates. /// @@ -666,18 +666,18 @@ impl CellList { /// radii for each position in `coord` can be provided as /// :class:`ndarray`. /// as_mask : bool, optional - /// **Deprecated:** Use ``result_format=CellListResult.MASK`` instead. + /// **Deprecated:** Use ``result_format=CellList.Result.MASK`` instead. 
/// If true, the result is returned as boolean mask instead /// of an index array. - /// result_format : CellListResult, optional - /// The format of the result. See :class:`CellListResult` for options. - /// Default is ``CellListResult.MAPPING``. + /// result_format : CellList.Result, optional + /// The format of the result. See :class:`CellList.Result` for options. + /// Default is ``CellList.Result.MAPPING``. /// /// Returns /// ------- /// result : ndarray /// The result format depends on `result_format`. - /// See :class:`CellListResult` for details. + /// See :class:`CellList.Result` for details. /// /// See Also /// -------- @@ -723,7 +723,7 @@ impl CellList { ) } - /// get_atoms_in_cells(coord, cell_radius=1, as_mask=False, result_format=CellListResult.MAPPING) + /// get_atoms_in_cells(coord, cell_radius=1, as_mask=False, result_format=CellList.Result.MAPPING) /// /// Find atoms with a maximum cell distance from given coordinates. /// @@ -751,18 +751,18 @@ impl CellList { /// By default, atoms are searched in the cell of `coord` /// and directly adjacent cells (``cell_radius=1``). /// as_mask : bool, optional - /// **Deprecated:** Use ``result_format=CellListResult.MASK`` instead. + /// **Deprecated:** Use ``result_format=CellList.Result.MASK`` instead. /// If true, the result is returned as boolean mask instead /// of an index array. - /// result_format : CellListResult, optional - /// The format of the result. See :class:`CellListResult` for options. - /// Default is ``CellListResult.MAPPING``. + /// result_format : CellList.Result, optional + /// The format of the result. See :class:`CellList.Result` for options. + /// Default is ``CellList.Result.MAPPING``. /// /// Returns /// ------- /// result : ndarray /// The result format depends on `result_format`. - /// See :class:`CellListResult` for details. + /// See :class:`CellList.Result` for details. 
/// /// See Also /// -------- diff --git a/tests/structure/test_celllist.py b/tests/structure/test_celllist.py index 29d2a825f..1b7a5b781 100644 --- a/tests/structure/test_celllist.py +++ b/tests/structure/test_celllist.py @@ -95,9 +95,9 @@ def test_adjacency_matrix(atoms, cell_size, threshold, periodic, use_selection): @pytest.mark.parametrize( "result_format", [ - struc.CellListResult.MAPPING, - struc.CellListResult.MASK, - struc.CellListResult.PAIRS, + struc.CellList.Result.MAPPING, + struc.CellList.Result.MASK, + struc.CellList.Result.PAIRS, ], ids=lambda format: str(format).split(".")[-1], ) @@ -113,7 +113,7 @@ def test_result_format_consistency(atoms, result_format): result = cell_list.get_atoms(atoms.coord, CELL_SIZE, result_format=result_format) match result_format: - case struc.CellListResult.MAPPING: + case struc.CellList.Result.MAPPING: mapping = result test_matrix = np.zeros( (atoms.array_length(), atoms.array_length()), dtype=bool @@ -121,9 +121,9 @@ def test_result_format_consistency(atoms, result_format): for i, adjacent_atoms in enumerate(mapping): for j in adjacent_atoms[adjacent_atoms != -1]: test_matrix[i, j] = True - case struc.CellListResult.MASK: + case struc.CellList.Result.MASK: test_matrix = result - case struc.CellListResult.PAIRS: + case struc.CellList.Result.PAIRS: pairs = result test_matrix = np.zeros( (atoms.array_length(), atoms.array_length()), dtype=bool @@ -168,7 +168,7 @@ def test_single_and_multiple_radii(atoms, method): cell_list, atoms.coord[:N_SAMPLES], radii, - result_format=struc.CellListResult.MAPPING, + result_format=struc.CellList.Result.MAPPING, ) max_neighbors = mutliple_radius_result.shape[-1] @@ -176,7 +176,7 @@ def test_single_and_multiple_radii(atoms, method): cell_list, atoms.coord[:N_SAMPLES], radii, - result_format=struc.CellListResult.MAPPING, + result_format=struc.CellList.Result.MAPPING, ) single_radius_result = np.full((N_SAMPLES, max_neighbors), -1, dtype=int) @@ -185,7 +185,7 @@ def 
test_single_and_multiple_radii(atoms, method): cell_list, atoms.coord[i], radii[i], - result_format=struc.CellListResult.MAPPING, + result_format=struc.CellList.Result.MAPPING, ) single_radius_result[i, : len(result)] = result @@ -218,9 +218,9 @@ def test_selection(atoms): @pytest.mark.parametrize( "result_format", [ - struc.CellListResult.MAPPING, - struc.CellListResult.MASK, - struc.CellListResult.PAIRS, + struc.CellList.Result.MAPPING, + struc.CellList.Result.MASK, + struc.CellList.Result.PAIRS, ], ids=lambda format: str(format).split(".")[-1], ) @@ -234,11 +234,11 @@ def test_empty_coordinates(atoms, method, result_format): result = method(cell_list, np.zeros((0, 3)), 1, result_format=result_format) match result_format: - case struc.CellListResult.MAPPING: + case struc.CellList.Result.MAPPING: assert len(result) == 0 - case struc.CellListResult.MASK: + case struc.CellList.Result.MASK: assert result.shape == (0, n_atoms) - case struc.CellListResult.PAIRS: + case struc.CellList.Result.PAIRS: assert result.shape == (0, 2) From 4995653e9b9da3a7098d8085976f6b8ad53f41bd Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Wed, 11 Feb 2026 19:29:42 +0100 Subject: [PATCH 3/4] Add links to Rust source --- doc/viewcode.py | 317 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 246 insertions(+), 71 deletions(-) diff --git a/doc/viewcode.py b/doc/viewcode.py index ec0b28974..2ebece2f1 100644 --- a/doc/viewcode.py +++ b/doc/viewcode.py @@ -10,14 +10,161 @@ __author__ = "Patrick Kunzmann" __all__ = ["linkcode_resolve"] +import ast import inspect +import re +from enum import Enum, auto from importlib import import_module -from os import listdir -from os.path import dirname, isdir, join, splitext +from pathlib import Path import biotite -def _index_attributes(package_name, src_path): +class Source(Enum): + """Type of source file for an attribute.""" + + PYTHON = auto() + CYTHON = auto() + RUST = auto() + + +def _index_rust_code(code_lines): + """ + Find the 
line position of structs and enums in *Rust* files. + + This analyzer looks for `pub struct` and `pub enum` definitions + decorated with `#[pyclass]`. + + Parameters + ---------- + code_lines : list of str + The *Rust* source code split into lines. + + Returns + ------- + line_index : dict (str -> tuple(int, int)) + Maps an attribute name to its first and last line in a Rust + module. + """ + line_index = {} + + # Track pyclass decorator lines + pyclass_line = None + + for i, line in enumerate(code_lines): + stripped_line = line.strip() + + # Skip empty and comment lines + if len(stripped_line) == 0 or stripped_line.startswith("//"): + continue + + # Check for #[pyclass] decorator + if stripped_line.startswith("#[pyclass"): + pyclass_line = i + continue + + # Check for pub struct or pub enum after pyclass + if pyclass_line is not None: + match = re.match(r"pub\s+(struct|enum)\s+(\w+)", stripped_line) + if match: + attr_name = match.group(2) + attr_line_start = pyclass_line + + # Find the end of the struct/enum by matching braces + brace_count = 0 + started = False + attr_line_stop = i + 1 + + for j in range(i, len(code_lines)): + for char in code_lines[j]: + if char == "{": + brace_count += 1 + started = True + elif char == "}": + brace_count -= 1 + if started and brace_count == 0: + attr_line_stop = j + 1 + break + + line_index[attr_name] = ( + # 'One' based indexing + attr_line_start + 1, + # 'One' based indexing and inclusive stop + attr_line_stop, + ) + pyclass_line = None + + return line_index + + +def _index_rust_files(rust_src_path): + """ + Index all Rust source files and their pyclass-decorated attributes. + + Parameters + ---------- + rust_src_path : Path + Path to the Rust source directory (src/rust). + + Returns + ------- + rust_attribute_index : dict(str -> str) + Maps attribute names to their Rust file paths (relative to src/). + rust_line_index : dict(str -> tuple(int, int)) + Maps attribute names to their first and last line in the Rust file. 
+ """ + rust_attribute_index = {} + rust_line_index = {} + + for file_path in rust_src_path.rglob("*.rs"): + lines = file_path.read_text().splitlines() + line_positions = _index_rust_code(lines) + for attr_name, (first, last) in line_positions.items(): + # Path relative to src/ directory + rel_path = file_path.relative_to(rust_src_path.parent) + rust_attribute_index[attr_name] = str(rel_path) + rust_line_index[attr_name] = (first, last) + + return rust_attribute_index, rust_line_index + + +def _get_rust_imports(module_path): + """ + Parse a Python module file to find attributes imported from biotite.rust. + + Parameters + ---------- + module_path : Path + Path to the Python module file. + + Returns + ------- + rust_imports : set of str + Names of attributes imported from biotite.rust. + """ + rust_imports = set() + + try: + tree = ast.parse(module_path.read_text()) + except SyntaxError: + return rust_imports + + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + if node.module and node.module.startswith("biotite.rust"): + for alias in node.names: + # Use the local name (asname if aliased, otherwise name) + local_name = alias.asname if alias.asname else alias.name + rust_imports.add(local_name) + + return rust_imports + + +def _index_attributes( + package_name, + src_path, + rust_attribute_index=None, + rust_line_index=None, +): """ Assign a Python module to each combination of (sub)package and attribute (e.g. function, class, etc.) in a given (sub)package. @@ -26,59 +173,60 @@ def _index_attributes(package_name, src_path): ---------- package_name : str Name of the (sub)package. - src_path : str + src_path : Path File path to `package_name`. + rust_attribute_index, rust_line_index : dict or None + Indices for Rust attributes. + If None (first call), they are computed from the Rust source files. 
- Parameters - ---------- - attribute_index : dict( tuple(str, str) -> (str, bool)) + Returns + ------- + attribute_index : dict( tuple(str, str) -> (str, Source)) Maps the combination of (sub)package name and attribute to - the name of a Python module and to a boolean value that - indicates, whether it is a Cython module. - cython_line_index : dict( tuple(str, str) -> tuple(int, int) ) ) + the name of a Python module and to the source type. + extension_line_index : dict( tuple(str, str) -> tuple(int, int) ) ) Maps the combination of (sub)package name and attribute to - the first and last line in a Cython module. - Does not contain entries for attributes that are not part of a - Cython module. + the first and last line in an extension module (Cython or Rust). + Does not contain entries for attributes that are not part of an + extension module. """ + if rust_attribute_index is None: + rust_attribute_index, rust_line_index = _index_rust_files( + src_path.parent / "rust" + ) + if not _is_package(src_path): # Directory is not a Python package/subpackage # -> Nothing to do return {}, {} attribute_index = {} - cython_line_index = {} - - # Identify all subdirectories... - directory_content = listdir(src_path) - dirs = [f for f in directory_content if isdir(join(src_path, f))] - # ... 
and index them recursively - for directory in dirs: - sub_attribute_index, sub_cython_line_index = _index_attributes( - f"{package_name}.{directory}", - join(src_path, directory), - ) - attribute_index.update(sub_attribute_index) - cython_line_index.update(sub_cython_line_index) + extension_line_index = {} + + # Identify all subdirectories and index them recursively + for subdir in src_path.iterdir(): + if subdir.is_dir(): + sub_attribute_index, sub_extension_line_index = _index_attributes( + f"{package_name}.{subdir.name}", + subdir, + rust_attribute_index, + rust_line_index, + ) + attribute_index.update(sub_attribute_index) + extension_line_index.update(sub_extension_line_index) # Import package package = import_module(package_name) # Import all modules in directory and index attributes source_files = [ - file_name - for file_name in directory_content - if file_name != "__init__.py" - and ( - # Standard Python modules - file_name.endswith(".py") - or - # Extension modules - file_name.endswith(".pyx") - ) + f + for f in src_path.iterdir() + if f.is_file() and f.name != "__init__.py" and f.suffix in (".py", ".pyx") ] + for source_file in source_files: - module_name = f"{package_name}.{splitext(source_file)[0]}" + module_name = f"{package_name}.{source_file.stem}" if module_name == "biotite.version": # Autogenerated module from hatch-vcs # It contains no '__all__' attribute on purpose @@ -97,16 +245,33 @@ def _index_attributes(package_name, src_path): if not all([hasattr(package, attr) for attr in module.__all__]): continue - is_cython = source_file.endswith(".pyx") + # Determine source type + is_cython = source_file.suffix == ".pyx" + rust_imports = set() if is_cython else _get_rust_imports(source_file) + for attribute in module.__all__: - attribute_index[(package_name, attribute)] = (module_name, is_cython) + if attribute in rust_imports and attribute in rust_attribute_index: + # Attribute is imported from Rust + source_type = Source.RUST + rust_file = 
rust_attribute_index[attribute] + attribute_index[(package_name, attribute)] = (rust_file, source_type) + if attribute in rust_line_index: + extension_line_index[(package_name, attribute)] = rust_line_index[ + attribute + ] + elif is_cython: + source_type = Source.CYTHON + attribute_index[(package_name, attribute)] = (module_name, source_type) + else: + source_type = Source.PYTHON + attribute_index[(package_name, attribute)] = (module_name, source_type) + if is_cython: - with open(join(src_path, source_file), "r") as cython_file: - lines = cython_file.read().splitlines() + lines = source_file.read_text().splitlines() for attribute, (first, last) in _index_cython_code(lines).items(): - cython_line_index[(package_name, attribute)] = (first, last) + extension_line_index[(package_name, attribute)] = (first, last) - return attribute_index, cython_line_index + return attribute_index, extension_line_index def _index_cython_code(code_lines): @@ -197,14 +362,12 @@ def _index_cython_code(code_lines): def _is_package(path): - content = listdir(path) - return "__init__.py" in content + return (path / "__init__.py").exists() -_attribute_index, _cython_line_index = _index_attributes( +_attribute_index, _extension_line_index = _index_attributes( "biotite", - # Directory to src/biotite - join(dirname(dirname(__file__)), "src", "biotite"), + Path(__file__).parent.parent / "src" / "biotite", ) @@ -218,36 +381,48 @@ def linkcode_resolve(domain, info): package_name = info["module"] attr_name = info["fullname"] try: - module_name, is_cython = _attribute_index[(package_name, attr_name)] + module_or_path, source_type = _attribute_index[(package_name, attr_name)] except KeyError: # The attribute is not defined within Biotite # It may be e.g. 
an inherited method from an external source return None - if is_cython: - if (package_name, attr_name) in _cython_line_index: - first, last = _cython_line_index[(package_name, attr_name)] - return base_url + f"{module_name.replace('.', '/')}.pyx#L{first}-L{last}" - else: - # In case the attribute is not found - # by the Cython code analyzer - return base_url + f"{module_name.replace('.', '/')}.pyx" + match source_type: + case Source.RUST: + if (package_name, attr_name) in _extension_line_index: + first, last = _extension_line_index[(package_name, attr_name)] + return base_url + f"{module_or_path}#L{first}-L{last}" + else: + return base_url + f"{module_or_path}" + + case Source.CYTHON: + module_name = module_or_path + if (package_name, attr_name) in _extension_line_index: + first, last = _extension_line_index[(package_name, attr_name)] + return ( + base_url + f"{module_name.replace('.', '/')}.pyx#L{first}-L{last}" + ) + else: + # In case the attribute is not found + # by the Cython code analyzer + return base_url + f"{module_name.replace('.', '/')}.pyx" - else: - module = import_module(module_name) + case Source.PYTHON: + module_name = module_or_path + module = import_module(module_name) - # Get the object defined by the attribute name, - # by traversing the 'attribute tree' to the leaf - obj = module - for attr_name_part in attr_name.split("."): - obj = getattr(obj, attr_name_part) + # Get the object defined by the attribute name, + # by traversing the 'attribute tree' to the leaf + obj = module + for attr_name_part in attr_name.split("."): + obj = getattr(obj, attr_name_part) - # Temporarily change the '__module__' attribute, which is set - # to the subpackage in Biotite, back to the actual module in - # order to fool Python's inspect module - obj.__module__ = module_name + # Temporarily change the '__module__' attribute, which is set + # to the subpackage in Biotite, back to the actual module in + # order to fool Python's inspect module + obj.__module__ = 
module_name - source_lines, first = inspect.getsourcelines(obj) - last = first + len(source_lines) - 1 + source_lines, first = inspect.getsourcelines(obj) + last = first + len(source_lines) - 1 - return base_url + f"{module_name.replace('.', '/')}.py#L{first}-L{last}" + return base_url + f"{module_name.replace('.', '/')}.py#L{first}-L{last}" From a99d0ebabef211a630336049782103f1a51726cc Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Thu, 12 Feb 2026 14:31:33 +0100 Subject: [PATCH 4/4] Fix test --- tests/interface/test_rdkit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/interface/test_rdkit.py b/tests/interface/test_rdkit.py index 9100322e7..78ab9eed2 100644 --- a/tests/interface/test_rdkit.py +++ b/tests/interface/test_rdkit.py @@ -37,6 +37,8 @@ def test_conversion_from_biotite(res_name): Run this on randomly selected molecules from the CCD. """ ref_atoms = info.residue(res_name, allow_missing_coord=True) + if np.any(ref_atoms.element == "X"): + pytest.skip("Molecule contains an atom with unknown element") mol = rdkit_interface.to_mol(ref_atoms) test_atoms = rdkit_interface.from_mol(mol, add_hydrogen=False)