From affc2f64730e059e0c9359d9204e67f819000490 Mon Sep 17 00:00:00 2001 From: Pratik Garg Date: Fri, 24 Apr 2026 19:12:21 -0700 Subject: [PATCH] Implement Prometheus metrics emission for Orbax. PiperOrigin-RevId: 905325203 --- checkpoint/orbax/__init__.py | 7 + checkpoint/orbax/checkpoint/__init__.py | 7 + .../checkpoint/_src/monitoring/monitoring.py | 208 ++++++++++++ .../_src/monitoring/monitoring_test.py | 303 ++++++++++++++++++ checkpoint/pyproject.toml | 1 + 5 files changed, 526 insertions(+) create mode 100644 checkpoint/orbax/checkpoint/_src/monitoring/monitoring.py create mode 100644 checkpoint/orbax/checkpoint/_src/monitoring/monitoring_test.py diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index 0a1693db1..bb1827939 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -18,6 +18,7 @@ import contextlib import functools +import os from orbax.checkpoint.experimental import v1 from orbax.checkpoint import arrays @@ -90,3 +91,9 @@ __version__ = version.__version__ del version + +# Autostart Prometheus metrics server if enabled via environment variable. +if os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true': + from orbax.checkpoint._src.monitoring import monitoring + + monitoring.setup_telemetry() diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index 0a1693db1..bb1827939 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -18,6 +18,7 @@ import contextlib import functools +import os from orbax.checkpoint.experimental import v1 from orbax.checkpoint import arrays @@ -90,3 +91,9 @@ __version__ = version.__version__ del version + +# Autostart Prometheus metrics server if enabled via environment variable. 
+if os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true': + from orbax.checkpoint._src.monitoring import monitoring + + monitoring.setup_telemetry() diff --git a/checkpoint/orbax/checkpoint/_src/monitoring/monitoring.py b/checkpoint/orbax/checkpoint/_src/monitoring/monitoring.py new file mode 100644 index 000000000..4326d0cad --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/monitoring/monitoring.py @@ -0,0 +1,208 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Orbax metrics telemetry.

Bridges `jax.monitoring` events, scalars and durations to `prometheus_client`
metrics and, optionally, starts an HTTP endpoint that exposes them.

