feat(train): Add native MLflowLoggerCallback for Train v2 API#64234

Open
ArjunPakhan wants to merge 3 commits into ray-project:master from ArjunPakhan:feat/train-v2-mlflow-callback

@ArjunPakhan

Description

This PR introduces a native MLflowLoggerCallback for the new Ray Train v2 API framework.

The callback uses an explicit MlflowClient instance rather than the process-global mlflow.* fluent API. Run state is therefore held by the callback itself and does not depend on which process or thread the remote TrainController actor runs in.

A robust local integration test suite has also been added to verify metric recording stability during successful runs and proper lifecycle termination (FAILED status updates) during unexpected worker exceptions.

Related issues

Fixes #

Additional information

  • API Changes: Adds MLflowLoggerCallback to ray.train.v2.api.
  • Formatters: Codebase verified and cleaned locally via black and isort.
  • Testing: Verified locally with the following command:
    python3 -m pytest venv/lib/python3.12/site-packages/ray/train/v2/tests/test_mlflow_callback.py -v

@ArjunPakhan ArjunPakhan requested review from a team as code owners June 20, 2026 16:51

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces dynamic configuration of maximum concurrent batches in Ray Serve's batching queue and adds a native MLflow logging callback (MLflowLoggerCallback) with associated unit tests for Ray Train v2. The review feedback highlights critical indentation errors in batching.py that would cause compilation or runtime failures, as well as an issue with directly modifying the private _value attribute of asyncio.Semaphore. For the MLflow callback, the reviewer recommends optimizing performance by using batch logging instead of sequential HTTP requests, correcting directory logging by using log_artifacts instead of log_artifact, and implementing a more robust cleanup mechanism via an explicit close method rather than relying solely on __del__.

Comment thread python/ray/serve/batching.py Outdated
Comment on lines +233 to +251
def set_max_concurrent_batches(self, new_max_concurrent_batches: int) -> None:
"""Safely updates queue's max_concurrent_batches and modifies semaphore limits."""
old_max = self.max_concurrent_batches
self.max_concurrent_batches = new_max_concurrent_batches

# Calculate the delta between old and new limits
delta = new_max_concurrent_batches - old_max

if delta > 0:
# If the limit increased, release the extra tokens onto the existing semaphore
for _ in range(delta):
self.semaphore.release()
elif delta < 0:
# If the limit decreased, manually lower the internal permit value
# asyncio.Semaphore tracks total capacity via its private _value attribute
self.semaphore._value += delta

# Fix Issue #3: Trigger the validation safety warning
self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

critical

The method set_max_concurrent_batches is defined with 0 indentation, making it a module-level function instead of a method of _BatchQueue. This will cause runtime errors when called. Additionally, modifying self.semaphore._value directly depends on a private attribute of asyncio.Semaphore: it is an implementation detail that can change between Python versions, and writing to it bypasses the semaphore's own bookkeeping of waiters.

To fix both issues, indent the method correctly and use a safe, standard asyncio approach to dynamically adjust the semaphore capacity by managing background permit-holding tasks.

    def set_max_concurrent_batches(self, new_max_concurrent_batches: int) -> None:
        """Safely updates queue's max_concurrent_batches and modifies semaphore limits."""
        old_max = self.max_concurrent_batches
        self.max_concurrent_batches = new_max_concurrent_batches
        
        delta = new_max_concurrent_batches - old_max
        if not hasattr(self, "_holding_tasks"):
            self._holding_tasks = []

        if delta > 0:
            while delta > 0 and self._holding_tasks:
                task = self._holding_tasks.pop()
                task.cancel()
                delta -= 1
            for _ in range(delta):
                self.semaphore.release()
        elif delta < 0:
            async def hold_permit(sem):
                try:
                    await sem.acquire()
                    try:
                        await asyncio.Event().wait()
                    finally:
                        sem.release()
                except asyncio.CancelledError:
                    pass

            for _ in range(abs(delta)):
                task = asyncio.create_task(hold_permit(self.semaphore))
                self._holding_tasks.append(task)

        self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

Comment thread python/ray/serve/batching.py Outdated
Comment on lines +1020 to +1024
wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
wrapper.set_batch_wait_timeout_s = lazy_batch_queue_wrapper.set_batch_wait_timeout_s
# Fix Bugbot & Gemini Bot: Expose the new APIs to the outer function wrapper handle
wrapper.set_max_concurrent_batches = lazy_batch_queue_wrapper.set_max_concurrent_batches
wrapper.get_max_concurrent_batches = lazy_batch_queue_wrapper.get_max_concurrent_batches

critical

The assignment wrapper.set_max_batch_size has 0 indentation, which will cause an IndentationError at import/compile time. Additionally, wrapper.set_max_batch_size and wrapper.set_batch_wait_timeout_s are already assigned with correct indentation on lines 1029-1032. We can safely remove the duplicate assignments here and only keep the new ones with correct indentation.

        wrapper.set_max_concurrent_batches = lazy_batch_queue_wrapper.set_max_concurrent_batches
        wrapper.get_max_concurrent_batches = lazy_batch_queue_wrapper.get_max_concurrent_batches

Comment thread python/ray/train/v2/api/mlflow.py Outdated
Comment on lines +67 to +69
for k, v in rank_0_metrics.items():
    if isinstance(v, (int, float)):
        self.client.log_metric(self._run_id, k, v, step=step)

high

Logging metrics one-by-one in a loop using self.client.log_metric issues one sequential HTTP request to the MLflow tracking server per metric, per step. With many metrics this can become a performance bottleneck and slow down the training loop. Use self.client.log_batch to log all metrics in a single HTTP request.

        metrics_to_log = []
        timestamp = int(time.time() * 1000)
        for k, v in rank_0_metrics.items():
            if isinstance(v, (int, float)):
                metrics_to_log.append(Metric(key=k, value=v, timestamp=timestamp, step=step))

        if metrics_to_log:
            self.client.log_batch(self._run_id, metrics=metrics_to_log)
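If a single report can carry a very large number of metrics, the batch may need to be split before calling log_batch. The helper below is a minimal sketch of that step; chunked and MAX_METRICS_PER_BATCH are hypothetical names, and the value 1000 is an assumption based on the per-request metric limit described in the MLflow REST API documentation, which should be checked against the MLflow version in use.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

# Assumption: at most 1000 metrics per log_batch request (verify for your MLflow version).
MAX_METRICS_PER_BATCH = 1000


def chunked(items: Sequence[T], size: int = MAX_METRICS_PER_BATCH) -> Iterator[List[T]]:
    """Yield consecutive slices of `items`, each holding at most `size` elements."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])
```

In the callback this would replace the single call with a loop such as `for chunk in chunked(metrics_to_log): self.client.log_batch(self._run_id, metrics=chunk)`.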

Comment thread python/ray/train/v2/api/mlflow.py Outdated
Comment on lines +74 to +78
self.client.log_artifact(
    self._run_id,
    checkpoint_dir,
    artifact_path=f"checkpoints/step_{step}",
)

high

self.client.log_artifact is designed to log a single file. To log a directory (such as the checkpoint directory), use self.client.log_artifacts instead. Using log_artifact on a directory can fail or behave unexpectedly depending on the MLflow version and storage backend.

Suggested change
- self.client.log_artifact(
-     self._run_id,
-     checkpoint_dir,
-     artifact_path=f"checkpoints/step_{step}",
- )
+ self.client.log_artifacts(
+     self._run_id,
+     checkpoint_dir,
+     artifact_path=f"checkpoints/step_{step}",
+ )

Comment on lines +1 to +8
import logging
from typing import Any, Dict, List, Optional

from mlflow.tracking import MlflowClient
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2.api.callback import UserCallback
from ray.util.annotations import DeveloperAPI

medium

Add imports for time and Metric to support the optimized batch metric logging.

Suggested change
- import logging
- from typing import Any, Dict, List, Optional
- from mlflow.tracking import MlflowClient
- from ray.train import Checkpoint
- from ray.train.v2._internal.execution.context import TrainRunContext
- from ray.train.v2.api.callback import UserCallback
- from ray.util.annotations import DeveloperAPI
+ import logging
+ import time
+ from typing import Any, Dict, List, Optional
+ from mlflow.entities import Metric
+ from mlflow.tracking import MlflowClient
+ from ray.train import Checkpoint
+ from ray.train.v2._internal.execution.context import TrainRunContext
+ from ray.train.v2.api.callback import UserCallback
+ from ray.util.annotations import DeveloperAPI

Comment on lines +89 to +95
def __del__(self):
    """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
    if hasattr(self, "_run_id") and self._run_id:
        try:
            self.client.set_terminated(self._run_id, status="FINISHED")
        except Exception:
            pass

medium

Relying solely on __del__ to close out the MLflow run is unreliable: Python does not guarantee that __del__ runs, and during interpreter shutdown the module globals and objects the destructor depends on (such as self.client) may already have been torn down, leading to runtime errors.

Add an explicit close() method to safely terminate the run, and make __del__ robust by checking if the client and run ID are still available.

Suggested change
- def __del__(self):
-     """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
-     if hasattr(self, "_run_id") and self._run_id:
-         try:
-             self.client.set_terminated(self._run_id, status="FINISHED")
-         except Exception:
-             pass
+ def close(self):
+     """Explicitly terminate the active MLflow run."""
+     if getattr(self, "_run_id", None):
+         try:
+             self.client.set_terminated(self._run_id, status="FINISHED")
+         except Exception as e:
+             logger.warning(f"Failed to terminate MLflow run: {e}")
+         self._run_id = None
+
+ def __del__(self):
+     """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
+     if getattr(self, "_run_id", None) and getattr(self, "client", None):
+         try:
+             self.close()
+         except Exception:
+             pass

Comment thread python/ray/serve/batching.py
try:
    self.close()
except Exception:
    pass

Success leaves MLflow RUNNING

Medium Severity

On a successful train run, the callback never marks the MLflow run FINISHED except via close() or __del__. Train v2’s UserCallbackHandler only calls after_report and after_exception, so completed jobs often stay RUNNING in MLflow unless finalization runs from destructor timing.
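One way to avoid depending on destructor timing is to finalize the run explicitly from the code that launches training, so the status is set on both the success and the failure path. The sketch below illustrates the pattern with a stand-in client. RecordingClient, RunFinalizer, and run_with_finalizer are hypothetical names, and where such a hook would live in Train v2 is an assumption; only the set_terminated(run_id, status=...) call mirrors the MlflowClient method used in this PR.

```python
from typing import Any, Callable, List, Tuple


class RecordingClient:
    """Stand-in for MlflowClient that records calls instead of sending them (hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str]] = []

    def set_terminated(self, run_id: str, status: str = "FINISHED") -> None:
        self.calls.append((run_id, status))


class RunFinalizer:
    """Marks a run terminal exactly once, whichever path gets there first."""

    def __init__(self, client: Any, run_id: str) -> None:
        self._client = client
        self._run_id = run_id
        self._closed = False

    def close(self, status: str = "FINISHED") -> None:
        if self._closed:
            return  # idempotent: later calls (including from __del__) are no-ops
        self._closed = True
        self._client.set_terminated(self._run_id, status=status)


def run_with_finalizer(train_fn: Callable[[], Any], finalizer: RunFinalizer) -> Any:
    """Run `train_fn` and record FINISHED on success or FAILED on any exception."""
    try:
        result = train_fn()
    except BaseException:
        finalizer.close(status="FAILED")
        raise
    finalizer.close(status="FINISHED")
    return result
```

Because close() is idempotent, a destructor can still call it as a last resort without overwriting a status that was already recorded.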

Reviewed by Cursor Bugbot for commit e8290a5. Configure here.

@ArjunPakhan ArjunPakhan force-pushed the feat/train-v2-mlflow-callback branch 2 times, most recently from 475bb98 to 8073318 Compare June 20, 2026 17:47
@ArjunPakhan ArjunPakhan force-pushed the feat/train-v2-mlflow-callback branch from 8073318 to 3dfd91e Compare June 20, 2026 17:47

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).

Reviewed by Cursor Bugbot for commit 3dfd91e. Configure here.

if isinstance(v, (int, float)):
    metrics_to_log.append(
        Metric(key=k, value=v, timestamp=timestamp, step=step)
    )

MLflow skips numeric metrics

Medium Severity

after_report only logs values passing isinstance(v, (int, float)), so common training metrics (numpy scalars, torch tensors converted to metric dicts, etc.) are silently omitted from log_batch even though MLflow accepts them as floats.
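A possible fix is to try converting each value with float() instead of checking its type, since scalar-like objects typically implement __float__. The helper below is a sketch under that assumption; coerce_metrics is a hypothetical name, and the decision to skip strings (which float() would otherwise parse) is a design choice of this example rather than something the PR specifies.

```python
from typing import Any, Dict, Mapping


def coerce_metrics(raw: Mapping[str, Any]) -> Dict[str, float]:
    """Keep values that convert cleanly to a float; silently drop the rest."""
    out: Dict[str, float] = {}
    for key, value in raw.items():
        if isinstance(value, (str, bytes)):
            continue  # float("0.1") would succeed, but strings are not metrics
        try:
            number = float(value)
        except (TypeError, ValueError, RuntimeError):
            # Lists, None, and multi-element arrays or tensors end up here.
            continue
        out[key] = number
    return out
```

The returned dict can then feed the existing loop that builds Metric entries, replacing the isinstance(v, (int, float)) check.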

Reviewed by Cursor Bugbot for commit 3dfd91e. Configure here.

@ray-gardener ray-gardener Bot added train Ray Train Related Issue community-contribution Contributed by the community labels Jun 20, 2026