diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e2a496b9..dd636a85 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,7 +78,7 @@ jobs: - name: install dependencies run: | - python -m pip install ".[test]" + python -m pip install . --group test - name: install astropy-xarray run: python -m pip install --no-deps . diff --git a/CHANGELOG.md b/CHANGELOG.md index 74bc0a66..8ba831e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ # What's New +## **Unreleased** + +- Added supported for {py:attr}`astropy.coordinates.SkyCoord.obstime`. +- Added supported for {py:attr}`astropy.coordinates.SkyCoord.equinox`. + ## 0.2.1 (18 Apr 2026) - Added support for `astropy==7.2.0` sky coordinates. diff --git a/astropy_xarray/coordinates/frame.py b/astropy_xarray/coordinates/frame.py index b1e438f2..76726784 100644 --- a/astropy_xarray/coordinates/frame.py +++ b/astropy_xarray/coordinates/frame.py @@ -176,11 +176,6 @@ def load_frame(frame_dict: dict, with_data: bool = False) -> BaseCoordinateFrame frame = HCRS( obstime=load_optional_object(Time, frame_dict["obstime"]), **kwargs ) - case "itrs": - frame = ITRS( - location=load_optional_earthlocation(frame_dict["location"]), - **kwargs, - ) case "altaz": frame = AltAz( obstime=load_optional_object(Time, frame_dict["obstime"]), @@ -274,7 +269,7 @@ def load_representation( frame_name: str | None, data: dict[str, np.ndarray], ) -> BaseRepresentation: - RepresentationClass = representation.REPRESENTATION_CLASSES.get(representation_type) + RepresentationClass = representation.REPRESENTATION_CLASSES[representation_type] DifferentialClass = representation.DIFFERENTIAL_CLASSES.get(differential_type) if frame_name is None: diff --git a/astropy_xarray/coordinates/sky_coord.py b/astropy_xarray/coordinates/sky_coord.py index 2369e3cf..b1afee46 100644 --- a/astropy_xarray/coordinates/sky_coord.py +++ b/astropy_xarray/coordinates/sky_coord.py @@ -5,8 +5,10 @@ import numpy as np import xarray as xr from astropy.coordinates import SkyCoord +from astropy.time import Time from astropy.utils import ShapedLikeNDArray +from astropy_xarray.coordinates.core import dump_time, load_optional_object from astropy_xarray.coordinates.frame import dump_frame, load_frame, load_representation _ArrayLike = list | np.ndarray | ShapedLikeNDArray @@ -102,11 +104,15 @@ def skycoord_to_dataset( quantified dataset. """ return xr.Dataset( - coords=coords if coords is not None else None, + coords=coords, data_vars=_skycoord_to_dataarrays(skycoord, coords), - attrs=dict( - frame=dump_frame(skycoord.frame), - ), + attrs={"frame": dump_frame(skycoord.frame)} + | { + attr: dump_time(value) + for attr in ("obstime", "equinox") + if (value := getattr(skycoord, attr)) is not None + and not hasattr(skycoord.frame, attr) + }, ) @@ -128,5 +134,15 @@ def dataset_to_skycoord(ds: xr.Dataset) -> SkyCoord: {k: v.data for k, v in dsq.data_vars.items()}, ) frame.representation_type = ds.attrs["frame"]["representation_type"] - frame.differential_type = ds.attrs["frame"]["differential_type"] - return SkyCoord(frame) + frame.differential_type = ds.attrs["frame"].get("differential_type") + + kwargs = { + attr: value + for attr in ("obstime", "equinox") + if not hasattr(frame, attr) + and (value := load_optional_object(Time, ds.attrs.get(attr))) is not None + } + return SkyCoord( + frame, + **kwargs, + ) diff --git a/docs/examples/skycoord.ipynb b/docs/examples/skycoord.ipynb index e049ad7a..26a9b900 100644 --- a/docs/examples/skycoord.ipynb +++ b/docs/examples/skycoord.ipynb @@ -40,6 +40,8 @@ " representation_type=\"unitspherical\",\n", " differential_type=\"unitsphericalcoslat\",\n", " ),\n", + " obstime=\"J2000\",\n", + " equinox=\"J2000\",\n", ")\n", "display(sky_direction)\n", "skycoord_to_dataset(\n", @@ -77,6 +79,7 @@ " representation_type=\"spherical\",\n", " differential_type=\"unitsphericalcoslat\",\n", " ),\n", + " obstime=\"J2000\",\n", ")\n", "display(sky_position)\n", "display(\n", @@ -110,7 +113,7 @@ "# frame specific info stored as dataset attribute\n", "assert load_frame(ds.attrs[\"frame\"]) == sky_position.replicate_without_data()\n", "\n", - "# dataset\n", + "# restored skycoord matches orginal\n", "result = dataset_to_skycoord(ds)\n", "display(result)\n", "np.testing.assert_array_equal(result.ra, sky_position.ra)\n", diff --git a/pyproject.toml b/pyproject.toml index 10816ed3..fe0fdb23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.23,<2.3.0", "xarray>=2022.06.0,<=2025.4.0", - "astropy>=6.1.0", + "astropy>=6.1.0,<8.0.0", "pandas>=2.3.0,<3.0.0", ] dynamic = ["version"] @@ -38,7 +38,7 @@ documentation = "https://astropy-xarray.readthedocs.io/en/stable" repository = "https://github.com/calgray/astropy-xarray" issues = "https://github.com/calgray/astropy-xarray/issues" -[project.optional-dependencies] +[dependency-groups] test = [ "pytest>=8.0", "pytest-cov", diff --git a/tests/test_sky_coord.py b/tests/test_sky_coord.py index 35bbb362..be748d74 100644 --- a/tests/test_sky_coord.py +++ b/tests/test_sky_coord.py @@ -49,6 +49,7 @@ dataset_to_skycoord, skycoord_to_dataset, ) +from astropy_xarray.coordinates.core import dump_time from astropy_xarray.coordinates.sky_coord import ( _skycoord_differential_component_names, _skycoord_representation_component_names, @@ -346,3 +347,100 @@ def test_skycoord_roundtrip( assert actual.is_equivalent_frame(expected) assert ds.coords.equals(xr.Coordinates(dict(coords))) np.testing.assert_array_equal(actual, expected) + + +def test_skycoord_obstime_empty(): + empty = SkyCoord(ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg) + assert empty.obstime is None + assert "obstime" not in skycoord_to_dataset(empty).attrs + ds = skycoord_to_dataset(empty) + assert "obstime" not in ds.attrs + assert dataset_to_skycoord(ds).obstime is None + + +def test_skycoord_obstime_optional(): + value = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg, obstime="J2000" + ) + assert value.obstime == Time("J2000") + ds = skycoord_to_dataset(value) + assert ds.attrs["obstime"] == dump_time(Time("J2000")) + assert dataset_to_skycoord(ds).obstime == Time("J2000") + + +def test_skycoord_obstime_frame_default(): + value = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK4(), + ) + assert value.obstime == Time("B1950") + ds = skycoord_to_dataset(value) + assert "obstime" not in ds.attrs + assert dataset_to_skycoord(ds).obstime == Time("B1950") + + +def test_skycoord_obstime_frame_override(): + _ = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK5(), + obstime="J2000", + ) + with pytest.raises(ValueError): + _ = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK4(), + obstime="J2000", + ) + + +def test_skycoord_equinox_empty(): + empty = SkyCoord(ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg) + assert empty.equinox is None + assert "equinox" not in skycoord_to_dataset(empty).attrs + ds = skycoord_to_dataset(empty) + assert "equinox" not in ds.attrs + assert dataset_to_skycoord(ds).equinox is None + + +def test_skycoord_equinox_optional(): + value = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg, equinox="J2000" + ) + assert value.equinox == Time("J2000") + assert skycoord_to_dataset(value).attrs["equinox"] == dump_time(Time("J2000")) + ds = skycoord_to_dataset(value) + assert ds.attrs["equinox"] == dump_time(Time("J2000")) + assert dataset_to_skycoord(ds).equinox == Time("J2000") + + +def test_skycoord_equinox_frame_default(): + value = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK4(), + ) + assert value.equinox == Time("B1950") + assert "equinox" not in skycoord_to_dataset(value).attrs + ds = skycoord_to_dataset(value) + assert "equinox" not in ds.attrs + assert dataset_to_skycoord(ds).equinox == Time("B1950") + + +def test_skycoord_equinox_frame_override(): + with pytest.raises(ValueError): + _ = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK5(), + equinox="J2000", + ) + with pytest.raises(ValueError): + _ = SkyCoord( + ra=[[0.1], [0.2]] * u.deg, + dec=[[0.5], [0.7]] * u.deg, + frame=FK4(), + equinox="J2000", + )