From 16d712d4a75db9a75ddc73753e3aa7f22fb2a089 Mon Sep 17 00:00:00 2001 From: Pratik Garg Date: Thu, 23 Apr 2026 17:54:53 -0700 Subject: [PATCH] Implement Prometheus metrics emission for Orbax via JAX monitoring listeners. PiperOrigin-RevId: 904724322 --- checkpoint/orbax/__init__.py | 12 + checkpoint/orbax/checkpoint/__init__.py | 12 + checkpoint/orbax/checkpoint/monitoring.py | 148 +++++++++++++ .../orbax/checkpoint/monitoring_test.py | 205 ++++++++++++++++++ checkpoint/pyproject.toml | 1 + 5 files changed, 378 insertions(+) create mode 100644 checkpoint/orbax/checkpoint/monitoring.py create mode 100644 checkpoint/orbax/checkpoint/monitoring_test.py diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index 0a1693db1..ece9c8802 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -32,6 +32,8 @@ from orbax.checkpoint import msgpack_utils from orbax.checkpoint import options from orbax.checkpoint import path +import os +from orbax.checkpoint import monitoring as _orbax_monitoring from orbax.checkpoint import pathways from orbax.checkpoint import serialization from orbax.checkpoint import transform_utils @@ -90,3 +92,13 @@ __version__ = version.__version__ del version + +# Autostart Prometheus metrics server if not disabled. +# Default port is 8000. Use environment variable to override or set to 0 to +# disable. +try: + _prometheus_port = int(os.environ.get('ORBAX_PROMETHEUS_PORT', 8000)) +except ValueError: + _prometheus_port = 8000 +if _prometheus_port > 0: + _orbax_monitoring.initialize(port=_prometheus_port) diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index 0a1693db1..ece9c8802 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -32,6 +32,8 @@ from orbax.checkpoint import msgpack_utils from orbax.checkpoint import options from orbax.checkpoint import path +import os +from orbax.checkpoint import monitoring as _orbax_monitoring from orbax.checkpoint import pathways from orbax.checkpoint import serialization from orbax.checkpoint import transform_utils @@ -90,3 +92,13 @@ __version__ = version.__version__ del version + +# Autostart Prometheus metrics server if not disabled. +# Default port is 8000. Use environment variable to override or set to 0 to +# disable. +try: + _prometheus_port = int(os.environ.get('ORBAX_PROMETHEUS_PORT', 8000)) +except ValueError: + _prometheus_port = 8000 +if _prometheus_port > 0: + _orbax_monitoring.initialize(port=_prometheus_port) diff --git a/checkpoint/orbax/checkpoint/monitoring.py b/checkpoint/orbax/checkpoint/monitoring.py new file mode 100644 index 000000000..3b43fb228 --- /dev/null +++ b/checkpoint/orbax/checkpoint/monitoring.py @@ -0,0 +1,148 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Orbax metrics telemetry.""" + +# pylint: disable=g-import-not-at-top +# pylint: disable=invalid-name + +import logging +import os +import threading + +from jax import monitoring + +try: + import prometheus_client # pytype: disable=import-error + + _Counter = prometheus_client.Counter + _Gauge = prometheus_client.Gauge + _Histogram = prometheus_client.Histogram +except (ImportError, AttributeError): + prometheus_client = None + _Counter = None + _Gauge = None + _Histogram = None + +_USE_PROMETHEUS = True + +_initialized = False +_metrics = {} # name -> metric object +_lock = threading.Lock() + +_PROMETHEUS_ALLOWED_METRICS = { + '/jax/orbax/write/async/start', + '/jax/checkpoint/write/async/blocking_duration_sec', + '/jax/orbax/write/start', + '/jax/checkpoint/write/durations_sec', +} + + +def _is_allowed(metric_name: str) -> bool: + """Returns True if the metric is allowed for Prometheus export.""" + return metric_name in _PROMETHEUS_ALLOWED_METRICS + + +def _record_event(metric_name: str, **kwargs): + """JAX monitoring handler for events to route to prometheus-client.""" + del kwargs # Unused. + if not _initialized or not _is_allowed(metric_name) or not _Counter: + return + metric_name_safe = metric_name.strip('/').replace('/', '_') + + if metric_name_safe not in _metrics: + with _lock: + if metric_name_safe not in _metrics: + _metrics[metric_name_safe] = _Counter(metric_name_safe, metric_name) + + metric = _metrics[metric_name_safe] + if _Counter and isinstance(metric, _Counter): + metric.inc() + + +def _record_scalar(metric_name: str, value: float | int, **kwargs): + """JAX monitoring handler for scalars to route to prometheus-client.""" + del kwargs # Unused. + if not _initialized or not _is_allowed(metric_name) or not _Gauge: + return + metric_name_safe = metric_name.strip('/').replace('/', '_') + + if metric_name_safe not in _metrics: + with _lock: + if metric_name_safe not in _metrics: + _metrics[metric_name_safe] = _Gauge(metric_name_safe, metric_name) + + metric = _metrics[metric_name_safe] + if _Gauge and isinstance(metric, _Gauge): + metric.set(value) + + +def _record_duration(metric_name: str, duration: float | int, **kwargs): + """JAX monitoring handler for duration to route to prometheus-client.""" + del kwargs # Unused. + if not _initialized or not _is_allowed(metric_name) or not _Histogram: + return + metric_name_safe = metric_name.strip('/').replace('/', '_') + + if metric_name_safe not in _metrics: + with _lock: + if metric_name_safe not in _metrics: + _metrics[metric_name_safe] = _Histogram(metric_name_safe, metric_name) + + metric = _metrics[metric_name_safe] + if _Histogram and isinstance(metric, _Histogram): + metric.observe(duration) + + +def initialize(port=8000): + """Initializes Orbax metric reporting.""" + global _initialized + if _initialized: + return + if not _USE_PROMETHEUS: + return + if os.environ.get('DISABLE_ORBAX_TELEMETRY', 'false').lower() == 'true': + logging.info('Orbax telemetry is deactivated via environment variable.') + return + + if not prometheus_client: + logging.warning( + 'prometheus-client not found. Orbax metrics will not be reported.' + ) + return + + with _lock: + if _initialized: + return + try: + if port > 0: + prometheus_client.start_http_server(port) + logging.info('Prometheus metrics server started on port %s.', port) + _initialized = True + monitoring.register_event_listener(_record_event) + monitoring.register_scalar_listener(_record_scalar) + monitoring.register_event_duration_secs_listener(_record_duration) + logging.info('Installed JAX monitoring listeners for Prometheus.') + except (OSError, ValueError) as e: + # Handle 'already in use' for Linux/macOS and Windows (10048). + if 'already in use' in str(e) or '10048' in str(e): + # If the server is already running (e.g. started by Grain), just + # register listeners. + _initialized = True + monitoring.register_event_listener(_record_event) + monitoring.register_scalar_listener(_record_scalar) + monitoring.register_event_duration_secs_listener(_record_duration) + logging.info('Prometheus server already active. Listeners installed.') + else: + logging.warning('Failed to initialize Prometheus metrics: %s', e) diff --git a/checkpoint/orbax/checkpoint/monitoring_test.py b/checkpoint/orbax/checkpoint/monitoring_test.py new file mode 100644 index 000000000..56c1694b8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/monitoring_test.py @@ -0,0 +1,205 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Orbax Prometheus metrics telemetry.""" + +from unittest import mock + +from absl.testing import absltest +from jax import monitoring +from orbax.checkpoint import monitoring as orbax_monitoring +import prometheus_client + + +class PrometheusMetricsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Use mock.patch to reset internal state without production hooks. + # This is standard Google practice for keeping production code clean. + self.enter_context( + mock.patch.object(orbax_monitoring, '_initialized', False) + ) + self.enter_context(mock.patch.object(orbax_monitoring, '_metrics', {})) + monitoring.clear_event_listeners() + + # Clear registry for hermetic tests. + registry = prometheus_client.REGISTRY + # pylint: disable=protected-access + if hasattr(registry, '_collector_to_names'): + for collector in list(registry._collector_to_names): + registry.unregister(collector) + # pylint: enable=protected-access + + def test_initialize(self): + with mock.patch.object(orbax_monitoring, '_USE_PROMETHEUS', True): + orbax_monitoring.initialize(port=0) + # Verify behavior (e.g. listeners registered) + monitoring.record_scalar('/jax/orbax/test_init', 1) + + def test_multiple_initializations(self): + with mock.patch.object(orbax_monitoring, '_USE_PROMETHEUS', True): + orbax_monitoring.initialize(port=0) + # The second initialization shouldn't throw any errors or wipe states. + orbax_monitoring.initialize(port=0) + + def test_record_before_initialize(self): + with mock.patch.object(orbax_monitoring, '_USE_PROMETHEUS', True): + monitoring.record_scalar('/jax/orbax/test_scalar_early', 123) + monitoring.record_event_duration_secs( + '/jax/orbax/test_duration_early', 0.5 + ) + monitoring.record_event('/jax/orbax/test_event_early') + # No metrics should be registered in prometheus if not initialized. + self.assertIsNone( + prometheus_client.REGISTRY.get_sample_value( + 'jax_orbax_test_scalar_early' + ) + ) + + def test_handler_scalar_metric(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_scalar'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_scalar('/jax/orbax/test_scalar', 123) + metric_name = 'jax_orbax_test_scalar' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name), 123.0 + ) + + def test_scalar_metric_updates(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_scalar'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_scalar('/jax/orbax/test_scalar', 123) + monitoring.record_scalar('/jax/orbax/test_scalar', 456) + metric_name = 'jax_orbax_test_scalar' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name), 456.0 + ) + + def test_handler_duration_metric(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_duration_secs'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_event_duration_secs( + '/jax/orbax/test_duration_secs', 0.5 + ) + metric_name = 'jax_orbax_test_duration_secs' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name + '_count'), + 1.0, + ) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name + '_sum'), 0.5 + ) + + def test_duration_metric_updates(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_duration_secs'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_event_duration_secs( + '/jax/orbax/test_duration_secs', 0.5 + ) + monitoring.record_event_duration_secs( + '/jax/orbax/test_duration_secs', 1.5 + ) + metric_name = 'jax_orbax_test_duration_secs' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name + '_count'), + 2.0, + ) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name + '_sum'), 2.0 + ) + + def test_handler_event_metric(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_event'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_event('/jax/orbax/test_event') + metric_name = 'jax_orbax_test_event_total' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name), 1.0 + ) + + def test_event_metric_increments(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/orbax/test_event'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_event('/jax/orbax/test_event') + monitoring.record_event('/jax/orbax/test_event') + monitoring.record_event('/jax/orbax/test_event') + metric_name = 'jax_orbax_test_event_total' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name), 3.0 + ) + + def test_ignore_unrelated_metrics(self): + with mock.patch.object(orbax_monitoring, '_USE_PROMETHEUS', True): + orbax_monitoring.initialize(port=0) + monitoring.record_scalar('/jax/compilation/time', 123) + metric_name = 'jax_compilation_time' + self.assertIsNone( + prometheus_client.REGISTRY.get_sample_value(metric_name) + ) + + def test_handler_second_prefix(self): + with mock.patch.object( + orbax_monitoring, '_USE_PROMETHEUS', True + ), mock.patch.object( + orbax_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/jax/checkpoint/test_scalar'}, + ): + orbax_monitoring.initialize(port=0) + monitoring.record_scalar('/jax/checkpoint/test_scalar', 123) + metric_name = 'jax_checkpoint_test_scalar' + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value(metric_name), 123.0 + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml index 8ff58e925..90a1ed4b0 100644 --- a/checkpoint/pyproject.toml +++ b/checkpoint/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ 'msgpack', 'jax >= 0.6.0', 'numpy', + 'prometheus-client', 'pyyaml', 'tensorstore >= 0.1.74', 'aiofiles',