diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py index ef66b5a6a..630890d5d 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py @@ -164,6 +164,7 @@ def add_monitoring_listener(self) -> list[tuple[str, dict[str, Any]]]: jax_events = [] jax.monitoring.clear_event_listeners() def monitoring_listener(event, **kwargs): + kwargs.pop('tags', None) jax_events.append((event, kwargs)) jax.monitoring.register_event_listener(monitoring_listener) return jax_events