From 98095b95da123392db0578321fd8b362ffbd1d6b Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 11 Mar 2026 12:29:15 -0700 Subject: [PATCH] [monitoring] Pass config-managed key-value tags to monitoring listeners. JAX monitoring listeners currently receive only arguments that events pass. However, we want to be able to tag monitoring data along with durations and counters. For example, if users want to collect data across jobs or tests, we want to be able to record metadata for post-processing (e.g., job attributes, test parameters). NOTE: Added `jax.monitoring_tags` to the public API so it's available to users as a context manager. However, other monitoring APIs are undocumented, so added to the `UNDOCUMENTED_APIS` list. PiperOrigin-RevId: 882140161 --- .../orbax/checkpoint/_src/serialization/serialization_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py index ef66b5a6a..630890d5d 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py @@ -164,6 +164,7 @@ def add_monitoring_listener(self) -> list[tuple[str, dict[str, Any]]]: jax_events = [] jax.monitoring.clear_event_listeners() def monitoring_listener(event, **kwargs): + kwargs.pop('tags', None) jax_events.append((event, kwargs)) jax.monitoring.register_event_listener(monitoring_listener) return jax_events