From 9c5bfc9609e07b35808b26d3b18ab66639c46033 Mon Sep 17 00:00:00 2001 From: Anthony Meza <64243783+anthony-meza@users.noreply.github.com> Date: Tue, 24 Feb 2026 10:02:51 +0000 Subject: [PATCH 1/4] adding "allow_rechunk" option this commit incorporates an optional "allow_rechunk" option for collect_budgets and budget_fill_dict --- .gitignore | 2 +- xbudget/collect.py | 68 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index d313308..59cec8f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ __pycache__ .ipynb_checkpoints -/data/data +/data/* diff --git a/xbudget/collect.py b/xbudget/collect.py index 5c8feb4..f3eb512 100644 --- a/xbudget/collect.py +++ b/xbudget/collect.py @@ -158,7 +158,7 @@ def _deep_search(b, new_b={}, k_last=None): _deep_search(v, new_b=new_b, k_last=k) return new_b -def collect_budgets(ds, xbudget_dict): +def collect_budgets(ds, xbudget_dict, allow_rechunk = True): """Fills xbudget dictionary with all tracer content tendencies Parameters @@ -183,9 +183,9 @@ def collect_budgets(ds, xbudget_dict): for eq, v in xbudget_dict.items(): for side in ["lhs", "rhs"]: if side in v: - budget_fill_dict(ds, v[side], f"{eq}_{side}") + budget_fill_dict(ds, v[side], f"{eq}_{side}", allow_rechunk = allow_rechunk) -def budget_fill_dict(data, xbudget_dict, namepath): +def budget_fill_dict(data, xbudget_dict, namepath, allow_rechunk = True): """Recursively fill xbudget dictionary Parameters @@ -216,7 +216,7 @@ def budget_fill_dict(data, xbudget_dict, namepath): op_list = [] for k_term, v_term in v.items(): if isinstance(v_term, dict): # recursive call to get this variable - v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}") + v_term_recursive = budget_fill_dict(data, v_term, f"{namepath}_{k}_{k_term}", allow_rechunk = allow_rechunk) if v_term_recursive is not None: op_list.append(v_term_recursive) elif v_term.get("var") is not None and v_term.get("var") not in ds: @@ -287,6 +287,7 @@ def budget_fill_dict(data, xbudget_dict, namepath): if var_pref is None: var_pref = var.copy() + if k == "difference": if grid is not None: staggered_axes = { @@ -294,26 +295,51 @@ def budget_fill_dict(data, xbudget_dict, namepath): for pos,c in ax.coords.items() if pos!="center" } - v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0] - if v_term not in ds: - warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.") - continue - candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims] - if len(candidate_axes) == 1: - axis = candidate_axes[0] - else: - raise ValueError("Flux difference inconsistent with finite volume discretization.") - var = grid.diff(ds[v_term].fillna(0.), axis) - var_name = f"{namepath}_difference" - var = var.rename(var_name) - var_provenance = v_term - var.attrs["provenance"] = var_provenance - ds[var_name] = var - if var_pref is None: - var_pref = var.copy() + v_term = [v_term for k_term,v_term in v.items() if k_term!="var"][0] + if v_term not in ds: + warnings.warn(f"Variable {v_term} is missing from the dataset `ds`, so it is being skipped. To suppress this warning, remove {v_term} from the `xbudget_dict`.") + continue + + candidate_axes = [axn for (axn,c) in staggered_axes.items() if c in ds[v_term].dims] + if len(candidate_axes) == 1: + axis = candidate_axes[0] + else: + raise ValueError("Flux difference inconsistent with finite volume discretization.") + + if allow_rechunk: #NEW CODE + try: #extract original chunks when possible + #not using ds[v_term] since it may not have the non-staggered dimension chunks. + original_chunks = dict(ds.chunksizes) + except Exception: + warnings.warn("Dataset chunks are inconsistent; using unify_chunks()", UserWarning) + original_chunks = dict(ds.unify_chunks().chunksizes) + + # Find the staggered dimension for the given axis in the DataArray + axis_dim = [d for d in ds[v_term].dims if d in grid.axes[axis].coords.values()] + if len(axis_dim) != 1: + raise ValueError(f"Expected to find one dimension for axis '{axis}' in variable '{v_term}', but found {len(dims_for_axis)}: {dims_for_axis}") + axis_dim = axis_dim[0] + + # Temporarily rechunk to put the difference dim in a single chunk, all other chunks are auto. + temporary_chunks = {axis_dim: -1, **{d: "auto" for d in ds[v_term].dims if d != axis_dim}} + var = grid.diff(ds[v_term].chunk(temporary_chunks).fillna(0.0), axis=axis) + # Attempt original chunking for preserved dimensions + var = var.chunk({d: original_chunks.get(d, var.chunksizes[d]) for d in var.dims}) + else: #OLD CODE + var = grid.diff(ds[v_term].fillna(0.), axis) + + var_name = f"{namepath}_difference" + var = var.rename(var_name) + var_provenance = v_term + var.attrs["provenance"] = var_provenance + ds[var_name] = var + if var_pref is None: + var_pref = var.copy() else: raise ValueError("Input `ds` must be `xgcm.Grid` instance if using `difference` operations.") + + return var_pref def get_vars(xbudget_dict, terms): From 61b806dcfa70be92a2e5987367cba11cc4c25e4d Mon Sep 17 00:00:00 2001 From: Anthony Meza <64243783+anthony-meza@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:41:07 +0000 Subject: [PATCH 2/4] adding test "test_budget_fill_dict_allow_rechunk" that tests the "allow_rechunk" option --- xbudget/tests/test_utilities.py | 68 ++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/xbudget/tests/test_utilities.py b/xbudget/tests/test_utilities.py index 1ba5388..1d39902 100644 --- a/xbudget/tests/test_utilities.py +++ b/xbudget/tests/test_utilities.py @@ -2,6 +2,8 @@ import numpy as np import xarray as xr import copy +import xgcm +import dask.array as da from xbudget.collect import ( aggregate, disaggregate, @@ -437,4 +439,68 @@ def test_budget_fill_dict_numeric_values(self): result = budget_fill_dict(ds, xbudget_dict, "heat_rhs") assert result is not None - assert np.allclose(ds["heat_rhs_product"].values, 2.0) \ No newline at end of file + assert np.allclose(ds["heat_rhs_product"].values, 2.0) + + def test_budget_fill_dict_allow_rechunk(self): + """Test the allow_rechunk option for the difference operation.""" + # Create a dataset with non-uniform chunks on the staggered grid, + # which would cause issues for xgcm.grid.diff + flux_data = da.from_array(np.random.rand(5, 3), chunks=((2, 2, 1), 3)) + ds_chunked = xr.Dataset( + { + "var": xr.DataArray( + flux_data, + dims=("x_g", "y_c"), + ) + }, + coords={ + "x_g": np.arange(5), + "x_c": np.arange(4) + 0.5, + "y_c": np.arange(3), + }, + ) + + grid_params = { + "coords": {"X": {"center": "x_c", "left": "x_g"}}, + "periodic": False, + "autoparse_metadata": False, + } + + xbudget_dict = { + "var": None, + "difference": {"var_diff": "var", "var": None}, + } + + # 1. Test that allow_rechunk=False raises an error when passing a chunked + # dataset through budget_fill_dict + with pytest.raises(ValueError): + grid_fail = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params) + budget_fill_dict( + grid_fail, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=False, + ) + + # 2. Test that shows allow_rechunk=True works + grid_success = xgcm.Grid(ds_chunked.copy(deep=True), **grid_params) + budget_fill_dict( + grid_success, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=True, + ) + tendency_rechunked = grid_success._ds["tendency_rhs_difference"] + + # 3. Compare with a correct result from an unchunked array + grid_unchunked = xgcm.Grid(ds_chunked.chunk(-1), **grid_params) + budget_fill_dict( + grid_unchunked, + copy.deepcopy(xbudget_dict), + "tendency_rhs", + allow_rechunk=False, + ) + tendency_correct = grid_unchunked._ds["tendency_rhs_difference"] + + # The numerical results should be identical + xr.testing.assert_allclose(tendency_rechunked, tendency_correct) From 45eac450a485fde2d34632c2b2f31ad99e2d59db Mon Sep 17 00:00:00 2001 From: Henri Drake Date: Tue, 24 Feb 2026 08:05:20 -0800 Subject: [PATCH 3/4] Clean linting and manually add chunks to small example dataset By explicitly adding chunks to the example MOM6 dataset, we should be able to better catch chunk-related problems in the future. --- .../MOM6_budget_examples_mass_heat_salt.ipynb | 666 ++++++++++++++---- examples/load_example_model_grid.py | 1 + xbudget/collect.py | 46 +- 3 files changed, 550 insertions(+), 163 deletions(-) diff --git a/examples/MOM6_budget_examples_mass_heat_salt.ipynb b/examples/MOM6_budget_examples_mass_heat_salt.ipynb index 9316595..8a58220 100644 --- a/examples/MOM6_budget_examples_mass_heat_salt.ipynb +++ b/examples/MOM6_budget_examples_mass_heat_salt.ipynb @@ -1047,111 +1047,131 @@ " stroke-width: 0.8px;\n", "}\n", "
<xarray.DataArray 'heat_lhs_sum_advection' (time: 1, z_l: 35, yh: 180, xh: 240)> Size: 12MB\n",
-       "array([[[[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "...\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]],\n",
-       "\n",
-       "        [[0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         ...,\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.],\n",
-       "         [0., 0., 0., ..., 0., 0., 0.]]]], shape=(1, 35, 180, 240))\n",
+       "dask.array<add, shape=(1, 35, 180, 240), dtype=float64, chunksize=(1, 35, 100, 100), chunktype=numpy.ndarray>\n",
        "Coordinates:\n",
        "  * time       (time) object 8B 2000-07-01 00:00:00\n",
        "  * z_l        (z_l) float64 280B 2.5 10.0 20.0 32.5 ... 5.5e+03 6e+03 6.5e+03\n",
        "  * yh         (yh) int64 1kB 0 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179\n",
        "  * xh         (xh) int64 2kB 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239\n",
-       "    geolon     (yh, xh) float64 346kB ...\n",
-       "    lon        (yh, xh) float64 346kB ...\n",
-       "    geolat     (yh, xh) float64 346kB ...\n",
-       "    lat        (yh, xh) float64 346kB ...\n",
-       "    deptho     (yh, xh) float32 173kB ...\n",
-       "    wet        (yh, xh) float32 173kB ...\n",
-       "    areacello  (yh, xh) float64 346kB ...\n",
+       "    geolon     (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    lon        (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    geolat     (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    lat        (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    deptho     (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    wet        (yh, xh) float32 173kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
+       "    areacello  (yh, xh) float64 346kB dask.array<chunksize=(100, 100), meta=np.ndarray>\n",
        "Attributes:\n",
        "    cell_measures:  volume: volcello area: areacello\n",
        "    time_avg_info:  average_T1,average_T2,average_DT\n",
        "    standard_name:  cell_area\n",
        "    note:           We ignore land cells in partially wet cells when coarseni...\n",
-       "    provenance:     heat_lhs_sum_advection_sum
  • cell_measures :
    volume: volcello area: areacello
    time_avg_info :
    average_T1,average_T2,average_DT
    standard_name :
    cell_area
    note :
    We ignore land cells in partially wet cells when coarsening, so that tracer content can be accurately reconstructed by multiplying coarsened area-averaged tendencies by it. Fully wet (`wet==1.0`) and fully dry (`wet==0.0`) cells should be unaffected, and will just represent the total cell area. For the partially wet cells, total cell area can be derived from the ocean area by divding `areacello` by `wet`.
    provenance :
    heat_lhs_sum_advection_sum
  • " ], "text/plain": [ " Size: 12MB\n", - "array([[[[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - "...\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]]]], shape=(1, 35, 180, 240))\n", + "dask.array\n", "Coordinates:\n", " * time (time) object 8B 2000-07-01 00:00:00\n", " * z_l (z_l) float64 280B 2.5 10.0 20.0 32.5 ... 5.5e+03 6e+03 6.5e+03\n", " * yh (yh) int64 1kB 0 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179\n", " * xh (xh) int64 2kB 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239\n", - " geolon (yh, xh) float64 346kB ...\n", - " lon (yh, xh) float64 346kB ...\n", - " geolat (yh, xh) float64 346kB ...\n", - " lat (yh, xh) float64 346kB ...\n", - " deptho (yh, xh) float32 173kB ...\n", - " wet (yh, xh) float32 173kB ...\n", - " areacello (yh, xh) float64 346kB ...\n", + " geolon (yh, xh) float64 346kB dask.array\n", + " lon (yh, xh) float64 346kB dask.array\n", + " geolat (yh, xh) float64 346kB dask.array\n", + " lat (yh, xh) float64 346kB dask.array\n", + " deptho (yh, xh) float32 173kB dask.array\n", + " wet (yh, xh) float32 173kB dask.array\n", + " areacello (yh, xh) float64 346kB dask.array\n", "Attributes:\n", " cell_measures: volume: volcello area: areacello\n", " time_avg_info: average_T1,average_T2,average_DT\n", diff --git a/examples/load_example_model_grid.py b/examples/load_example_model_grid.py index c5eb9ac..9fd917f 100644 --- a/examples/load_example_model_grid.py +++ b/examples/load_example_model_grid.py @@ -25,6 +25,7 @@ def load_MOM6_example_grid(file_name): "z_l":xr.DataArray([3000], dims=("z_l",)), "z_i":xr.DataArray([0,6000], dims=("z_i",)) }) + ds = ds.chunk({"xh":100, "yh":100, "xq":100, "yq":100, "time":1}) # Chunk up the data to make it more like a user's typical dataset return construct_grid(ds) def load_MOM6_coarsened_diagnostics(): diff --git a/xbudget/collect.py b/xbudget/collect.py index f3eb512..2b9e7d6 100644 --- a/xbudget/collect.py +++ b/xbudget/collect.py @@ -304,29 +304,29 @@ def budget_fill_dict(data, xbudget_dict, namepath, allow_rechunk = True): if len(candidate_axes) == 1: axis = candidate_axes[0] else: - raise ValueError("Flux difference inconsistent with finite volume discretization.") - - if allow_rechunk: #NEW CODE - try: #extract original chunks when possible - #not using ds[v_term] since it may not have the non-staggered dimension chunks. - original_chunks = dict(ds.chunksizes) - except Exception: - warnings.warn("Dataset chunks are inconsistent; using unify_chunks()", UserWarning) - original_chunks = dict(ds.unify_chunks().chunksizes) - - # Find the staggered dimension for the given axis in the DataArray - axis_dim = [d for d in ds[v_term].dims if d in grid.axes[axis].coords.values()] - if len(axis_dim) != 1: - raise ValueError(f"Expected to find one dimension for axis '{axis}' in variable '{v_term}', but found {len(dims_for_axis)}: {dims_for_axis}") - axis_dim = axis_dim[0] - - # Temporarily rechunk to put the difference dim in a single chunk, all other chunks are auto. - temporary_chunks = {axis_dim: -1, **{d: "auto" for d in ds[v_term].dims if d != axis_dim}} - var = grid.diff(ds[v_term].chunk(temporary_chunks).fillna(0.0), axis=axis) - # Attempt original chunking for preserved dimensions - var = var.chunk({d: original_chunks.get(d, var.chunksizes[d]) for d in var.dims}) - else: #OLD CODE - var = grid.diff(ds[v_term].fillna(0.), axis) + raise ValueError("Finite difference inconsistent with finite volume discretization.") + + if allow_rechunk: + try: #extract original chunks when possible + #not using ds[v_term] since it may not have the non-staggered dimension chunks. + original_chunks = dict(ds.chunksizes) + except Exception: + warnings.warn("Dataset chunks are inconsistent; using unify_chunks()", UserWarning) + original_chunks = dict(ds.unify_chunks().chunksizes) + + # Find the staggered dimension for the given axis in the DataArray + axis_dim = [d for d in ds[v_term].dims if d in grid.axes[axis].coords.values()] + if len(axis_dim) != 1: + raise ValueError(f"Expected to find one dimension for axis '{axis}' in variable '{v_term}', but found {len(axis_dim)}: {axis_dim}") + axis_dim = axis_dim[0] + + # Temporarily rechunk to put the difference dim in a single chunk, all other chunks are auto. + temporary_chunks = {axis_dim: -1, **{d: "auto" for d in ds[v_term].dims if d != axis_dim}} + var = grid.diff(ds[v_term].chunk(temporary_chunks).fillna(0.0), axis=axis) + # Attempt original chunking for preserved dimensions + var = var.chunk({d: original_chunks.get(d, var.chunksizes[d]) for d in var.dims}) + else: + var = grid.diff(ds[v_term].fillna(0.), axis) var_name = f"{namepath}_difference" var = var.rename(var_name) From d7fb7f1074e17f414ca1bdbb4d5dd4d2c2507e66 Mon Sep 17 00:00:00 2001 From: Henri Drake Date: Tue, 24 Feb 2026 08:12:07 -0800 Subject: [PATCH 4/4] Add `allow_chunk` to docstrings --- xbudget/collect.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xbudget/collect.py b/xbudget/collect.py index 2b9e7d6..509c3ee 100644 --- a/xbudget/collect.py +++ b/xbudget/collect.py @@ -179,6 +179,10 @@ def collect_budgets(ds, xbudget_dict, allow_rechunk = True): } } } + allow_rechunk : bool (default: True) + Whether to temporarily rechunk when taking differences along a dimension, + e.g. to compute flux divergences on `center` from fluxes on `outer` or + tendencies on `center` from snapshots on `outer`. """ for eq, v in xbudget_dict.items(): for side in ["lhs", "rhs"]: @@ -193,6 +197,10 @@ def budget_fill_dict(data, xbudget_dict, namepath, allow_rechunk = True): data : xgcm.grid or xr.Dataset xbudget_dict : dictionary in xbudget-compatible format containing variable in namepath namepath : name of variable in dataset (data._ds or data) + allow_rechunk : bool (default: True) + Whether to temporarily rechunk when taking differences along a dimension, + e.g. to compute flux divergences on `center` from fluxes on `outer` or + tendencies on `center` from snapshots on `outer`. """ if type(data)==xgcm.grid.Grid: grid = data