Only metrics whose name starts with one of
`_PROMETHEUS_ALLOWED_METRIC_PREFIXES` are exported. Events become Prometheus
counters; scalars and durations become histograms.
"""

from __future__ import annotations

# pylint: disable=g-import-not-at-top
# pylint: disable=invalid-name

import errno
import importlib
import logging
import os
import re
import tempfile
import threading

from jax import monitoring

# Must run before `prometheus_client` is imported below so that the library
# sees the multiprocess directory when it is first loaded.
# NOTE(review): the directory created here is never removed; confirm whether
# cleanup is expected to be handled by the job environment.
if os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true':
  if 'PROMETHEUS_MULTIPROC_DIR' not in os.environ:
    os.environ['PROMETHEUS_MULTIPROC_DIR'] = tempfile.mkdtemp(
        prefix='orbax_metrics_'
    )

try:
  prometheus_client = importlib.import_module('prometheus_client')

  _prom_counter = prometheus_client.Counter  # pytype: disable=attribute-error
  _prom_histogram = prometheus_client.Histogram  # pytype: disable=attribute-error
except (ImportError, AttributeError):
  prometheus_client = None
  _prom_counter = None
  _prom_histogram = None

_USE_PROMETHEUS = True

_DEFAULT_PORT = 9432
# Winsock error code for "address already in use".
_WSAEADDRINUSE = 10048
# Anything outside this character class is replaced when deriving a metric
# name from a JAX monitoring path.
_INVALID_METRIC_CHARS = re.compile(r'[^a-zA-Z0-9_:]')

_initialized = False
_metrics = {}  # sanitized name -> metric object (None if creation failed)
_reported_problems = set()  # keys of problems that were already logged
_lock = threading.Lock()

_PROMETHEUS_ALLOWED_METRIC_PREFIXES = (
    '/jax/orbax/write/',
    '/jax/checkpoint/write/',
    '/jax/orbax/read/',
)


def _is_allowed(metric_name: str) -> bool:
  """Returns True if the metric is allowed for Prometheus export."""
  if not _USE_PROMETHEUS:
    return False
  # `str.startswith` accepts a tuple, so one call covers every prefix.
  return metric_name.startswith(_PROMETHEUS_ALLOWED_METRIC_PREFIXES)


def _sanitize_metric_name(metric_name: str) -> str:
  """Converts a JAX monitoring path into a Prometheus-safe metric name.

  Leading/trailing slashes are dropped and every character outside
  `[a-zA-Z0-9_:]` (including the remaining slashes) becomes an underscore.

  Args:
    metric_name: The JAX monitoring name, e.g. '/jax/orbax/write/foo'.

  Returns:
    The sanitized name, e.g. 'jax_orbax_write_foo'.
  """
  return _INVALID_METRIC_CHARS.sub('_', metric_name.strip('/'))


def _warn_once(key: str, message: str, *args) -> None:
  """Logs a warning the first time `key` is seen and stays quiet afterwards.

  Deliberately lock-free: it is called while `_lock` is held, and a duplicate
  log line from a race is harmless.
  """
  if key in _reported_problems:
    return
  _reported_problems.add(key)
  logging.warning(message, *args)


def _get_or_create_metric(metric_cls, metric_name: str, labelnames: tuple):  # pylint: disable=g-bare-generic
  """Returns the metric registered for `metric_name`, creating it on first use.

  Args:
    metric_cls: The prometheus_client metric class (Counter or Histogram).
    metric_name: The original JAX monitoring name; also used as help text.
    labelnames: Sorted label names to declare if the metric must be created.

  Returns:
    The metric object, or None when it could not be created or when the name
    is already bound to a metric of a different class.
  """
  key = _sanitize_metric_name(metric_name)
  if key not in _metrics:
    with _lock:
      # Re-check under the lock: another thread may have created it meanwhile.
      if key not in _metrics:
        try:
          _metrics[key] = metric_cls(key, metric_name, labelnames=labelnames)
        except ValueError as e:
          # Remember the failure so the constructor is not retried (and does
          # not raise into the caller) on every subsequent sample.
          _warn_once(
              'create:' + key,
              'Could not create Prometheus metric %s: %s',
              key,
              e,
          )
          _metrics[key] = None

  metric = _metrics[key]
  if metric is None:
    return None
  if not isinstance(metric, metric_cls):
    _warn_once(
        'kind:' + key,
        'Prometheus metric %s is already registered with a different type;'
        ' dropping samples.',
        key,
    )
    return None
  return metric


def _update_metric(metric_cls, method: str, metric_name: str, args, labels):
  """Applies one sample to the metric for `metric_name`.

  Args:
    metric_cls: The prometheus_client metric class to use.
    method: Name of the update method to call ('inc' or 'observe').
    metric_name: The original JAX monitoring name.
    args: Positional arguments forwarded to `method`.
    labels: Mapping of label name to value supplied by the JAX listener call.
  """
  labelnames = tuple(sorted(labels))
  metric = _get_or_create_metric(metric_cls, metric_name, labelnames)
  if metric is None:
    return
  try:
    if labelnames:
      # Keyword form so a call whose label *names* differ from those the
      # metric was created with is rejected rather than silently mis-assigned.
      target = metric.labels(**{k: str(labels[k]) for k in labelnames})
    else:
      target = metric
    getattr(target, method)(*args)
  except ValueError as e:
    # A monitoring listener must not raise into the code being monitored.
    _warn_once(
        'update:' + _sanitize_metric_name(metric_name),
        'Dropping sample for Prometheus metric %s: %s',
        metric_name,
        e,
    )


def _record_event(metric_name: str, **kwargs):
  """JAX monitoring handler for events to route to prometheus-client.

  Args:
    metric_name: The JAX monitoring event name.
    **kwargs: Event attributes, exported as Prometheus labels.
  """
  if not _initialized or not _is_allowed(metric_name) or not _prom_counter:
    return
  _update_metric(_prom_counter, 'inc', metric_name, (), kwargs)


def _record_scalar(metric_name: str, value: float | int, **kwargs):
  """JAX monitoring handler for scalars to route to prometheus-client.

  Args:
    metric_name: The JAX monitoring scalar name.
    value: The value to observe in the histogram.
    **kwargs: Scalar attributes, exported as Prometheus labels.
  """
  if not _initialized or not _is_allowed(metric_name) or not _prom_histogram:
    return
  _update_metric(_prom_histogram, 'observe', metric_name, (value,), kwargs)


def _record_duration(metric_name: str, duration: float | int, **kwargs):
  """JAX monitoring handler for duration to route to prometheus-client.

  Args:
    metric_name: The JAX monitoring duration event name.
    duration: The duration to observe in the histogram.
    **kwargs: Event attributes, exported as Prometheus labels.
  """
  if not _initialized or not _is_allowed(metric_name) or not _prom_histogram:
    return
  _update_metric(_prom_histogram, 'observe', metric_name, (duration,), kwargs)


def _is_address_in_use(error: Exception) -> bool:
  """Returns True if `error` says the listening address is already taken.

  The error code is checked first because the message text is not stable
  across platforms; the substring checks are kept as a fallback for errors
  that carry no code.
  """
  if getattr(error, 'errno', None) == errno.EADDRINUSE:
    return True
  if getattr(error, 'winerror', None) == _WSAEADDRINUSE:
    return True
  text = str(error)
  return 'already in use' in text or str(_WSAEADDRINUSE) in text


def _start_metrics_server(port: int) -> None:
  """Starts the Prometheus HTTP endpoint on `port`.

  When `PROMETHEUS_MULTIPROC_DIR` is set, a dedicated registry backed by
  `MultiProcessCollector` is served so that metrics from all worker processes
  are aggregated. If the multiprocess support cannot be loaded, the standard
  single-process server is started instead.

  Args:
    port: TCP port to listen on.

  Raises:
    OSError, ValueError: Propagated from prometheus_client; handled by
      `initialize`.
  """
  if 'PROMETHEUS_MULTIPROC_DIR' in os.environ:
    try:
      multiprocess = importlib.import_module('prometheus_client.multiprocess')
      registry = prometheus_client.CollectorRegistry()  # pytype: disable=attribute-error
      multiprocess.MultiProcessCollector(registry)  # pytype: disable=attribute-error
    except (ImportError, AttributeError) as e:
      logging.warning(
          'Prometheus multiprocess mode unavailable (%s); falling back to the'
          ' single-process server.',
          e,
      )
    else:
      prometheus_client.start_http_server(port, registry=registry)  # pytype: disable=attribute-error
      logging.info(
          'Prometheus multiprocess metrics server started on port %s.',
          port,
      )
      return

  # Standard single-process server.
  prometheus_client.start_http_server(port)  # pytype: disable=attribute-error
  logging.info('Prometheus metrics server started on port %s.', port)


def initialize(port=_DEFAULT_PORT):
  """Initializes Orbax metric reporting.

  Idempotent: once initialization has succeeded, further calls do nothing.

  Args:
    port: Port for the Prometheus HTTP endpoint. A value <= 0 skips starting a
      server and only installs the JAX monitoring listeners.
  """
  global _initialized
  if _initialized or not _USE_PROMETHEUS:
    return

  if not prometheus_client:
    logging.warning(
        'prometheus-client not found. Orbax metrics will not be reported.'
    )
    return

  with _lock:
    if _initialized:
      return

    if port > 0:
      try:
        _start_metrics_server(port)
      except (OSError, ValueError) as e:
        if not _is_address_in_use(e):
          logging.warning('Failed to start Prometheus server: %s', e)
          return
        # If the server is already running (e.g. started by Grain), just
        # register listeners.
        logging.info('Prometheus server already active.')

    _initialized = True
    monitoring.register_event_listener(_record_event)
    monitoring.register_scalar_listener(_record_scalar)
    monitoring.register_event_duration_secs_listener(_record_duration)
    logging.info('Installed JAX monitoring listeners for Prometheus.')


def setup_telemetry():
  """Autostarts Prometheus metrics server if enabled via environment variable.

  Does nothing unless `ENABLE_ORBAX_TELEMETRY` is 'true' (case-insensitive).
  Only the process named 'MainProcess' starts the HTTP server; any other
  process just installs the listeners.
  """
  if os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() != 'true':
    return

  import multiprocessing

  if multiprocessing.current_process().name == 'MainProcess':
    initialize(port=_DEFAULT_PORT)
  else:
    initialize(port=0)


# ---------------------------------------------------------------------------
# Patch envelope retained from the original span (next file in the patch):
#   diff --git a/checkpoint/orbax/checkpoint/_src/monitoring/monitoring_test.py
#              b/checkpoint/orbax/checkpoint/_src/monitoring/monitoring_test.py
#   new file mode 100644
#   index 000000000..5acf1fe88
#   --- /dev/null
#   +++ b/checkpoint/orbax/checkpoint/_src/monitoring/monitoring_test.py
#   @@ -0,0 +1,303 @@
# ---------------------------------------------------------------------------
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Orbax Prometheus metrics telemetry."""

