Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:

- name: install dependencies
run: |
python -m pip install ".[test]"
python -m pip install . --group test

- name: install astropy-xarray
run: python -m pip install --no-deps .
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

# What's New

## **Unreleased**

- Added supported for {py:attr}`astropy.coordinates.SkyCoord.obstime`.
- Added supported for {py:attr}`astropy.coordinates.SkyCoord.equinox`.

## 0.2.1 (18 Apr 2026)

- Added support for `astropy==7.2.0` sky coordinates.
Expand Down
7 changes: 1 addition & 6 deletions astropy_xarray/coordinates/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,6 @@ def load_frame(frame_dict: dict, with_data: bool = False) -> BaseCoordinateFrame
frame = HCRS(
obstime=load_optional_object(Time, frame_dict["obstime"]), **kwargs
)
case "itrs":
frame = ITRS(
location=load_optional_earthlocation(frame_dict["location"]),
**kwargs,
)
case "altaz":
frame = AltAz(
obstime=load_optional_object(Time, frame_dict["obstime"]),
Expand Down Expand Up @@ -274,7 +269,7 @@ def load_representation(
frame_name: str | None,
data: dict[str, np.ndarray],
) -> BaseRepresentation:
RepresentationClass = representation.REPRESENTATION_CLASSES.get(representation_type)
RepresentationClass = representation.REPRESENTATION_CLASSES[representation_type]
DifferentialClass = representation.DIFFERENTIAL_CLASSES.get(differential_type)

if frame_name is None:
Expand Down
28 changes: 22 additions & 6 deletions astropy_xarray/coordinates/sky_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import numpy as np
import xarray as xr
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astropy.utils import ShapedLikeNDArray

from astropy_xarray.coordinates.core import dump_time, load_optional_object
from astropy_xarray.coordinates.frame import dump_frame, load_frame, load_representation

_ArrayLike = list | np.ndarray | ShapedLikeNDArray
Expand Down Expand Up @@ -102,11 +104,15 @@ def skycoord_to_dataset(
quantified dataset.
"""
return xr.Dataset(
coords=coords if coords is not None else None,
coords=coords,
data_vars=_skycoord_to_dataarrays(skycoord, coords),
attrs=dict(
frame=dump_frame(skycoord.frame),
),
attrs={"frame": dump_frame(skycoord.frame)}
| {
attr: dump_time(value)
for attr in ("obstime", "equinox")
if (value := getattr(skycoord, attr)) is not None
and not hasattr(skycoord.frame, attr)
},
)


Expand All @@ -128,5 +134,15 @@ def dataset_to_skycoord(ds: xr.Dataset) -> SkyCoord:
{k: v.data for k, v in dsq.data_vars.items()},
)
frame.representation_type = ds.attrs["frame"]["representation_type"]
frame.differential_type = ds.attrs["frame"]["differential_type"]
return SkyCoord(frame)
frame.differential_type = ds.attrs["frame"].get("differential_type")

kwargs = {
attr: value
for attr in ("obstime", "equinox")
if not hasattr(frame, attr)
and (value := load_optional_object(Time, ds.attrs.get(attr))) is not None
}
return SkyCoord(
frame,
**kwargs,
)
5 changes: 4 additions & 1 deletion docs/examples/skycoord.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
" representation_type=\"unitspherical\",\n",
" differential_type=\"unitsphericalcoslat\",\n",
" ),\n",
" obstime=\"J2000\",\n",
" equinox=\"J2000\",\n",
")\n",
"display(sky_direction)\n",
"skycoord_to_dataset(\n",
Expand Down Expand Up @@ -77,6 +79,7 @@
" representation_type=\"spherical\",\n",
" differential_type=\"unitsphericalcoslat\",\n",
" ),\n",
" obstime=\"J2000\",\n",
")\n",
"display(sky_position)\n",
"display(\n",
Expand Down Expand Up @@ -110,7 +113,7 @@
"# frame specific info stored as dataset attribute\n",
"assert load_frame(ds.attrs[\"frame\"]) == sky_position.replicate_without_data()\n",
"\n",
"# dataset\n",
"# restored skycoord matches orginal\n",
"result = dataset_to_skycoord(ds)\n",
"display(result)\n",
"np.testing.assert_array_equal(result.ra, sky_position.ra)\n",
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.23,<2.3.0",
"xarray>=2022.06.0,<=2025.4.0",
"astropy>=6.1.0",
"astropy>=6.1.0,<8.0.0",
"pandas>=2.3.0,<3.0.0",
]
dynamic = ["version"]
Expand All @@ -38,7 +38,7 @@ documentation = "https://astropy-xarray.readthedocs.io/en/stable"
repository = "https://github.com/calgray/astropy-xarray"
issues = "https://github.com/calgray/astropy-xarray/issues"

[project.optional-dependencies]
[dependency-groups]
test = [
"pytest>=8.0",
"pytest-cov",
Expand Down
98 changes: 98 additions & 0 deletions tests/test_sky_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
dataset_to_skycoord,
skycoord_to_dataset,
)
from astropy_xarray.coordinates.core import dump_time
from astropy_xarray.coordinates.sky_coord import (
_skycoord_differential_component_names,
_skycoord_representation_component_names,
Expand Down Expand Up @@ -346,3 +347,100 @@ def test_skycoord_roundtrip(
assert actual.is_equivalent_frame(expected)
assert ds.coords.equals(xr.Coordinates(dict(coords)))
np.testing.assert_array_equal(actual, expected)


def test_skycoord_obstime_empty():
empty = SkyCoord(ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg)
assert empty.obstime is None
assert "obstime" not in skycoord_to_dataset(empty).attrs
ds = skycoord_to_dataset(empty)
assert "obstime" not in ds.attrs
assert dataset_to_skycoord(ds).obstime is None


def test_skycoord_obstime_optional():
value = SkyCoord(
ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg, obstime="J2000"
)
assert value.obstime == Time("J2000")
ds = skycoord_to_dataset(value)
assert ds.attrs["obstime"] == dump_time(Time("J2000"))
assert dataset_to_skycoord(ds).obstime == Time("J2000")


def test_skycoord_obstime_frame_default():
value = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK4(),
)
assert value.obstime == Time("B1950")
ds = skycoord_to_dataset(value)
assert "obstime" not in ds.attrs
assert dataset_to_skycoord(ds).obstime == Time("B1950")


def test_skycoord_obstime_frame_override():
_ = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK5(),
obstime="J2000",
)
with pytest.raises(ValueError):
_ = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK4(),
obstime="J2000",
)


def test_skycoord_equinox_empty():
empty = SkyCoord(ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg)
assert empty.equinox is None
assert "equinox" not in skycoord_to_dataset(empty).attrs
ds = skycoord_to_dataset(empty)
assert "equinox" not in ds.attrs
assert dataset_to_skycoord(ds).equinox is None


def test_skycoord_equinox_optional():
value = SkyCoord(
ra=[[0.1], [0.2]] * u.deg, dec=[[0.5], [0.7]] * u.deg, equinox="J2000"
)
assert value.equinox == Time("J2000")
assert skycoord_to_dataset(value).attrs["equinox"] == dump_time(Time("J2000"))
ds = skycoord_to_dataset(value)
assert ds.attrs["equinox"] == dump_time(Time("J2000"))
assert dataset_to_skycoord(ds).equinox == Time("J2000")


def test_skycoord_equinox_frame_default():
value = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK4(),
)
assert value.equinox == Time("B1950")
assert "equinox" not in skycoord_to_dataset(value).attrs
ds = skycoord_to_dataset(value)
assert "equinox" not in ds.attrs
assert dataset_to_skycoord(ds).equinox == Time("B1950")


def test_skycoord_equinox_frame_override():
with pytest.raises(ValueError):
_ = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK5(),
equinox="J2000",
)
with pytest.raises(ValueError):
_ = SkyCoord(
ra=[[0.1], [0.2]] * u.deg,
dec=[[0.5], [0.7]] * u.deg,
frame=FK4(),
equinox="J2000",
)
Loading