Skip to content

Add GPU acceleration module (CuPy/JAX) with CPU fallback#90

Closed
FIrgolitsch wants to merge 1 commit intomainfrom
pr-6-gpu-acceleration
Closed

Add GPU acceleration module (CuPy/JAX) with CPU fallback#90
FIrgolitsch wants to merge 1 commit intomainfrom
pr-6-gpu-acceleration

Conversation

@FIrgolitsch
Copy link
Copy Markdown
Contributor

Summary

Introduces a linumpy/gpu/ package that provides GPU-accelerated implementations of the most computationally expensive pipeline operations. All functions fall back to CPU (NumPy/SciPy) automatically when no GPU is available, so the package remains usable on CPU-only machines.

New linumpy/gpu/ Modules

Module Description
__init__.py CuPy/JAX detection, get_array_module(), device selection
array_ops.py Projections, resizing, tiling
corrections.py Galvo shift correction
cuda_env.py CUDA environment detection and GPU selection
fft_ops.py FFT and phase correlation
image_quality.py Sharpness, contrast, SNR, SSIM metrics
interpolation.py Affine transform and resampling
morphology.py Morphological operations
registration.py Phase-correlation registration

New GPU Scripts

Drop-in GPU replacements for CPU counterparts:

  • linum_aip_gpu.py
  • linum_create_masks_gpu.py
  • linum_create_mosaic_grid_3d_gpu.py
  • linum_estimate_transform_gpu.py
  • linum_fix_illumination_3d_gpu.py
  • linum_generate_mosaic_aips_gpu.py
  • linum_normalize_intensities_per_slice_gpu.py
  • linum_resample_mosaic_grid_gpu.py

Utilities

  • linum_gpu_info.py — list available GPUs and select a device
  • linum_benchmark_gpu.py — CPU vs GPU performance comparison for all pipeline stages
  • shell_scripts/fix_jax_cuda_plugin.sh — script to patch JAX/CUDA plugin compatibility

Dependencies

Adds CuPy, JAX, and BaSiCPy as optional extras in requirements.txt and setup.py.

Depends on PR #85 (thread config module).


Merge order: Merge after PR #85.

New linumpy/gpu/ package with automatic CPU fallback: array_ops.py,
corrections.py, cuda_env.py, fft_ops.py, image_quality.py,
interpolation.py, morphology.py, registration.py. GPU-accelerated
counterparts added for all major pipeline scripts (AIP, mosaic grid,
masks, illumination correction, resampling, transform estimation).
Adds linum_gpu_info.py for device selection and linum_benchmark_gpu.py
for CPU/GPU performance comparison. Updates requirements.txt and setup.py
with optional CuPy/JAX/BaSiCPy dependencies.
@FIrgolitsch
Copy link
Copy Markdown
Contributor Author

Closing in favour of #99 (recreated with squashed commits as part of the PR split plan refresh).

@FIrgolitsch FIrgolitsch closed this Apr 1, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant