diff --git a/CHANGES.rst b/CHANGES.rst index 3ef45c96..27da5e58 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,8 @@ SCICO Release Notes Version 0.0.8 (unreleased) ---------------------------- +• Substantial computational improvement in ``linop.xray.XRayTransform3D``, + which is now faster than ``linop.xray.astra.XRayTransform3D``. • Enable certain parameters of array creation functions to trigger ``BlockArray`` creation when they receive lists (currently ``device``). • New functional ``functional.BoxIndicator``. diff --git a/examples/scripts/ct_3d_tv_padmm.py b/examples/scripts/ct_3d_tv_padmm.py new file mode 100644 index 00000000..d67c212a --- /dev/null +++ b/examples/scripts/ct_3d_tv_padmm.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) +====================================================================== + +This example demonstrates solution of a sparse-view, 3D CT +reconstruction problem with isotropic total variation (TV) +regularization + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} + \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ + +where $C$ is the X-ray transform (the CT forward projection operator), +$\mathbf{y}$ is the sinogram, $D$ is a 3D finite difference operator, +and $\mathbf{x}$ is the reconstructed image. + +This example uses the native scico 3d X-Ray projector, while the +[companion example](ct_astra_3d_tv_padmm.rst) uses the astra projector. +""" + +import numpy as np + +from mpl_toolkits.axes_grid1 import make_axes_locatable + +import scico.numpy as snp +from scico import functional, linop, loss, metric, plot +from scico.examples import create_tangle_phantom +from scico.linop.xray import XRayTransform3D +from scico.linop.xray.astra import angle_to_vector, convert_to_scico_geometry +from scico.optimize import ProximalADMM +from scico.util import device_info + +""" +Create a ground truth image and projector. +""" +Nx = 128 +Ny = 256 +Nz = 64 + +tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) + +n_projection = 10 # number of projections +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_spacing = [1.0, 1.0] +det_count = (Nz, max(Nx, Ny)) +vectors = angle_to_vector(det_spacing, angles) + +# It would have been more straightforward to use the det_spacing and angles keywords +# in this case (since vectors is just computed directly from these two quantities), but +# the more general form is used here as a demonstration. +matrices = convert_to_scico_geometry(input_shape=tangle.shape, det_count=det_count, vectors=vectors) +C = XRayTransform3D(tangle.shape, matrices, det_count) # CT projection operator +y = C @ tangle # sinogram + + +r""" +Set up problem and solver. We want to minimize the functional + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} + \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,$$ + +where $C$ is the X-ray transform and $D$ is a finite difference +operator. This problem can be expressed as + + $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - + \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_{2,1} \;\; + \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; + \mathbf{z}_1 = D \mathbf{x} \;,$$ + +which can be written in the form of a standard ADMM problem + + $$\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) + \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}$$ + +with + + $$f = 0 \qquad g = g_0 + g_1$$ + $$g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad + g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}$$ + $$A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad + B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad + \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.$$ + +This is a more complex splitting than that used in the +[companion example](ct_astra_3d_tv_admm.rst), but it allows the use of a +proximal ADMM solver in a way that avoids the need for the conjugate +gradient sub-iterations used by the ADMM solver in the +[companion example](ct_astra_3d_tv_admm.rst). +""" +𝛼 = 1e2 # improve problem conditioning by balancing C and D components of A +λ = 2e0 # ℓ2,1 norm regularization parameter +ρ = 5e-3 # ADMM penalty parameter +maxiter = 1000 # number of ADMM iterations + +f = functional.ZeroFunctional() +g0 = loss.SquaredL2Loss(y=y) +g1 = (λ / 𝛼) * functional.L21Norm() +g = functional.SeparableFunctional((g0, g1)) +D = linop.FiniteDifference(input_shape=tangle.shape, append=0) + +A = linop.VerticalStack((C, 𝛼 * D)) +mu, nu = ProximalADMM.estimate_parameters(A) + +solver = ProximalADMM( + f=f, + g=g, + A=A, + B=None, + rho=ρ, + mu=mu, + nu=nu, + maxiter=maxiter, + itstat_options={"display": True, "period": 50}, +) + +""" +Run the solver. +""" +print(f"Solving on {device_info()}\n") +tangle_recon = solver.solve() + +print( + "TV Restruction\nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)) +) + + +""" +Show the recovered image. +""" +fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6)) +plot.imview( + tangle[32], + title="Ground truth (central slice)", + cmap=plot.cm.Blues, + cbar=None, + fig=fig, + ax=ax[0], +) +plot.imview( + tangle_recon[32], + title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)), + cmap=plot.cm.Blues, + fig=fig, + ax=ax[1], +) +divider = make_axes_locatable(ax[1]) +cax = divider.append_axes("right", size="5%", pad=0.2) +fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") +fig.show() + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 551a68e9..a391bb08 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -15,6 +15,7 @@ Computed Tomography - ct_astra_noreg_pcg.py - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - ct_tv_admm.py - ct_astra_tv_admm.py - ct_multi_tv_admm.py @@ -112,6 +113,7 @@ Total Variation - ct_astra_tv_admm.py - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - ct_astra_weighted_tv_admm.py - ct_svmbir_tv_multi.py - deconv_circ_tv_admm.py @@ -210,6 +212,7 @@ Proximal ADMM ^^^^^^^^^^^^^ - ct_astra_3d_tv_padmm.py + - ct_3d_tv_padmm.py - deconv_tv_padmm.py - denoise_tv_multi.py - deconv_ppp_dncnn_padmm.py diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 99e2b0cd..ddf625ab 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2023-2024 by SCICO Developers +# Copyright (C) 2023-2026 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -19,7 +19,7 @@ import scico.numpy as snp from scico.numpy.util import is_scalar_equiv -from scico.typing import Shape +from scico.typing import DType, Shape from scipy.spatial.transform import Rotation from .._linop import LinearOperator @@ -328,14 +328,12 @@ class XRayTransform3D(LinearOperator): adjoint of the forward projector. It is written purely in JAX, allowing it to run on either CPU or GPU and minimizing host copies. - Warning: This class is experimental and may be up to ten times slower - than :class:`scico.linop.xray.astra.XRayTransform3D`. - For each view, the projection geometry is specified by an array with shape (2, 4) that specifies a :math:`2 \times 3` projection matrix and a :math:`2 \times 1` offset vector. Denoting the matrix - by :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel at array - index `(i, j, k)` has its center projected to the detector coordinates + by :math:`\mathbf{M}` and the offset by :math:`\mathbf{t}`, a voxel + at array index `(i, j, k)` has its center projected to the detector + coordinates .. math:: \mathbf{M} \begin{bmatrix} @@ -354,23 +352,28 @@ def __init__( input_shape: Shape, matrices: ArrayLike, det_shape: Shape, + input_dtype: DType = np.float32, ): r""" Args: - input_shape: Shape of input image. - matrices: (num_views, 2, 4) array of homogeneous projection matrices. + input_shape: Input array shape. + matrices: (num_views, 2, 4) array of homogeneous projection + matrices. det_shape: Shape of detector. + input_dtype: Input array dtype. """ self.input_shape: Shape = input_shape self.matrices = jnp.asarray(matrices, dtype=np.float32) - self.det_shape = det_shape + self.det_shape = tuple(det_shape) # in case det_shape is a list self.output_shape = (len(matrices), *det_shape) super().__init__( input_shape=input_shape, output_shape=self.output_shape, eval_fn=self.project, adj_fn=self.back_project, + input_dtype=input_dtype, + output_dtype=input_dtype, ) def project(self, im: ArrayLike) -> snp.Array: @@ -382,32 +385,39 @@ def back_project(self, proj: ArrayLike) -> snp.Array: return XRayTransform3D._back_project(proj, self.matrices, self.input_shape) @staticmethod + @partial(jax.jit, static_argnames="det_shape") def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> snp.Array: r""" Args: im: Input image. - matrix: (num_views, 2, 4) array of homogeneous projection matrices. + matrix: (num_views, 2, 4) array of homogeneous projection + matrices. det_shape: Shape of detector. """ - MAX_SLICE_LEN = 10 - slice_offsets = list(range(0, im.shape[0], MAX_SLICE_LEN)) - - num_views = len(matrices) - proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype) - for view_ind, matrix in enumerate(matrices): - for slice_offset in slice_offsets: - proj = proj.at[view_ind].set( - XRayTransform3D._project_single( - im[slice_offset : slice_offset + MAX_SLICE_LEN], - matrix, - proj[view_ind], - slice_offset=slice_offset, - ) - ) - return proj + BATCH_SIZE = 8 + + # Apply gradient checkpointing to the underlying core operator + project_single = jax.remat(XRayTransform3D._project_single) + + # Define projection behavior for a single matrix over the full image + def project_single_matrix(matrix): + # Start with an empty detector plane baseline + init_plane = jnp.zeros(det_shape, dtype=im.dtype) + + # Call the rematerialized operator on the full image + return project_single( + im, + matrix, + init_plane, + slice_offset=0, # No manual loops: processed as a whole + ) + + # Automatically chunk and execute views sequentially/parallelized via JAX. + # If len(matrices) is not divisible by BATCH_SIZE, JAX natively handles the + # remainder. + return jax.lax.map(project_single_matrix, matrices, batch_size=BATCH_SIZE) @staticmethod - @partial(jax.jit, donate_argnames="proj") def _project_single( im: ArrayLike, matrix: ArrayLike, proj: ArrayLike, slice_offset: int = 0 ) -> snp.Array: @@ -428,33 +438,46 @@ def _project_single( return proj @staticmethod + @partial(jax.jit, static_argnames="input_shape") def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array: r""" Args: - proj: Input (set of) projection(s). + proj: Input projection data of shape (num_views, *det_shape). matrix: (num_views, 2, 4) array of homogeneous projection matrices. - input_shape: Shape of desired back projection. + input_shape: Shape of back projection. """ - MAX_SLICE_LEN = 10 - slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN)) - - HTy = jnp.zeros(input_shape, dtype=proj.dtype) - for view_ind, matrix in enumerate(matrices): - for slice_offset in slice_offsets: - HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set( - XRayTransform3D._back_project_single( - proj[view_ind], - matrix, - HTy[slice_offset : slice_offset + MAX_SLICE_LEN], - slice_offset=slice_offset, - ) - ) - HTy.block_until_ready() # prevent OOM + BATCH_SIZE = 8 - return HTy + # Wrap the single back-project function for gradient checkpointing + back_project_single = jax.remat(XRayTransform3D._back_project_single) + + # Process an individual view slice-by-slice natively via map mapping + def back_project_single_view(packed_inputs): + # Unpack the active iteration variables provided by lax.map + single_proj, single_matrix = packed_inputs + + # Initialize a full-sized target volume structure for this single projection + # contribution + init_volume = jnp.zeros(input_shape, dtype=proj.dtype) + + return back_project_single( + single_proj, + single_matrix, + init_volume, + slice_offset=0, # Let JAX optimize the internal execution structure + ) + + # Map across the zip-like structure of projections and matrices. + # lax.map accumulates a stacked array of individual volume reconstructions. + individual_volumes = jax.lax.map( + back_project_single_view, (proj, matrices), batch_size=BATCH_SIZE + ) + + # Collapse the mapped axis by summing the independent view contributions + # to finalize the reconstructed 3D output image array. + return jnp.sum(individual_volumes, axis=0) @staticmethod - @partial(jax.jit, donate_argnames="HTy") def _back_project_single( y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0 ) -> snp.Array: @@ -547,11 +570,14 @@ def matrices_from_euler_angles( Args: input_shape: Shape of input image. output_shape: Shape of output (detector). - str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'} - for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and - intrinsic rotations cannot be mixed in one function call. + str: Sequence of axes for rotation. Up to 3 characters + belonging to the set {'X', 'Y', 'Z'} for intrinsic + rotations, or {'x', 'y', 'z'} for extrinsic rotations. + Extrinsic and intrinsic rotations cannot be mixed in one + function call. angles: (num_views, N), N = 1, 2, or 3 Euler angles. - degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians. + degrees: If ``True``, angles are in degrees, otherwise + radians. Default: ``True``, radians. voxel_spacing: (3,) array giving the spacing of image voxels. Default: `[1.0, 1.0, 1.0]`. Experimental. det_spacing: (2,) array giving the spacing of detector