diff --git a/sectionate/gridutils.py b/sectionate/gridutils.py index 3ad3e6f..19cccfc 100644 --- a/sectionate/gridutils.py +++ b/sectionate/gridutils.py @@ -25,17 +25,21 @@ def get_geo_corners(grid): vorticity coordinates at 'outer' and 'right' positions, respectively.") coords = grid._ds.coords - return { + + geo_coord_dict = { axis: [ coords[c] for c in coords if ( - (geoc in c) and + (geoc in c.lower()) and (dims["X"] in coords[c].dims) and (dims["Y"] in coords[c].dims) ) - ][0] + ] for axis, geoc in zip(["X", "Y"], ["lon", "lat"]) } + if any([len(v) == 0 for (k,v) in geo_coord_dict.items()]): + raise ValueError("""grid._ds must contain two-dimensional ("X", "Y") coordinates including the strings "lon" and "lat", consistent with grid.coords.""") + return {k:v[0] for (k,v) in geo_coord_dict.items()} def coord_dict(grid): """ diff --git a/sectionate/tests/test_utils.py b/sectionate/tests/test_utils.py index ae3db07..6c83e36 100644 --- a/sectionate/tests/test_utils.py +++ b/sectionate/tests/test_utils.py @@ -1,4 +1,42 @@ +import pytest + def test_load_section(): from sectionate.utils import get_all_section_names, load_section section_names = get_all_section_names() - load_section(section_names[0]) \ No newline at end of file + load_section(section_names[0]) + +def test_get_geo_corners(): + import numpy as np + import xarray as xr + import xgcm + coords = { + "xh": np.arange(0, 10), + "yh": np.arange(0, 10), + "xq": np.arange(0, 11), + "yq": np.arange(0, 11), + } + + # Without coordinates + ds = xr.Dataset(coords=coords) + grid = xgcm.Grid( + ds = ds, + coords={ + "X":{"center":"xh", "outer":"xq"}, + "Y":{"center":"yh", "outer":"yq"} + }, + boundary={"X":"periodic", "Y":"extend"}, + autoparse_metadata=False + ) + + from sectionate.gridutils import get_geo_corners + + # Fail when (lon, lat) coordinates are missing + with pytest.raises(ValueError): + get_geo_corners(grid) + + # Pass when coordinates are there + grid._ds = grid._ds.assign_coords({ + "lon_c": xr.DataArray(xr.broadcast(grid._ds.xq, grid._ds.yq)[0]), + "lat_c": xr.DataArray(xr.broadcast(grid._ds.xq, grid._ds.yq)[1]) + }) + get_geo_corners(grid) \ No newline at end of file