diff --git a/docs/source/user_guide/rasterize.ipynb b/docs/source/user_guide/rasterize.ipynb
index d9361fbc8..4a0f90082 100644
--- a/docs/source/user_guide/rasterize.ipynb
+++ b/docs/source/user_guide/rasterize.ipynb
@@ -3,7 +3,7 @@
{
"cell_type": "markdown",
"id": "jb1yj58wq4q",
- "source": "## Rasterize\n\n`xrspatial.rasterize` converts vector geometries (polygons, lines, points) into a 2D `xr.DataArray`. No GDAL dependency required.\n\nThis guide covers:\n- [Basic rasterization](#Basic-rasterization) -- polygons, lines, and points\n- [Merge modes](#Merge-modes) -- controlling how overlapping geometries combine\n- [Custom merge functions](#Custom-merge-functions) -- user-defined numba-jitted merge logic\n- [Dask parallel rasterization](#Dask-parallel-rasterization) -- tile-based output chunking",
+ "source": "## Rasterize\n\n`xrspatial.rasterize` converts vector geometries (polygons, lines, points) into a 2D `xr.DataArray`. No GDAL dependency required.\n\nThis guide covers:\n- [Basic rasterization](#Basic-rasterization) -- polygons, lines, and points\n- [Categorical columns](#Categorical-columns) -- burn string/categorical labels, readable in QGIS\n- [Merge modes](#Merge-modes) -- controlling how overlapping geometries combine\n- [Custom merge functions](#Custom-merge-functions) -- user-defined numba-jitted merge logic\n- [Dask parallel rasterization](#Dask-parallel-rasterization) -- tile-based output chunking",
"metadata": {}
}
],
diff --git a/examples/user_guide/28_Rasterize.ipynb b/examples/user_guide/28_Rasterize.ipynb
index ea04d9fe4..97567689b 100644
--- a/examples/user_guide/28_Rasterize.ipynb
+++ b/examples/user_guide/28_Rasterize.ipynb
@@ -18,20 +18,21 @@
"### What you'll build\n",
"\n",
"1. Rasterize land-use zones with the `.xrs` accessor\n",
- "2. Handle overlapping polygons and interior holes\n",
- "3. Burn lines and points into a raster\n",
- "4. Compare merge modes for overlapping features\n",
- "5. Write a custom numba merge function\n",
- "6. Use multi-column properties for density mapping\n",
- "7. Run Dask parallel rasterization\n",
- "8. Combine rasterization with zonal statistics\n",
- "9. Compare default vs. `all_touched` rasterization\n",
- "10. Use the standalone `rasterize()` function with geometry pairs\n",
+ "2. Burn a string or categorical column into a labeled raster\n",
+ "3. Handle overlapping polygons and interior holes\n",
+ "4. Burn lines and points into a raster\n",
+ "5. Compare merge modes for overlapping features\n",
+ "6. Write a custom numba merge function\n",
+ "7. Use multi-column properties for density mapping\n",
+ "8. Run Dask parallel rasterization\n",
+ "9. Combine rasterization with zonal statistics\n",
+ "10. Compare default vs. `all_touched` rasterization\n",
+ "11. Use the standalone `rasterize()` function with geometry pairs\n",
"\n",
"\n",
"\n",
"**Jump to a section:**\n",
- "[Basic rasterization](#Basic-rasterization) | [Overlapping polygons](#Overlapping-polygons) | [Lines and points](#Lines-and-points) | [Merge modes](#Merge-modes) | [Custom merge](#Custom-merge) | [Multi-column properties](#Multi-column-properties) | [Dask parallel](#Dask-parallel) | [Zonal statistics](#Zonal-statistics) | [All touched](#All-touched) | [Standalone function](#Standalone-function)"
+ "[Basic rasterization](#Basic-rasterization) | [Categorical columns](#Categorical-columns) | [Overlapping polygons](#Overlapping-polygons) | [Lines and points](#Lines-and-points) | [Merge modes](#Merge-modes) | [Custom merge](#Custom-merge) | [Multi-column properties](#Multi-column-properties) | [Dask parallel](#Dask-parallel) | [Zonal statistics](#Zonal-statistics) | [All touched](#All-touched) | [Standalone function](#Standalone-function)"
]
},
{
@@ -120,6 +121,58 @@
"plt.tight_layout()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "cat3482md",
+ "metadata": {},
+ "source": [
+ "## Categorical columns\n",
+ "\n",
+ "Pass a string or categorical column straight to `column=` and `rasterize` label-encodes it for you. Each distinct label becomes an integer code, the output is an `int32` band with a `-1` nodata value, and the label map rides along in `result.attrs['category_names']` (the list index is the pixel code). A matching `attrs['category_colors']` holds one RGBA per class.\n",
+ "\n",
+ "The plot below burns the land-use `label` column directly, with no manual encoding step."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "cat3482code",
+ "metadata": {},
+ "source": [
+ "# Burn the string 'label' column directly -- no manual encoding\n",
+ "landcover = template.xrs.rasterize(gdf, column='label')\n",
+ "\n",
+ "names = landcover.attrs['category_names']\n",
+ "print('category_names:', names)\n",
+ "print('dtype:', landcover.dtype, '| nodata:', landcover.attrs['nodata'])\n",
+ "\n",
+ "# Colors the encoder assigned, scaled to 0-1 for matplotlib\n",
+ "colors = [tuple(c / 255 for c in rgba)\n",
+ " for rgba in landcover.attrs['category_colors']]\n",
+ "cmap = ListedColormap(colors)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(10, 4))\n",
+ "landcover.where(landcover >= 0).plot.imshow(\n",
+ " ax=ax, cmap=cmap, add_colorbar=False, vmin=0, vmax=len(names) - 1)\n",
+ "ax.legend(handles=[Patch(facecolor=colors[i], label=name)\n",
+ " for i, name in enumerate(names)],\n",
+ " loc='upper right', fontsize=10, framealpha=0.9)\n",
+ "ax.set_title('Rasterized by string label')\n",
+ "ax.set_axis_off()\n",
+ "plt.tight_layout()"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cat3482alert",
+ "metadata": {},
+ "source": [
+ "**Labels travel to QGIS.** Writing this result with `to_geotiff(landcover, 'landcover.tif')` also writes a `landcover.tif.aux.xml` sidecar holding the category names and colors. GDAL reads the sidecar, so the file opens in QGIS showing the class names instead of bare codes. Keep the sidecar next to the `.tif` when you move the file. `open_geotiff` restores `category_names` / `category_colors` onto `attrs`."
+ ]
+ },
{
"cell_type": "markdown",
"id": "wtyzoml2vc",
@@ -596,4 +649,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py
index 5764875be..41e7c7fca 100644
--- a/xrspatial/geotiff/__init__.py
+++ b/xrspatial/geotiff/__init__.py
@@ -910,6 +910,17 @@ def open_geotiff(source: str | BinaryIO, *,
source = _coerce_path(source)
+ def _attach_category_attrs(da):
+ # Categorical labels/colors live in a PAM ``.aux.xml`` sidecar
+ # (GDAL ignores them when embedded in the TIFF). Merge them back onto the
+ # result so a rasterize -> to_geotiff -> open_geotiff round-trip
+ # preserves ``category_names`` / ``category_colors``. Local string
+ # sources only; a missing/malformed sidecar yields {} and is a no-op.
+ if isinstance(source, str):
+ from ._pam import read_pam_sidecar
+ da.attrs.update(read_pam_sidecar(source))
+ return da
+
# Resolve the rioxarray-compatible renames. ``masked`` / ``default_name``
# are the canonical names; ``mask_nodata`` / ``name`` are deprecated
# aliases kept for back-compat. Mirrors the sentinel-based deprecation in
@@ -1115,7 +1126,8 @@ def open_geotiff(source: str | BinaryIO, *,
vrt_kwargs = {}
if missing_sources_passed:
vrt_kwargs['missing_sources'] = missing_sources
- return _read_vrt(source, dtype=dtype, window=window, band=band,
+ return _attach_category_attrs(_read_vrt(
+ source, dtype=dtype, window=window, band=band,
name=default_name, chunks=chunks, gpu=gpu,
max_pixels=max_pixels,
allow_rotated=allow_rotated,
@@ -1126,7 +1138,7 @@ def open_geotiff(source: str | BinaryIO, *,
allow_internal_only_jpeg=allow_internal_only_jpeg,
band_nodata=band_nodata,
mask_nodata=masked,
- **vrt_kwargs)
+ **vrt_kwargs))
# File-like buffer rejections for ``gpu=True`` / ``chunks=...`` already
# fired inside ``_validate_dispatch_kwargs`` above; the non-VRT branches
@@ -1138,7 +1150,8 @@ def open_geotiff(source: str | BinaryIO, *,
gpu_kwargs = {}
if on_gpu_failure is not _ON_GPU_FAILURE_SENTINEL:
gpu_kwargs['on_gpu_failure'] = on_gpu_failure
- return _read_geotiff_gpu(source, dtype=dtype,
+ return _attach_category_attrs(_read_geotiff_gpu(
+ source, dtype=dtype,
overview_level=overview_level,
window=window, band=band,
name=default_name, chunks=chunks,
@@ -1153,11 +1166,12 @@ def open_geotiff(source: str | BinaryIO, *,
allow_internal_only_jpeg),
mask_nodata=masked,
mask_and_scale=unpack,
- **gpu_kwargs)
+ **gpu_kwargs))
# Dask path (CPU)
if chunks is not None:
- return _read_geotiff_dask(source, dtype=dtype, chunks=chunks,
+ return _attach_category_attrs(_read_geotiff_dask(
+ source, dtype=dtype, chunks=chunks,
overview_level=overview_level,
window=window, band=band,
max_pixels=max_pixels, name=default_name,
@@ -1171,7 +1185,7 @@ def open_geotiff(source: str | BinaryIO, *,
allow_internal_only_jpeg),
mask_nodata=masked,
mask_and_scale=unpack,
- parse_coordinates=parse_coordinates)
+ parse_coordinates=parse_coordinates))
kwargs = {}
if max_pixels is not None:
@@ -1212,7 +1226,7 @@ def open_geotiff(source: str | BinaryIO, *,
getattr(geo_info, '_mask_nodata', nodata)
if nodata is not None else None
)
- return _finalize_eager_read(
+ return _attach_category_attrs(_finalize_eager_read(
arr,
geo_info=geo_info,
nodata=nodata,
@@ -1226,7 +1240,7 @@ def open_geotiff(source: str | BinaryIO, *,
mask_and_scale=unpack,
parse_coordinates=parse_coordinates,
band=band,
- )
+ ))
def plot_geotiff(da: xr.DataArray, **kwargs):
diff --git a/xrspatial/geotiff/_attrs.py b/xrspatial/geotiff/_attrs.py
index 4f20ddc16..6a843665e 100644
--- a/xrspatial/geotiff/_attrs.py
+++ b/xrspatial/geotiff/_attrs.py
@@ -116,6 +116,12 @@
- ``extra_samples``: TIFF ExtraSamples tag.
- ``colormap``: raw uint16 RGB triples from the TIFF ColorMap tag (320),
attached to single-band paletted images.
+- ``category_names``: ordered list of class label strings (index == pixel
+ value) for a categorical raster. Written to / read from a PAM
+ ``.aux.xml`` sidecar (``<CategoryNames>`` plus a thematic
+ ``<GDALRasterAttributeTable>``); see :mod:`xrspatial.geotiff._pam`.
+- ``category_colors``: list of ``(r, g, b, a)`` int tuples (0-255), one per
+ category, emitted as the RAT's Red/Green/Blue/Alpha columns.
Removed in contract v2 (issue #2016):
diff --git a/xrspatial/geotiff/_pam.py b/xrspatial/geotiff/_pam.py
new file mode 100644
index 000000000..aa7255bca
--- /dev/null
+++ b/xrspatial/geotiff/_pam.py
@@ -0,0 +1,195 @@
+"""PAM (``.aux.xml``) sidecar helpers for categorical rasters.
+
+GDAL stores a band's category names and Raster Attribute Table (RAT) for a
+GeoTIFF in a ``.aux.xml`` PAM sidecar, not in the TIFF itself: an
+embedded RAT in the GDAL_METADATA tag is silently ignored on read. QGIS
+reads the sidecar to label and color discrete classes, so writing one is
+what makes a categorical ``rasterize`` result show class names instead of
+bare integers.
+
+This module builds that sidecar from ``attrs['category_names']`` /
+``attrs['category_colors']`` and parses it back. The XML shape matches what
+``gdalinfo`` round-trips: a single ``<PAMRasterBand>`` carrying a
+``<CategoryNames>`` list plus a thematic ``<GDALRasterAttributeTable>``.
+"""
+from __future__ import annotations
+
+import os
+from xml.sax.saxutils import escape
+
+from ._safe_xml import safe_fromstring
+
+# GDAL RAT field-usage codes (GDALRATFieldUsage).
+_USAGE_MINMAX = 5
+_USAGE_NAME = 2
+_USAGE_RED = 6
+_USAGE_GREEN = 7
+_USAGE_BLUE = 8
+_USAGE_ALPHA = 9
+# GDAL RAT field types (GDALRATFieldType): Integer=0, Real=1, String=2.
+_TYPE_INT = 0
+_TYPE_STRING = 2
+
+
+def sidecar_path(path: str) -> str:
+ """Return the PAM sidecar path GDAL expects for *path* (``.aux.xml``)."""
+ return path + '.aux.xml'
+
+
+def build_pam_xml(category_names, category_colors=None):
+ """Build a PAM ``.aux.xml`` document for a categorical band.
+
+ Parameters
+ ----------
+ category_names : sequence of str
+ Class labels; list index is the pixel value.
+ category_colors : sequence of (r, g, b, a), optional
+ One RGBA tuple (components 0-255) per category. When given, the RAT
+ gains Red/Green/Blue/Alpha columns so QGIS colors each class.
+
+ Returns
+ -------
+ str
+ The ``<PAMDataset>`` XML document.
+ """
+ names = [str(n) for n in category_names]
+ have_colors = category_colors is not None and len(category_colors) == len(names)
+
+ field_defs = [
+ ('Value', _TYPE_INT, _USAGE_MINMAX),
+ ('Class', _TYPE_STRING, _USAGE_NAME),
+ ]
+ if have_colors:
+ field_defs += [
+ ('Red', _TYPE_INT, _USAGE_RED),
+ ('Green', _TYPE_INT, _USAGE_GREEN),
+ ('Blue', _TYPE_INT, _USAGE_BLUE),
+ ('Alpha', _TYPE_INT, _USAGE_ALPHA),
+ ]
+
+    lines = ['<PAMDataset>', '  <PAMRasterBand band="1">']
+
+    lines.append('    <CategoryNames>')
+    for name in names:
+        lines.append(f'      <Category>{escape(name)}</Category>')
+    lines.append('    </CategoryNames>')
+
+    lines.append('    <GDALRasterAttributeTable tableType="thematic">')
+    for i, (fname, ftype, usage) in enumerate(field_defs):
+        lines.append(f'      <FieldDefn index="{i}">')
+        lines.append(f'        <Name>{fname}</Name>')
+        lines.append(f'        <Type>{ftype}</Type>')
+        lines.append(f'        <Usage>{usage}</Usage>')
+        lines.append('      </FieldDefn>')
+    for value, name in enumerate(names):
+        cells = [str(value), escape(name)]
+        if have_colors:
+            r, g, b, a = category_colors[value]
+            cells += [str(int(r)), str(int(g)), str(int(b)), str(int(a))]
+        row = ''.join(f'<F>{c}</F>' for c in cells)
+        lines.append(f'      <Row index="{value}">{row}</Row>')
+    lines.append('    </GDALRasterAttributeTable>')
+
+    lines.append('  </PAMRasterBand>')
+    lines.append('</PAMDataset>')
+    return '\n'.join(lines) + '\n'
+
+
+def write_pam_sidecar(path, category_names, category_colors=None):
+ """Write the PAM sidecar for *path* and return the sidecar path."""
+ xml = build_pam_xml(category_names, category_colors)
+ out = sidecar_path(path)
+ with open(out, 'w', encoding='utf-8') as fh:
+ fh.write(xml)
+ return out
+
+
+def read_pam_sidecar(path):
+ """Read ``category_names`` / ``category_colors`` from *path*'s sidecar.
+
+ Returns a dict with whatever it could recover (``category_names`` and,
+ when the RAT carries color columns, ``category_colors``). Returns an
+ empty dict when no sidecar exists or it cannot be parsed -- a missing or
+ malformed sidecar is non-fatal auxiliary metadata, not a read error.
+ """
+ aux = sidecar_path(path)
+ if not os.path.exists(aux):
+ return {}
+ try:
+ with open(aux, 'r', encoding='utf-8') as fh:
+ root = safe_fromstring(fh.read())
+
+ band = root.find('.//PAMRasterBand')
+ if band is None:
+ return {}
+
+ names = None
+ colors = None
+ # Only a thematic RAT with a Name column describes categories. GDAL
+ # writes an athematic histogram/statistics RAT next to many ordinary
+ # rasters; it must not masquerade as category names. Prefer the RAT
+ # (it carries colors); fall back to the <CategoryNames> element,
+ # which GDAL only writes for real categories.
+ rat = band.find('GDALRasterAttributeTable')
+ if rat is not None and rat.get('tableType') == 'thematic':
+ names, colors = _parse_rat(rat)
+ if names is None:
+ cat_el = band.find('CategoryNames')
+ if cat_el is not None:
+ names = [c.text or '' for c in cat_el.findall('Category')]
+
+ out = {}
+ if names is not None:
+ out['category_names'] = names
+ if colors is not None:
+ out['category_colors'] = colors
+ return out
+ except (OSError, ValueError, TypeError, LookupError):
+ # A missing, malformed, or foreign sidecar is non-fatal auxiliary
+ # metadata, not a read error -- never let it break open_geotiff.
+ return {}
+
+
+def _parse_rat(rat):
+ """Return (names, colors) from a thematic ``<GDALRasterAttributeTable>``.
+
+ Returns ``(None, None)`` when the table has no Name column, i.e. it does
+ not describe categories. ``names`` is ordered by the row's Value column;
+ ``colors`` is the RGBA list when the RAT defines color columns, else
+ ``None``.
+ """
+ usage_to_col = {}
+ for fd in rat.findall('FieldDefn'):
+ usage_el = fd.find('Usage')
+ if fd.get('index') is None or usage_el is None \
+ or usage_el.text is None:
+ continue
+ usage_to_col[int(usage_el.text)] = int(fd.get('index'))
+
+ name_col = usage_to_col.get(_USAGE_NAME)
+ if name_col is None:
+ return None, None
+ value_col = usage_to_col.get(_USAGE_MINMAX)
+ rgba_cols = [usage_to_col.get(u) for u in
+ (_USAGE_RED, _USAGE_GREEN, _USAGE_BLUE, _USAGE_ALPHA)]
+ have_colors = all(c is not None for c in rgba_cols)
+
+ rows = []
+ for row in rat.findall('Row'):
+ fields = [f.text or '' for f in row.findall('F')]
+ # Tolerate a Real-typed value cell ("0.0") as well as an integer.
+ value = int(float(fields[value_col])) if value_col is not None \
+ else int(row.get('index'))
+ name = fields[name_col]
+ color = None
+ if have_colors:
+ color = tuple(int(fields[c]) for c in rgba_cols)
+ rows.append((value, name, color))
+
+ if not rows:
+ return None, None
+
+ rows.sort(key=lambda r: r[0])
+ names = [name for _, name, _ in rows]
+ colors = [color for _, _, color in rows] if have_colors else None
+ return names, colors
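Reviewer note: the sidecar layout that `_pam.py` builds and parses can be sketched with the standard library alone. This is a minimal illustration, not the module itself. `build_sidecar_xml` and `parse_names` are hypothetical names, color columns are omitted, and the stdlib `ElementTree` parser stands in for the project's `safe_fromstring`, which is an assumption made only to keep the sketch self-contained.

```python
import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape

# GDALRATFieldUsage codes used below: MinMax=5, Name=2.
USAGE_MINMAX, USAGE_NAME = 5, 2


def build_sidecar_xml(names):
    """Build a minimal PAM document: CategoryNames plus a thematic RAT."""
    lines = ['<PAMDataset>', '  <PAMRasterBand band="1">',
             '    <CategoryNames>']
    lines += [f'      <Category>{escape(n)}</Category>' for n in names]
    lines.append('    </CategoryNames>')
    lines.append('    <GDALRasterAttributeTable tableType="thematic">')
    # Two columns: the integer pixel value and the string class name.
    for i, (fname, ftype, usage) in enumerate(
            [('Value', 0, USAGE_MINMAX), ('Class', 2, USAGE_NAME)]):
        lines.append(
            f'      <FieldDefn index="{i}"><Name>{fname}</Name>'
            f'<Type>{ftype}</Type><Usage>{usage}</Usage></FieldDefn>')
    for value, name in enumerate(names):
        lines.append(
            f'      <Row index="{value}"><F>{value}</F>'
            f'<F>{escape(name)}</F></Row>')
    lines += ['    </GDALRasterAttributeTable>', '  </PAMRasterBand>',
              '</PAMDataset>']
    return '\n'.join(lines) + '\n'


def parse_names(xml_text):
    """Recover names from the thematic RAT, ordered by the Value column."""
    root = ET.fromstring(xml_text)
    rat = root.find('.//GDALRasterAttributeTable')
    if rat is None or rat.get('tableType') != 'thematic':
        return None
    # Map each field usage code to its column index.
    col = {int(fd.find('Usage').text): int(fd.get('index'))
           for fd in rat.findall('FieldDefn')}
    if USAGE_NAME not in col:
        return None
    rows = []
    for row in rat.findall('Row'):
        cells = [f.text or '' for f in row.findall('F')]
        rows.append((int(cells[col[USAGE_MINMAX]]), cells[col[USAGE_NAME]]))
    return [name for _, name in sorted(rows)]


xml_text = build_sidecar_xml(['forest', 'a & b'])
print(parse_names(xml_text))  # ['forest', 'a & b']
```

The ampersand is escaped on the way out and unescaped by the parser, which is the property `test_special_characters_escaped` relies on.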
diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py
index 67755292a..1c3c3a5e6 100644
--- a/xrspatial/geotiff/_writers/eager.py
+++ b/xrspatial/geotiff/_writers/eager.py
@@ -435,6 +435,22 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
path = _coerce_path(path)
+ # Categorical rasters carry their value->label map in attrs. GDAL/QGIS
+ # only read category names and colors from a PAM ``.aux.xml``
+ # sidecar (an embedded RAT is ignored), so capture the labels now and
+ # emit the sidecar next to the file on the way out. File-like
+ # destinations have no path to anchor a sidecar, so skip them.
+ _cat_names = None
+ _cat_colors = None
+ if isinstance(path, str) and isinstance(data, xr.DataArray):
+ _cat_names = data.attrs.get('category_names')
+ _cat_colors = data.attrs.get('category_colors')
+
+ def _write_category_sidecar():
+ if _cat_names:
+ from .._pam import write_pam_sidecar
+ write_pam_sidecar(path, _cat_names, _cat_colors)
+
# Reject bool / np.bool_ nodata up front. ``bool`` is a subclass of
# ``int`` in Python, so a typo like ``nodata=True`` slips past every
# downstream ``isinstance(nodata, (int, float))`` guard. The geotag
@@ -812,6 +828,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
allow_unparseable_crs=allow_unparseable_crs,
allow_internal_only_jpeg=allow_internal_only_jpeg,
drop_rotation=drop_rotation)
+ _write_category_sidecar()
return path
# Dispatch to _write_geotiff_gpu when GPU was selected (explicit
@@ -860,6 +877,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
allow_unparseable_crs=allow_unparseable_crs,
drop_rotation=drop_rotation,
)
+ _write_category_sidecar()
return path
except ImportError as e:
# ``_write_geotiff_gpu`` raises ImportError when cupy itself
@@ -1042,6 +1060,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs,
)
+ _write_category_sidecar()
return path
# Eager compute (numpy, CuPy, or dask+COG)
@@ -1137,6 +1156,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs,
)
+ _write_category_sidecar()
return path
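Reviewer note: the read side treats the sidecar as optional metadata, so a missing or broken file should yield an empty dict rather than an exception. A rough sketch of that contract, assuming a hypothetical helper `read_category_names` and the stdlib parser (the real code goes through `read_pam_sidecar` and `safe_fromstring`, whose exact exception behavior is not shown in this diff):

```python
import os
import tempfile
import xml.etree.ElementTree as ET


def read_category_names(tif_path):
    """Return {'category_names': [...]} or {} if the sidecar is unusable."""
    aux = tif_path + '.aux.xml'
    if not os.path.exists(aux):
        return {}
    try:
        with open(aux, 'r', encoding='utf-8') as fh:
            root = ET.fromstring(fh.read())
    except (OSError, ValueError, ET.ParseError):
        # Unreadable or malformed XML is non-fatal auxiliary metadata.
        return {}
    cat_el = root.find('.//PAMRasterBand/CategoryNames')
    if cat_el is None:
        return {}
    return {'category_names': [c.text or ''
                               for c in cat_el.findall('Category')]}


with tempfile.TemporaryDirectory() as tmp:
    tif = os.path.join(tmp, 'landcover.tif')
    print(read_category_names(tif))  # no sidecar yet: {}

    with open(tif + '.aux.xml', 'w', encoding='utf-8') as fh:
        fh.write('<PAMDataset><PAMRasterBand band="1"><CategoryNames>'
                 '<Category>forest</Category><Category>water</Category>'
                 '</CategoryNames></PAMRasterBand></PAMDataset>')
    print(read_category_names(tif))  # {'category_names': ['forest', 'water']}

    with open(tif + '.aux.xml', 'w', encoding='utf-8') as fh:
        fh.write('<PAMDataset><unclosed>')
    print(read_category_names(tif))  # malformed XML: {}
```

Note that the stdlib `ParseError` derives from `SyntaxError`, not `ValueError`, so a tolerant reader built directly on `ElementTree` has to catch it explicitly.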
diff --git a/xrspatial/rasterize.py b/xrspatial/rasterize.py
index a681c42d7..d001246fb 100644
--- a/xrspatial/rasterize.py
+++ b/xrspatial/rasterize.py
@@ -3744,8 +3744,52 @@ def _run_dask_cupy(geometries, props_array, bounds, height, width, fill,
# Input parsing
# ---------------------------------------------------------------------------
+def _is_categorical_like(series):
+ """True when a GeoDataFrame column should be label-encoded.
+
+ Catches pandas ``CategoricalDtype`` and any non-numeric dtype
+ (object / string). Bool stays numeric (it burns as 0/1), matching
+ the historical ``.astype(np.float64)`` behaviour.
+ """
+ import pandas as pd
+ return (isinstance(series.dtype, pd.CategoricalDtype)
+ or not pd.api.types.is_numeric_dtype(series.dtype))
+
+
+def _encode_categorical(series):
+ """Label-encode a column to integer codes plus an ordered name list.
+
+ An existing ``CategoricalDtype`` keeps its declared category order;
+ a plain string/object column is converted with ``astype('category')``,
+ whose categories come out lexically sorted. Missing values keep the
+ pandas ``-1`` code, which the caller maps to the fill value.
+ """
+ import pandas as pd
+ cat = series if isinstance(series.dtype, pd.CategoricalDtype) \
+ else series.astype('category')
+ codes = cat.cat.codes.to_numpy()
+ names = [str(c) for c in cat.cat.categories]
+ return codes, names
+
+
+def _categorical_colors(n):
+ """Generate ``n`` distinct opaque RGBA colors from an HSV spread.
+
+ Evenly spaced hues at fixed saturation/value give visually separable
+ classes without pulling in matplotlib. Returns a list of
+ ``(r, g, b, a)`` int tuples with components in 0-255.
+ """
+ import colorsys
+ colors = []
+ for i in range(n):
+ h = i / n if n else 0.0
+ r, g, b = colorsys.hsv_to_rgb(h, 0.65, 0.95)
+ colors.append((round(r * 255), round(g * 255), round(b * 255), 255))
+ return colors
+
+
def _parse_input(geometries, column=None, columns=None):
- """Normalise input to (geometry_list, props_array, bounds, crs).
+ """Normalise input to (geometry_list, props_array, bounds, crs, cats).
Returns
-------
@@ -3755,6 +3799,9 @@ def _parse_input(geometries, column=None, columns=None):
crs : the GeoDataFrame's ``.crs`` (any pyproj-parseable value) or
``None``. Only a GeoDataFrame exposes a CRS; the
``(geometry, value)`` iterable path always returns ``None``.
+ category_names : ordered list of label strings when ``column`` named a
+ string/categorical field, else ``None``. When set, ``props_array``
+ holds the integer codes (``-1`` for missing).
"""
# Handle dask-geopandas by materializing eagerly. Geometry data is
# typically much smaller than the output raster, so this is fine.
@@ -3775,6 +3822,7 @@ def _parse_input(geometries, column=None, columns=None):
# guard fires instead of producing a raster with nan coords.
if any(not np.isfinite(v) for v in total_bounds):
total_bounds = None
+ category_names = None
if columns is not None:
props_array = geometries[columns].values.astype(np.float64)
else:
@@ -3786,9 +3834,15 @@ def _parse_input(geometries, column=None, columns=None):
"GeoDataFrame has no numeric columns to burn. "
"Pass a 'column' name explicitly.")
column = numeric_cols[0]
- props_array = geometries[column].values.astype(
- np.float64).reshape(-1, 1)
- return geom_list, props_array, total_bounds, geometries.crs
+ col = geometries[column]
+ if _is_categorical_like(col):
+ codes, category_names = _encode_categorical(col)
+ props_array = codes.astype(np.float64).reshape(-1, 1)
+ else:
+ props_array = col.values.astype(
+ np.float64).reshape(-1, 1)
+ return (geom_list, props_array, total_bounds,
+ geometries.crs, category_names)
except ImportError:
pass
@@ -3804,13 +3858,13 @@ def _parse_input(geometries, column=None, columns=None):
if not geom_list:
props_array = np.empty((0, 1), dtype=np.float64)
- return geom_list, props_array, None, None
+ return geom_list, props_array, None, None, None
props_array = np.array(value_list, dtype=np.float64).reshape(-1, 1)
# Bounds computation is deferred: return None here and let the
# caller compute bboxes only when bounds are actually needed.
- return geom_list, props_array, None, None
+ return geom_list, props_array, None, None, None
def _check_uniform_axis(axis_name, coords, expected_step):
@@ -4135,6 +4189,18 @@ def rasterize(
Name of the GeoDataFrame column whose values are burned into
the raster. Ignored when ``geometries`` is a list of pairs.
Mutually exclusive with ``columns``.
+
+ A string or categorical column is label-encoded: each distinct
+ label gets an integer code ``0..N-1`` and the result is an
+ ``int32`` band with a ``-1`` nodata sentinel (unless ``dtype`` /
+ ``fill`` are passed explicitly). Plain string/object columns are
+ ordered lexically; an existing pandas ``Categorical`` keeps its
+ declared order. The label map is stored on the result as
+ ``attrs['category_names']`` (index == pixel code) plus an
+ auto-generated ``attrs['category_colors']`` (one RGBA per class).
+ ``to_geotiff`` writes these to a PAM ``.aux.xml`` sidecar
+ so GDAL/QGIS display the class names, and ``open_geotiff`` reads
+ them back.
columns : list of str, optional
Names of multiple GeoDataFrame columns to pass as a properties
array to the merge function. Mutually exclusive with ``column``.
@@ -4396,8 +4462,30 @@ def rasterize(
like_x_descending = grid.x_descending
# Parse input geometries
- geom_list, props_array, inferred_bounds, geom_crs = _parse_input(
- geometries, column=column, columns=columns)
+ geom_list, props_array, inferred_bounds, geom_crs, category_names = \
+ _parse_input(geometries, column=column, columns=columns)
+
+ # Categorical column: burn integer codes onto an integer band with a
+ # -1 nodata sentinel (pandas' own missing code) so the result renders
+ # as discrete classes in QGIS rather than a float ramp. Explicit
+ # ``dtype`` / ``fill`` always win. ``like`` only supplies the grid
+ # here, not the dtype -- a float template must not silently demote the
+ # codes back to a float ramp.
+ category_colors = None
+ if category_names is not None:
+ if dtype is None:
+ dtype = np.int32
+ try:
+ fill_is_nan = np.isnan(float(fill))
+ except (TypeError, ValueError):
+ fill_is_nan = False
+ if fill_is_nan:
+ fill = -1
+ # Map the pandas missing code (-1) to the resolved fill so
+ # geometries with no category become nodata.
+ if fill != -1:
+ props_array = np.where(props_array == -1, fill, props_array)
+ category_colors = _categorical_colors(len(category_names))
# Guard against silently burning geometries onto a template in a
# different CRS. The output inherits the template CRS (attrs /
@@ -4771,6 +4859,13 @@ def rasterize(
out_attrs['_FillValue'] = fill
out_attrs['nodatavals'] = (fill,)
+ # Carry the label map so to_geotiff can emit a PAM sidecar and QGIS
+ # shows class names. Index == pixel code; colors are one RGBA per
+ # category. These are pass-through attrs (xrspatial/geotiff/_attrs.py).
+ if category_names is not None:
+ out_attrs['category_names'] = category_names
+ out_attrs['category_colors'] = category_colors
+
# Emit the geometry CRS when the output would otherwise carry none.
# The grid is laid out in the geometry's coordinate system (bounds
# come from the geometry coords), so a CRS-carrying GeoDataFrame
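Reviewer note: the encoding rules described in the `column` docstring (lexical order for plain strings, `-1` for missing labels, one opaque RGBA per class) can be illustrated without pandas. This is a sketch under assumptions: `encode_labels` and `spread_colors` are hypothetical stand-ins for `_encode_categorical` and `_categorical_colors`, and Python's `sorted` is assumed to match the order pandas produces for plain string columns.

```python
import colorsys


def encode_labels(labels):
    """Map labels to codes 0..N-1 in lexical order; None becomes -1."""
    names = sorted({str(v) for v in labels if v is not None})
    index = {name: code for code, name in enumerate(names)}
    codes = [-1 if v is None else index[str(v)] for v in labels]
    return codes, names


def spread_colors(n):
    """Return n opaque RGBA tuples (0-255) with evenly spaced hues."""
    colors = []
    for i in range(n):
        # Fixed saturation/value, hue stepped evenly around the wheel.
        r, g, b = colorsys.hsv_to_rgb(i / n, 0.65, 0.95)
        colors.append((round(r * 255), round(g * 255), round(b * 255), 255))
    return colors


codes, names = encode_labels(['water', 'forest', None, 'urban'])
print(names)  # ['forest', 'urban', 'water']
print(codes)  # [2, 0, -1, 1]
colors = spread_colors(len(names))
```

The list index of each name is the pixel code, so `names[codes[0]]` gives back `'water'`, and `colors` lines up with `names` entry for entry.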
diff --git a/xrspatial/tests/test_rasterize_categorical_3482.py b/xrspatial/tests/test_rasterize_categorical_3482.py
new file mode 100644
index 000000000..2aa1e26bf
--- /dev/null
+++ b/xrspatial/tests/test_rasterize_categorical_3482.py
@@ -0,0 +1,273 @@
+"""Tests for categorical/string column support in rasterize (issue #3482).
+
+A string or categorical ``column`` is label-encoded to integer codes on an
+int32 band with a -1 nodata sentinel, and the value->label map is carried in
+``attrs['category_names']`` / ``attrs['category_colors']``. ``to_geotiff``
+emits a PAM ``.aux.xml`` sidecar so GDAL/QGIS show the class names, and
+``open_geotiff`` reads it back for a full round-trip.
+"""
+import os
+import shutil
+import subprocess
+
+import numpy as np
+import pytest
+
+try:
+ from shapely.geometry import box
+ has_shapely = True
+except ImportError:
+ has_shapely = False
+
+try:
+ import geopandas as gpd
+ has_geopandas = True
+except ImportError:
+ has_geopandas = False
+
+try:
+ import pandas as pd
+ has_pandas = True
+except ImportError:
+ has_pandas = False
+
+try:
+ import cupy # noqa: F401
+ from numba import cuda
+ has_cuda = cuda.is_available()
+except Exception:
+ has_cuda = False
+
+if has_shapely:
+ from xrspatial.rasterize import rasterize
+
+pytestmark = [
+ pytest.mark.skipif(not has_shapely, reason="shapely not installed"),
+ pytest.mark.skipif(not has_geopandas, reason="geopandas not installed"),
+]
+
+
+@pytest.fixture
+def landcover_gdf():
+ """Three non-overlapping squares keyed by a string land-cover column."""
+ return gpd.GeoDataFrame(
+ {'landcover': ['water', 'forest', 'urban']},
+ geometry=[box(0, 0, 5, 5), box(5, 0, 10, 5), box(0, 5, 5, 10)],
+ crs='EPSG:4326',
+ )
+
+
+# ---------------------------------------------------------------------------
+# Encoding
+# ---------------------------------------------------------------------------
+
+class TestCategoricalEncoding:
+ def test_string_column_encodes_to_int32(self, landcover_gdf):
+ result = rasterize(landcover_gdf, column='landcover',
+ width=10, height=10)
+ assert result.dtype == np.int32
+ # Codes are 0..N-1 plus the -1 nodata for untouched cells.
+ assert set(np.unique(result.values)) <= {-1, 0, 1, 2}
+ assert result.attrs['nodata'] == -1
+
+ def test_category_names_sorted_for_plain_strings(self, landcover_gdf):
+ result = rasterize(landcover_gdf, column='landcover',
+ width=10, height=10)
+ # astype('category') sorts object categories lexically.
+ assert result.attrs['category_names'] == ['forest', 'urban', 'water']
+
+ def test_codes_match_category_order(self, landcover_gdf):
+ result = rasterize(landcover_gdf, column='landcover',
+ width=10, height=10)
+ names = result.attrs['category_names']
+ # 'forest' geometry is box(5,0,10,5): right half, bottom rows.
+ forest_code = names.index('forest')
+ vals = result.values
+ # Bottom-right quadrant should be all the forest code.
+ assert np.all(vals[5:, 5:] == forest_code)
+
+ @pytest.mark.skipif(not has_pandas, reason="pandas not installed")
+ def test_ordered_categorical_preserves_order(self):
+ cat = pd.Categorical(['b', 'a', 'c'],
+ categories=['c', 'b', 'a'], ordered=True)
+ gdf = gpd.GeoDataFrame(
+ {'k': cat},
+ geometry=[box(0, 0, 2, 2), box(2, 0, 4, 2), box(4, 0, 6, 2)],
+ )
+ result = rasterize(gdf, column='k', width=6, height=2)
+ assert result.attrs['category_names'] == ['c', 'b', 'a']
+
+ @pytest.mark.skipif(not has_pandas, reason="pandas not installed")
+ def test_missing_category_becomes_nodata(self):
+ gdf = gpd.GeoDataFrame(
+ {'k': ['x', None]},
+ geometry=[box(0, 0, 2, 2), box(2, 0, 4, 2)],
+ )
+ result = rasterize(gdf, column='k', width=4, height=2)
+ assert result.attrs['category_names'] == ['x']
+ # The None geometry must not paint a real code; only -1 and 0 appear.
+ assert set(np.unique(result.values)) <= {-1, 0}
+
+ def test_explicit_fill_remaps_missing(self):
+ gdf = gpd.GeoDataFrame(
+ {'k': ['x', None]},
+ geometry=[box(0, 0, 2, 2), box(2, 0, 4, 2)],
+ )
+ result = rasterize(gdf, column='k', width=4, height=2,
+ fill=99, dtype=np.int32)
+ assert result.attrs['nodata'] == 99
+ assert -1 not in np.unique(result.values)
+
+ def test_colors_one_per_category(self, landcover_gdf):
+ result = rasterize(landcover_gdf, column='landcover',
+ width=10, height=10)
+ colors = result.attrs['category_colors']
+ assert len(colors) == len(result.attrs['category_names'])
+ for r, g, b, a in colors:
+ assert all(0 <= c <= 255 for c in (r, g, b, a))
+ assert a == 255
+
+ def test_numeric_column_unchanged(self):
+ gdf = gpd.GeoDataFrame({'v': [5.0]}, geometry=[box(0, 0, 4, 4)])
+ result = rasterize(gdf, column='v', width=4, height=4)
+ assert result.dtype == np.float64
+ assert 'category_names' not in result.attrs
+
+ def test_explicit_dtype_respected(self, landcover_gdf):
+ result = rasterize(landcover_gdf, column='landcover',
+ width=10, height=10, fill=0, dtype=np.uint8)
+ assert result.dtype == np.uint8
+ assert result.attrs['category_names'] == ['forest', 'urban', 'water']
+
+
+# ---------------------------------------------------------------------------
+# GeoTIFF round-trip
+# ---------------------------------------------------------------------------
+
+class TestGeoTIFFRoundTrip:
+ def test_sidecar_written_and_read_back(self, landcover_gdf, tmp_path):
+ pytest.importorskip('tifffile')
+ from xrspatial.geotiff import to_geotiff, open_geotiff
+
+ result = rasterize(landcover_gdf, column='landcover',
+ width=20, height=20)
+ path = str(tmp_path / 'landcover_3482.tif')
+ to_geotiff(result, path)
+
+ assert os.path.exists(path + '.aux.xml')
+
+ back = open_geotiff(path)
+ assert back.attrs['category_names'] == ['forest', 'urban', 'water']
+ assert (list(back.attrs['category_colors'])
+ == list(result.attrs['category_colors']))
+
+ def test_no_sidecar_for_numeric(self, tmp_path):
+ pytest.importorskip('tifffile')
+ from xrspatial.geotiff import to_geotiff
+
+ gdf = gpd.GeoDataFrame({'v': [5.0]}, geometry=[box(0, 0, 4, 4)],
+ crs='EPSG:4326')
+ result = rasterize(gdf, column='v', width=8, height=8)
+ path = str(tmp_path / 'numeric_3482.tif')
+ to_geotiff(result, path)
+ assert not os.path.exists(path + '.aux.xml')
+
+ @pytest.mark.skipif(shutil.which('gdalinfo') is None,
+ reason="gdalinfo not available")
+ def test_gdalinfo_shows_categories(self, landcover_gdf, tmp_path):
+ """GDAL (and therefore QGIS) reads the labels from the sidecar."""
+ pytest.importorskip('tifffile')
+ from xrspatial.geotiff import to_geotiff
+
+ result = rasterize(landcover_gdf, column='landcover',
+ width=20, height=20)
+ path = str(tmp_path / 'gdalinfo_3482.tif')
+ to_geotiff(result, path)
+
+ out = subprocess.run(['gdalinfo', path], capture_output=True,
+ text=True, check=True).stdout
+ assert 'Categories:' in out
+ for name in ('forest', 'urban', 'water'):
+ assert name in out
+
+
+# ---------------------------------------------------------------------------
+# PAM XML helpers
+# ---------------------------------------------------------------------------
+
+class TestPamHelpers:
+ def test_build_and_parse_round_trip(self):
+ from xrspatial.geotiff._pam import build_pam_xml, _parse_rat
+ from xrspatial.geotiff._safe_xml import safe_fromstring
+
+ names = ['a', 'b', 'c']
+ colors = [(10, 20, 30, 255), (40, 50, 60, 255), (70, 80, 90, 255)]
+ xml = build_pam_xml(names, colors)
+ root = safe_fromstring(xml)
+ rat = root.find('.//GDALRasterAttributeTable')
+ parsed_names, parsed_colors = _parse_rat(rat)
+ assert parsed_names == names
+ assert parsed_colors == colors
+
+ def test_special_characters_escaped(self):
+ from xrspatial.geotiff._pam import build_pam_xml
+ from xrspatial.geotiff._safe_xml import safe_fromstring
+
+ # Ampersand / angle brackets must survive XML serialization.
+ xml = build_pam_xml(['a & b', 'c '])
+ root = safe_fromstring(xml) # must not raise
+ cats = [c.text for c in root.findall('.//Category')]
+ assert cats == ['a & b', 'c ']
+
+ def test_read_missing_sidecar_returns_empty(self, tmp_path):
+ from xrspatial.geotiff._pam import read_pam_sidecar
+ assert read_pam_sidecar(str(tmp_path / 'nope.tif')) == {}
+
+ def test_athematic_stats_sidecar_ignored(self, tmp_path):
+ """A GDAL statistics/histogram sidecar must not yield categories."""
+ from xrspatial.geotiff._pam import read_pam_sidecar
+ path = str(tmp_path / 'stats.tif')
+ with open(path + '.aux.xml', 'w') as fh:
+ fh.write(
+ ''
+ '0'
+ ''
+ 'Value1'
+ '5'
+ 'Count0'
+ '1'
+ '0.05
'
+ '1.04
'
+ ''
+ '')
+ assert read_pam_sidecar(path) == {}
+
+ def test_malformed_sidecar_returns_empty(self, tmp_path):
+ from xrspatial.geotiff._pam import read_pam_sidecar
+ path = str(tmp_path / 'bad.tif')
+ with open(path + '.aux.xml', 'w') as fh:
+ fh.write(''
+ ''
+ 'Class2'
+ '2'
+ 'ok
'
+ 'badextra
'
+ 'GARBAGE'
+ '')
+ # Must never raise; worst case returns {}.
+ assert isinstance(read_pam_sidecar(path), dict)
+
+
+# ---------------------------------------------------------------------------
+# GPU backend parity
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(not has_cuda, reason="CUDA / CuPy not available")
+class TestCuPyParity:
+ def test_gpu_codes_match_cpu(self, landcover_gdf):
+ cpu = rasterize(landcover_gdf, column='landcover',
+ width=16, height=16)
+ gpu = rasterize(landcover_gdf, column='landcover',
+ width=16, height=16, gpu=True)
+ assert gpu.attrs['category_names'] == cpu.attrs['category_names']
+ np.testing.assert_array_equal(gpu.data.get(), cpu.values)