From 7b4ca3cedbc09629349263fefa9fb11921879ef6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 16 Jun 2026 16:19:49 -0600 Subject: [PATCH 01/13] Avoid linop __call__ on initialization by specifying input and output dtypes --- scico/linop/xray/_xray.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 99e2b0cd5..09aa2b50a 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2023-2024 by SCICO Developers +# Copyright (C) 2023-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -19,7 +19,7 @@ import scico.numpy as snp from scico.numpy.util import is_scalar_equiv -from scico.typing import Shape +from scico.typing import DType, Shape from scipy.spatial.transform import Rotation from .._linop import LinearOperator @@ -334,8 +334,9 @@ class XRayTransform3D(LinearOperator): For each view, the projection geometry is specified by an array with shape (2, 4) that specifies a :math:`2 \times 3` projection matrix and a :math:`2 \times 1` offset vector. Denoting the matrix - by :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel at array - index `(i, j, k)` has its center projected to the detector coordinates + by :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel + at array index `(i, j, k)` has its center projected to the detector + coordinates .. math:: \mathbf{M} \begin{bmatrix} @@ -354,12 +355,15 @@ def __init__( input_shape: Shape, matrices: ArrayLike, det_shape: Shape, + input_dtype: DType = np.float32, ): r""" Args: - input_shape: Shape of input image. - matrices: (num_views, 2, 4) array of homogeneous projection matrices. + input_shape: Input array shape. + matrices: (num_views, 2, 4) array of homogeneous projection + matrices. det_shape: Shape of detector. + input_dtype: Input array dtype. """ self.input_shape: Shape = input_shape @@ -371,6 +375,8 @@ def __init__( output_shape=self.output_shape, eval_fn=self.project, adj_fn=self.back_project, + input_dtype=input_dtype, + output_dtype=input_dtype, ) def project(self, im: ArrayLike) -> snp.Array: @@ -386,7 +392,8 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: r""" Args: im: Input image. - matrix: (num_views, 2, 4) array of homogeneous projection matrices. + matrix: (num_views, 2, 4) array of homogeneous projection + matrices. det_shape: Shape of detector. """ MAX_SLICE_LEN = 10 @@ -547,11 +554,14 @@ def matrices_from_euler_angles( Args: input_shape: Shape of input image. output_shape: Shape of output (detector). - str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'} - for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and - intrinsic rotations cannot be mixed in one function call. + str: Sequence of axes for rotation. Up to 3 characters + belonging to the set {'X', 'Y', 'Z'} for intrinsic + rotations, or {'x', 'y', 'z'} for extrinsic rotations. + Extrinsic and intrinsic rotations cannot be mixed in one + function call. angles: (num_views, N), N = 1, 2, or 3 Euler angles. - degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians. + degrees: If ``True``, angles are in degrees, otherwise + radians. Default: ``True``, radians. voxel_spacing: (3,) array giving the spacing of image voxels. Default: `[1.0, 1.0, 1.0]`. Experimental. det_spacing: (2,) array giving the spacing of detector From a4dd05c6e229b58ded0442e5fe9f6f282342b0ee Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 17:19:14 -0600 Subject: [PATCH 02/13] Attempt to speed up 3D projection and adjoint and make them jit-compatible --- scico/linop/xray/_xray.py | 74 +++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 09aa2b50a..88ae065a9 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -388,6 +388,7 @@ def back_project(self, proj: ArrayLike) -> snp.Array: return XRayTransform3D._back_project(proj, self.matrices, self.input_shape) @staticmethod + @partial(jax.jit, static_argnames="det_shape") def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: r""" Args: @@ -399,22 +400,25 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: MAX_SLICE_LEN = 10 slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN)) - num_views = len(matrices) - proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype) - for view_ind, matrix in enumerate(matrices): + blank_view_plane = jnp.zeros(det_shape, dtype=im.dtype) + + def project_single_matrix(matrix, init_plane): + proj_plane = init_plane for slice_offset in slice_offsets: - proj = proj.at[view_ind].set( - XRayTransform3D._project_single( - im[slice_offset : slice_offset + MAX_SLICE_LEN], - matrix, - proj[view_ind], - slice_offset=slice_offset, - ) + proj_plane = XRayTransform3D._project_single( + im[slice_offset : slice_offset + MAX_SLICE_LEN], + matrix, + proj_plane, + slice_offset=slice_offset, ) - return proj + return proj_plane + + mapped_project = jax.vmap(project_single_matrix, in_axes=(0, None)) + + return mapped_project(matrices, blank_view_plane) @staticmethod - @partial(jax.jit, donate_argnames="proj") + # @partial(jax.jit, donate_argnames="proj") def _project_single( im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0 ) -> snp.Array: @@ -435,33 +439,49 @@ def _project_single( return proj @staticmethod + @partial(jax.jit, static_argnames="input_shape") def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array: r""" Args: - proj: Input (set of) projection(s). + proj: Input projection data of shape (num_views, *det_shape). matrix: (num_views, 2, 4) array of homogeneous projection matrices. - input_shape: Shape of desired back projection. + input_shape: Shape of back projection. """ MAX_SLICE_LEN = 10 slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN)) - HTy = jnp.zeros(input_shape, dtype=proj.dtype) - for view_ind, matrix in enumerate(matrices): - for slice_offset in slice_offsets: - HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set( - XRayTransform3D._back_project_single( - proj[view_ind], - matrix, - HTy[slice_offset : slice_offset + MAX_SLICE_LEN], - slice_offset=slice_offset, - ) + # 1. Provide an accumulator buffer matching the inner target slice size + # This acts as the tracer buffer, preventing 'donated buffer' mismatches. + blank_im_slice_buffer = jnp.zeros((MAX_SLICE_LEN,) + input_shape[1:], dtype=proj.dtype) + + # 2. Extract specific image slices across ALL views simultaneously + def back_project_single_slice(slice_offset): + # We start with our empty, statically shaped sub-slice buffer + im_slice = blank_im_slice_buffer + + # Sequentially accumulate contributions from all projection views into this slice + for view_ind, matrix in enumerate(matrices): + im_slice = XRayTransform3D._back_project_single( + proj[view_ind], + matrix, + im_slice, + slice_offset=slice_offset, ) - HTy.block_until_ready() # prevent OOM + return im_slice - return HTy + # 3. Use vmap to calculate all slice blocks in parallel + mapped_back_project = jax.vmap(back_project_single_slice) + all_slices = mapped_back_project(jnp.array(slice_offsets)) + + # 4. Reshape the stacked vmap output back into a continuous 3D image array + # vmap gives (num_slices, MAX_SLICE_LEN, Y, Z), we flatten the first two axes + im = all_slices.reshape((-1,) + input_shape[1:]) + + # Trim any padding if the image size isn't perfectly divisible by MAX_SLICE_LEN + return im[: input_shape[0]] @staticmethod - @partial(jax.jit, donate_argnames="HTy") + # @partial(jax.jit, donate_argnames="HTy") def _back_project_single( y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0 ) -> snp.Array: From 66ab63f188d3ee0fd9d610baca8ec2355c229f62 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 17:28:56 -0600 Subject: [PATCH 03/13] Clean up --- scico/linop/xray/_xray.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 88ae065a9..8a500fd75 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -400,7 +400,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: MAX_SLICE_LEN = 10 slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN)) - blank_view_plane = jnp.zeros(det_shape, dtype=im.dtype) + view_plane = jnp.zeros(det_shape, dtype=im.dtype) def project_single_matrix(matrix, init_plane): proj_plane = init_plane @@ -415,10 +415,9 @@ def project_single_matrix(matrix, init_plane): mapped_project = jax.vmap(project_single_matrix, in_axes=(0, None)) - return mapped_project(matrices, blank_view_plane) + return mapped_project(matrices, view_plane) @staticmethod - # @partial(jax.jit, donate_argnames="proj") def _project_single( im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0 ) -> snp.Array: @@ -450,14 +449,14 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> s MAX_SLICE_LEN = 10 slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN)) - # 1. Provide an accumulator buffer matching the inner target slice size + # Provide an accumulator buffer matching the inner target slice size # This acts as the tracer buffer, preventing 'donated buffer' mismatches. - blank_im_slice_buffer = jnp.zeros((MAX_SLICE_LEN,) + input_shape[1:], dtype=proj.dtype) + im_slice_buffer = jnp.zeros((MAX_SLICE_LEN,) + input_shape[1:], dtype=proj.dtype) - # 2. Extract specific image slices across ALL views simultaneously + # Extract specific image slices across ALL views simultaneously def back_project_single_slice(slice_offset): # We start with our empty, statically shaped sub-slice buffer - im_slice = blank_im_slice_buffer + im_slice = im_slice_buffer # Sequentially accumulate contributions from all projection views into this slice for view_ind, matrix in enumerate(matrices): @@ -469,19 +468,18 @@ def back_project_single_slice(slice_offset): ) return im_slice - # 3. Use vmap to calculate all slice blocks in parallel + # Use vmap to calculate all slice blocks in parallel. mapped_back_project = jax.vmap(back_project_single_slice) all_slices = mapped_back_project(jnp.array(slice_offsets)) - # 4. Reshape the stacked vmap output back into a continuous 3D image array - # vmap gives (num_slices, MAX_SLICE_LEN, Y, Z), we flatten the first two axes + # Reshape the stacked vmap output back into a continuous 3D image array + # vmap gives (num_slices, MAX_SLICE_LEN, Y, Z), we flatten the first two axes. im = all_slices.reshape((-1,) + input_shape[1:]) - # Trim any padding if the image size isn't perfectly divisible by MAX_SLICE_LEN + # Trim any padding if the image size isn't perfectly divisible by MAX_SLICE_LEN. return im[: input_shape[0]] @staticmethod - # @partial(jax.jit, donate_argnames="HTy") def _back_project_single( y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0 ) -> snp.Array: From db5ad17ab6e73e2fda252c49428b903f62605b24 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 17:32:02 -0600 Subject: [PATCH 04/13] Clean up --- scico/linop/xray/_xray.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 8a500fd75..417aa8bc5 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -400,8 +400,12 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: MAX_SLICE_LEN = 10 slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN)) + # Provide an accumulator buffer matching the inner target detector size + # to avoid tracer buffer mismatch warnings during vmap tracing. view_plane = jnp.zeros(det_shape, dtype=im.dtype) + # Compute complete projections across all slice blocks for a single view + # matrix. def project_single_matrix(matrix, init_plane): proj_plane = init_plane for slice_offset in slice_offsets: @@ -413,6 +417,7 @@ def project_single_matrix(matrix, init_plane): ) return proj_plane + # Use vmap to process all projection views in parallel mapped_project = jax.vmap(project_single_matrix, in_axes=(0, None)) return mapped_project(matrices, view_plane) From cf8e12929949cdfd539c35a1b79eae2275e82159 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 17:32:15 -0600 Subject: [PATCH 05/13] Clean up --- scico/linop/xray/_xray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 417aa8bc5..06fff0dfa 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -458,7 +458,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> s # This acts as the tracer buffer, preventing 'donated buffer' mismatches. im_slice_buffer = jnp.zeros((MAX_SLICE_LEN,) + input_shape[1:], dtype=proj.dtype) - # Extract specific image slices across ALL views simultaneously + # Extract specific image slices across all views simultaneously def back_project_single_slice(slice_offset): # We start with our empty, statically shaped sub-slice buffer im_slice = im_slice_buffer From 7df37ffd3362389afeef769decee95a509e0b6d4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 18:02:05 -0600 Subject: [PATCH 06/13] Remove manual slicing/batching from implementation --- scico/linop/xray/_xray.py | 107 ++++++++++++++++++-------------------- 1 file changed, 51 insertions(+), 56 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 06fff0dfa..dc5002c7d 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -397,30 +397,28 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: matrices. det_shape: Shape of detector. """ - MAX_SLICE_LEN = 10 - slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN)) - - # Provide an accumulator buffer matching the inner target detector size - # to avoid tracer buffer mismatch warnings during vmap tracing. - view_plane = jnp.zeros(det_shape, dtype=im.dtype) - - # Compute complete projections across all slice blocks for a single view - # matrix. - def project_single_matrix(matrix, init_plane): - proj_plane = init_plane - for slice_offset in slice_offsets: - proj_plane = XRayTransform3D._project_single( - im[slice_offset : slice_offset + MAX_SLICE_LEN], - matrix, - proj_plane, - slice_offset=slice_offset, - ) - return proj_plane - - # Use vmap to process all projection views in parallel - mapped_project = jax.vmap(project_single_matrix, in_axes=(0, None)) - - return mapped_project(matrices, view_plane) + BATCH_SIZE = 10 + + # Apply gradient checkpointing to the underlying core operator + project_single = jax.remat(XRayTransform3D._project_single) + + # Define projection behavior for a single matrix over the full image + def project_single_matrix(matrix): + # Start with an empty detector plane baseline + init_plane = jnp.zeros(det_shape, dtype=im.dtype) + + # Call the rematerialized operator on the full image + return XRayTransform3D._project_single( + im, + matrix, + init_plane, + slice_offset=0, # No manual loops: processed as a whole + ) + + # Automatically chunk and execute views sequentially/parallelized via JAX. + # If len(matrices) is not divisible by BATCH_SIZE, JAX natively handles the + # remainder. + return jax.lax.map(project_single_matrix, matrices, batch_size=BATCH_SIZE) @staticmethod def _project_single( @@ -451,38 +449,35 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> s matrix: (num_views, 2, 4) array of homogeneous projection matrices. input_shape: Shape of back projection. """ - MAX_SLICE_LEN = 10 - slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN)) - - # Provide an accumulator buffer matching the inner target slice size - # This acts as the tracer buffer, preventing 'donated buffer' mismatches. - im_slice_buffer = jnp.zeros((MAX_SLICE_LEN,) + input_shape[1:], dtype=proj.dtype) - - # Extract specific image slices across all views simultaneously - def back_project_single_slice(slice_offset): - # We start with our empty, statically shaped sub-slice buffer - im_slice = im_slice_buffer - - # Sequentially accumulate contributions from all projection views into this slice - for view_ind, matrix in enumerate(matrices): - im_slice = XRayTransform3D._back_project_single( - proj[view_ind], - matrix, - im_slice, - slice_offset=slice_offset, - ) - return im_slice - - # Use vmap to calculate all slice blocks in parallel. - mapped_back_project = jax.vmap(back_project_single_slice) - all_slices = mapped_back_project(jnp.array(slice_offsets)) - - # Reshape the stacked vmap output back into a continuous 3D image array - # vmap gives (num_slices, MAX_SLICE_LEN, Y, Z), we flatten the first two axes. - im = all_slices.reshape((-1,) + input_shape[1:]) - - # Trim any padding if the image size isn't perfectly divisible by MAX_SLICE_LEN. - return im[: input_shape[0]] + BATCH_SIZE = 10 + + # Wrap the single back-project function for gradient checkpointing + back_project_single_core = jax.remat(XRayTransform3D._back_project_single) + + # Process an individual view slice-by-slice natively via map mapping + def back_project_single_view(packed_inputs): + # Unpack the active iteration variables provided by lax.map + single_proj, single_matrix = packed_inputs + + # Initialize a full-sized target volume structure for this single projection contribution + init_volume = jnp.zeros(input_shape, dtype=proj.dtype) + + return back_project_single_core( + single_proj, + single_matrix, + init_volume, + slice_offset=0, # Let JAX optimize the internal execution structure + ) + + # Map across the zip-like structure of projections and matrices. + # lax.map accumulates a stacked array of individual volume reconstructions. + individual_volumes = jax.lax.map( + back_project_single_view, (proj, matrices), batch_size=BATCH_SIZE + ) + + # Collapse the mapped axis by summing the independent view contributions + # to finalize the reconstructed 3D output image array. + return jnp.sum(individual_volumes, axis=0) @staticmethod def _back_project_single( From 6e1391448955718827c7ed75d4bb809103b2fa2d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 18:08:21 -0600 Subject: [PATCH 07/13] Fix some errors --- scico/linop/xray/_xray.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index dc5002c7d..9bb410739 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -408,7 +408,7 @@ def project_single_matrix(matrix): init_plane = jnp.zeros(det_shape, dtype=im.dtype) # Call the rematerialized operator on the full image - return XRayTransform3D._project_single( + return project_single( im, matrix, init_plane, @@ -452,17 +452,18 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> s BATCH_SIZE = 10 # Wrap the single back-project function for gradient checkpointing - back_project_single_core = jax.remat(XRayTransform3D._back_project_single) + back_project_single = jax.remat(XRayTransform3D._back_project_single) # Process an individual view slice-by-slice natively via map mapping def back_project_single_view(packed_inputs): # Unpack the active iteration variables provided by lax.map single_proj, single_matrix = packed_inputs - # Initialize a full-sized target volume structure for this single projection contribution + # Initialize a full-sized target volume structure for this single projection + # contribution init_volume = jnp.zeros(input_shape, dtype=proj.dtype) - return back_project_single_core( + return back_project_single( single_proj, single_matrix, init_volume, From 67ee13d727341d9fb73f7bcd340f31b5b36fb9e3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 18:22:21 -0600 Subject: [PATCH 08/13] Change batch size --- scico/linop/xray/_xray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 9bb410739..c1c044c26 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -397,7 +397,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: matrices. det_shape: Shape of detector. """ - BATCH_SIZE = 10 + BATCH_SIZE = 8 # Apply gradient checkpointing to the underlying core operator project_single = jax.remat(XRayTransform3D._project_single) @@ -449,7 +449,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> s matrix: (num_views, 2, 4) array of homogeneous projection matrices. input_shape: Shape of back projection. """ - BATCH_SIZE = 10 + BATCH_SIZE = 8 # Wrap the single back-project function for gradient checkpointing back_project_single = jax.remat(XRayTransform3D._back_project_single) From 0a72cb440172caaa50087f77ff375839e5e74560 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 17 Jun 2026 18:42:54 -0600 Subject: [PATCH 09/13] Improve initializer parameter handling --- scico/linop/xray/_xray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index c1c044c26..631c463f9 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -368,7 +368,7 @@ def __init__( self.input_shape: Shape = input_shape self.matrices = jnp.asarray(matrices, dtype=np.float32) - self.det_shape = det_shape + self.det_shape = tuple(det_shape) # in case det_shape is a list self.output_shape = (len(matrices), *det_shape) super().__init__( input_shape=input_shape, From 68cb2eb904cd39225f751b582730a8540955ca83 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jun 2026 07:07:55 -0600 Subject: [PATCH 10/13] New example script --- examples/scripts/ct_3d_tv_padmm.py | 158 +++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 examples/scripts/ct_3d_tv_padmm.py diff --git a/examples/scripts/ct_3d_tv_padmm.py b/examples/scripts/ct_3d_tv_padmm.py new file mode 100644 index 000000000..d67c212ad --- /dev/null +++ b/examples/scripts/ct_3d_tv_padmm.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) +====================================================================== + +This example demonstrates solution of a sparse-view, 3D CT +reconstruction problem with isotropic total variation (TV) +regularization + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} + \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ + +where $C$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator, +and $\mathbf{x}$ is the reconstructed image. + +This example uses the native scico 3d X-Ray projector, while the +[companion example](ct_astra_3d_tv_padmm.rst) uses the astra projector. +""" + +import numpy as np + +from mpl_toolkits.axes_grid1 import make_axes_locatable + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.examples import create_tangle_phantom +from scico.linop.xray import XRayTransform3D +from scico.linop.xray.astra import angle_to_vector, convert_to_scico_geometry +from scico.optimize import ProximalADMM +from scico.util import device_info + +""" +Create a ground truth image and projector. +""" +Nx = 128 +Ny = 256 +Nz = 64 + +tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) + +n_projection = 10 # number of projections +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_spacing = [1.0, 1.0] +det_count = (Nz, max(Nx, Ny)) +vectors = angle_to_vector(det_spacing, angles) + +# It would have been more straightforward to use the det_spacing and angles keywords +# in this case (since vectors is just computed directly from these two quantities), but +# the more general form is used here as a demonstration. +matrices = convert_to_scico_geometry(input_shape=tangle.shape, det_count=det_count, vectors=vectors) +C = XRayTransform3D(tangle.shape, matrices, det_count) # CT projection operator +y = C @ tangle # sinogram + + +r""" +Set up problem and solver. We want to minimize the functional + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} + \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ + +where $C$ is the X-ray transform and $D$ is a finite difference +operator. This problem can be expressed as + + $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - + \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_{2,1} \;\; + \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; + \mathbf{z}_1 = D \mathbf{x} \;,$$ + +which can be written in the form of a standard ADMM problem + + $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) + \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}$$ + +with + + $$f = 0 \qquad g = g_0 + g_1$$ + $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad + g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ + $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad + B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad + \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ + +This is a more complex splitting than that used in the +[companion example](ct_astra_3d_tv_admm.rst), but it allows the use of a +proximal ADMM solver in a way that avoids the need for the conjugate +gradient sub-iterations used by the ADMM solver in the +[companion example](ct_astra_3d_tv_admm.rst). +""" +𝛼 = 1e2 # improve problem conditioning by balancing C and D components of A +λ = 2e0 # ℓ2,1 norm regularization parameter +ρ = 5e-3 # ADMM penalty parameter +maxiter = 1000 # number of ADMM iterations + +f = functional.ZeroFunctional() +g0 = loss.SquaredL2Loss(y=y) +g1 = (λ / 𝛼) * functional.L21Norm() +g = functional.SeparableFunctional((g0, g1)) +D = linop.FiniteDifference(input_shape=tangle.shape, append=0) + +A = linop.VerticalStack((C, 𝛼 * D)) +mu, nu = ProximalADMM.estimate_parameters(A) + +solver = ProximalADMM( + f=f, + g=g, + A=A, + B=None, + rho=ρ, + mu=mu, + nu=nu, + maxiter=maxiter, + itstat_options={"display": True, "period": 50}, +) + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +tangle_recon = solver.solve() + +print( + "TV Restruction\nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) +) + + +""" +Show the recovered image. +""" +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6)) +plot.imview( + tangle[32], + title="Ground truth (central slice)", + cmap=plot.cm.Blues, + cbar=None, + fig=fig, + ax=ax[0], +) +plot.imview( + tangle_recon[32], + title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)), + cmap=plot.cm.Blues, + fig=fig, + ax=ax[1], +) +divider = make_axes_locatable(ax[1]) +cax = divider.append_axes("right", size="5%", pad=0.2) +fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") +fig.show() + +input("\nWaiting for input to close figures and exit") From e27eacb238f1c5346c6eb4e7588236e9bb98c5cc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jun 2026 07:20:22 -0600 Subject: [PATCH 11/13] Update change summary --- CHANGES.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 9f7b0070e..c2236c7df 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,9 @@ SCICO Release Notes Version 0.0.8 (unreleased) ---------------------------- + +• Substantial computational improvement in ``linop.xray.XRayTransform3D``, + which is now faster than ``linop.xray.astra.XRayTransform3D``. • Enable certain parameters of array creation functions to trigger ``BlockArray`` creation when they receive lists (currently ``device``). • New functional ``functional.BoxIndicator``. From 5ed508545b43451409138bde3448a2c440391424 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jun 2026 07:26:05 -0600 Subject: [PATCH 12/13] Remove experimental warning --- scico/linop/xray/_xray.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 631c463f9..ddf625ab7 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -328,9 +328,6 @@ class XRayTransform3D(LinearOperator): adjoint of the forward projector. It is written purely in JAX, allowing it to run on either CPU or GPU and minimizing host copies. - Warning: This class is experimental and may be up to ten times slower - than :class:`scico.linop.xray.astra.XRayTransform3D`. - For each view, the projection geometry is specified by an array with shape (2, 4) that specifies a :math:`2 \times 3` projection matrix and a :math:`2 \times 1` offset vector. Denoting the matrix From ac2390cc4cd6d1b44b38a0ee714e6cb0d0653caf Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 18 Jun 2026 07:26:20 -0600 Subject: [PATCH 13/13] Update example index --- examples/scripts/index.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 551a68e9a..a391bb08c 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -15,6 +15,7 @@ Computed Tomography - ct_astra_noreg_pcg.py - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - ct_tv_admm.py - ct_astra_tv_admm.py - ct_multi_tv_admm.py @@ -112,6 +113,7 @@ Total Variation - ct_astra_tv_admm.py - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - ct_astra_weighted_tv_admm.py - ct_svmbir_tv_multi.py - deconv_circ_tv_admm.py @@ -210,6 +212,7 @@ Proximal ADMM ^^^^^^^^^^^^^ - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - deconv_tv_padmm.py - denoise_tv_multi.py - deconv_ppp_dncnn_padmm.py