feat(train): Add native MLflowLoggerCallback for Train v2 API #64234
ArjunPakhan wants to merge 3 commits into
Code Review
This pull request introduces dynamic configuration of maximum concurrent batches in Ray Serve's batching queue and adds a native MLflow logging callback (MLflowLoggerCallback) with associated unit tests for Ray Train v2. The review feedback highlights critical indentation errors in batching.py that would cause compilation or runtime failures, as well as an issue with directly modifying the private _value attribute of asyncio.Semaphore. For the MLflow callback, the reviewer recommends optimizing performance by using batch logging instead of sequential HTTP requests, correcting directory logging by using log_artifacts instead of log_artifact, and implementing a more robust cleanup mechanism via an explicit close method rather than relying solely on __del__.
def set_max_concurrent_batches(self, new_max_concurrent_batches: int) -> None:
    """Safely updates queue's max_concurrent_batches and modifies semaphore limits."""
    old_max = self.max_concurrent_batches
    self.max_concurrent_batches = new_max_concurrent_batches

    # Calculate the delta between old and new limits
    delta = new_max_concurrent_batches - old_max

    if delta > 0:
        # If the limit increased, release the extra tokens onto the existing semaphore
        for _ in range(delta):
            self.semaphore.release()
    elif delta < 0:
        # If the limit decreased, manually lower the internal permit value
        # asyncio.Semaphore tracks total capacity via its private _value attribute
        self.semaphore._value += delta

    # Fix Issue #3: Trigger the validation safety warning
    self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

The method set_max_concurrent_batches is defined with 0 indentation, making it a module-level function instead of a method of _BatchQueue. This will cause runtime errors when called. Additionally, modifying self.semaphore._value directly depends on a private implementation detail of asyncio.Semaphore. It bypasses the semaphore's own acquire/release bookkeeping, and when more permits are already held than the new limit allows it can drive the counter negative, a state the semaphore is not designed to handle.
To fix both issues, indent the method correctly and use a safe, standard asyncio approach to dynamically adjust the semaphore capacity by managing background permit-holding tasks.
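The permit-holding idea can be tried in isolation before wiring it into the batch queue. The following is a minimal sketch under stated assumptions, not the Ray Serve implementation: ResizableLimiter, resize, measure_peak, and demo are hypothetical names introduced for illustration, and the only library calls used are the public asyncio.Semaphore, asyncio.Event, and asyncio.create_task APIs.

```python
import asyncio


class ResizableLimiter:
    """Hypothetical helper: resizes an asyncio.Semaphore without touching _value."""

    def __init__(self, limit):
        self.sem = asyncio.Semaphore(limit)
        self._limit = limit
        self._holders = []  # background tasks that each park one permit

    async def _hold(self):
        # Take one permit and keep it until this task is cancelled.
        await self.sem.acquire()
        try:
            await asyncio.Event().wait()  # never set, so this waits until cancelled
        finally:
            self.sem.release()

    def resize(self, new_limit):
        delta = new_limit - self._limit
        self._limit = new_limit
        # Growing: first free permits parked in holder tasks...
        while delta > 0 and self._holders:
            self._holders.pop().cancel()
            delta -= 1
        # ...then add brand-new permits if the limit exceeds the original size.
        for _ in range(max(delta, 0)):
            self.sem.release()
        # Shrinking: park one holder task per permit that should disappear.
        for _ in range(max(-delta, 0)):
            self._holders.append(asyncio.create_task(self._hold()))


async def measure_peak(limiter, jobs):
    """Run `jobs` short tasks through the limiter and report peak concurrency."""
    running = 0
    peak = 0

    async def job():
        nonlocal running, peak
        async with limiter.sem:
            running += 1
            peak = max(peak, running)
            await asyncio.sleep(0.01)
            running -= 1

    await asyncio.gather(*(job() for _ in range(jobs)))
    return peak


async def demo():
    limiter = ResizableLimiter(4)
    before = await measure_peak(limiter, 8)
    limiter.resize(2)
    await asyncio.sleep(0)  # give the holder tasks a turn to take their permits
    shrunk = await measure_peak(limiter, 8)
    limiter.resize(3)
    await asyncio.sleep(0)  # give the cancelled holder a turn to release
    grown = await measure_peak(limiter, 8)
    return before, shrunk, grown


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Note that shrinking takes effect only once the holder tasks manage to acquire permits, so a limit lowered while every permit is in use is applied gradually as running work finishes.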
    def set_max_concurrent_batches(self, new_max_concurrent_batches: int) -> None:
        """Safely updates queue's max_concurrent_batches and modifies semaphore limits."""
        old_max = self.max_concurrent_batches
        self.max_concurrent_batches = new_max_concurrent_batches
        delta = new_max_concurrent_batches - old_max
        if not hasattr(self, "_holding_tasks"):
            self._holding_tasks = []
        if delta > 0:
            while delta > 0 and self._holding_tasks:
                task = self._holding_tasks.pop()
                task.cancel()
                delta -= 1
            for _ in range(delta):
                self.semaphore.release()
        elif delta < 0:
            async def hold_permit(sem):
                try:
                    await sem.acquire()
                    try:
                        await asyncio.Event().wait()
                    finally:
                        sem.release()
                except asyncio.CancelledError:
                    pass
            for _ in range(abs(delta)):
                task = asyncio.create_task(hold_permit(self.semaphore))
                self._holding_tasks.append(task)
        self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
wrapper.set_batch_wait_timeout_s = lazy_batch_queue_wrapper.set_batch_wait_timeout_s
# Fix Bugbot & Gemini Bot: Expose the new APIs to the outer function wrapper handle
wrapper.set_max_concurrent_batches = lazy_batch_queue_wrapper.set_max_concurrent_batches
wrapper.get_max_concurrent_batches = lazy_batch_queue_wrapper.get_max_concurrent_batches

The assignment wrapper.set_max_batch_size has 0 indentation, which will cause an IndentationError at import/compile time. Additionally, wrapper.set_max_batch_size and wrapper.set_batch_wait_timeout_s are already assigned with correct indentation on lines 1029-1032. We can safely remove the duplicate assignments here and only keep the new ones with correct indentation.
    wrapper.set_max_concurrent_batches = lazy_batch_queue_wrapper.set_max_concurrent_batches
    wrapper.get_max_concurrent_batches = lazy_batch_queue_wrapper.get_max_concurrent_batches

for k, v in rank_0_metrics.items():
    if isinstance(v, (int, float)):
        self.client.log_metric(self._run_id, k, v, step=step)

Logging metrics one-by-one in a loop using self.client.log_metric makes a separate HTTP request for every metric, which can slow down reporting. Use self.client.log_batch to log all metrics in a single HTTP request.
metrics_to_log = []
timestamp = int(time.time() * 1000)
for k, v in rank_0_metrics.items():
    if isinstance(v, (int, float)):
        metrics_to_log.append(Metric(key=k, value=v, timestamp=timestamp, step=step))
if metrics_to_log:
    self.client.log_batch(self._run_id, metrics=metrics_to_log)

self.client.log_artifact(
    self._run_id,
    checkpoint_dir,
    artifact_path=f"checkpoints/step_{step}",
)

self.client.log_artifact is designed to log a single file. To log a directory (such as the checkpoint directory), use self.client.log_artifacts instead. Using log_artifact on a directory can fail or behave unexpectedly depending on the MLflow version and storage backend.
- self.client.log_artifact(
-     self._run_id,
-     checkpoint_dir,
-     artifact_path=f"checkpoints/step_{step}",
- )
+ self.client.log_artifacts(
+     self._run_id,
+     checkpoint_dir,
+     artifact_path=f"checkpoints/step_{step}",
+ )

import logging
from typing import Any, Dict, List, Optional

from mlflow.tracking import MlflowClient
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2.api.callback import UserCallback
from ray.util.annotations import DeveloperAPI

Add imports for time and Metric to support the optimized batch metric logging.
  import logging
+ import time
  from typing import Any, Dict, List, Optional
+ from mlflow.entities import Metric
  from mlflow.tracking import MlflowClient
  from ray.train import Checkpoint
  from ray.train.v2._internal.execution.context import TrainRunContext
  from ray.train.v2.api.callback import UserCallback
  from ray.util.annotations import DeveloperAPI

def __del__(self):
    """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
    if hasattr(self, "_run_id") and self._run_id:
        try:
            self.client.set_terminated(self._run_id, status="FINISHED")
        except Exception:
            pass

Relying solely on __del__ to close out the MLflow run is unreliable because Python does not guarantee __del__ execution, and during interpreter shutdown, global modules/attributes (like self.client) may already be cleaned up, leading to runtime errors.
Add an explicit close() method to safely terminate the run, and make __del__ robust by checking if the client and run ID are still available.
- def __del__(self):
-     """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
-     if hasattr(self, "_run_id") and self._run_id:
-         try:
-             self.client.set_terminated(self._run_id, status="FINISHED")
-         except Exception:
-             pass
+ def close(self):
+     """Explicitly terminate the active MLflow run."""
+     if getattr(self, "_run_id", None):
+         try:
+             self.client.set_terminated(self._run_id, status="FINISHED")
+         except Exception as e:
+             logger.warning(f"Failed to terminate MLflow run: {e}")
+         self._run_id = None
+ def __del__(self):
+     """Ensure the run status is closed out as FINISHED when the training workflow concludes."""
+     if getattr(self, "_run_id", None) and getattr(self, "client", None):
+         try:
+             self.close()
+         except Exception:
+             pass

try:
    self.close()
except Exception:
    pass

Success leaves MLflow RUNNING
Medium Severity
On a successful train run, the callback never marks the MLflow run FINISHED except via close() or __del__. Train v2’s UserCallbackHandler only calls after_report and after_exception, so completed jobs often stay RUNNING in MLflow unless finalization runs from destructor timing.
Reviewed by Cursor Bugbot for commit e8290a5.
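One way to avoid depending on destructor timing is to finalize the run deterministically from the code that owns the callback. The following is a minimal sketch under stated assumptions: FakeClient is a hypothetical stand-in for MlflowClient that only records set_terminated calls, and RunFinalizer is an illustrative name, not part of Ray Train or of this PR. It shows an idempotent close() that a context manager calls on both the success and the failure path.

```python
class FakeClient:
    """Hypothetical stand-in for MlflowClient; records termination calls only."""

    def __init__(self):
        self.calls = []

    def set_terminated(self, run_id, status="FINISHED"):
        self.calls.append((run_id, status))


class RunFinalizer:
    """Illustrative owner of a run ID that terminates the run exactly once."""

    def __init__(self, client, run_id):
        self.client = client
        self._run_id = run_id

    def close(self, status="FINISHED"):
        # Idempotent: clear the run ID first so repeated calls are no-ops.
        if self._run_id is None:
            return
        run_id, self._run_id = self._run_id, None
        self.client.set_terminated(run_id, status=status)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Mark the run FAILED if the block raised, FINISHED otherwise.
        self.close("FAILED" if exc_type is not None else "FINISHED")
        return False  # never swallow the exception


if __name__ == "__main__":
    client = FakeClient()
    with RunFinalizer(client, "run-1"):
        pass  # training would happen here
    print(client.calls)
```

Whether the Train v2 callback interface offers a suitable completion hook is not established by this review, so the sketch places the responsibility in the calling code instead.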
475bb98 to 8073318
8073318 to 3dfd91e
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
There are 2 total unresolved issues (including 1 from previous review).
Reviewed by Cursor Bugbot for commit 3dfd91e.

if isinstance(v, (int, float)):
    metrics_to_log.append(
        Metric(key=k, value=v, timestamp=timestamp, step=step)
    )

MLflow skips numeric metrics
Medium Severity
after_report only logs values passing isinstance(v, (int, float)), so common training metrics (numpy scalars, torch tensors converted to metric dicts, etc.) are silently omitted from log_batch even though MLflow accepts them as floats.
Reviewed by Cursor Bugbot for commit 3dfd91e.
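A broader filter could accept anything that converts cleanly to a float instead of only int and float. The following is a minimal sketch under stated assumptions: to_float_or_none and select_numeric_metrics are hypothetical helper names, and the demo uses the standard library's Fraction and Decimal as stand-ins for third-party scalar types, on the assumption that such types either register as numbers.Real or implement __float__.

```python
import numbers
from decimal import Decimal
from fractions import Fraction


def to_float_or_none(value):
    """Return `value` as a float when it is scalar-like, otherwise None."""
    # Strings parse with float("3.5"), but they are labels, not metrics.
    if isinstance(value, (str, bytes)):
        return None
    if isinstance(value, numbers.Real):
        return float(value)
    # Fall back to duck typing for objects that implement __float__.
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def select_numeric_metrics(metrics):
    """Keep only the entries that can be logged as float metrics."""
    selected = {}
    for key, value in metrics.items():
        converted = to_float_or_none(value)
        if converted is not None:
            selected[key] = converted
    return selected


if __name__ == "__main__":
    reported = {
        "loss": Fraction(1, 4),  # registered as numbers.Real
        "acc": Decimal("0.5"),   # not numbers.Real, but implements __float__
        "epoch": 3,
        "tag": "baseline",       # dropped: string
        "hist": [1, 2],          # dropped: not scalar
    }
    print(select_numeric_metrics(reported))
```

The resulting floats could then be passed to Metric(...) in place of the raw values; some non-scalar objects may raise exception types other than TypeError or ValueError on conversion, which this sketch does not cover.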


Description
This PR introduces a native MLflowLoggerCallback for the new Ray Train v2 API framework.

The callback uses an explicit MlflowClient backend instance rather than relying on process-global state wrappers (mlflow.*). This decoupled architecture ensures thread-safe, process-agnostic metric tracking and database synchronization from remote TrainController actor processes.

A robust local integration test suite has also been added to verify metric recording stability during successful runs and proper lifecycle termination (FAILED status updates) during unexpected worker exceptions.

Related issues
Fixes #

Additional information
MLflowLoggerCallback to ray.train.v2.api.
black and isort.