Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import contextlib
import functools
import os

from orbax.checkpoint.experimental import v1
from orbax.checkpoint import arrays
Expand Down Expand Up @@ -90,3 +91,9 @@
__version__ = version.__version__
del version


# Opt-in telemetry: only when ENABLE_ORBAX_TELEMETRY is set to 'true'
# (case-insensitive) is the monitoring module imported and asked to set up
# the Prometheus metrics server. The import is kept inside the branch so the
# monitoring module is not loaded when telemetry is disabled.
if os.getenv('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true':
  from orbax.checkpoint._src.monitoring import monitoring

  monitoring.setup_telemetry()
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import contextlib
import functools
import os

from orbax.checkpoint.experimental import v1
from orbax.checkpoint import arrays
Expand Down Expand Up @@ -90,3 +91,9 @@
__version__ = version.__version__
del version


# Opt-in telemetry: only when ENABLE_ORBAX_TELEMETRY is set to 'true'
# (case-insensitive) is the monitoring module imported and asked to set up
# the Prometheus metrics server. The import is kept inside the branch so the
# monitoring module is not loaded when telemetry is disabled.
if os.getenv('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true':
  from orbax.checkpoint._src.monitoring import monitoring

  monitoring.setup_telemetry()
208 changes: 208 additions & 0 deletions checkpoint/orbax/checkpoint/_src/monitoring/monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Orbax metrics telemetry."""

from __future__ import annotations

# pylint: disable=g-import-not-at-top
# pylint: disable=invalid-name

import importlib
import logging
import os
import tempfile
import threading

from jax import monitoring

# When telemetry is enabled and no multiprocess directory has been configured,
# point PROMETHEUS_MULTIPROC_DIR at a fresh temporary directory. This happens
# before prometheus_client is imported below.
# NOTE(review): presumably prometheus_client inspects this variable when it is
# imported, which is why the ordering matters -- confirm against its docs.
if (
    os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() == 'true'
    and 'PROMETHEUS_MULTIPROC_DIR' not in os.environ
):
  os.environ['PROMETHEUS_MULTIPROC_DIR'] = tempfile.mkdtemp(
      prefix='orbax_metrics_'
  )

# prometheus_client is an optional dependency: if it (or either metric class)
# is unavailable, all three names are None and metric reporting is disabled.
try:
  prometheus_client = importlib.import_module('prometheus_client')
  _prom_counter, _prom_histogram = (
      prometheus_client.Counter,  # pytype: disable=attribute-error
      prometheus_client.Histogram,  # pytype: disable=attribute-error
  )
except (ImportError, AttributeError):
  prometheus_client = _prom_counter = _prom_histogram = None

# Module-wide switch consulted before anything is exported to Prometheus.
_USE_PROMETHEUS = True

# Set to True once `initialize` has installed the JAX monitoring listeners.
_initialized = False
# Guards one-time initialization and creation of entries in `_metrics`.
_lock = threading.Lock()
_metrics = {}  # sanitized metric name -> prometheus metric object

# Only metrics whose name starts with one of these prefixes are exported.
_PROMETHEUS_ALLOWED_METRIC_PREFIXES = (
    '/jax/orbax/write/',
    '/jax/checkpoint/write/',
    '/jax/orbax/read/',
)


def _is_allowed(metric_name: str) -> bool:
  """Tells whether `metric_name` may be exported to Prometheus.

  Args:
    metric_name: Full metric name as reported through JAX monitoring.

  Returns:
    False whenever Prometheus export is switched off; otherwise True iff the
    name begins with one of `_PROMETHEUS_ALLOWED_METRIC_PREFIXES`.
  """
  if not _USE_PROMETHEUS:
    return False
  # str.startswith accepts a tuple and matches if any prefix matches.
  return metric_name.startswith(_PROMETHEUS_ALLOWED_METRIC_PREFIXES)


def _labels_from_kwargs(kwargs):
  """Splits listener kwargs into Prometheus label names and label values.

  Args:
    kwargs: Mapping of label name to label value.

  Returns:
    A `(labelnames, labelvalues)` pair of tuples. Names are sorted so the
    ordering is stable across calls; values are stringified and follow the
    same ordering.
  """
  labelnames = tuple(sorted(kwargs))
  labelvalues = tuple(str(kwargs[name]) for name in labelnames)
  return labelnames, labelvalues


def _get_or_create_metric(metric_name, metric_cls, labelnames):
  """Returns the cached metric for `metric_name`, creating it on first use.

  Args:
    metric_name: Original metric name; also used as the metric description.
    metric_cls: prometheus_client class used to build the metric if missing.
    labelnames: Label names the metric is created with.

  Returns:
    The metric object stored in `_metrics`. It may be an instance of a class
    other than `metric_cls` if the name was first registered by another
    handler.
  """
  # Slashes are replaced because the registered name is derived from a path.
  safe_name = metric_name.strip('/').replace('/', '_')
  metric = _metrics.get(safe_name)
  if metric is None:
    with _lock:
      # Re-check under the lock: another thread may have created it meanwhile.
      metric = _metrics.get(safe_name)
      if metric is None:
        metric = metric_cls(safe_name, metric_name, labelnames=labelnames)
        _metrics[safe_name] = metric
  return metric


def _record(metric_name, metric_cls, update, kwargs):
  """Shared implementation of the JAX monitoring handlers.

  Args:
    metric_name: Metric name as reported through JAX monitoring.
    metric_cls: prometheus_client metric class to use, or None if
      prometheus_client is unavailable.
    update: Callable applied to the (possibly labelled) metric to record the
      sample, e.g. calling `inc()` or `observe(value)` on it.
    kwargs: Extra keyword arguments of the listener call, used as labels.
  """
  if not _initialized or not _is_allowed(metric_name) or not metric_cls:
    return
  labelnames, labelvalues = _labels_from_kwargs(kwargs)
  try:
    metric = _get_or_create_metric(metric_name, metric_cls, labelnames)
    if not isinstance(metric, metric_cls):
      # The name is already registered as a different metric type; the sample
      # is dropped, as before.
      return
    update(metric.labels(*labelvalues) if labelnames else metric)
  except ValueError as e:
    # prometheus_client raises ValueError for, among others, a label set that
    # differs from the one the metric was created with, or an invalid or
    # duplicate metric name. Telemetry must not propagate such errors into the
    # code that emitted the metric, so the sample is dropped and logged.
    logging.warning(
        'Dropping Orbax metric %s (labels=%s): %s', metric_name, labelnames, e
    )


def _record_event(metric_name: str, **kwargs):
  """JAX monitoring handler for events to route to prometheus-client.

  Increments a Counter named after `metric_name`.

  Args:
    metric_name: Name of the event.
    **kwargs: Additional event attributes, exported as Prometheus labels.
  """
  _record(metric_name, _prom_counter, lambda metric: metric.inc(), kwargs)


def _record_scalar(metric_name: str, value: float | int, **kwargs):
  """JAX monitoring handler for scalars to route to prometheus-client.

  Observes `value` on a Histogram named after `metric_name`.

  Args:
    metric_name: Name of the scalar metric.
    value: Value to observe.
    **kwargs: Additional attributes, exported as Prometheus labels.
  """
  _record(
      metric_name,
      _prom_histogram,
      lambda metric: metric.observe(value),
      kwargs,
  )


def _record_duration(metric_name: str, duration: float | int, **kwargs):
  """JAX monitoring handler for duration to route to prometheus-client.

  Observes `duration` on a Histogram named after `metric_name`.

  Args:
    metric_name: Name of the duration metric.
    duration: Duration to observe.
    **kwargs: Additional attributes, exported as Prometheus labels.
  """
  _record(
      metric_name,
      _prom_histogram,
      lambda metric: metric.observe(duration),
      kwargs,
  )


def _start_metrics_server(port):
  """Starts the Prometheus HTTP endpoint on `port`.

  If PROMETHEUS_MULTIPROC_DIR is set, the endpoint is served from a dedicated
  registry backed by a MultiProcessCollector so that metrics from all worker
  processes are aggregated. If the multiprocess module or one of the needed
  attributes is unavailable, a standard single-process server is started
  instead.

  Args:
    port: TCP port to serve metrics on.

  Raises:
    OSError: Propagated from prometheus_client, e.g. when the port is taken.
    ValueError: Propagated from prometheus_client.
  """
  if 'PROMETHEUS_MULTIPROC_DIR' in os.environ:
    try:
      multiprocess = importlib.import_module('prometheus_client.multiprocess')
      registry = prometheus_client.CollectorRegistry()  # pytype: disable=attribute-error
      multiprocess.MultiProcessCollector(registry)  # pytype: disable=attribute-error
      prometheus_client.start_http_server(port, registry=registry)  # pytype: disable=attribute-error
    except (ImportError, AttributeError) as e:
      # Make the fallback visible: without it, metrics from worker processes
      # would silently be missing from the endpoint.
      logging.warning(
          'Prometheus multiprocess mode unavailable (%s); falling back to a'
          ' single-process metrics server.',
          e,
      )
    else:
      logging.info(
          'Prometheus multiprocess metrics server started on port %s.',
          port,
      )
      return

  # Standard single-process server.
  prometheus_client.start_http_server(port)  # pytype: disable=attribute-error
  logging.info('Prometheus metrics server started on port %s.', port)


def initialize(port=9432):
  """Initializes Orbax metric reporting.

  Starts the Prometheus HTTP endpoint (when `port` is positive) and installs
  the JAX monitoring listeners that forward metrics to prometheus-client.
  Calling it again after a successful initialization is a no-op.

  Args:
    port: Port for the Prometheus HTTP endpoint. A value <= 0 skips starting
      a server and only installs the listeners.
  """
  global _initialized
  if _initialized or not _USE_PROMETHEUS:
    return

  if not prometheus_client:
    logging.warning(
        'prometheus-client not found. Orbax metrics will not be reported.'
    )
    return

  with _lock:
    # Re-check under the lock in case another thread initialized meanwhile.
    if _initialized:
      return

    if port > 0:
      try:
        _start_metrics_server(port)
      except (OSError, ValueError) as e:
        # Handle 'already in use' for Linux/macOS and Windows (10048).
        message = str(e)
        if 'already in use' not in message and '10048' not in message:
          logging.warning('Failed to start Prometheus server: %s', e)
          return
        # If the server is already running (e.g. started by Grain), just
        # register listeners.
        logging.info('Prometheus server already active.')

    # The flag is set before registration because the handlers ignore every
    # sample while `_initialized` is False.
    _initialized = True
    monitoring.register_event_listener(_record_event)
    monitoring.register_scalar_listener(_record_scalar)
    monitoring.register_event_duration_secs_listener(_record_duration)
    logging.info('Installed JAX monitoring listeners for Prometheus.')


def setup_telemetry():
  """Autostarts Prometheus metrics reporting if enabled via environment.

  Does nothing unless ENABLE_ORBAX_TELEMETRY is 'true' (case-insensitive).
  The process named 'MainProcess' initializes with port 9432; every other
  process initializes with port 0, which installs the listeners without
  starting an HTTP server.
  """
  if os.environ.get('ENABLE_ORBAX_TELEMETRY', 'false').lower() != 'true':
    return

  import multiprocessing

  in_main_process = multiprocessing.current_process().name == 'MainProcess'
  initialize(port=9432 if in_main_process else 0)
Loading
Loading