From 055ad4f7952116c900693d8b7ddbaa853829a8da Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Sat, 14 Feb 2026 12:24:07 +0100 Subject: [PATCH] Implement `sasa()` in Rust --- src/biotite/structure/sasa.py | 171 +++++++++++++++++ src/biotite/structure/sasa.pyx | 322 --------------------------------- src/rust/lib.rs | 1 + src/rust/structure/celllist.rs | 219 +++++++++------------- src/rust/structure/mod.rs | 7 +- src/rust/structure/sasa.rs | 173 ++++++++++++++++++ src/rust/structure/util.rs | 61 +++++++ src/rust/util.rs | 13 ++ 8 files changed, 508 insertions(+), 459 deletions(-) create mode 100644 src/biotite/structure/sasa.py delete mode 100644 src/biotite/structure/sasa.pyx create mode 100644 src/rust/structure/sasa.rs create mode 100644 src/rust/structure/util.rs create mode 100644 src/rust/util.rs diff --git a/src/biotite/structure/sasa.py b/src/biotite/structure/sasa.py new file mode 100644 index 000000000..4b374b113 --- /dev/null +++ b/src/biotite/structure/sasa.py @@ -0,0 +1,171 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. + +""" +Use this module to calculate the Solvent Accessible Surface Area (SASA) of +a protein or single atoms. +""" + +__name__ = "biotite.structure" +__author__ = "Patrick Kunzmann" +__all__ = ["sasa"] + +import numpy as np +from biotite.rust.structure import sasa as rust_sasa +from biotite.structure.filter import ( + filter_heavy, + filter_monoatomic_ions, + filter_solvent, +) +from biotite.structure.info.radii import vdw_radius_protor, vdw_radius_single + + +def sasa( + array, + probe_radius=1.4, + atom_filter=None, + ignore_ions=True, + point_number=1000, + point_distr="Fibonacci", + vdw_radii="ProtOr", +): + """ + sasa(array, probe_radius=1.4, atom_filter=None, ignore_ions=True, + point_number=1000, point_distr="Fibonacci", vdw_radii="ProtOr") + + Calculate the Solvent Accessible Surface Area (SASA) of a protein. 
+ + This function uses the Shrake-Rupley ("rolling probe") + algorithm :footcite:`Shrake1973`: + Every atom is occupied by an evenly distributed point mesh. The + points that can be reached by the "rolling probe", are surface + accessible. + + Parameters + ---------- + array : AtomArray + The protein model to calculate the SASA for. + probe_radius : float, optional + The VdW-radius of the solvent molecules. + atom_filter : ndarray, dtype=bool, optional + If this parameter is given, SASA is only calculated for the + filtered atoms. + ignore_ions : bool, optional + If true, all monoatomic ions are removed before SASA calculation. + point_number : int, optional + The number of points in the mesh occupying each atom for SASA + calculation. + The SASA calculation time is proportional to the amount of sphere points. + point_distr : str or function, optional + If a function is given, the function is used to calculate the + point distribution for the mesh. + The function must take `int` *n* as parameter and return a + *(n x 3)* :class:`ndarray` containing points on the surface of a unit sphere. + Alternatively a string can be given to choose a built-in + distribution: + + - **Fibonacci** - Distribute points using a golden section spiral. + + By default *Fibonacci* is used. + vdw_radii : str or ndarray, dtype=float, optional + Indicates the set of VdW radii to be used. If an `array`-length + :class:`ndarray` is given, each atom gets the radius at the + corresponding index. Radii given for atoms that are not used in + SASA calculation (e.g. solvent atoms) can have arbitrary values + (e.g. `NaN`). If instead a string is given, one of the + built-in sets is used: + + - **ProtOr** - A set, which does not require hydrogen atoms + in the model. Suitable for crystal structures. + :footcite:`Tsai1999` + - **Single** - A set, which uses a defined VdW radius for + every single atom, therefore hydrogen atoms are required + in the model (e.g. NMR elucidated structures). 
+ Values for main group elements are taken from :footcite:`Mantina2009`, + and for relevant transition metals from the :footcite:`RDKit`. + + By default *ProtOr* is used. + + Returns + ------- + sasa : ndarray, dtype=float, shape=(n,) + Atom-wise SASA. `NaN` for atoms where SASA has not been + calculated + (solvent atoms, hydrogen atoms (ProtOr), atoms not in `atom_filter`). + + References + ---------- + + .. footbibliography:: + """ + if atom_filter is not None: + # Filter for all atoms to calculate SASA for + sasa_filter = np.array(atom_filter, dtype=bool) + else: + sasa_filter = np.ones(len(array), dtype=bool) + # Only include atoms within finite coordinates + sasa_filter &= np.isfinite(array.coord).all(axis=-1) + # Filter for all atoms that are considered for occlusion calculation + # sasa_filter is subfilter of occlusion_filter + occlusion_filter = np.ones(len(array), dtype=bool) + # Remove water residues, since it is the solvent + filter = ~filter_solvent(array) + sasa_filter = sasa_filter & filter + occlusion_filter = occlusion_filter & filter + if ignore_ions: + filter = ~filter_monoatomic_ions(array) + sasa_filter = sasa_filter & filter + occlusion_filter = occlusion_filter & filter + + if callable(point_distr): + sphere_points = point_distr(point_number) + elif point_distr == "Fibonacci": + sphere_points = _create_fibonacci_points(point_number) + else: + raise ValueError(f"'{point_distr}' is not a valid point distribution") + sphere_points = sphere_points.astype(np.float32, copy=False) + + if isinstance(vdw_radii, np.ndarray): + radii = vdw_radii.astype(np.float32) + if len(radii) != array.array_length(): + raise ValueError( + f"Amount VdW radii ({len(radii)}) and " + f"amount of atoms ({array.array_length()}) are not equal" + ) + elif vdw_radii == "ProtOr": + filter = filter_heavy(array) + sasa_filter = sasa_filter & filter + occlusion_filter = occlusion_filter & filter + radii = np.full(len(array), np.nan, dtype=np.float32) + for i in 
+ np.arange(len(radii))[occlusion_filter]: + rad = vdw_radius_protor(array.res_name[i], array.atom_name[i]) + # 1.8 is default radius + radii[i] = rad if rad is not None else 1.8 + elif vdw_radii == "Single": + radii = np.full(len(array), np.nan, dtype=np.float32) + for i in np.arange(len(radii))[occlusion_filter]: + rad = vdw_radius_single(array.element[i]) + # 1.8 is default radius + radii[i] = rad if rad is not None else 1.8 + else: + raise KeyError(f"'{vdw_radii}' is not a valid radii set") + # Increase atom radii by probe size ("rolling probe") + radii += probe_radius + + return rust_sasa(array.coord, radii, sphere_points, sasa_filter, occlusion_filter) + + +def _create_fibonacci_points(n): + """ + Get an array of approximately equidistant points on a unit sphere surface + using a golden section spiral. + """ + phi = (3 - np.sqrt(5)) * np.pi * np.arange(n) + z = np.linspace(1 - 1.0 / n, 1.0 / n - 1, n) + radius = np.sqrt(1 - z * z) + coords = np.zeros((n, 3)) + coords[:, 0] = radius * np.cos(phi) + coords[:, 1] = radius * np.sin(phi) + coords[:, 2] = z + return coords diff --git a/src/biotite/structure/sasa.pyx b/src/biotite/structure/sasa.pyx deleted file mode 100644 index a8561a6e3..000000000 --- a/src/biotite/structure/sasa.pyx +++ /dev/null @@ -1,322 +0,0 @@ -# This source code is part of the Biotite package and is distributed -# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further -# information. - -""" -Use this module to calculate the Solvent Accessible Surface Area (SASA) of -a protein or single atoms. 
-""" - -__name__ = "biotite.structure" -__author__ = "Patrick Kunzmann" -__all__ = ["sasa"] - -cimport cython -cimport numpy as np -from libc.stdlib cimport malloc, free - -import numpy as np -from biotite.rust.structure import CellList -from .filter import filter_solvent, filter_monoatomic_ions, filter_heavy -from .info.radii import vdw_radius_protor, vdw_radius_single - -ctypedef np.uint8_t np_bool -ctypedef np.int64_t int64 -ctypedef np.float32_t float32 - - -@cython.boundscheck(False) -@cython.wraparound(False) -def sasa(array, float probe_radius=1.4, np.ndarray atom_filter=None, - bint ignore_ions=True, int point_number=1000, - point_distr="Fibonacci", vdw_radii="ProtOr"): - """ - sasa(array, probe_radius=1.4, atom_filter=None, ignore_ions=True, - point_number=1000, point_distr="Fibonacci", vdw_radii="ProtOr") - - Calculate the Solvent Accessible Surface Area (SASA) of a protein. - - This function uses the Shrake-Rupley ("rolling probe") - algorithm :footcite:`Shrake1973`: - Every atom is occupied by a evenly distributed point mesh. The - points that can be reached by the "rolling probe", are surface - accessible. - - Parameters - ---------- - array : AtomArray - The protein model to calculate the SASA for. - probe_radius : float, optional - The VdW-radius of the solvent molecules. - atom_filter : ndarray, dtype=bool, optional - If this parameter is given, SASA is only calculated for the - filtered atoms. - ignore_ions : bool, optional - If true, all monoatomic ions are removed before SASA calculation. - point_number : int, optional - The number of points in the mesh occupying each atom for SASA - calculation. - The SASA calculation time is proportional to the amount of sphere points. - point_distr : str or function, optional - If a function is given, the function is used to calculate the - point distribution for the mesh (the function must take `float` - *n* as parameter and return a *(n x 3)* :class:`ndarray`). 
- Alternatively a string can be given to choose a built-in - distribution: - - - **Fibonacci** - Distribute points using a golden section - spiral. - - By default *Fibonacci* is used. - vdw_radii : str or ndarray, dtype=float, optional - Indicates the set of VdW radii to be used. If an `array`-length - :class:`ndarray` is given, each atom gets the radius at the - corresponding index. Radii given for atoms that are not used in - SASA calculation (e.g. solvent atoms) can have arbitrary values - (e.g. `NaN`). If instead a string is given, one of the - built-in sets is used: - - - **ProtOr** - A set, which does not require hydrogen atoms - in the model. Suitable for crystal structures. - :footcite:`Tsai1999` - - **Single** - A set, which uses a defined VdW radius for - every single atom, therefore hydrogen atoms are required - in the model (e.g. NMR elucidated structures). - Values for main group elements are taken from :footcite:`Mantina2009`, - and for relevant transition metals from the :footcite:`RDKit`. - - By default *ProtOr* is used. - - - Returns - ------- - sasa : ndarray, dtype=bool, shape=(n,) - Atom-wise SASA. `NaN` for atoms where SASA has not been - calculated - (solvent atoms, hydrogen atoms (ProtOr), atoms not in `filter`). - - References - ---------- - - .. 
footbibliography:: - - """ - cdef int i=0, j=0, k=0, adj_atom_i=0, rel_atom_i=0 - - cdef np.ndarray sasa_filter - cdef np.ndarray occl_filter - if atom_filter is not None: - # Filter for all atoms to calculate SASA for - sasa_filter = np.array(atom_filter, dtype=bool) - else: - sasa_filter = np.ones(len(array), dtype=bool) - # Filter for all atoms that are considered for occlusion calculation - # sasa_filter is subfilter of occlusion_filter - occl_filter = np.ones(len(array), dtype=bool) - # Remove water residues, since it is the solvent - filter = ~filter_solvent(array) - sasa_filter = sasa_filter & filter - occl_filter = occl_filter & filter - if ignore_ions: - filter = ~filter_monoatomic_ions(array) - sasa_filter = sasa_filter & filter - occl_filter = occl_filter & filter - - cdef np.ndarray sphere_points - if callable(point_distr): - sphere_points = point_distr(point_number) - elif point_distr == "Fibonacci": - sphere_points = _create_fibonacci_points(point_number) - else: - raise ValueError(f"'{point_distr}' is not a valid point distribution") - sphere_points = sphere_points.astype(np.float32) - - cdef np.ndarray radii - if isinstance(vdw_radii, np.ndarray): - radii = vdw_radii.astype(np.float32) - if len(radii) != array.array_length(): - raise ValueError( - f"Amount VdW radii ({len(radii)}) and " - f"amount of atoms ({array.array_length()}) are not equal" - ) - elif vdw_radii == "ProtOr": - filter = filter_heavy(array) - sasa_filter = sasa_filter & filter - occl_filter = occl_filter & filter - radii = np.full(len(array), np.nan, dtype=np.float32) - for i in np.arange(len(radii))[occl_filter]: - rad = vdw_radius_protor(array.res_name[i], array.atom_name[i]) - # 1.8 is default radius - radii[i] = rad if rad is not None else 1.8 - elif vdw_radii == "Single": - radii = np.full(len(array), np.nan, dtype=np.float32) - for i in np.arange(len(radii))[occl_filter]: - rad = vdw_radius_single(array.element[i]) - # 1.5 is default radius - radii[i] = rad if rad is not 
None else 1.8 - else: - raise KeyError(f"'{vdw_radii}' is not a valid radii set") - # Increase atom radii by probe size ("rolling probe") - radii += probe_radius - - # Memoryview for filter - # Problem with creating boolean memoryviews - # -> Type uint8 is used - cdef np_bool[:] sasa_filter_view = np.frombuffer(sasa_filter, - dtype=np.uint8) - - cdef np.ndarray occl_r = radii[occl_filter] - # Atom array containing occluding atoms - occl_array = array[occl_filter] - - # Memoryviews for coordinates of entire (main) array - # and for coordinates of occluding atom array - cdef float32[:,:] main_coord = array.coord.astype(np.float32, - copy=False) - cdef float32[:,:] occl_coord = occl_array.coord.astype(np.float32, - copy=False) - # Memoryviews for sphere points - cdef float32[:,:] sphere_coord = sphere_points - # Check if any of these arrays are empty to prevent segfault - if main_coord.shape[0] == 0 \ - or occl_coord.shape[0] == 0 \ - or sphere_coord.shape[0] == 0: - raise ValueError("Coordinates are empty") - # Memoryviews for radii of SASA and occluding atoms - # their squares and their sum of sqaures - cdef float32[:] atom_radii = radii - cdef float32[:] atom_radii_sq = radii * radii - cdef float32[:] occl_radii = occl_r - cdef float32[:] occl_radii_sq = occl_r * occl_r - # Memoryview for atomwise SASA - cdef float32[:] sasa = np.full(len(array), np.nan, dtype=np.float32) - - # Area of a sphere point on a unit sphere - cdef float32 area_per_point = 4.0 * np.pi / point_number - - # Define further statically typed variables - # that are needed for SASA calculation - cdef int n_accesible = 0 - cdef float32 radius = 0 - cdef float32 radius_sq = 0 - cdef float32 adj_radius = 0 - cdef float32 adj_radius_sq = 0 - cdef float32 dist_sq = 0 - cdef float32 point_x = 0 - cdef float32 point_y = 0 - cdef float32 point_z = 0 - cdef float32 atom_x = 0 - cdef float32 atom_y = 0 - cdef float32 atom_z = 0 - cdef float32 occl_x = 0 - cdef float32 occl_y = 0 - cdef float32 occl_z = 0 - 
cdef float32[:,:] relevant_occl_coord = None - - # Cell size is as large as the maximum distance, - # where two atom can intersect. - # Therefore intersecting atoms are always in the same or adjacent cell. - cell_list = CellList(occl_array, np.max(radii[occl_filter])*2) - cdef np.ndarray cell_indices - cdef int64[:,:] cell_indices_view - cdef int length - cdef int max_adj_list_length = 0 - cdef int array_length = array.array_length() - - cell_indices = cell_list.get_atoms_in_cells(array.coord) - cell_indices_view = cell_indices - max_adj_list_length = cell_indices.shape[0] - - # Later on, this array stores coordinates for actual - # occluding atoms for a certain atom to calculate the - # SASA for - # The first three indices of the second axis - # are x, y and z, the last one is the squared radius - # This list is as long as the maximal length of a list of - # adjacent atoms - relevant_occl_coord = np.zeros((max_adj_list_length, 4), - dtype=np.float32) - - # Actual SASA calculation - for i in range(array_length): - # First level: The atoms to calculate SASA for - if not sasa_filter_view[i]: - # SASA is not calculated for this atom - continue - n_accesible = point_number - atom_x = main_coord[i,0] - atom_y = main_coord[i,1] - atom_z = main_coord[i,2] - radius = atom_radii[i] - radius_sq = atom_radii_sq[i] - # Find occluding atoms from list of adjacent atoms - rel_atom_i = 0 - for j in range(max_adj_list_length): - # Remove all atoms, where the distance to the relevant atom - # is larger than the sum of the radii, - # since those atoms do not touch - # If distance is 0, it is the same atom, - # and the atom is removed from the list as well - adj_atom_i = cell_indices_view[i,j] - if adj_atom_i == -1: - # -1 means end of list - break - occl_x = occl_coord[adj_atom_i,0] - occl_y = occl_coord[adj_atom_i,1] - occl_z = occl_coord[adj_atom_i,2] - adj_radius = occl_radii[adj_atom_i] - adj_radius_sq = occl_radii_sq[adj_atom_i] - dist_sq = distance_sq(atom_x, atom_y, atom_z, - 
occl_x, occl_y, occl_z) - if dist_sq != 0 \ - and dist_sq < (adj_radius+radius) * (adj_radius+radius): - relevant_occl_coord[rel_atom_i,0] = occl_x - relevant_occl_coord[rel_atom_i,1] = occl_y - relevant_occl_coord[rel_atom_i,2] = occl_z - relevant_occl_coord[rel_atom_i,3] = adj_radius_sq - rel_atom_i += 1 - for j in range(sphere_coord.shape[0]): - # Second level: The sphere points for that atom - # Transform sphere point to sphere of current atom - point_x = sphere_coord[j,0] * radius + atom_x - point_y = sphere_coord[j,1] * radius + atom_y - point_z = sphere_coord[j,2] * radius + atom_z - for k in range(rel_atom_i): - # Third level: Compare point to occluding atoms - dist_sq = distance_sq(point_x, point_y, point_z, - relevant_occl_coord[k, 0], - relevant_occl_coord[k, 1], - relevant_occl_coord[k, 2]) - # Compare squared distance - # to squared radius of occluding atom - # (Radius is relevant_occl_coord[3]) - if dist_sq < relevant_occl_coord[k, 3]: - # Point is occluded - # -> Continue with next point - n_accesible -= 1 - break - sasa[i] = area_per_point * n_accesible * radius_sq - return np.asarray(sasa) - - -cdef inline float32 distance_sq(float32 x1, float32 y1, float32 z1, - float32 x2, float32 y2, float32 z2): - cdef float32 dx = x2 - x1 - cdef float32 dy = y2 - y1 - cdef float32 dz = z2 - z1 - return dx*dx + dy*dy + dz*dz - - -def _create_fibonacci_points(n): - """ - Get an array of approximately equidistant points on a sphere surface - using a golden section spiral. 
- """ - phi = (3 - np.sqrt(5)) * np.pi * np.arange(n) - z = np.linspace(1 - 1.0/n, 1.0/n - 1, n) - radius = np.sqrt(1 - z*z) - coords = np.zeros((n, 3)) - coords[:,0] = radius * np.cos(phi) - coords[:,1] = radius * np.sin(phi) - coords[:,2] = z - return coords \ No newline at end of file diff --git a/src/rust/lib.rs b/src/rust/lib.rs index 5c8517fb4..dfca1e0a6 100644 --- a/src/rust/lib.rs +++ b/src/rust/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; mod structure; +pub mod util; /// Add a submodule to a module and make it discoverable as package fn add_subpackage( diff --git a/src/rust/structure/celllist.rs b/src/rust/structure/celllist.rs index f5a97f6b7..cae4b0974 100644 --- a/src/rust/structure/celllist.rs +++ b/src/rust/structure/celllist.rs @@ -1,3 +1,5 @@ +use crate::structure::util::{distance_squared, extract_coord}; +use crate::util::check_signals_periodically; use numpy::ndarray::Array2; use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; use pyo3::exceptions; @@ -81,13 +83,13 @@ pub enum CellListResult { PAIRS, } -/// Internal enum for efficient support of both single and multiple radii +/// Enum for efficient support of both single and multiple radii /// given in :meth:`CellList.get_atoms` and :meth:`CellList.get_atoms_in_cells`. /// /// This avoids the overhead of always allocating a ``Vec`` when a single radius /// is provided, which is the common case. #[derive(Clone, Debug)] -enum Radius { +pub enum Radius { /// A single radius value applied to all coordinates. Single(T), /// Individual radius values for each coordinate. 
@@ -357,7 +359,7 @@ pub struct CellList { impl CellList { #[new] #[pyo3(signature = (atom_array, cell_size, periodic=false, r#box=None, selection=None))] - fn new<'py>( + fn from_python_objects<'py>( py: Python<'py>, atom_array: Bound<'py, PyAny>, cell_size: f32, @@ -438,49 +440,7 @@ impl CellList { orig_length = coord_object.call_method0("__len__")?.extract()?; } let coord: Vec<[f32; 3]> = extract_coord(coord_object.extract::>()?)?; - let coord_range: [[f32; 3]; 2] = calculate_coord_range(&coord)?; - - if let Some(ref sel) = selection { - if sel.len() != orig_length { - return Err(exceptions::PyIndexError::new_err(format!( - "Atom array has length {}, but selection has length {}", - orig_length, - sel.len() - ))); - } - } - - let mut positions: Vec<[usize; 3]> = Vec::with_capacity(coord.len()); - let mut elements: Vec = Vec::with_capacity(coord.len()); - for (i, coord) in coord.iter().enumerate() { - if let Some(ref selection) = selection { - if !selection[i % orig_length] { - continue; - } - } - positions - .push(as_usize(Self::get_cell_position(*coord, &coord_range, cell_size)).unwrap()); - elements.push(i); - } - // To get the number of cells in each dimension, use the position of the maximum coordinate - let dimensions: [usize; 3] = - Self::get_cell_position(coord_range[1], &coord_range, cell_size) - .iter() - .map(|x| (x + 1) as usize) - .collect::>() - .try_into() - .unwrap(); - let cells = CellGrid::from_elements(dimensions, positions, elements); - - Ok(CellList { - coord, - selection, - cells, - cell_size, - coord_range, - periodic_box, - orig_length, - }) + Self::new(coord, cell_size, periodic_box, selection, orig_length) } /// Reconstruct a :class:`CellList` from its serialized state. @@ -803,29 +763,56 @@ impl CellList { } impl CellList { - /// Convert a 3D coordinate to the corresponding cell position in the grid. - /// - /// Given a coordinate *[x, y, z]*, this computes the cell indices *[i, j, k]* - /// based on the coordinate range and cell size. 
- /// - /// Returns - /// ------- - /// The cell position *[i, j, k]*. Note that negative indices or indices - /// exceeding the grid dimensions can be returned if the coordinate is - /// outside the original coordinate range. The :class:`CellGrid` handles - /// out-of-bounds access gracefully by returning empty slices. - #[inline(always)] - fn get_cell_position( - coord: [f32; 3], - coord_range: &[[f32; 3]; 2], - cellsize: f32, - ) -> [isize; 3] { - [ - // Conversion to 'isize' automatically floors the result - ((coord[0] - coord_range[0][0]) / cellsize) as isize, - ((coord[1] - coord_range[0][1]) / cellsize) as isize, - ((coord[2] - coord_range[0][2]) / cellsize) as isize, - ] + pub fn new( + coord: Vec<[f32; 3]>, + cell_size: f32, + periodic_box: Option<[[f32; 3]; 3]>, + selection: Option>, + orig_length: usize, + ) -> PyResult { + let coord_range: [[f32; 3]; 2] = calculate_coord_range(&coord)?; + + if let Some(ref sel) = selection { + if sel.len() != orig_length { + return Err(exceptions::PyIndexError::new_err(format!( + "Atom array has length {}, but selection has length {}", + orig_length, + sel.len() + ))); + } + } + + let mut positions: Vec<[usize; 3]> = Vec::with_capacity(coord.len()); + let mut elements: Vec = Vec::with_capacity(coord.len()); + for (i, coord) in coord.iter().enumerate() { + if let Some(ref selection) = selection { + if !selection[i % orig_length] { + continue; + } + } + positions + .push(as_usize(Self::get_cell_position(*coord, &coord_range, cell_size)).unwrap()); + elements.push(i); + } + // To get the number of cells in each dimension, use the position of the maximum coordinate + let dimensions: [usize; 3] = + Self::get_cell_position(coord_range[1], &coord_range, cell_size) + .iter() + .map(|x| (x + 1) as usize) + .collect::>() + .try_into() + .unwrap(); + let cells = CellGrid::from_elements(dimensions, positions, elements); + + Ok(CellList { + coord, + selection, + cells, + cell_size, + coord_range, + periodic_box, + orig_length, + }) 
} /// Find atoms within a Euclidean distance from given coordinates. @@ -839,7 +826,7 @@ impl CellList { /// A vector of *[query_idx, atom_idx]* pairs where ``atom_idx`` may refer to /// periodic copies (``indices >= orig_length``). The caller is responsible for /// mapping these back to original indices using ``atom_idx % orig_length``. - fn get_atoms_from_slice( + pub fn get_atoms_from_slice( &self, py: Python<'_>, coord: &[[f32; 3]], @@ -900,7 +887,7 @@ impl CellList { /// the full coordinate array, which may include periodic copies when /// ``periodic=True``. These indices can exceed ``orig_length`` and should be /// mapped back using ``atom_idx % orig_length`` when needed. - fn get_atoms_in_cells_from_slice( + pub fn get_atoms_in_cells_from_slice( &self, py: Python<'_>, coord: &[[f32; 3]], @@ -942,16 +929,36 @@ impl CellList { } } } - // The size of `coord` is arbitrary - // -> the function may take a long time to complete - // Check for interrupts periodically (every 256 coords) - if coord_idx & 0xFF == 0 { - py.check_signals()?; - } + check_signals_periodically(py, coord_idx)?; } Ok(adjacent_atoms) } + /// Convert a 3D coordinate to the corresponding cell position in the grid. + /// + /// Given a coordinate *[x, y, z]*, this computes the cell indices *[i, j, k]* + /// based on the coordinate range and cell size. + /// + /// Returns + /// ------- + /// The cell position *[i, j, k]*. Note that negative indices or indices + /// exceeding the grid dimensions can be returned if the coordinate is + /// outside the original coordinate range. The :class:`CellGrid` handles + /// out-of-bounds access gracefully by returning empty slices. 
+ #[inline(always)] + fn get_cell_position( + coord: [f32; 3], + coord_range: &[[f32; 3]; 2], + cellsize: f32, + ) -> [isize; 3] { + [ + // Conversion to 'isize' automatically floors the result + ((coord[0] - coord_range[0][0]) / cellsize) as isize, + ((coord[1] - coord_range[0][1]) / cellsize) as isize, + ((coord[2] - coord_range[0][2]) / cellsize) as isize, + ] + } + /// Convert Python coordinate input to a uniform Rust representation. /// /// This method handles the various input formats accepted by the Python API: @@ -1078,55 +1085,6 @@ fn calculate_coord_range(coord: &[[f32; 3]]) -> PyResult<[[f32; 3]; 2]> { Ok([min_coord, max_coord]) } -/// Extract coordinates from a 2D NumPy array into a ``Vec<[f32; 3]>``. -/// -/// Raises -/// ------ -/// PyValueError -/// If the array does not have shape *(n, 3)* or contains non-finite values (NaN or Inf). -fn extract_coord(coord_array: PyReadonlyArray2) -> PyResult> { - let coord_ndarray = coord_array.as_array(); - let shape = coord_ndarray.shape(); - if shape.len() != 2 { - return Err(exceptions::PyValueError::new_err( - "Coordinates must have shape (n,3)", - )); - } - if shape[1] != 3 { - return Err(exceptions::PyValueError::new_err( - "Coordinates must have form (x,y,z)", - )); - } - if !coord_ndarray.is_all_infinite() { - return Err(exceptions::PyValueError::new_err( - "Coordinates contain non-finite values", - )); - } - - // Try to access as contiguous slice for maximum efficiency - if let Some(slice) = coord_ndarray.as_slice() { - // Data is contiguous in memory, we can reinterpret directly - let n_coords = shape[0]; - let mut result: Vec<[f32; 3]> = Vec::with_capacity(n_coords); - for i in 0..n_coords { - let base = i * 3; - result.push([slice[base], slice[base + 1], slice[base + 2]]); - } - Ok(result) - } else { - // Fallback for non-contiguous arrays (e.g., transposed or sliced) - let mut result: Vec<[f32; 3]> = Vec::with_capacity(shape[0]); - for i in 0..shape[0] { - result.push([ - coord_ndarray[[i, 0]], - 
coord_ndarray[[i, 1]], - coord_ndarray[[i, 2]], - ]); - } - Ok(result) - } -} - /// Convert an ``[isize; 3]`` to ``[usize; 3]`` if all elements are non-negative. /// /// Returns ``None`` if any element is negative. @@ -1140,15 +1098,6 @@ fn as_usize(x: [isize; 3]) -> Option<[usize; 3]> { Some([x[0] as usize, x[1] as usize, x[2] as usize]) } -/// Compute the squared Euclidean distance between two 3D points. -/// -/// Using squared distance avoids the expensive square root operation, -/// which is sufficient for distance comparisons. -#[inline(always)] -fn distance_squared(a: [f32; 3], b: [f32; 3]) -> f32 { - (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2) + (a[2] - b[2]).powi(2) -} - /// Format query results as a mapping array (the default format). /// /// Creates a 2D array where each row corresponds to a query coordinate, diff --git a/src/rust/structure/mod.rs b/src/rust/structure/mod.rs index 9b6180efc..548340cb6 100644 --- a/src/rust/structure/mod.rs +++ b/src/rust/structure/mod.rs @@ -1,8 +1,10 @@ use crate::add_subpackage; use pyo3::prelude::*; -mod celllist; -mod io; +pub mod celllist; +pub mod io; +pub mod sasa; +pub mod util; use celllist::*; @@ -10,6 +12,7 @@ pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult()?; module.add_class::()?; + module.add_function(pyo3::wrap_pyfunction!(sasa::sasa, &module)?)?; add_subpackage(&module, &io::module(&module)?, "biotite.rust.structure.io")?; Ok(module) } diff --git a/src/rust/structure/sasa.rs b/src/rust/structure/sasa.rs new file mode 100644 index 000000000..468c86151 --- /dev/null +++ b/src/rust/structure/sasa.rs @@ -0,0 +1,173 @@ +use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2}; +use pyo3::exceptions; +use pyo3::prelude::*; +use std::f32::consts::PI; + +use crate::structure::celllist::{CellList, Radius}; +use crate::structure::util::{distance_squared, extract_coord}; +use crate::util::check_signals_periodically; + +/// Calculate the Solvent Accessible Surface Area (SASA) 
using the Shrake-Rupley algorithm.
+///
+/// Parameters
+/// ----------
+/// coord
+///     Coordinates of all atoms.
+/// radii
+///     VdW radii + probe radius for each atom.
+/// sphere_points
+///     Points on the surface of a unit sphere.
+/// sasa_filter
+///     Boolean mask indicating which atoms to calculate SASA for.
+/// occlusion_filter
+///     Boolean mask indicating which atoms are considered for occlusion.
+///
+/// Returns
+/// -------
+/// sasa
+///     Atom-wise SASA. NaN for atoms where SASA was not calculated.
+///
+/// Raises
+/// ------
+/// PyIndexError
+///     If the number of radii differs from the number of atoms.
+/// PyValueError
+///     If `coord` or `sphere_points` is rejected by `extract_coord`, a mask length differs
+///     from the number of atoms or `occlusion_filter` selects no atom.
+#[pyfunction]
+pub fn sasa<'py>(
+    py: Python<'py>,
+    coord: PyReadonlyArray2<'py, f32>,
+    radii: PyReadonlyArray1<'py, f32>,
+    sphere_points: PyReadonlyArray2<'py, f32>,
+    sasa_filter: PyReadonlyArray1<'py, bool>,
+    occlusion_filter: PyReadonlyArray1<'py, bool>,
+) -> PyResult<Bound<'py, PyArray1<f32>>> {
+    let coord = extract_coord(coord)?;
+    let radii = radii.as_slice()?;
+    let sphere_points = extract_coord(sphere_points)?;
+    let sasa_filter = sasa_filter.as_slice()?;
+    let occlusion_filter = occlusion_filter.as_slice()?;
+
+    let n_atoms = coord.len();
+    let n_sphere_points = sphere_points.len();
+    // Area of a single point on a unit sphere
+    let area_per_point = 4.0 * PI / (n_sphere_points as f32);
+
+    if radii.len() != n_atoms {
+        return Err(exceptions::PyIndexError::new_err(format!(
+            "{} radii were given for {} atoms",
+            radii.len(),
+            n_atoms
+        )));
+    }
+    for mask in [sasa_filter, occlusion_filter] {
+        if mask.len() != n_atoms {
+            return Err(exceptions::PyValueError::new_err(format!(
+                "Mask has length {}, but {} atoms were provided",
+                mask.len(),
+                n_atoms
+            )));
+        }
+    }
+
+    // Filter coordinates and radii for occluding atoms
+    // Also track the original index for each occluding atom
+    let mut occlusion_coord: Vec<[f32; 3]> = Vec::new();
+    let mut occlusion_radii: Vec<f32> = Vec::new();
+    let mut occlusion_orig_idx: Vec<usize> = Vec::new();
+    for i in 0..n_atoms {
+        if occlusion_filter[i] {
+            occlusion_coord.push(coord[i]);
+            occlusion_radii.push(radii[i]);
+            occlusion_orig_idx.push(i);
+        }
+    }
+    if occlusion_coord.is_empty() {
+        return Err(exceptions::PyValueError::new_err(
+            "No atoms are within the occlusion filter",
+        ));
+    }
+    // Pre-compute squared radii for occluding atoms
+    let occlusion_radii_sq: Vec<f32> = occlusion_radii.iter().map(|r| r * r).collect();
+
+    // Cell size is as large as the maximum distance where two atoms can intersect.
+    // Therefore intersecting atoms are always in the same or adjacent cell.
+    let max_occlusion_radius = occlusion_radii.iter().cloned().fold(0.0f32, f32::max);
+    let cell_size = max_occlusion_radius * 2.0;
+    let cell_list = CellList::new(
+        occlusion_coord.clone(),
+        cell_size,
+        None,
+        None,
+        occlusion_coord.len(),
+    )?;
+
+    // Initialize result with NaN
+    let mut sasa_result: Vec<f32> = vec![f32::NAN; n_atoms];
+    // Reusable buffer of `(coord, radius_sq)` entries for the occluders of one atom
+    let mut relevant_occluders: Vec<([f32; 3], f32)> = Vec::new();
+    for i in 0..n_atoms {
+        // Poll for interrupts before the filter, so skipped atoms cannot starve the check
+        check_signals_periodically(py, i)?;
+        if !sasa_filter[i] {
+            continue;
+        }
+
+        let atom_coord = coord[i];
+        let atom_radius = radii[i];
+        let atom_radius_sq = atom_radius * atom_radius;
+
+        // Query the cell list with this single coordinate.
+        // Two spheres intersect if distance < radius1 + radius2, so every potential
+        // occluder lies within atom_radius + max_occlusion_radius
+        let adjacent_pairs = cell_list.get_atoms_from_slice(
+            py,
+            &[atom_coord],
+            &Radius::Single(atom_radius + max_occlusion_radius),
+        )?;
+
+        // Filter adjacent atoms to only those that actually intersect with the current atom
+        relevant_occluders.clear();
+        for [_, adjacent_idx] in &adjacent_pairs {
+            // Skip if this is the same atom we're computing SASA for
+            if occlusion_orig_idx[*adjacent_idx] == i {
+                continue;
+            }
+
+            let occlusion_atom_coord = occlusion_coord[*adjacent_idx];
+            let adjacent_radius = occlusion_radii[*adjacent_idx];
+            let dist_sq = distance_squared(atom_coord, occlusion_atom_coord);
+            // Only include if spheres intersect
+            if dist_sq < (adjacent_radius + atom_radius).powi(2) {
+                relevant_occluders.push((occlusion_atom_coord, occlusion_radii_sq[*adjacent_idx]));
+            }
+        }
+
+        // Count accessible sphere points
+        let mut n_accessible = n_sphere_points;
+        for sphere_point in &sphere_points {
+            // Transform unit sphere point to the atom's sphere
+            let point = [
+                sphere_point[0] * atom_radius + atom_coord[0],
+                sphere_point[1] * atom_radius + atom_coord[1],
+                sphere_point[2] * atom_radius + atom_coord[2],
+            ];
+
+            // Check if this point is occluded by any neighboring atom
+            for (occluder_coord, occluder_radius_sq) in &relevant_occluders {
+                let dist_sq = distance_squared(point, *occluder_coord);
+                if dist_sq < *occluder_radius_sq {
+                    // Point is inside the occluding atom's sphere
+                    n_accessible -= 1;
+                    break;
+                }
+            }
+        }
+
+        // Calculate SASA for this atom
+        sasa_result[i] = area_per_point * (n_accessible as f32) * atom_radius_sq;
+    }
+
+    Ok(sasa_result.into_pyarray(py))
+}
diff --git a/src/rust/structure/util.rs b/src/rust/structure/util.rs
new file mode 100644
index 000000000..c6664812c
--- /dev/null
+++ b/src/rust/structure/util.rs
@@ -0,0 +1,61 @@
+use numpy::PyReadonlyArray2;
+use pyo3::exceptions;
+use pyo3::prelude::*;
+
+/// Extract coordinates from a 2D NumPy array into a ``Vec<[f32; 3]>``.
+///
+/// Raises
+/// ------
+/// PyValueError
+///     If the array does not have shape *(n, 3)* or contains non-finite values (NaN or Inf).
+pub fn extract_coord(coord_array: PyReadonlyArray2<f32>) -> PyResult<Vec<[f32; 3]>> {
+    let coord_ndarray = coord_array.as_array();
+    let shape = coord_ndarray.shape();
+    if shape.len() != 2 {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates must have shape (n,3)",
+        ));
+    }
+    if shape[1] != 3 {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates must have form (x,y,z)",
+        ));
+    }
+    // `f32::is_finite` is false for NaN and for +/-Inf, so one pass rejects both
+    if coord_ndarray.iter().any(|value| !value.is_finite()) {
+        return Err(exceptions::PyValueError::new_err(
+            "Coordinates contain non-finite values",
+        ));
+    }
+
+    // Try to access as contiguous slice for maximum efficiency
+    if let Some(slice) = coord_ndarray.as_slice() {
+        // Data is contiguous in memory, we can reinterpret directly
+        let mut result: Vec<[f32; 3]> = Vec::with_capacity(shape[0]);
+        for i in 0..shape[0] {
+            let base = i * 3;
+            result.push([slice[base], slice[base + 1], slice[base + 2]]);
+        }
+        Ok(result)
+    } else {
+        // Fallback for non-contiguous arrays (e.g., transposed or sliced)
+        let mut result: Vec<[f32; 3]> = Vec::with_capacity(shape[0]);
+        for i in 0..shape[0] {
+            result.push([
+                coord_ndarray[[i, 0]],
+                coord_ndarray[[i, 1]],
+                coord_ndarray[[i, 2]],
+            ]);
+        }
+        Ok(result)
+    }
+}
+
+/// Compute the squared Euclidean distance between two 3D points.
+///
+/// Using squared distance avoids the expensive square root operation,
+/// which is sufficient for distance comparisons.
+#[inline(always)]
+pub fn distance_squared(a: [f32; 3], b: [f32; 3]) -> f32 {
+    (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2) + (a[2] - b[2]).powi(2)
+}
diff --git a/src/rust/util.rs b/src/rust/util.rs
new file mode 100644
index 000000000..26953f2ff
--- /dev/null
+++ b/src/rust/util.rs
@@ -0,0 +1,13 @@
+use pyo3::prelude::*;
+
+/// Check for Python interrupts periodically to allow breaking long-running loops.
+///
+/// This should be called inside loops that may take a long time. It checks for
+/// signals every 256 iterations to avoid the overhead of checking on every iteration.
+#[inline(always)]
+pub fn check_signals_periodically(py: Python<'_>, iteration: usize) -> PyResult<()> {
+    if iteration & 0xFF == 0 {
+        py.check_signals()?;
+    }
+    Ok(())
+}