feat: GPU acceleration module and scripts #99

Open

FIrgolitsch wants to merge 5 commits into pr-a-build-tooling from pr-f-gpu-acceleration

Conversation

Contributor

@FIrgolitsch FIrgolitsch commented Apr 1, 2026

PR #99 — GPU Acceleration (unified CPU/GPU scripts)

Adds the linumpy/gpu/ module (CuPy + JAX primitives) and wires the existing CPU scripts in the repo to pick up an optional GPU path via --use_gpu. Stand-alone *_gpu.py script variants are intentionally removed in this PR (replaced by the unified CPU scripts in PR #97).
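
The opt-in pattern is straightforward: the flag switches to a CuPy-backed array module when one can be imported and otherwise keeps the NumPy path. Below is a minimal sketch of that pattern; only the --use_gpu flag comes from this PR, while get_backend and the normalization workload are illustrative assumptions rather than the scripts' actual code.

```python
# Hypothetical sketch of the unified CPU/GPU script pattern described above;
# the --use_gpu flag is from this PR, get_backend and the workload are illustrative.
import argparse

import numpy as np


def get_backend(use_gpu: bool):
    """Return (array module, on_gpu): CuPy when requested and importable, else NumPy."""
    if use_gpu:
        try:
            import cupy as cp   # optional dependency
            return cp, True
        except ImportError:
            print("CuPy not available, falling back to the CPU path.")
    return np, False


def main():
    parser = argparse.ArgumentParser(description="Unified CPU/GPU script example.")
    parser.add_argument("--use_gpu", action="store_true",
                        help="Run the heavy array work on the GPU when possible.")
    args = parser.parse_args()

    xp, on_gpu = get_backend(args.use_gpu)
    tile = xp.random.random((1024, 1024)).astype(xp.float32)

    # Percentile intensity normalization: the same code runs on NumPy or CuPy.
    lo, hi = xp.percentile(tile, 1), xp.percentile(tile, 99)
    normalized = xp.clip((tile - lo) / (hi - lo), 0.0, 1.0)

    if on_gpu:
        normalized = normalized.get()   # bring the result back to host memory
    print(f"normalized tile on {'GPU' if on_gpu else 'CPU'}: shape {normalized.shape}")


if __name__ == "__main__":
    main()
```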

New library module linumpy/gpu/

  • array_ops.py — cupy/jax array primitives and CPU fallbacks
  • corrections.py — GPU illumination / flat-field corrections
  • cuda_env.py — CUDA environment detection (CUDA version, CuPy availability, JAX plugin status) + device selection; used by every GPU-capable script to negotiate the backend
  • fft_ops.py — GPU FFT (phase-correlation building blocks)
  • image_quality.py — GPU variance / sharpness metrics
  • interpolation.py — GPU affine/linear resampling
  • morphology.py — GPU morphological ops (erode/dilate/tophat)
  • normalization.py — GPU intensity normalization (percentile / rolling-ball), paired with linumpy/preproc/normalization.py from PR #97 (feat: utility modules, preprocessing improvements, galvo correction)
  • registration.py — GPU phase-correlation + refinement; callable from linumpy/stitching/registration.py (a short phase-correlation sketch follows this list)
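
To make the registration piece concrete, here is a minimal, self-contained sketch of GPU phase correlation under the same CuPy-or-NumPy fallback idea; phase_correlation_shift is a hypothetical name and not the actual linumpy/gpu/registration.py API.

```python
# Hypothetical sketch of GPU phase correlation in the spirit of
# linumpy/gpu/registration.py; names and details here are illustrative only.
import numpy as np

try:
    import cupy as xp   # run on the GPU when CuPy is installed
except ImportError:
    xp = np             # otherwise fall back to NumPy


def phase_correlation_shift(fixed, moving):
    """Estimate the integer (row, col) shift aligning `moving` onto `fixed`."""
    F = xp.fft.fft2(xp.asarray(fixed))
    M = xp.fft.fft2(xp.asarray(moving))
    cross_power = F * xp.conj(M)
    cross_power /= xp.abs(cross_power) + 1e-12          # keep only the phase
    correlation = xp.abs(xp.fft.ifft2(cross_power))
    peak = xp.unravel_index(xp.argmax(correlation), correlation.shape)
    shift = np.array([int(p) for p in peak])
    # Peaks beyond the midpoint wrap around and encode negative shifts.
    shape = np.array(correlation.shape)
    shift[shift > shape // 2] -= shape[shift > shape // 2]
    return tuple(shift)
```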

Scripts

  • scripts/linum_benchmark_gpu.py — benchmark harness comparing CPU vs GPU on the key hot paths (normalization, registration, resampling); see the timing sketch after this list
  • scripts/linum_gpu_info.py — report available CUDA / CuPy / JAX devices and plugin status
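
As a flavor of what such a harness measures, here is a small, hypothetical CPU-vs-GPU timing sketch; bench and the FFT workload are assumptions for illustration and do not mirror the real linum_benchmark_gpu.py.

```python
# Hypothetical CPU-vs-GPU micro-benchmark in the spirit of linum_benchmark_gpu.py;
# bench() and the FFT workload are illustrative, not the script's real contents.
import time

import numpy as np


def bench(fn, *args, repeats=5):
    """Return the best wall-clock time of fn(*args) over a few repeats."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


data = np.random.random((2048, 2048)).astype(np.float32)
cpu_time = bench(np.fft.fft2, data)
print(f"CPU fft2: {cpu_time * 1e3:.1f} ms")

try:
    import cupy as cp

    gpu_data = cp.asarray(data)

    def gpu_fft(x):
        cp.fft.fft2(x)
        cp.cuda.Stream.null.synchronize()   # wait for the kernel before timing stops

    bench(gpu_fft, gpu_data)                # warm-up pass (builds FFT plans)
    gpu_time = bench(gpu_fft, gpu_data)
    print(f"GPU fft2: {gpu_time * 1e3:.1f} ms ({cpu_time / gpu_time:.1f}x)")
except ImportError:
    print("CuPy not installed, skipping the GPU measurement.")
```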

Removed (unified into CPU scripts via --use_gpu)

  • linum_aip_gpu.py, linum_assess_slice_quality_gpu.py, linum_create_mosaic_grid_3d_gpu.py, linum_estimate_transform_gpu.py, linum_fix_illumination_3d_gpu.py, linum_generate_mosaic_aips_gpu.py, linum_normalize_intensities_per_slice_gpu.py, linum_resample_mosaic_grid_gpu.py

Shell helpers

  • shell_scripts/fix_jax_cuda_plugin.sh — resets a broken JAX CUDA plugin install; autodetects CUDA 12 vs 13 and picks the matching jax-cuda*-plugin wheel

Commits

  1. feat: GPU acceleration module and scripts — initial linumpy/gpu/ drop
  2. feat(gpu): unify CPU/GPU scripts, add normalization module — removes the *_gpu.py script variants, adds gpu/normalization.py, hardens cuda_env.py

Dependencies

@FIrgolitsch FIrgolitsch force-pushed the pr-f-gpu-acceleration branch 2 times, most recently from 73f2ced to ece07fe on April 17, 2026 at 22:15
@FIrgolitsch FIrgolitsch changed the base branch from main to pr-a-build-tooling on April 17, 2026 at 22:22
@FIrgolitsch FIrgolitsch force-pushed the pr-f-gpu-acceleration branch from ece07fe to 5aae60d on April 23, 2026 at 19:43
@FIrgolitsch FIrgolitsch force-pushed the pr-f-gpu-acceleration branch from 5aae60d to aa42b00 on April 23, 2026 at 21:09
@FIrgolitsch FIrgolitsch force-pushed the pr-f-gpu-acceleration branch from aa42b00 to fdaffcf on April 23, 2026 at 21:23
- Delete linumpy/gpu/cuda_env.py entirely
- Remove cuda_env imports from gpu/__init__.py
- Remove ensure_cuda_env/preload_cuda_libraries from linum_fix_illumination_3d.py
  and drop the now-obsolete XLA_FLAGS setup (was JAX-specific)
- Remove ensure_cuda_env/preload_cuda_libraries from linum_benchmark_gpu.py
- Update BaSiCPy docstring to reflect PyTorch backend
@FIrgolitsch FIrgolitsch force-pushed the pr-f-gpu-acceleration branch from fdaffcf to 5efa698 on April 23, 2026 at 21:28