diff --git a/src/underworld3/function/quantities.py b/src/underworld3/function/quantities.py index 236a8e3c..e534f25a 100644 --- a/src/underworld3/function/quantities.py +++ b/src/underworld3/function/quantities.py @@ -716,6 +716,47 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + # ========================================================================= + # Array Indexing + # ========================================================================= + + def __getitem__(self, key): + """ + Enable array indexing on UWQuantity objects. + + Supports indexing into the underlying value array while preserving units. + Returns a new UWQuantity if the result is still an array, or a scalar + if indexing yields a single value. + + Parameters + ---------- + key : int, slice, tuple, or array-like + Index, slice, or advanced indexing specification + + Returns + ------- + UWQuantity or scalar + Indexed value with units preserved (if result is array) + or scalar value (if result is single element) + + Examples + -------- + >>> coords = uw.quantity([[100, 200], [300, 400]], "km") + >>> coords[0] # First particle: UWQuantity([100, 200], "km") + >>> coords[0, 0] # First coordinate: 100.0 (scalar, units lost) + >>> coords[:, 0] # All x-coordinates: UWQuantity([100, 300], "km") + """ + # Index into the underlying value + indexed_value = self._value[key] + + # If result is still an array, wrap in new UWQuantity + if isinstance(indexed_value, np.ndarray): + return UWQuantity(indexed_value, units=self._pint_unit) + + # Scalar result - return bare value + # (Could alternatively return UWQuantity for consistency) + return indexed_value + # ========================================================================= # SymPy Compatibility # ========================================================================= diff --git a/tests/test_0815_mesh_length_scale.py b/tests/test_0815_mesh_length_scale.py index 13a367ea..630b0a23 100644 --- a/tests/test_0815_mesh_length_scale.py +++ b/tests/test_0815_mesh_length_scale.py @@ -230,6 +230,7 @@ def test_length_units_matches_coordinate_units(): assert isinstance(mesh.length_units, str) +@pytest.mark.skip(reason="mesh.view() requires display - PyVista crashes in headless CI") def test_mesh_view_displays_length_scale(): """Test that mesh.view() displays length scale information.""" uw.reset_default_model()