from __future__ import annotations

import importlib
import os
import tempfile
from unittest import mock

from absl.testing import absltest
from jax import monitoring
from orbax.checkpoint._src.monitoring import monitoring as orbax_monitoring
import prometheus_client

_MULTIPROC_ENV = 'PROMETHEUS_MULTIPROC_DIR'
_JAX_REGISTRARS = (
    'register_event_listener',
    'register_scalar_listener',
    'register_event_duration_secs_listener',
)


def _sample(name, labels=None):
  """Reads one sample from the default Prometheus registry (None if absent)."""
  return prometheus_client.REGISTRY.get_sample_value(name, labels)


class PrometheusMetricsTest(absltest.TestCase):
  """Exercises `orbax_monitoring` against the default Prometheus registry.

  `prometheus_client` is imported unconditionally above, so no test needs to
  guard against it being missing.
  """

  def setUp(self):
    super().setUp()
    # Reset module state for every test; the patches are undone on teardown.
    module_state = {
        '_initialized': False,
        '_metrics': {},
        '_USE_PROMETHEUS': True,
    }
    for attr, value in module_state.items():
      self.enter_context(mock.patch.object(orbax_monitoring, attr, value))

    # Mock JAX monitoring registration to prevent accumulating listeners
    # across tests.
    for registrar in _JAX_REGISTRARS:
      self.enter_context(mock.patch.object(monitoring, registrar))

    # Clear registry for hermetic tests.
    registry = prometheus_client.REGISTRY
    # pylint: disable=protected-access
    for collector in list(getattr(registry, '_collector_to_names', ())):
      registry.unregister(collector)
    # pylint: enable=protected-access

  def _patch_start_http_server(self, **kwargs):
    """Replaces `prometheus_client.start_http_server` for the current test."""
    return self.enter_context(
        mock.patch.object(prometheus_client, 'start_http_server', **kwargs)
    )

  def _use_multiproc_dir(self):
    """Points PROMETHEUS_MULTIPROC_DIR at a fresh temporary directory."""
    tmp_dir = self.enter_context(tempfile.TemporaryDirectory())
    self.enter_context(mock.patch.dict(os.environ, {_MULTIPROC_ENV: tmp_dir}))
    return tmp_dir

  def _set_telemetry_process(self, process_name):
    """Enables telemetry and fakes the current process name.

    Returns:
      The mock that replaces `orbax_monitoring.initialize`.
    """
    fake_process = mock.Mock()
    fake_process.name = process_name
    self.enter_context(
        mock.patch.dict(os.environ, {'ENABLE_ORBAX_TELEMETRY': 'true'})
    )
    self.enter_context(
        mock.patch('multiprocessing.current_process', return_value=fake_process)
    )
    return self.enter_context(
        mock.patch.object(orbax_monitoring, 'initialize')
    )

  def test_initialize_prometheus_server_called_once(self):
    # The single-process call signature is asserted below, so make sure an
    # ambient multiprocess directory cannot switch the code path.
    self.enter_context(mock.patch.dict(os.environ))
    os.environ.pop(_MULTIPROC_ENV, None)
    mock_start_http_server = self._patch_start_http_server(autospec=True)

    orbax_monitoring.initialize(port=9432)
    mock_start_http_server.assert_called_once_with(9432)
    orbax_monitoring.initialize(port=9432)
    mock_start_http_server.assert_called_once_with(9432)

  def test_initialize(self):
    orbax_monitoring.initialize(port=0)
    monitoring.register_scalar_listener.assert_called_once()  # pytype: disable=attribute-error

  def test_multiple_initializations(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring.initialize(port=0)
    # The second call must be a no-op, so each listener is installed once.
    for registrar in _JAX_REGISTRARS:
      getattr(monitoring, registrar).assert_called_once()

  def test_record_before_initialize(self):
    orbax_monitoring._record_scalar('/jax/orbax/write/test_scalar_early', 123)
    orbax_monitoring._record_duration(
        '/jax/orbax/write/test_duration_early', 0.5
    )
    orbax_monitoring._record_event('/jax/orbax/write/test_event_early')
    self.assertIsNone(_sample('jax_orbax_write_test_scalar_early_sum'))
    self.assertIsNone(_sample('jax_orbax_write_test_duration_early_sum'))
    self.assertIsNone(_sample('jax_orbax_write_test_event_early_total'))

  def test_handler_scalar_metric(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_scalar('/jax/orbax/write/test_scalar', 123)
    metric_name = 'jax_orbax_write_test_scalar'
    self.assertEqual(_sample(metric_name + '_sum'), 123.0)
    self.assertEqual(_sample(metric_name + '_count'), 1.0)

  def test_scalar_metric_updates(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_scalar('/jax/orbax/write/test_scalar', 123)
    orbax_monitoring._record_scalar('/jax/orbax/write/test_scalar', 456)
    self.assertEqual(_sample('jax_orbax_write_test_scalar_sum'), 579.0)

  def test_handler_duration_metric(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_duration(
        '/jax/orbax/write/test_duration_secs', 0.5
    )
    metric_name = 'jax_orbax_write_test_duration_secs'
    self.assertEqual(_sample(metric_name + '_count'), 1.0)
    self.assertEqual(_sample(metric_name + '_sum'), 0.5)

  def test_duration_metric_updates(self):
    orbax_monitoring.initialize(port=0)
    for duration in (0.5, 1.5):
      orbax_monitoring._record_duration(
          '/jax/orbax/write/test_duration_secs', duration
      )
    self.assertEqual(
        _sample('jax_orbax_write_test_duration_secs_count'), 2.0
    )

  def test_handler_event_metric(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_event('/jax/orbax/write/test_event')
    self.assertEqual(_sample('jax_orbax_write_test_event_total'), 1.0)

  def test_event_metric_increments(self):
    orbax_monitoring.initialize(port=0)
    for _ in range(3):
      orbax_monitoring._record_event('/jax/orbax/write/test_event')
    self.assertEqual(_sample('jax_orbax_write_test_event_total'), 3.0)

  def test_ignore_unrelated_metrics(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_scalar('/jax/compilation/time', 123)
    self.assertIsNone(_sample('jax_compilation_time'))

  def test_handler_second_prefix(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_scalar('/jax/checkpoint/write/test_scalar', 123)
    self.assertEqual(_sample('jax_checkpoint_write_test_scalar_sum'), 123.0)

  def test_labels(self):
    orbax_monitoring.initialize(port=0)
    orbax_monitoring._record_event(
        '/jax/orbax/write/test_event_label', key1='val1'
    )
    self.assertEqual(
        _sample('jax_orbax_write_test_event_label_total', {'key1': 'val1'}),
        1.0,
    )

  def test_initialize_multiprocess(self):
    self._use_multiproc_dir()
    mock_start = self._patch_start_http_server()

    orbax_monitoring.initialize(9432)

    mock_start.assert_called_once()
    args, kwargs = mock_start.call_args
    self.assertEqual(args[0], 9432)
    self.assertIn('registry', kwargs)

  def test_setup_telemetry_main_process(self):
    mock_init = self._set_telemetry_process('MainProcess')
    orbax_monitoring.setup_telemetry()
    mock_init.assert_called_once_with(port=9432)

  def test_setup_telemetry_worker_process(self):
    mock_init = self._set_telemetry_process('Worker-1')
    orbax_monitoring.setup_telemetry()
    mock_init.assert_called_once_with(port=0)

  def test_initialize_port_already_in_use(self):
    self._patch_start_http_server(
        side_effect=OSError('Address already in use')
    )
    with self.assertLogs(level='INFO') as log:
      orbax_monitoring.initialize(9432)
    self.assertTrue(orbax_monitoring._initialized)  # pylint: disable=protected-access
    self.assertTrue(
        any('Prometheus server already active' in m for m in log.output)
    )

  def test_initialize_other_oserror(self):
    self._patch_start_http_server(side_effect=OSError('Some other error'))
    with self.assertLogs(level='WARNING') as log:
      orbax_monitoring.initialize(9432)
    self.assertFalse(orbax_monitoring._initialized)  # pylint: disable=protected-access
    self.assertTrue(
        any('Failed to start Prometheus server' in m for m in log.output)
    )

  def test_initialize_multiprocess_import_error(self):
    original_import_module = importlib.import_module

    def mock_import_module(name, *args, **kwargs):
      # Fail only the multiprocess submodule; everything else imports normally.
      if name == 'prometheus_client.multiprocess':
        raise ImportError('Mocked ImportError')
      return original_import_module(name, *args, **kwargs)

    self._use_multiproc_dir()
    self.enter_context(
        mock.patch.object(
            importlib, 'import_module', side_effect=mock_import_module
        )
    )
    mock_start = self._patch_start_http_server()

    orbax_monitoring.initialize(9432)

    mock_start.assert_called_once_with(9432)
    self.assertTrue(orbax_monitoring._initialized)  # pylint: disable=protected-access


if __name__ == '__main__':
  absltest.main()


# ---------------------------------------------------------------------------
# Patch envelope retained from the original span (last hunk of the patch):
#   diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml
#   index d8415291c..91d4b113d 100644
#   --- a/checkpoint/pyproject.toml
#   +++ b/checkpoint/pyproject.toml
#   @@ -28,6 +28,7 @@ dependencies = [
#        'msgpack',
#        'jax >= 0.6.0',
#        'numpy',
#   +    'prometheus-client',
#        'pyyaml',
#        'tensorstore >= 0.1.74',
#        'aiofiles',
# ---------------------------------------------------------------------------