diff --git a/docs/installation.rst b/docs/installation.rst
index 92144e526..1c956addd 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -13,7 +13,7 @@ This package has the following dependencies:
* `Numpy `__ 1.20 or later
-* `Astropy `__ 5.0 or later
+* `Astropy `__ 6.0 or later
* `Scipy `__ 1.5 or later
diff --git a/pyproject.toml b/pyproject.toml
index 7e3bd49a6..1f8295ac4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ urls = {Homepage = "https://reproject.readthedocs.io"}
requires-python = ">=3.11"
dependencies = [
"numpy>=1.23",
- "astropy>=5.0",
+ "astropy>=6.0",
"astropy-healpix>=1.0",
"scipy>=1.9",
"dask[array]>=2024.4.1",
diff --git a/reproject/_common.py b/reproject/_common.py
index b3e0a137a..bc0e45ef0 100644
--- a/reproject/_common.py
+++ b/reproject/_common.py
@@ -61,6 +61,7 @@ def _reproject_dispatcher(
shape_out,
wcs_out,
block_size=None,
+ non_reprojected_dims=None,
array_out=None,
return_footprint=True,
output_footprint=None,
@@ -94,6 +95,15 @@ def _reproject_dispatcher(
the block size automatically determined. If ``block_size`` is not
specified or set to `None`, the reprojection will not be carried out in
blocks.
+ non_reprojected_dims : tuple, optional
+ Leading dimensions of the data that should not be reprojected but for
+ which a one-to-one mapping between input and output pixels is assumed,
+ given as a tuple of sequential integers starting from zero (e.g.
+ ``(0,)`` or ``(0, 1)``). If `None` (the default), any leading dimensions
+ for which the WCS has fewer dimensions than the data are treated this
+ way. Reprojecting fewer dimensions than the WCS currently requires a
+ ``block_size`` that matches the output shape along the reprojected
+ dimensions.
array_out : `~numpy.ndarray`, optional
An array in which to store the reprojected data. This can be any numpy
array including a memory map, which may be helpful when dealing with
@@ -144,6 +154,50 @@ def _reproject_dispatcher(
if reproject_func_kwargs is None:
reproject_func_kwargs = {}
+ # For now, we are quite restrictive in what non_reprojected_dims can
+ # be, but it is designed so that if we wanted we could support more use
+ # cases in future. For now, it has to be a tuple where each element is
+ # sequential from zero, e.g. (0,) or (0, 1) or (0, 1, 2)
+
+ if non_reprojected_dims is None:
+ n_dim_reproject = min(wcs_in.low_level_wcs.pixel_n_dim, wcs_out.low_level_wcs.pixel_n_dim)
+ else:
+ if non_reprojected_dims != tuple(range(len(non_reprojected_dims))):
+ raise ValueError(
+ "non_reprojected_dims should be a tuple with values increasing sequentially from zero"
+ )
+ # If either WCS already has fewer dimensions than the data, the missing
+ # dimensions are implicitly non-reprojected, so the shortfall has to be
+ # consistent with the number of non_reprojected_dims requested.
+ for label, wcs in (("input", wcs_in), ("output", wcs_out)):
+ n_dim_missing = len(shape_out) - wcs.low_level_wcs.pixel_n_dim
+ if n_dim_missing > 0 and n_dim_missing != len(non_reprojected_dims):
+ raise ValueError(
+ f"The {label} WCS has {wcs.low_level_wcs.pixel_n_dim} pixel dimensions "
+ f"which is fewer than the {len(shape_out)} data dimensions, but the "
+ f"difference ({n_dim_missing}) does not match the number of "
+ f"non_reprojected_dims ({len(non_reprojected_dims)})"
+ )
+ n_dim_reproject = len(shape_out) - len(non_reprojected_dims)
+ if n_dim_reproject < 1:
+ raise ValueError(
+ "non_reprojected_dims should leave at least one dimension to be " "reprojected"
+ )
+
+ # If we are reprojecting fewer dimensions than the input or output WCS has,
+ # the WCS needs to be sliced down to the reprojected dimensions for each
+ # non-reprojected slice. This is currently only done when parallelizing over
+ # the non-reprojected (broadcasted) dimensions, so any other code path would
+ # silently reproject the dimensions that should have been left untouched.
+ # This is gated on non_reprojected_dims being set since that is the only way
+ # to opt into reprojecting fewer dimensions than the WCS; a plain mismatch
+ # between the input and output WCS dimensionality is instead a validation
+ # error raised by the underlying reprojection function.
+ wcs_slicing_required = non_reprojected_dims is not None and (
+ n_dim_reproject < wcs_in.low_level_wcs.pixel_n_dim
+ or n_dim_reproject < wcs_out.low_level_wcs.pixel_n_dim
+ )
+
# We set up a global temporary directory since this will be used e.g. to
# store memory mapped Numpy arrays and zarr arrays.
@@ -168,7 +222,7 @@ def _reproject_dispatcher(
# If neither parallel nor blocked reprojection are requested, we simply
# call the underlying core reproject function with the full arrays.
- if block_size is None and parallel is False:
+ if block_size is None and parallel is False and not wcs_slicing_required:
# If a dask array was passed as input, we first convert this to a
# Numpy memory mapped array
@@ -207,9 +261,49 @@ def _reproject_dispatcher(
# shape_out will be the full size of the output array as this is updated
# in parse_output_projection, even if shape_out was originally passed in as
# the shape of a single image.
- broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)
+ broadcasting = n_dim_reproject < len(shape_out)
+
+ logger.info(
+ f"Broadcasting is {'' if broadcasting else 'not '}being used, "
+ f"reprojecting last {n_dim_reproject} axes"
+ )
- logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used")
+ # The output shape must match the input shape along any non-reprojected
+ # (broadcasted) dimensions.
+
+ shape_in = array_in.shape
+ shape_out = tuple(shape_out)
+
+ if shape_out[:-n_dim_reproject] != shape_in[:-n_dim_reproject]:
+ raise ValueError("Input shape should match output shape for non-reprojected dimensions")
+
+ # If an explicit block size was passed, normalize it to have the same
+ # number of elements as shape_out, expanding it if it only covers the
+ # reprojected dimensions and replacing any -1 values by the full size
+ # along the corresponding dimension. If block_size is None or 'auto',
+ # the chunking is determined automatically further below.
+
+ if block_size is not None and block_size != "auto":
+
+ if len(block_size) > len(shape_out):
+ raise ValueError(
+ f"block_size {block_size} cannot have more elements "
+ f"than the dimensionality of the output ({len(shape_out)})"
+ )
+
+ if len(block_size) != n_dim_reproject and len(block_size) != len(shape_out):
+ raise ValueError(
+ f"block_size {block_size} should have either "
+ f"{n_dim_reproject} or {len(shape_out)} elements"
+ )
+
+ if len(block_size) == n_dim_reproject:
+ block_size = (-1,) * (len(shape_out) - n_dim_reproject) + tuple(block_size)
+
+ block_size = tuple(
+ block_size[i] if block_size[i] != -1 else shape_out[i]
+ for i in range(len(block_size))
+ )
# Check block size and determine whether block size indicates we should
# parallelize over broadcasted dimension. The logic is as follows: if
@@ -222,42 +316,39 @@ def _reproject_dispatcher(
# missing dimensions.
broadcasted_parallelization = False
if broadcasting and block_size is not None and block_size != "auto":
- if len(block_size) == len(shape_out):
- if (
- block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
- == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
- ):
- broadcasted_parallelization = True
- block_size = (
- block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
- + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
- )
- else:
- for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
- if block_size[i] != -1 and block_size[i] != shape_out[i]:
- raise ValueError(
- "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
- )
- elif len(block_size) < len(shape_out):
- block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
- else:
+ if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]:
+ # TODO: maybe error if block_size was given in full and is wrong
+ broadcasted_parallelization = True
+ block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[
+ -n_dim_reproject:
+ ]
+ elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]:
raise ValueError(
- f"block_size {len(block_size)} cannot have more elements "
- f"than the dimensionality of the output ({len(shape_out)})"
+ "block shape should either match output data shape along "
+ "reprojected dimensions or non-reprojected dimensions"
)
- # TODO: check for shape_out not matching shape_in along broadcasted dimensions
-
logger.info(
f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along "
f"broadcasted dimension ({block_size=}, {shape_out=})"
)
+ # TODO: support block_size="auto" (and the default of None) together
+ # with non_reprojected_dims so that this does not have to raise; "auto"
+ # currently falls through to the generic auto-chunking path further
+ # below, which cannot parallelize over the non-reprojected dimensions.
+ if wcs_slicing_required and not broadcasted_parallelization:
+ raise NotImplementedError(
+ "Reprojecting fewer dimensions than the input or output WCS "
+ "(for example using non_reprojected_dims) currently requires "
+ "passing a block_size whose entries along the reprojected "
+ "dimensions match the output shape (optionally with parallel=True "
+ "to compute the blocks concurrently)"
+ )
+
if output_footprint is None and return_footprint:
output_footprint = np.zeros(shape_out, dtype=float)
- shape_in = array_in.shape
-
def reproject_single_block(a, array_or_path, block_info=None):
if (
@@ -271,6 +362,8 @@ def reproject_single_block(a, array_or_path, block_info=None):
if isinstance(array_or_path, str) and array_or_path == "from-dict":
array_or_path = dask_arrays["array"]
+ shape_out = block_info[None]["chunk-shape"][1:]
+
# The WCS class from astropy is not thread-safe, see e.g.
# https://github.com/astropy/astropy/issues/16244
# https://github.com/astropy/astropy/issues/16245
@@ -282,16 +375,41 @@ def reproject_single_block(a, array_or_path, block_info=None):
wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in
wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out
- slices = [
- slice(*x) for x in block_info[None]["array-location"][-wcs_out_cp.pixel_n_dim :]
- ]
+ slices_in = []
+ slices_out = []
+ for idx in range(len(shape_out)):
+ interval = block_info[None]["array-location"][idx + 1]
+ if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject:
+ if interval[1] - interval[0] != 1:
+ raise RuntimeError(
+ f"Expected a chunk of width 1 along dimension {idx} "
+ f"(got {interval[1] - interval[0]})"
+ )
+ slices_in.append(interval[0])
+ slices_out.append(interval[0])
+ else:
+ slices_in.append(slice(None))
+ slices_out.append(slice(*block_info[None]["array-location"][idx + 1]))
+
+ slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :]
+ slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :]
+
+ if broadcasted_parallelization:
+ if isinstance(wcs_in_cp, BaseHighLevelWCS):
+ low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in)
+ else:
+ low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in)
+
+ wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in)
+ else:
+ wcs_in_sub = wcs_in_cp
- if isinstance(wcs_out, BaseHighLevelWCS):
- low_level_wcs = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices)
+ if isinstance(wcs_out_cp, BaseHighLevelWCS):
+ low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out)
else:
- low_level_wcs = SlicedLowLevelWCS(wcs_out_cp, slices=slices)
+ low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out)
- wcs_out_sub = HighLevelWCSWrapper(low_level_wcs)
+ wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out)
if isinstance(array_or_path, tuple):
array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r")
@@ -303,11 +421,9 @@ def reproject_single_block(a, array_or_path, block_info=None):
if array_or_path is None:
raise RuntimeError("array_or_path is not set")
- shape_out = block_info[None]["chunk-shape"][1:]
-
array, footprint = reproject_func(
array_in,
- wcs_in_cp,
+ wcs_in_sub,
wcs_out_sub,
shape_out=shape_out,
array_out=np.zeros(shape_out),
@@ -319,17 +435,17 @@ def reproject_single_block(a, array_or_path, block_info=None):
if broadcasted_parallelization:
array_out_dask = da.empty(shape_out, chunks=block_size)
+
+ # The input is reprojected in full for each output block, so it must
+ # not be chunked along the reprojected dimensions (which can have a
+ # different size from the output); only the broadcasted dimensions are
+ # chunked, matching array_out_dask block for block.
+ input_chunks = (1,) * (array_in.ndim - n_dim_reproject) + (-1,) * n_dim_reproject
if isinstance(array_in, da.core.Array):
- if array_in.chunksize != block_size:
- logger.info(
- f"Rechunking input dask array as chunks ({array_in.chunksize}) "
- "do not match block size ({block_size})"
- )
- array_in = array_in.rechunk(block_size)
+ array_in = array_in.rechunk(input_chunks)
else:
-
array_in = da.asarray(
- ArrayWrapper(array_in), name=str(uuid.uuid4()), chunks=block_size
+ ArrayWrapper(array_in), name=str(uuid.uuid4()), chunks=input_chunks
)
result = da.map_blocks(
@@ -380,8 +496,8 @@ def reproject_single_block(a, array_or_path, block_info=None):
array_out_dask = da.empty(shape_out, chunks=block_size)
else:
if broadcasting:
- chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
- chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
+ chunks = (-1,) * (len(shape_out) - n_dim_reproject)
+ chunks += ("auto",) * n_dim_reproject
rechunk_kwargs = {"chunks": chunks}
else:
rechunk_kwargs = {}
diff --git a/reproject/adaptive/_high_level.py b/reproject/adaptive/_high_level.py
index 9b3347502..2fede6f8e 100644
--- a/reproject/adaptive/_high_level.py
+++ b/reproject/adaptive/_high_level.py
@@ -30,6 +30,7 @@ def reproject_adaptive(
output_footprint=None,
return_footprint=True,
block_size=None,
+ non_reprojected_dims=None,
parallel=False,
return_type=None,
dask_method=None,
@@ -205,6 +206,16 @@ def reproject_adaptive(
the block size automatically determined. If ``block_size`` is not
specified or set to `None`, the reprojection will not be carried out in
blocks.
+ non_reprojected_dims : tuple, optional
+ Leading dimensions of the data that should not be reprojected but for
+ which a one-to-one mapping between input and output pixels is assumed.
+ This makes it possible to broadcast a reprojection over these dimensions
+ even when the input and output WCS have the same number of dimensions as
+ the data. The dimensions must be the leading ones, given as a tuple of
+ sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``).
+ This currently requires passing a ``block_size`` whose entries along
+ the reprojected dimensions match ``shape_out`` (optionally combined
+ with ``parallel`` to compute the blocks concurrently).
parallel : bool or int or str, optional
If `True`, the reprojection is carried out in parallel, and if a
positive integer, this specifies the number of threads to use.
@@ -253,6 +264,7 @@ def reproject_adaptive(
array_out=output_array,
parallel=parallel,
block_size=block_size,
+ non_reprojected_dims=non_reprojected_dims,
return_footprint=return_footprint,
output_footprint=output_footprint,
reproject_func_kwargs=dict(
diff --git a/reproject/conftest.py b/reproject/conftest.py
index 56f44dd7f..46c3bab19 100644
--- a/reproject/conftest.py
+++ b/reproject/conftest.py
@@ -55,6 +55,11 @@
def pytest_configure(config):
+
+ from astropy.utils.iers import conf
+
+ conf.auto_download = False
+
if ASTROPY_HEADER:
config.option.astropy_header = True
diff --git a/reproject/interpolation/_core.py b/reproject/interpolation/_core.py
index 1497827eb..28a93995b 100644
--- a/reproject/interpolation/_core.py
+++ b/reproject/interpolation/_core.py
@@ -11,7 +11,10 @@
def _validate_wcs(wcs_in, wcs_out, shape_in, shape_out):
if wcs_in.low_level_wcs.pixel_n_dim != wcs_out.low_level_wcs.pixel_n_dim:
- raise ValueError("Number of dimensions in input and output WCS should match")
+ raise ValueError(
+ "Number of dimensions in input and output WCS should match "
+ f"(got {wcs_in.low_level_wcs.pixel_n_dim} and {wcs_out.low_level_wcs.pixel_n_dim})"
+ )
elif len(shape_out) < wcs_out.low_level_wcs.pixel_n_dim:
raise ValueError("Too few dimensions in shape_out")
elif len(shape_in) < wcs_in.low_level_wcs.pixel_n_dim:
diff --git a/reproject/interpolation/_high_level.py b/reproject/interpolation/_high_level.py
index 7d293ab0e..876b63605 100644
--- a/reproject/interpolation/_high_level.py
+++ b/reproject/interpolation/_high_level.py
@@ -25,6 +25,7 @@ def reproject_interp(
output_footprint=None,
return_footprint=True,
block_size=None,
+ non_reprojected_dims=None,
parallel=False,
return_type=None,
dask_method=None,
@@ -101,6 +102,16 @@ def reproject_interp(
the block size automatically determined. If ``block_size`` is not
specified or set to `None`, the reprojection will not be carried out in
blocks.
+ non_reprojected_dims : tuple, optional
+ Leading dimensions of the data that should not be reprojected but for
+ which a one-to-one mapping between input and output pixels is assumed.
+ This makes it possible to broadcast a reprojection over these dimensions
+ even when the input and output WCS have the same number of dimensions as
+ the data. The dimensions must be the leading ones, given as a tuple of
+ sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``).
+ This currently requires passing a ``block_size`` whose entries along
+ the reprojected dimensions match ``shape_out`` (optionally combined
+ with ``parallel`` to compute the blocks concurrently).
parallel : bool or int or str, optional
If `True`, the reprojection is carried out in parallel, and if a
positive integer, this specifies the number of threads to use.
@@ -150,6 +161,7 @@ def reproject_interp(
array_out=output_array,
parallel=parallel,
block_size=block_size,
+ non_reprojected_dims=non_reprojected_dims,
return_footprint=return_footprint,
output_footprint=output_footprint,
reproject_func_kwargs=dict(
diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py
new file mode 100644
index 000000000..81d17f15d
--- /dev/null
+++ b/reproject/tests/test_non_reprojected_dims.py
@@ -0,0 +1,189 @@
+# Licensed under a 3-clause BSD style license - see LICENSE.rst
+
+import numpy as np
+import pytest
+from astropy.wcs import WCS
+from numpy.testing import assert_allclose
+
+from reproject import reproject_adaptive, reproject_interp
+
+# Reprojection functions that support non_reprojected_dims. reproject_exact can
+# be added here once it gains support.
+REPROJECT_FUNCTIONS = [reproject_interp, reproject_adaptive]
+
+
+@pytest.fixture(params=REPROJECT_FUNCTIONS, ids=lambda func: func.__name__)
+def reproject_function(request):
+ return request.param
+
+
+def _spectral_cube_wcs(crval_dec, crval_freq):
+ wcs = WCS(naxis=3)
+ wcs.wcs.ctype = ["RA---TAN", "DEC--TAN", "FREQ"]
+ wcs.wcs.crpix = [10, 10, 1]
+ wcs.wcs.crval = [40.0, crval_dec, crval_freq]
+ wcs.wcs.cdelt = [-0.01, 0.01, 1e6]
+ return wcs
+
+
+def test_non_reprojected_dims(reproject_function):
+ # Reproject a cube where the input and output WCS have the same number of
+ # dimensions as the data, treating the leading (spectral) axis as a
+ # non-reprojected dimension. The result should match reprojecting each
+ # spectral slice independently with the corresponding 2D WCS, and in
+ # particular should not be affected by the (deliberately different) spectral
+ # part of the WCS.
+
+ data = np.arange(4 * 20 * 20, dtype=float).reshape((4, 20, 20))
+ wcs_in = _spectral_cube_wcs(0.0, 1e9)
+ wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6)
+ shape_out = (4, 20, 20)
+
+ reference = np.empty_like(data)
+ for islice in range(data.shape[0]):
+ reference[islice], _ = reproject_function(
+ (data[islice], wcs_in.celestial), wcs_out.celestial, shape_out=(20, 20)
+ )
+
+ array_out, _ = reproject_function(
+ (data, wcs_in),
+ wcs_out,
+ shape_out=shape_out,
+ non_reprojected_dims=(0,),
+ parallel=True,
+ block_size=(20, 20),
+ )
+
+ assert_allclose(array_out, reference, equal_nan=True)
+
+
+def test_non_reprojected_dims_invalid_order(reproject_function):
+ data = np.ones((4, 20, 20))
+ wcs = _spectral_cube_wcs(0.0, 1e9)
+ with pytest.raises(ValueError, match="increasing sequentially from zero"):
+ reproject_function((data, wcs), wcs, shape_out=(4, 20, 20), non_reprojected_dims=(1,))
+
+
+def test_non_reprojected_dims_inconsistent_with_wcs(reproject_function):
+ # The WCS already has fewer dimensions than the data, but the shortfall does
+ # not match the number of non_reprojected_dims requested.
+ data = np.ones((3, 4, 20, 20))
+ wcs = WCS(naxis=2)
+ wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
+ with pytest.raises(ValueError, match="does not match the number of non_reprojected_dims"):
+ reproject_function(
+ (data, wcs),
+ wcs,
+ shape_out=(3, 4, 20, 20),
+ non_reprojected_dims=(0,),
+ parallel=True,
+ block_size=(20, 20),
+ )
+
+
+@pytest.mark.parametrize(
+ "kwargs", [{}, {"parallel": True}, {"parallel": True, "block_size": (4, 10, 10)}]
+)
+def test_non_reprojected_dims_unsupported_mode(reproject_function, kwargs):
+ # non_reprojected_dims with a full-dimensional WCS is only supported when
+ # parallelizing over the non-reprojected dimensions; other modes should
+ # raise rather than silently reprojecting the non-reprojected axis.
+ data = np.ones((4, 20, 20))
+ wcs_in = _spectral_cube_wcs(0.0, 1e9)
+ wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6)
+ with pytest.raises(NotImplementedError, match="non_reprojected_dims"):
+ reproject_function(
+ (data, wcs_in), wcs_out, shape_out=(4, 20, 20), non_reprojected_dims=(0,), **kwargs
+ )
+
+
+def _drifting_cube_wcs(drift):
+ # 3D WCS over (time, y, x) where the celestial axes are coupled to the time
+ # pixel axis via the PC matrix, so the celestial coordinates drift along the
+ # time axis (while the time axis itself stays independent of the celestial
+ # axes). A drift of zero gives celestial coordinates that are constant in
+ # time.
+ wcs = WCS(naxis=3)
+ wcs.wcs.ctype = ["RA---TAN", "DEC--TAN", "TIME"]
+ wcs.wcs.crpix = [15, 15, 1]
+ wcs.wcs.crval = [40.0, 0.0, 0.0]
+ wcs.wcs.cdelt = [-0.01, 0.01, 1.0]
+ wcs.wcs.pc = [[1.0, 0.0, drift], [0.0, 1.0, drift], [0.0, 0.0, 1.0]]
+ return wcs
+
+
+def test_non_reprojected_dims_time_varying_wcs(reproject_function):
+ # Motivating use case: a cube whose celestial coordinates drift along a
+ # non-reprojected (time) axis, reprojected to a cube where they do not. Each
+ # time slice must be reprojected using its own (drifted) celestial WCS, which
+ # should match reprojecting each slice independently with the WCS sliced at
+ # that time.
+ n_time = 5
+ shape_out = (n_time, 30, 30)
+ wcs_in = _drifting_cube_wcs(drift=0.6)
+ wcs_out = _drifting_cube_wcs(drift=0.0)
+
+ data = np.random.default_rng(0).random((n_time, 30, 30))
+
+ array_out, _ = reproject_function(
+ (data, wcs_in),
+ wcs_out,
+ shape_out=shape_out,
+ non_reprojected_dims=(0,),
+ parallel=True,
+ block_size=(30, 30),
+ )
+
+ reference = np.empty_like(data)
+ for itime in range(n_time):
+ reference[itime], _ = reproject_function(
+ (data[itime], wcs_in[itime]), wcs_out[itime], shape_out=(30, 30)
+ )
+
+ assert_allclose(array_out, reference, equal_nan=True)
+
+ # Make sure the drift is actually exercised (otherwise the test would pass
+ # trivially even if a single WCS were reused for all slices).
+ assert not np.allclose(np.nan_to_num(reference[0]), np.nan_to_num(reference[-1]))
+
+
+@pytest.mark.filterwarnings("ignore::erfa.ErfaWarning")
+def test_non_reprojected_dims_matches_full_reproject():
+ # The full N-D reproject transforms the TIME axis through world coordinates,
+ # which emits an incidental ERFA "dubious year" warning for this synthetic
+ # epoch; that is unrelated to what we are checking here.
+ # Cross-check the non_reprojected_dims fast path against a full N-D reproject
+ # (with no non_reprojected_dims), which is a completely independent code path.
+ # Because the time axis maps one-to-one between the input and output WCS, the
+ # two must agree. This only applies to reproject_interp, since
+ # reproject_adaptive does not support a full N-D reproject of a cube with a
+ # coupled WCS (it is celestial-2D only).
+ n_time = 5
+ shape_out = (n_time, 30, 30)
+ wcs_in = _drifting_cube_wcs(drift=0.6)
+ wcs_out = _drifting_cube_wcs(drift=0.0)
+
+ data = np.random.default_rng(0).random((n_time, 30, 30))
+
+ array_out, _ = reproject_interp(
+ (data, wcs_in),
+ wcs_out,
+ shape_out=shape_out,
+ non_reprojected_dims=(0,),
+ parallel=True,
+ block_size=(30, 30),
+ )
+
+ reference_full, _ = reproject_interp((data, wcs_in), wcs_out, shape_out=shape_out)
+
+ assert_allclose(array_out, reference_full, equal_nan=True, atol=1e-8)
+
+
+def test_non_reprojected_dims_all_dimensions(reproject_function):
+ # Marking every dimension as non-reprojected leaves nothing to reproject and
+ # should raise a clear error rather than failing obscurely further down.
+ data = np.ones((20, 20))
+ wcs = WCS(naxis=2)
+ wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
+ with pytest.raises(ValueError, match="at least one dimension"):
+ reproject_function((data, wcs), wcs, shape_out=(20, 20), non_reprojected_dims=(0, 1))
diff --git a/tox.ini b/tox.ini
index cc02c2352..acca7fe5a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -21,7 +21,7 @@ deps =
numpy121: numpy==1.21.*
oldestdeps: numpy==1.23.*
- oldestdeps: astropy==5.0.*
+ oldestdeps: astropy==6.0.*
oldestdeps: astropy-healpix==1.0.*
oldestdeps: scipy==1.9.*
oldestdeps: dask==2024.4.*