diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 898b98a..491ced2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.6
+    rev: v0.14.14
     hooks:
       # Run the linter
       - id: ruff
@@ -11,9 +11,15 @@ repos:
       # Run the formatter
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.3.0
+    rev: v2.4.1
    hooks:
      - id: codespell
        args: [--skip, "pyproject.toml,docs/_build/*,*.egg-info"]
        additional_dependencies:
          - tomli
+  - repo: https://github.com/numpy/numpydoc
+    rev: v1.10.0
+    hooks:
+      - id: numpydoc-validation
+        name: numpydoc validate
+        description: Run numpydoc validation across trx package
diff --git a/.spin/__init__.py b/.spin/__init__.py
index 85155f8..fa7a7eb 100644
--- a/.spin/__init__.py
+++ b/.spin/__init__.py
@@ -1 +1 @@
-# Spin commands package
+"""Spin commands package."""
diff --git a/.spin/cmds.py b/.spin/cmds.py
index 8b7873e..5269933 100644
--- a/.spin/cmds.py
+++ b/.spin/cmds.py
@@ -14,7 +14,22 @@
 def run(cmd, check=True, capture=True):
-    """Run a shell command."""
+    """Run a shell command.
+
+    Parameters
+    ----------
+    cmd : list of str
+        Command and arguments to execute.
+    check : bool, optional
+        If True, check the return code and report errors.
+    capture : bool, optional
+        If True, capture stdout and stderr.
+
+    Returns
+    -------
+    str or int or None
+        Captured stdout string, return code, or None on error.
+    """
     result = subprocess.run(cmd, capture_output=capture, text=True, check=False)
     if check and result.returncode != 0:
         if capture:
@@ -24,7 +39,13 @@
 def get_remotes():
-    """Get dict of remote names to URLs."""
+    """Get dict of remote names to URLs.
+
+    Returns
+    -------
+    dict
+        Mapping of remote names to their fetch URLs.
+    """
     output = run(["git", "remote", "-v"])
     if not output:
         return {}
@@ -110,11 +131,14 @@ def test(pattern, verbose, pytest_args):
     Additional arguments are passed directly to pytest.
 
-    Examples:
-        spin test                   # Run all tests
-        spin test -m memmap         # Run tests matching 'memmap'
-        spin test -v                # Verbose output
-        spin test -- -x --tb=short  # Pass args to pytest
+    Parameters
+    ----------
+    pattern : str or None
+        Only run tests matching this pattern (passed to pytest -k).
+    verbose : bool
+        If True, enable verbose output.
+    pytest_args : tuple
+        Additional arguments passed directly to pytest.
     """
     cmd = ["pytest", "trx/tests"]
@@ -138,9 +162,10 @@ def lint(fix):
     """Run linting checks using ruff and codespell.
 
-    Examples:
-        spin lint        # Run ruff and codespell checks
-        spin lint --fix  # Run ruff and auto-fix issues
+    Parameters
+    ----------
+    fix : bool
+        If True, automatically fix issues where possible.
     """
     click.echo("Running ruff linter...")
     cmd = ["ruff", "check", "."]
@@ -191,10 +216,12 @@ def docs(clean, open_browser):
     """Build documentation using Sphinx.
 
-    Examples:
-        spin docs          # Build docs
-        spin docs --clean  # Clean and rebuild
-        spin docs --open   # Build and open in browser
+    Parameters
+    ----------
+    clean : bool
+        If True, clean build directory before building.
+    open_browser : bool
+        If True, open documentation in browser after building.
     """
     import os
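The `run()` helper documented above collapses three outcomes into one return value. A minimal stand-alone sketch of that contract (illustrative only, not the actual `.spin/cmds.py` body):

```python
import subprocess

def run(cmd, check=True, capture=True):
    # Illustrative version of the contract documented for .spin/cmds.py:run.
    result = subprocess.run(cmd, capture_output=capture, text=True, check=False)
    if check and result.returncode != 0:
        return None  # the docstring's "None on error"
    return result.stdout if capture else result.returncode

remotes = run(["git", "remote", "-v"])  # str when capture=True
```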
""" import os diff --git a/docs/source/conf.py b/docs/source/conf.py index a05264e..27eaff0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,3 +1,4 @@ +"""Sphinx configuration for trx-python documentation.""" # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full diff --git a/pyproject.toml b/pyproject.toml index bd099eb..18de8f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,3 +102,38 @@ package = "trx" "Test" = [".spin/cmds.py:test", ".spin/cmds.py:lint"] "Docs" = [".spin/cmds.py:docs"] "Clean" = [".spin/cmds.py:clean"] + +[tool.numpydoc_validation] +checks = [ + "all", # report on all checks, except the below + # These we we ignore: + "GL01", # Docstring should start in the line immediately after the quotes + "GL02", # Closing quotes on own line (doesn't work on Python 3.13 anyway) + "EX01", + "EX02", # examples failed (we test them separately) + "ES01", # no extended summary + "SA01", # no see also + "YD01", # no yields section + "SA04", # no description in See Also + "PR04", # Parameter "shape (n_channels" has no type + "RT02", # The first line of the Returns section should +] +# remember to use single quotes for regex in TOML +exclude = [ # don't report on objects that match any of these regex + '\.undocumented_method$', + '\.__repr__$', + '\.__str__$', + '\.__len__$', + '\.__getitem__$', + '\.__deepcopy__$', +] +exclude_files = [ # don't process filepaths that match these regex + '^trx/tests/.*', + '^module/gui.*', + '^examples/.*', +] +override_SS05 = [ # override SS05 to allow docstrings starting with these words + '^Process ', + '^Assess ', + '^Access ', +] diff --git a/tools/update_switcher.py b/tools/update_switcher.py index 4a66988..04a04e2 100644 --- a/tools/update_switcher.py +++ b/tools/update_switcher.py @@ -15,7 +15,18 @@ def load_switcher(path): - """Load existing switcher.json or return empty list.""" + """Load existing switcher.json or return empty list. + + Parameters + ---------- + path : str or Path + Path to the switcher.json file. + + Returns + ------- + list + List of version entries from the switcher file. + """ try: with open(path, "r") as f: return json.load(f) @@ -24,14 +35,33 @@ def load_switcher(path): def save_switcher(path, versions): - """Save switcher.json with proper formatting.""" + """Save switcher.json with proper formatting. + + Parameters + ---------- + path : str or Path + Path to the switcher.json file. + versions : list + List of version entries to write. + """ with open(path, "w") as f: json.dump(versions, f, indent=4) f.write("\n") def ensure_dev_entry(versions): - """Ensure dev entry exists in versions list.""" + """Ensure dev entry exists in versions list. + + Parameters + ---------- + versions : list + List of version entries. + + Returns + ------- + list + Updated versions list with dev entry. + """ dev_exists = any(v.get("version") == "dev" for v in versions) if not dev_exists: versions.insert(0, {"name": "dev", "version": "dev", "url": f"{BASE_URL}/dev/"}) @@ -39,7 +69,18 @@ def ensure_dev_entry(versions): def ensure_stable_entry(versions): - """Ensure stable entry exists with preferred flag.""" + """Ensure stable entry exists with preferred flag. + + Parameters + ---------- + versions : list + List of version entries. + + Returns + ------- + list + Updated versions list with stable entry. 
+ """ stable_idx = next( (i for i, v in enumerate(versions) if v.get("version") == "stable"), None ) @@ -98,7 +139,13 @@ def add_version(versions, version): def main(): - """Main entry point.""" + """Run the switcher update workflow. + + Returns + ------- + int + Exit code (0 for success). + """ parser = argparse.ArgumentParser( description="Update switcher.json for documentation version switching" ) diff --git a/trx/__init__.py b/trx/__init__.py index e69de29..047e068 100644 --- a/trx/__init__.py +++ b/trx/__init__.py @@ -0,0 +1 @@ +"""TRX file format for brain tractography data.""" diff --git a/trx/cli.py b/trx/cli.py index 69f6429..92202a6 100644 --- a/trx/cli.py +++ b/trx/cli.py @@ -109,6 +109,28 @@ def concatenate_tractograms( If the data_per_point or data_per_streamline is not the same for all tractograms, the data must be deleted first using the appropriate flags. + + Parameters + ---------- + in_tractograms : list of Path + Input tractogram files (.trk, .tck, .vtk, .fib, .dpy, .trx). + out_tractogram : Path + Output filename for the concatenated tractogram. + delete_dpv : bool, optional + Delete ``data_per_vertex`` if metadata differ across inputs. + delete_dps : bool, optional + Delete ``data_per_streamline`` if metadata differ across inputs. + delete_groups : bool, optional + Delete groups when metadata differ across inputs. + reference : Path or None, optional + Reference anatomy for tck/vtk/fib/dpy inputs. + force : bool, optional + Overwrite output if it already exists. + + Returns + ------- + None + Writes the concatenated tractogram to ``out_tractogram``. """ _check_overwrite(out_tractogram, force) @@ -185,6 +207,26 @@ def convert( Supports conversion of .tck, .trk, .fib, .vtk, .trx and .dpy files. TCK files always need a reference NIFTI file for conversion. + + Parameters + ---------- + in_tractogram : Path + Input tractogram file. + out_tractogram : Path + Output tractogram path. + reference : Path or None, optional + Reference anatomy required for some input formats. + positions_dtype : str, optional + Datatype for positions in TRX output. + offsets_dtype : str, optional + Datatype for offsets in TRX output. + force : bool, optional + Overwrite output if it already exists. + + Returns + ------- + None + Writes the converted tractogram to disk. """ _check_overwrite(out_tractogram, force) @@ -238,13 +280,27 @@ def convert_dsi( typer.Option("--force", "-f", help="Force overwriting of output files."), ] = False, ) -> None: - """Fix DSI-Studio TRK files for compatibility. + """Convert a DSI-Studio TRK file to TRX or TRK and fix space metadata. - This script fixes DSI-Studio TRK files (unknown space/convention) to make - them compatible with TrackVis, MI-Brain, and Dipy Horizon. + Parameters + ---------- + in_dsi_tractogram : Path + Input DSI-Studio tractogram (.trk or .trk.gz). + in_dsi_fa : Path + FA volume used as reference (.nii.gz). + out_tractogram : Path + Output tractogram path (.trx or .trk). + remove_invalid : bool, optional + Remove streamlines outside the bounding box. Defaults to False. + keep_invalid : bool, optional + Keep streamlines outside the bounding box. Defaults to False. + force : bool, optional + Overwrite output if it already exists. - [bold yellow]WARNING:[/bold yellow] This script is experimental. DSI-Studio evolves - quickly and results may vary depending on the data and DSI-Studio version. + Returns + ------- + None + Writes the converted tractogram to disk. 
""" _check_overwrite(out_tractogram, force) @@ -367,15 +423,48 @@ def generate( typer.Option("--force", "-f", help="Force overwriting of output files."), ] = False, ) -> None: - """Generate TRX file from raw data files. + """Generate a TRX file from raw data files. Create a TRX file from CSV, TXT, or NPY files by specifying positions, offsets, data_per_vertex, data_per_streamlines, groups, and data_per_group. - Each --dpv, --dps, --groups option requires FILE,DTYPE format. - Each --dpg option requires GROUP,FILE,DTYPE format. + Parameters + ---------- + reference : Path + Reference anatomy (.nii or .nii.gz). + out_tractogram : Path + Output tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx). + positions : Path or None, optional + Binary file with streamline coordinates (Nx3 .npy). + offsets : Path or None, optional + Binary file with streamline offsets (.npy). + positions_csv : Path or None, optional + CSV file with flattened streamline coordinates. + space : str, optional + Coordinate space. Non-default requires Dipy. + origin : str, optional + Coordinate origin. Non-default requires Dipy. + positions_dtype : str, optional + Datatype for positions. + offsets_dtype : str, optional + Datatype for offsets. + dpv : list of str or None, optional + Data per vertex entries as FILE,DTYPE pairs. + dps : list of str or None, optional + Data per streamline entries as FILE,DTYPE pairs. + groups : list of str or None, optional + Group entries as FILE,DTYPE pairs. + dpg : list of str or None, optional + Data per group entries as GROUP,FILE,DTYPE triplets. + verify_invalid : bool, optional + Verify positions are inside bounding box (requires Dipy). + force : bool, optional + Overwrite output if it already exists. - Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool + Returns + ------- + None + Writes the generated tractogram to disk. """ _check_overwrite(out_tractogram, force) @@ -522,7 +611,31 @@ def manipulate_dtype( Change the data types of positions, offsets, data_per_vertex, data_per_streamline, groups, and data_per_group arrays. - Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool + Parameters + ---------- + in_tractogram : Path + Input TRX file. + out_tractogram : Path + Output TRX file. + positions_dtype : str or None, optional + Target dtype for positions (float16, float32, float64). + offsets_dtype : str or None, optional + Target dtype for offsets (uint32, uint64). + dpv : list of str or None, optional + Data per vertex dtype overrides as NAME,DTYPE pairs. + dps : list of str or None, optional + Data per streamline dtype overrides as NAME,DTYPE pairs. + groups : list of str or None, optional + Group dtype overrides as NAME,DTYPE pairs. + dpg : list of str or None, optional + Data per group dtype overrides as GROUP,NAME,DTYPE triplets. + force : bool, optional + Overwrite output if it already exists. + + Returns + ------- + None + Writes the dtype-converted TRX file. """ _check_overwrite(out_tractogram, force) @@ -584,12 +697,21 @@ def compare( ), ] = None, ) -> None: - """Simple comparison of tractograms by subtracting coordinates. + """Compare two tractograms and report basic differences. - Does not account for shuffling of streamlines. Simple A-B operations. + Parameters + ---------- + in_tractogram1 : Path + First tractogram file. + in_tractogram2 : Path + Second tractogram file. + reference : Path or None, optional + Reference anatomy for formats requiring it. 
@@ -584,12 +697,21 @@ def compare(
         ),
     ] = None,
 ) -> None:
-    """Simple comparison of tractograms by subtracting coordinates.
+    """Compare two tractograms and report basic differences.
 
-    Does not account for shuffling of streamlines. Simple A-B operations.
+    Parameters
+    ----------
+    in_tractogram1 : Path
+        First tractogram file.
+    in_tractogram2 : Path
+        Second tractogram file.
+    reference : Path or None, optional
+        Reference anatomy for formats requiring it.
 
-    Differences below 1e-3 are expected for affines with large rotation/scaling.
-    Differences below 1e-6 are expected for isotropic data with small rotation.
+    Returns
+    -------
+    None
+        Prints comparison summary to stdout.
     """
     ref = str(reference) if reference else None
     tractogram_simple_compare([str(in_tractogram1), str(in_tractogram2)], ref)
@@ -637,13 +759,27 @@ def validate(
         typer.Option("--force", "-f", help="Force overwriting of output files."),
     ] = False,
 ) -> None:
-    """Validate TRX file and remove invalid streamlines.
+    """Validate a tractogram and optionally clean invalid/duplicate streamlines.
 
-    Removes streamlines that are out of the volume bounding box (in voxel space,
-    no negative coordinates or coordinates above volume dimensions).
+    Parameters
+    ----------
+    in_tractogram : Path
+        Input tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx).
+    out_tractogram : Path or None, optional
+        Optional output tractogram with invalid streamlines removed.
+    remove_identical : bool, optional
+        Remove duplicate streamlines based on hashing precision.
+    precision : int, optional
+        Number of decimals when hashing streamline points.
+    reference : Path or None, optional
+        Reference anatomy for formats requiring it.
+    force : bool, optional
+        Overwrite output if it already exists.
 
-    Also removes streamlines with single or no points.
-    Use --remove-identical to remove duplicate streamlines based on precision.
+    Returns
+    -------
+    None
+        Prints validation summary and optionally writes cleaned output.
     """
     if out_tractogram:
         _check_overwrite(out_tractogram, force)
@@ -684,8 +820,15 @@ def verify_header(
 ) -> None:
     """Compare spatial attributes of input files.
 
-    Compares all input files against the first one for compatibility of
-    spatial attributes: affine, dimensions, voxel sizes, and voxel order.
+    Parameters
+    ----------
+    in_files : list of Path
+        Files to compare (.trk, .trx, .nii, .nii.gz).
+
+    Returns
+    -------
+    None
+        Prints compatibility results to stdout.
     """
     verify_header_compatibility([str(f) for f in in_files])
@@ -710,8 +853,19 @@ def visualize(
 ) -> None:
     """Display tractogram and density map with bounding box.
 
-    Shows the tractogram and its density map (computed from Dipy) in
-    rasmm, voxmm, and vox space with its bounding box.
+    Parameters
+    ----------
+    in_tractogram : Path
+        Input tractogram (.trk, .tck, .vtk, .fib, .dpy, .trx).
+    reference : Path
+        Reference anatomy (.nii or .nii.gz).
+    remove_invalid : bool, optional
+        Remove invalid streamlines to avoid density map crashes.
+
+    Returns
+    -------
+    None
+        Opens visualization windows when fury is available.
     """
     tractogram_visualize_overlap(
         str(in_tractogram),
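For scripted cleanups, the `validate` command's workflow counterpart can be called directly; a sketch (passing `reference=None` assumes a .trx input that carries its own header):

```python
from trx.workflows import validate_tractogram

validate_tractogram(
    "bundle.trx",
    None,                  # reference; assumed unnecessary for .trx inputs
    "bundle_clean.trx",
    remove_identical_streamlines=True,
    precision=1,           # decimals used when hashing points
)
```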
diff --git a/trx/fetcher.py b/trx/fetcher.py
index 4aa3d1b..488427d 100644
--- a/trx/fetcher.py
+++ b/trx/fetcher.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+"""Test data management for downloading and verifying test assets."""

 import hashlib
 import logging
@@ -19,7 +20,13 @@
 def get_home():
-    """Set a user-writeable file-system location to put files"""
+    """Return a user-writeable file-system location to put files.
+
+    Returns
+    -------
+    str
+        Path to the TRX home directory.
+    """
     if "TRX_HOME" in os.environ:
         trx_home = os.environ["TRX_HOME"]
     else:
@@ -28,11 +35,16 @@
 def get_testing_files_dict():
-    """Get dictionary linking zip file to their GitHub release URL & checksums.
+    """Return dictionary linking zip files to their GitHub release URL and checksums.
 
     Assets are hosted under the v0.1.0 release of tee-ar-ex/trx-test-data.
     If URLs change, check TEST_DATA_API_URL to discover the latest asset
     locations.
+
+    Returns
+    -------
+    dict
+        Mapping of filenames to (url, md5, sha256) tuples.
     """
     return {
         "DSI.zip": (
@@ -59,7 +71,18 @@
 def md5sum(filename):
-    """Compute one md5 checksum for a file"""
+    """Compute the MD5 checksum of a file.
+
+    Parameters
+    ----------
+    filename : str
+        Path to file to hash.
+
+    Returns
+    -------
+    str
+        Hexadecimal MD5 digest.
+    """
     h = hashlib.md5()
     with open(filename, "rb") as f:
         for chunk in iter(lambda: f.read(128 * h.block_size), b""):
@@ -68,7 +91,18 @@
 def sha256sum(filename):
-    """Compute one sha256 checksum for a file"""
+    """Compute the SHA256 checksum of a file.
+
+    Parameters
+    ----------
+    filename : str
+        Path to file to hash.
+
+    Returns
+    -------
+    str
+        Hexadecimal SHA256 digest.
+    """
     h = hashlib.sha256()
     with open(filename, "rb") as f:
         for chunk in iter(lambda: f.read(128 * h.block_size), b""):
@@ -77,15 +111,18 @@
 def fetch_data(files_dict, keys=None):  # noqa: C901
-    """Downloads files to folder and checks their md5 checksums
+    """Download files to folder and check their md5 checksums.
 
     Parameters
     ----------
-    files_dict : dictionary
+    files_dict : dict
         For each file in `files_dict` the value should be (url, md5).
         The file will be downloaded from url, if the file does not already
         exist or if the file exists but the md5 checksum does not match.
-        Zip files are automatically unzipped and its content* are md5 checked.
+        Zip files are automatically unzipped and their contents are md5 checked.
+    keys : list of str or str or None, optional
+        Subset of keys from ``files_dict`` to download. When None, all
+        keys are downloaded.
 
     Raises
     ------
diff --git a/trx/io.py b/trx/io.py
index 86c8156..102d3f3 100644
--- a/trx/io.py
+++ b/trx/io.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+"""Unified I/O interface for tractogram file formats."""

 import logging
 import os
@@ -16,6 +17,19 @@
 def get_trx_tmp_dir():
+    """Return a temporary directory honoring the ``TRX_TMPDIR`` setting.
+
+    When the ``TRX_TMPDIR`` environment variable is set to ``"use_working_dir"``
+    the current working directory is used. Otherwise, the value of
+    ``TRX_TMPDIR`` is used directly. If the variable is not set, the system
+    temporary directory is used.
+
+    Returns
+    -------
+    tempfile.TemporaryDirectory
+        Context-managed temporary directory placed according to the
+        environment configuration.
+    """
     if os.getenv("TRX_TMPDIR") is not None:
         if os.getenv("TRX_TMPDIR") == "use_working_dir":
             trx_tmp_dir = os.getcwd()
@@ -33,6 +47,29 @@
 def load_sft_with_reference(filepath, reference=None, bbox_check=True):
+    """Load a tractogram as a StatefulTractogram with an explicit reference.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to the tractogram file (.trk, .tck, .fib, .vtk, .dpy).
+    reference : str or nibabel.Nifti1Image, optional
+        Reference image used for formats without embedded affine information.
+        Pass ``"same"`` to reuse the header embedded in .trk files.
+    bbox_check : bool, optional
+        If True, validate that streamlines lie within the reference bounding
+        box. Defaults to True.
+
+    Returns
+    -------
+    StatefulTractogram or None
+        Loaded tractogram. Returns ``None`` when ``dipy`` is unavailable.
+
+    Raises
+    ------
+    IOError
+        If the file format is unsupported or a required reference is missing.
+    """
     if not dipy_available:
         logging.error(
             "Dipy library is missing, cannot use functions related "
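A short sketch of how the fetcher pieces combine (it assumes downloaded archives end up under `get_home()`, which the helpers suggest but this diff does not show):

```python
import os
from trx.fetcher import fetch_data, get_home, get_testing_files_dict, md5sum

fetch_data(get_testing_files_dict(), keys=["DSI.zip"])  # download + checksum
print(md5sum(os.path.join(get_home(), "DSI.zip")))      # hex MD5 digest
```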
+ """ if not dipy_available: logging.error( "Dipy library is missing, cannot use functions related " @@ -64,6 +101,20 @@ def load_sft_with_reference(filepath, reference=None, bbox_check=True): def load(tractogram_filename, reference): + """Load a tractogram from disk and return a TRX or StatefulTractogram. + + Parameters + ---------- + tractogram_filename : str + Path to the input tractogram. TRX directories are supported. + reference : str or nibabel.Nifti1Image + Reference image used for formats without embedded affine information. + + Returns + ------- + TrxFile or StatefulTractogram + TRX file handle for ``.trx`` inputs, otherwise a StatefulTractogram. + """ import trx.trx_file_memmap as tmm in_ext = split_name_with_gz(tractogram_filename)[1] @@ -78,6 +129,26 @@ def load(tractogram_filename, reference): def save(tractogram_obj, tractogram_filename, bbox_valid_check=False): + """Save a tractogram object to disk. + + Parameters + ---------- + tractogram_obj : TrxFile or StatefulTractogram + Tractogram to persist. Non-TRX inputs are converted to StatefulTractogram + before saving to non-TRX formats. + tractogram_filename : str + Destination file name. ``.trx`` will be saved using the TRX writer; all + other extensions are handled by ``dipy.save_tractogram``. + bbox_valid_check : bool, optional + If True, validate that streamlines lie within the reference bounding + box when saving non-TRX formats. Defaults to False. + + Returns + ------- + None + The function writes to disk and returns ``None``. Returns ``None`` + immediately when ``dipy`` is unavailable. + """ if not dipy_available: logging.error( "Dipy library is missing, cannot use functions related " diff --git a/trx/streamlines_ops.py b/trx/streamlines_ops.py index 007aaba..bf7a87a 100644 --- a/trx/streamlines_ops.py +++ b/trx/streamlines_ops.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +"""Set operations on streamlines with precision-based matching.""" from functools import reduce import itertools @@ -10,39 +11,78 @@ def intersection(left, right): - """Intersection of two streamlines dict (see hash_streamlines)""" + """Return the intersection of two streamline hash dictionaries. + + Parameters + ---------- + left : dict + Hash dictionary returned by :func:`hash_streamlines`. + right : dict + Hash dictionary returned by :func:`hash_streamlines`. + + Returns + ------- + dict + Dictionary containing only keys present in both inputs. + """ return {k: v for k, v in left.items() if k in right} def difference(left, right): - """Difference of two streamlines dict (see hash_streamlines)""" + """Return the difference of two streamline hash dictionaries. + + Parameters + ---------- + left : dict + Hash dictionary returned by :func:`hash_streamlines`. + right : dict + Hash dictionary returned by :func:`hash_streamlines`. + + Returns + ------- + dict + Dictionary containing keys present in ``left`` but not in ``right``. + """ return {k: v for k, v in left.items() if k not in right} def union(left, right): - """Union of two streamlines dict (see hash_streamlines)""" + """Return the union of two streamline hash dictionaries. + + Parameters + ---------- + left : dict + Hash dictionary returned by :func:`hash_streamlines`. + right : dict + Hash dictionary returned by :func:`hash_streamlines`. + + Returns + ------- + dict + Dictionary containing all keys from both inputs. Values from ``left`` + overwrite those from ``right`` when keys overlap. 
+ """ result = right.copy() result.update(left) return result def get_streamline_key(streamline, precision=None): - """Produces a key using a hash from a streamline using a few points only and - the desired precision + """Produce a hash key from a streamline using a few points. Parameters ---------- - streamlines: ndarray - A single streamline (N,3) - precision: int, optional + streamline : ndarray + A single streamline (N, 3). + precision : int, optional The number of decimals to keep when hashing the points of the streamlines. Allows a soft comparison of streamlines. If None, no rounding is performed. Returns ------- - Value of the hash of the first/last MIN_NB_POINTS points of the streamline. - + bytes + Hash of the first/last MIN_NB_POINTS points of the streamline. """ # Use just a few data points as hash key. I could use all the data of @@ -62,34 +102,34 @@ def get_streamline_key(streamline, precision=None): def hash_streamlines(streamlines, start_index=0, precision=None): - """Produces a dict from streamlines + """Produce a dict from streamlines. - Produces a dict from streamlines by using the points as keys and the + Produce a dict from streamlines by using the points as keys and the indices of the streamlines as values. Parameters ---------- - streamlines: list of ndarray + streamlines : list of ndarray The list of streamlines used to produce the dict. - start_index: int, optional + start_index : int, optional The index of the first streamline. 0 by default. - precision: int, optional + precision : int, optional The number of decimals to keep when hashing the points of the streamlines. Allows a soft comparison of streamlines. If None, no rounding is performed. Returns ------- - A dict where the keys are streamline points and the values are indices - starting at start_index. - + dict + A dict where the keys are streamline points and the values are + indices starting at start_index. """ keys = [get_streamline_key(s, precision) for s in streamlines] return {k: i for i, k in enumerate(keys, start_index)} def perform_streamlines_operation(operation, streamlines, precision=0): - """Performs an operation on a list of list of streamlines + """Perform an operation on a list of list of streamlines. Given a list of list of streamlines, this function applies the operation to the first two lists of streamlines. The result in then used recursively @@ -101,24 +141,23 @@ def perform_streamlines_operation(operation, streamlines, precision=0): Parameters ---------- - operation: callable - A callable that takes two streamlines dicts as inputs and preduces a + operation : callable + A callable that takes two streamlines dicts as inputs and produces a new streamline dict. - streamlines: list of list of streamlines + streamlines : list of list of streamlines The streamlines used in the operation. - precision: int, optional + precision : int, optional The number of decimals to keep when hashing the points of the streamlines. Allows a soft comparison of streamlines. If None, no rounding is performed. Returns ------- - streamlines: list of `nib.streamline.ArraySequence` + streamlines : list of `nib.streamline.ArraySequence` The streamlines obtained after performing the operation on all the input streamlines. - indices: np.ndarray + indices : np.ndarray The indices of the streamlines that are used in the output. - """ # Hash the streamlines using the desired precision. 
diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py
index 0c0e7cb..2bfd0bc 100644
--- a/trx/trx_file_memmap.py
+++ b/trx/trx_file_memmap.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+"""Core TrxFile class with memory-mapped data access."""

 from copy import deepcopy
 import json
@@ -43,7 +44,7 @@ def _get_dtype_little_endian(dtype: Union[np.dtype, str, type]) -> np.dtype:
     Parameters
     ----------
     dtype : np.dtype, str, or type
-        Input dtype specification (e.g., np.float32, 'float32', '>f4')
+        Input dtype specification (e.g., np.float32, 'float32', '>f4').
 
     Returns
     -------
@@ -68,7 +69,7 @@ def _ensure_little_endian(arr: np.ndarray) -> np.ndarray:
     Parameters
     ----------
     arr : np.ndarray
-        Input array
+        Input array.
 
     Returns
     -------
@@ -108,6 +109,18 @@ def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarray:
     """

     def is_sorted(a):
+        """Return True if array is sorted non-decreasing.
+
+        Parameters
+        ----------
+        a : np.ndarray
+            1D array of numeric offsets.
+
+        Returns
+        -------
+        bool
+            True when ``a`` is monotonically non-decreasing.
+        """
         return np.all(a[:-1] <= a[1:])

     if not is_sorted(nib_offsets):
@@ -461,12 +474,34 @@ def load_from_directory(directory: str) -> Type["TrxFile"]:


 def _filter_empty_trx_files(trx_list: List["TrxFile"]) -> List["TrxFile"]:
-    """Remove empty TrxFiles from the list."""
+    """Remove empty TrxFiles from the list.
+
+    Parameters
+    ----------
+    trx_list : list of TrxFile class instances
+        Collection of tractograms to filter.
+
+    Returns
+    -------
+    list of TrxFile class instances
+        Only entries containing at least one streamline.
+    """
     return [curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0]


 def _get_all_data_keys(trx_list: List["TrxFile"]) -> Tuple[set, set]:
-    """Get all dps and dpv keys from the TrxFile list."""
+    """Get all dps and dpv keys from the TrxFile list.
+
+    Parameters
+    ----------
+    trx_list : list of TrxFile class instances
+        Collection of tractograms.
+
+    Returns
+    -------
+    tuple of set
+        Sets of ``data_per_streamline`` keys and ``data_per_vertex`` keys.
+    """
     all_dps = []
     all_dpv = []
     for curr_trx in trx_list:
@@ -476,7 +511,18 @@
 def _check_space_attributes(trx_list: List["TrxFile"]) -> None:
-    """Verify that space attributes are consistent across TrxFiles."""
+    """Verify that space attributes are consistent across TrxFiles.
+
+    Parameters
+    ----------
+    trx_list : list of TrxFile
+        Tractograms to compare for affine and dimension consistency.
+
+    Raises
+    ------
+    ValueError
+        If voxel-to-RASMM matrices or dimensions differ.
+    """
     ref_trx = trx_list[0]
     for curr_trx in trx_list[1:]:
         if not np.allclose(
@@ -490,7 +536,24 @@
 def _verify_dpv_coherence(
     trx_list: List["TrxFile"], all_dpv: set, ref_trx: "TrxFile", delete_dpv: bool
 ) -> None:
-    """Verify dpv coherence across TrxFiles."""
+    """Verify dpv coherence across TrxFiles.
+
+    Parameters
+    ----------
+    trx_list : list of TrxFile class instances
+        Tractograms being concatenated.
+    all_dpv : set
+        Union of ``data_per_vertex`` keys across tractograms.
+    ref_trx : TrxFile class instance
+        Reference tractogram for dtype/key checks.
+    delete_dpv : bool
+        Drop mismatched dpv keys instead of raising when True.
+
+    Raises
+    ------
+    ValueError
+        If dpv keys or dtypes differ and ``delete_dpv`` is False.
+    """
     for curr_trx in trx_list:
         for key in all_dpv:
             if (
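The little-endian normalization these helpers document reduces to one numpy call; a minimal demonstration:

```python
import numpy as np

# Any byte order in, little-endian out - the core of _get_dtype_little_endian.
le = np.dtype(">f4").newbyteorder("<")
assert le == np.dtype("<f4")
```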
+ """ for curr_trx in trx_list: for key in all_dpv: if ( @@ -516,7 +579,24 @@ def _verify_dpv_coherence( def _verify_dps_coherence( trx_list: List["TrxFile"], all_dps: set, ref_trx: "TrxFile", delete_dps: bool ) -> None: - """Verify dps coherence across TrxFiles.""" + """Verify dps coherence across TrxFiles. + + Parameters + ---------- + trx_list : list of TrxFile class instances + Tractograms being concatenated. + all_dps : set + Union of data_per_streamline keys across tractograms. + ref_trx : TrxFile class instance + Reference tractogram for dtype/key checks. + delete_dps : bool + Drop mismatched dps keys instead of raising when True. + + Raises + ------ + ValueError + If dps keys or dtypes differ and `delete_dps` is False. + """ for curr_trx in trx_list: for key in all_dps: if ( @@ -540,7 +620,18 @@ def _verify_dps_coherence( def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]: - """Compute group length and dtype information.""" + """Compute group length and dtype information. + + Parameters + ---------- + trx_list : list of TrxFile class instances + Tractograms being concatenated. + + Returns + ------- + tuple of dict + (group lengths, group dtypes) keyed by group name. + """ all_groups_len = {} all_groups_dtype = {} @@ -569,7 +660,26 @@ def _create_new_trx_for_concatenation( delete_dpv: bool, delete_groups: bool, ) -> "TrxFile": - """Create a new TrxFile for concatenation.""" + """Create a new TrxFile for concatenation. + + Parameters + ---------- + trx_list : list of TrxFile class instances + Input tractograms to concatenate. + ref_trx : TrxFile class instance + Reference tractogram for header/dtype template. + delete_dps : bool + Drop `data_per_streamline` keys not shared. + delete_dpv : bool + Drop `data_per_vertex` keys not shared. + delete_groups : bool + Drop groups when metadata differ. + + Returns + ------- + TrxFile + Empty TRX ready to receive concatenated data. + """ nb_vertices = 0 nb_streamlines = 0 for curr_trx in trx_list: @@ -597,7 +707,21 @@ def _setup_groups_for_concatenation( all_groups_dtype: dict, delete_groups: bool, ) -> None: - """Setup groups in the new TrxFile for concatenation.""" + """Setup groups in the new TrxFile for concatenation. + + Parameters + ---------- + new_trx : TrxFile class instance + Destination tractogram. + trx_list : list of TrxFile class instances + Source tractograms. + all_groups_len : dict + Mapping of group name to total length. + all_groups_dtype : dict + Mapping of group name to dtype. + delete_groups : bool + If True, skip creating group arrays. + """ if delete_groups: return @@ -757,7 +881,20 @@ def zip_from_folder( class TrxFile: - """Core class of the TrxFile""" + """Core class of the TrxFile. + + Parameters + ---------- + nb_vertices : int, optional + The number of vertices to use in the new TrxFile. + nb_streamlines : int, optional + The number of streamlines in the new TrxFile. + init_as : TrxFile class instance, optional + A TrxFile to use as reference. + + reference : str, dict, Nifti1Image, TrkFile, or Nifti1Header, optional + A Nifti or Trk file/obj to use as reference. + """ header: dict streamlines: Type[ArraySequence] @@ -905,6 +1042,18 @@ def __getitem__(self, key) -> Any: return self.select(key, keep_group=False) def __deepcopy__(self) -> Type["TrxFile"]: + """Return a deep copy of the TrxFile. + + Parameters + ---------- + self + TrxFile class instance. + + Returns + ------- + TrxFile class instance + Deep-copied instance. 
+ """ return self.deepcopy() def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 @@ -1527,6 +1676,20 @@ def get_dtype_dict(self): return dtype_dict def append(self, obj, extra_buffer: int = 0) -> None: + """Append another tractogram-like object to this TRX. + + Parameters + ---------- + obj : TrxFile or Tractogram or StatefulTractogram class instance + Object whose streamlines and associated data will be appended. + extra_buffer : int, optional + Additional preallocation buffer for streamlines (in count). + + Returns + ------- + None + Mutates the current TrxFile in-place. + """ curr_dtype_dict = self.get_dtype_dict() if dipy_available: from dipy.io.stateful_tractogram import StatefulTractogram @@ -1762,7 +1925,21 @@ def from_lazy_tractogram( @staticmethod def from_sft(sft, dtype_dict=None): - """Generate a valid TrxFile from a StatefulTractogram""" + """Generate a TrxFile from a StatefulTractogram. + + Parameters + ---------- + sft : StatefulTractogram class instance + Input tractogram. + dtype_dict : dict or None, optional + Mapping of target dtypes for positions, offsets, dpv, and dps. When + None, uses ``sft.dtype_dict`` or sensible defaults. + + Returns + ------- + TrxFile + TRX representation of the StatefulTractogram. + """ if dtype_dict is None: dtype_dict = {} @@ -1854,7 +2031,22 @@ def from_tractogram( reference, dtype_dict=None, ): - """Generate a valid TrxFile from a Nibabel Tractogram""" + """Generate a TrxFile from a nibabel Tractogram. + + Parameters + ---------- + tractogram : nibabel.streamlines.Tractogram class instance + Input tractogram to convert. + reference : object + Reference anatomy used to populate header fields. + dtype_dict : dict or None, optional + Mapping of target dtypes for positions, offsets, dpv, and dps. + + Returns + ------- + TrxFile class instance + TRX representation of the tractogram. + """ if dtype_dict is None: dtype_dict = { "positions": np.float32, @@ -1929,7 +2121,18 @@ def from_tractogram( return trx def to_tractogram(self, resize=False): - """Convert a TrxFile to a nibabel Tractogram (in RAM)""" + """Convert this TrxFile to a nibabel Tractogram. + + Parameters + ---------- + resize : bool, optional + If True, resize to actual data length before conversion. + + Returns + ------- + nibabel.streamlines.Tractogram class instance + Tractogram containing streamlines and metadata. + """ if resize: self.resize() @@ -1977,7 +2180,18 @@ def to_memory(self, resize: bool = False) -> Type["TrxFile"]: return trx_obj def to_sft(self, resize=False): - """Convert a TrxFile to a valid StatefulTractogram (in RAM)""" + """Convert this TrxFile to a StatefulTractogram. + + Parameters + ---------- + resize : bool, optional + If True, resize to actual data length before conversion. + + Returns + ------- + StatefulTractogram class instance or None + StatefulTractogram object, or None if dipy is unavailable. + """ try: from dipy.io.stateful_tractogram import Space, StatefulTractogram except ImportError: @@ -2009,7 +2223,13 @@ def to_sft(self, resize=False): return sft def close(self) -> None: - """Cleanup on-disk temporary folder and initialize an empty TrxFile""" + """Cleanup on-disk temporary folder and memmaps. + + Returns + ------- + None + Releases file handles and removes temporary storage. 
+ """ if self._uncompressed_folder_handle is not None: close_or_delete_mmap(self.streamlines) diff --git a/trx/utils.py b/trx/utils.py index 02f0869..885f988 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +"""Utility functions for reference handling, coordinate flips, and file operations.""" import logging import os @@ -37,19 +38,21 @@ def close_or_delete_mmap(obj): def split_name_with_gz(filename): - """ - Returns the clean basename and extension of a file. - Means that this correctly manages the ".nii.gz" extensions. + """Return the clean basename and extension of a file. + + Correctly manages the ".nii.gz" extensions. Parameters ---------- - filename: str - The filename to clean + filename : str + The filename to clean. Returns ------- - base, ext : tuple(str, str) - Clean basename and the full extension + base : str + Clean basename. + ext : str + The full extension. """ base, ext = os.path.splitext(filename) @@ -65,20 +68,23 @@ def split_name_with_gz(filename): def get_reference_info_wrapper(reference): # noqa: C901 - """Will compare the spatial attribute of 2 references. + """Extract spatial attributes from a reference object. Parameters ---------- - reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or - trk.header (dict), TrxFile or trx.header (dict) + reference : str or dict or Nifti1Image or TrkFile or Nifti1Header or TrxFile Reference that provides the spatial attribute. + Returns ------- - output : tuple - - affine ndarray (4,4), np.float32, transformation of VOX to RASMM - - dimensions ndarray (3,), int16, volume shape for each axis - - voxel_sizes ndarray (3,), float32, size of voxel for each axis - - voxel_order, string, Typically 'RAS' or 'LPS' + affine : ndarray (4, 4) + Transformation of VOX to RASMM, np.float32. + dimensions : ndarray (3,) + Volume shape for each axis, int16. + voxel_sizes : ndarray (3,) + Size of voxel for each axis, float32. + voxel_order : str + Typically 'RAS' or 'LPS'. """ from trx import trx_file_memmap @@ -158,7 +164,7 @@ def get_reference_info_wrapper(reference): # noqa: C901 def is_header_compatible(reference_1, reference_2): - """Will compare the spatial attribute of 2 references. + """Compare the spatial attributes of 2 references. Parameters ---------- @@ -168,10 +174,11 @@ def is_header_compatible(reference_1, reference_2): reference_2 : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or trk.header (dict) Reference that provides the spatial attribute. + Returns ------- - output : bool - Does all the spatial attribute match + bool + Whether all the spatial attributes match. """ affine_1, dimensions_1, voxel_sizes_1, voxel_order_1 = get_reference_info_wrapper( @@ -202,17 +209,19 @@ def is_header_compatible(reference_1, reference_2): def get_axis_shift_vector(flip_axes): - """ + """Return a shift vector for the given axes. + Parameters ---------- flip_axes : list of str String containing the axis to flip. - Possible values are 'x', 'y', 'z' + Possible values are 'x', 'y', 'z'. + Returns ------- - flip_vector : np.ndarray (3,) - Vector containing the axis to flip. - Possible values are -1, 1 + shift_vector : np.ndarray (3,) + Vector containing the axis to shift. + Possible values are -1, 0. """ shift_vector = np.zeros(3) if "x" in flip_axes: @@ -226,17 +235,19 @@ def get_axis_shift_vector(flip_axes): def get_axis_flip_vector(flip_axes): - """ + """Return a flip vector for the given axes. + Parameters ---------- flip_axes : list of str String containing the axis to flip. 
@@ -226,17 +235,19 @@
 def get_axis_flip_vector(flip_axes):
-    """
+    """Return a flip vector for the given axes.
+
     Parameters
     ----------
     flip_axes : list of str
         String containing the axis to flip.
-        Possible values are 'x', 'y', 'z'
+        Possible values are 'x', 'y', 'z'.
+
     Returns
     -------
     flip_vector : np.ndarray (3,)
         Vector containing the axis to flip.
-        Possible values are -1, 1
+        Possible values are -1, 1.
     """
     flip_vector = np.ones(3)
     if "x" in flip_axes:
@@ -250,18 +261,20 @@
 def get_shift_vector(sft):
-    """
+    """Return the shift vector for flipping a tractogram.
+
     When flipping a tractogram the shift vector is used to change the origin
     of the grid from the corner to the center of the grid.
 
     Parameters
     ----------
     sft : StatefulTractogram
-        StatefulTractogram object
+        StatefulTractogram object.
+
     Returns
     -------
     shift_vector : ndarray
-        Shift vector to apply to the streamlines
+        Shift vector to apply to the streamlines.
     """
     dims = sft.space_attributes[1]
     shift_vector = -1.0 * (np.array(dims) / 2.0)
@@ -270,21 +283,23 @@
 def flip_sft(sft, flip_axes):
-    """Flip the streamlines in the StatefulTractogram according to the
-    flip_axes. Uses the spatial information to flip according to the center
+    """Flip the streamlines in a StatefulTractogram.
+
+    Use the spatial information to flip according to the center
     of the grid.
 
     Parameters
     ----------
     sft : StatefulTractogram
-        StatefulTractogram to flip
+        StatefulTractogram to flip.
     flip_axes : list of str
         Axes to flip.
-        Possible values are 'x', 'y', 'z'
+        Possible values are 'x', 'y', 'z'.
+
     Returns
     -------
     sft : StatefulTractogram
-        StatefulTractogram with flipped axes
+        StatefulTractogram with flipped axes.
     """
     if not dipy_available:
         logging.error(
@@ -321,6 +336,7 @@ def load_matrix_in_any_format(filepath):
     ----------
     filepath : str
         Path to the matrix file.
+
     Returns
     -------
     matrix : numpy.ndarray
@@ -346,10 +362,13 @@
         String representing the space.
     origin_str : str
         String representing the origin.
+
     Returns
     -------
-    output : str
-        Space and Origin as Enums.
+    space : Space
+        Space enum value.
+    origin : Origin
+        Origin enum value.
     """
     if not dipy_available:
         logging.error(
@@ -411,6 +430,23 @@ def convert_data_dict_to_tractogram(data):


 def append_generator_to_dict(gen, data):
+    """Append items yielded by a tractogram generator into the data dict.
+
+    Parameters
+    ----------
+    gen : TractogramItem class instance or np.ndarray
+        Item produced by a tractogram generator. Structured entries include
+        per-point and per-streamline metadata.
+    data : dict
+        Accumulator containing ``strs`` (positions), ``dpv`` and ``dps``
+        dictionaries that will be extended in place.
+
+    Returns
+    -------
+    None
+        The function mutates ``data`` and returns ``None``.
+    """
+
     if isinstance(gen, TractogramItem):
         data["strs"].append(gen.streamline.tolist())
         for key in gen.data_for_points:
@@ -426,8 +462,7 @@
 def verify_trx_dtype(trx, dict_dtype):  # noqa: C901
-    """Verify if the dtype of the data in the trx is the same as the one in
-    the dict.
+    """Verify that data dtypes in the trx match the given dict.
 
     Parameters
     ----------
     trx : TrxFile
         Tractogram to verify.
     dict_dtype : dict
         Dictionary containing all elements dtype to verify.
+
     Returns
     -------
-    output : bool
+    bool
         True if the dtype is the same, False otherwise.
     """
     identical = True
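How the flip vector feeds the grid-centered flip (`sft` stands in for a loaded StatefulTractogram; `flip_sft` requires dipy):

```python
from trx.utils import flip_sft, get_axis_flip_vector

print(get_axis_flip_vector(["x"]))  # array([-1., 1., 1.])
flipped = flip_sft(sft, ["x"])      # flips around the grid center
```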
""" identical = True diff --git a/trx/viz.py b/trx/viz.py index c995a84..03faf68 100644 --- a/trx/viz.py +++ b/trx/viz.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +"""Optional 3D visualization using FURY/VTK.""" import itertools import logging @@ -19,6 +20,26 @@ def display( volume, volume_affine=None, streamlines=None, title="FURY", display_bounds=True ): + """Display a volume with optional streamlines using fury. + + Parameters + ---------- + volume : np.ndarray + 3D volume to display. + volume_affine : np.ndarray or None, optional + Affine matrix for the volume; None assumes identity. + streamlines : sequence or None, optional + Streamlines to render as lines. + title : str, optional + Window title. + display_bounds : bool, optional + If True, draw bounding box and coordinate annotations. + + Returns + ------- + None + Opens an interactive visualization window when fury is available. + """ if not fury_available: logging.error( "Fury library is missing, visualization functions are not available." diff --git a/trx/workflows.py b/trx/workflows.py index c4c1579..2efbd1f 100644 --- a/trx/workflows.py +++ b/trx/workflows.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +"""High-level processing workflows for tractogram operations.""" from copy import deepcopy import csv @@ -41,6 +42,26 @@ def convert_dsi_studio( remove_invalid=True, keep_invalid=False, ): + """Convert a DSI-Studio TRK file to TRX, fixing space metadata. + + Parameters + ---------- + in_dsi_tractogram : str + Input DSI-Studio TRK path (optionally .trk.gz). + in_dsi_fa : str + FA image (.nii.gz) used as reference anatomy. + out_tractogram : str + Destination tractogram path; ``.trx`` will be written using TRX writer. + remove_invalid : bool, optional + Remove streamlines falling outside the bounding box. Defaults to True. + keep_invalid : bool, optional + Keep invalid streamlines even if outside bounding box. Defaults to False. + + Returns + ------- + None + Writes the converted tractogram to disk. + """ if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return None @@ -94,6 +115,26 @@ def convert_tractogram( # noqa: C901 pos_dtype="float32", offsets_dtype="uint32", ): + """Convert tractograms between formats with dtype control. + + Parameters + ---------- + in_tractogram : str + Input tractogram path. + out_tractogram : str + Output tractogram path. + reference : str + Reference anatomy required for formats without header affine. + pos_dtype : str, optional + Datatype for positions in TRX output. + offsets_dtype : str, optional + Datatype for offsets in TRX output. + + Returns + ------- + None + Writes the converted tractogram to disk. + """ if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return None @@ -134,6 +175,20 @@ def convert_tractogram( # noqa: C901 def tractogram_simple_compare(in_tractograms, reference): + """Compare tractograms against a reference and return a summary diff. + + Parameters + ---------- + in_tractograms : list of str + Paths to tractograms to compare. + reference : str + Reference tractogram path. + + Returns + ------- + dict + Dictionary capturing differences across tractograms. + """ if not dipy_available: logging.error("Dipy library is missing, scripts are not available.") return @@ -186,6 +241,18 @@ def tractogram_simple_compare(in_tractograms, reference): def verify_header_compatibility(in_files): + """Verify that multiple tractogram headers are mutually compatible. 
@@ -186,6 +241,18 @@
 def verify_header_compatibility(in_files):
+    """Verify that multiple tractogram headers are mutually compatible.
+
+    Parameters
+    ----------
+    in_files : list of str
+        Paths to tractogram or NIfTI files to compare.
+
+    Returns
+    -------
+    None
+        Prints compatibility results to stdout.
+    """
     if not dipy_available:
         logging.error("Dipy library is missing, scripts are not available.")
         return
@@ -207,6 +274,22 @@
 def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True):
+    """Visualize overlap between tractogram density maps in different spaces.
+
+    Parameters
+    ----------
+    in_tractogram : str
+        Input tractogram path.
+    reference : str
+        Reference anatomy (.nii or .nii.gz).
+    remove_invalid : bool, optional
+        Remove streamlines outside bounding box before visualization.
+
+    Returns
+    -------
+    None
+        Opens interactive windows when fury is available.
+    """
     if not dipy_available:
         logging.error("Dipy library is missing, scripts are not available.")
         return None
@@ -272,6 +355,26 @@
 def validate_tractogram(
     in_tractogram,
     reference,
     out_tractogram,
     remove_identical_streamlines=True,
     precision=1,
 ):
+    """Validate a tractogram and optionally remove invalid/duplicate streamlines.
+
+    Parameters
+    ----------
+    in_tractogram : str
+        Input tractogram path.
+    reference : str
+        Reference anatomy for formats requiring it.
+    out_tractogram : str or None
+        Optional output path to save the cleaned tractogram.
+    remove_identical_streamlines : bool, optional
+        Remove duplicate streamlines based on hashing precision.
+    precision : int, optional
+        Number of decimals when hashing streamline points.
+
+    Returns
+    -------
+    None
+        Prints warnings and optionally writes a cleaned tractogram.
+    """
     if not dipy_available:
         logging.error("Dipy library is missing, scripts are not available.")
         return None
@@ -349,7 +452,18 @@
 def _load_streamlines_from_csv(positions_csv):
-    """Load streamlines from CSV file."""
+    """Load streamlines from a CSV file.
+
+    Parameters
+    ----------
+    positions_csv : str
+        Path to CSV containing flattened coordinates.
+
+    Returns
+    -------
+    nibabel.streamlines.ArraySequence class instance
+        Streamlines reconstructed from the CSV rows.
+    """
     with open(positions_csv, newline="") as f:
         reader = csv.reader(f)
         data = list(reader)
@@ -358,7 +472,20 @@
 def _load_streamlines_from_arrays(positions, offsets):
-    """Load streamlines from position and offset arrays."""
+    """Load streamlines from position and offset arrays.
+
+    Parameters
+    ----------
+    positions : str
+        Path to positions array (.npy or text) shaped (N, 3).
+    offsets : str
+        Path to offsets array marking streamline boundaries.
+
+    Returns
+    -------
+    tuple
+        (ArraySequence, np.ndarray) of streamlines and offsets.
+    """
     positions = load_matrix_in_any_format(positions)
     offsets = load_matrix_in_any_format(offsets)
     lengths = tmm._compute_lengths(offsets)
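What `_load_streamlines_from_arrays` reconstructs from positions and offsets, sketched with nibabel primitives; writing the private ArraySequence fields mirrors what the surrounding helper code appears to do internally, so treat this as an assumption:

```python
import numpy as np
from nibabel.streamlines import ArraySequence

positions = np.random.rand(10, 3).astype(np.float32)   # flattened vertices
offsets = np.array([0, 4, 7], dtype=np.uint32)         # first vertex of each line
lengths = np.diff(np.append(offsets, len(positions)))  # -> [4, 3, 3]

streamlines = ArraySequence()
streamlines._data = positions
streamlines._offsets = offsets
streamlines._lengths = lengths
```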
@@ -372,7 +499,28 @@
 def _apply_spatial_transforms(
     streamlines, reference, space_str, origin_str, verify_invalid, offsets
 ):
-    """Apply spatial transforms and verify streamlines."""
+    """Apply spatial transforms and optionally remove invalid streamlines.
+
+    Parameters
+    ----------
+    streamlines : ArraySequence class instance
+        Streamlines to transform.
+    reference : str
+        Reference anatomy used for space/origin.
+    space_str : str
+        Desired space (e.g., "rasmm").
+    origin_str : str
+        Desired origin (e.g., "nifti").
+    verify_invalid : bool
+        Remove streamlines outside bounding box when True.
+    offsets : np.ndarray
+        Offsets array to preserve after transforms.
+
+    Returns
+    -------
+    ArraySequence class instance or None
+        Transformed streamlines, or None if dipy is unavailable.
+    """
     if not dipy_available:
         logging.error(
             "Dipy library is missing, advanced options "
@@ -398,7 +546,17 @@
 def _write_header(tmp_dir_name, reference, streamlines):
-    """Write header file."""
+    """Write TRX header file to a temporary directory.
+
+    Parameters
+    ----------
+    tmp_dir_name : str
+        Temporary directory where header.json is written.
+    reference : str
+        Reference anatomy used to derive affine and dimensions.
+    streamlines : ArraySequence class instance
+        Streamlines whose counts populate the header.
+    """
     affine, dimensions, _, _ = get_reference_info_wrapper(reference)
     header = {
         "DIMENSIONS": dimensions.tolist(),
@@ -415,7 +573,19 @@
 def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, offsets_dtype):
-    """Write streamline position and offset data."""
+    """Write streamline position and offset data.
+
+    Parameters
+    ----------
+    tmp_dir_name : str
+        Temporary directory to store binary arrays.
+    streamlines : ArraySequence class instance
+        Streamlines to serialize.
+    positions_dtype : str
+        Datatype for positions array.
+    offsets_dtype : str
+        Datatype for offsets array.
+    """
     curr_filename = os.path.join(tmp_dir_name, "positions.3.{}".format(positions_dtype))
     positions = streamlines._data.astype(positions_dtype)
     tmm._ensure_little_endian(positions).tofile(curr_filename)
@@ -426,12 +596,40 @@
 def _normalize_dtype(dtype_str):
-    """Normalize dtype string format."""
+    """Normalize dtype string format for file naming.
+
+    Parameters
+    ----------
+    dtype_str : str
+        Input dtype string (e.g., "bool", "float32").
+
+    Returns
+    -------
+    str
+        Normalized dtype string where ``bool`` is mapped to ``bit``.
+    """
     return "bit" if dtype_str == "bool" else dtype_str


 def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False):
-    """Write data array to file."""
+    """Write a data array (dpv/dps/group/dpg) to disk.
+
+    Parameters
+    ----------
+    tmp_dir_name : str
+        Base temporary directory.
+    subdir_name : str
+        Subdirectory name (dpv, dps, groups, dpg).
+    args : tuple
+        Tuple describing the array path and dtype (and group when dpg).
+    is_dpg : bool, optional
+        True when writing data_per_group arrays.
+
+    Returns
+    -------
+    None
+        Writes the array to disk.
+    """
     if is_dpg:
         os.makedirs(os.path.join(tmp_dir_name, "dpg", args[0]), exist_ok=True)
         curr_arr = load_matrix_in_any_format(args[1]).astype(args[2])
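The writer helpers encode shape and dtype into filenames (e.g. `positions.3.float32`), with booleans spelled `bit`; a small sketch of the convention:

```python
import os
import numpy as np

def normalize_dtype(dtype_str):
    # Mirrors _normalize_dtype: TRX spells boolean arrays "bit" on disk.
    return "bit" if dtype_str == "bool" else dtype_str

os.makedirs("trx_build", exist_ok=True)
positions = np.zeros((12, 3), dtype="float32")
positions.tofile(os.path.join("trx_build", f"positions.3.{positions.dtype}"))
print(normalize_dtype("bool"))  # -> "bit"
```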
@@ -481,7 +679,44 @@ def generate_trx_from_scratch(  # noqa: C901
     groups=None,
     dpg=None,
 ):
-    """Generate TRX file from scratch using various input formats."""
+    """Generate TRX file from scratch using various input formats.
+
+    Parameters
+    ----------
+    reference : str
+        Reference anatomy used to set affine and dimensions.
+    out_tractogram : str
+        Output TRX filename.
+    positions_csv : str or bool, optional
+        CSV file containing streamline coordinates; False to disable.
+    positions : str or bool, optional
+        Binary positions array file; False to disable.
+    offsets : str or bool, optional
+        Offsets array file; False to disable.
+    positions_dtype : str, optional
+        Datatype for positions.
+    offsets_dtype : str, optional
+        Datatype for offsets.
+    space_str : str, optional
+        Desired space for generated streamlines.
+    origin_str : str, optional
+        Desired origin for generated streamlines.
+    verify_invalid : bool, optional
+        Remove invalid streamlines when True.
+    dpv : list or None, optional
+        Data per vertex definitions.
+    dps : list or None, optional
+        Data per streamline definitions.
+    groups : list or None, optional
+        Group definitions.
+    dpg : list or None, optional
+        Data per group definitions.
+
+    Returns
+    -------
+    None
+        Writes the generated TRX file to disk.
+    """
     if dpv is None:
         dpv = []
     if dps is None:
@@ -536,6 +771,22 @@
 def manipulate_trx_datatype(in_filename, out_filename, dict_dtype):  # noqa: C901
+    """Change dtype of positions, offsets, dpv, dps, dpg, and groups in a TRX.
+
+    Parameters
+    ----------
+    in_filename : str
+        Input TRX file path.
+    out_filename : str
+        Output TRX file path.
+    dict_dtype : dict
+        Mapping describing target dtypes for each data category.
+
+    Returns
+    -------
+    None
+        Writes the converted TRX to ``out_filename``.
+    """
     trx = tmm.load(in_filename)

     # For each key in dict_dtype, we create a new memmap with the new dtype
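A final sketch of the dtype-manipulation entry point; the shape of `dict_dtype` (category name mapped to target dtype) is inferred from the CLI wrapper earlier in this diff, so treat it as an assumption:

```python
import numpy as np
from trx.workflows import manipulate_trx_datatype

manipulate_trx_datatype(
    "in.trx",
    "out.trx",
    {"positions": np.float16, "offsets": np.uint64},  # assumed mapping shape
)
```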