From 6f4246f29ab8e85fb8acb3645bff17595ff7c668 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Tue, 12 May 2026 01:32:46 -0700 Subject: [PATCH 1/2] fix(tracking): make `step` optional and accept **kwargs in MLflowTracker.log Several tracker `log` / `log_images` methods declared `step: Optional[int]` without a default, even though their docstrings said the parameter was optional. Calling `tracker.log(values)` therefore raised `TypeError: log() missing 1 required positional argument: 'step'` on `GeneralTracker`, `TensorBoardTracker.log_images`, `AimTracker.log`, and `MLflowTracker.log`. `MLflowTracker.log` also did not accept `**kwargs`, so passing per-tracker arguments via `Accelerator.log(log_kwargs={"mlflow": {...}})` raised `TypeError: log() got an unexpected keyword argument`. The forwarded kwargs are now passed through to `mlflow.log_metrics`. Added two regression tests on `MLflowTracker` covering both paths. --- src/accelerate/tracking.py | 16 +++++++--------- tests/test_tracking.py | 28 +++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index b9b199bb3a5..95d910cd9c6 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -144,7 +144,6 @@ def start(self): Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before InitProcessGroupKwargs. """ - pass def store_init_configuration(self, values: dict): """ @@ -156,9 +155,8 @@ def store_init_configuration(self, values: dict): Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, `str`, `float`, `int`, or `None`. """ - pass - def log(self, values: dict, step: Optional[int], **kwargs): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with special behavior for the `step parameter. @@ -169,14 +167,12 @@ def log(self, values: dict, step: Optional[int], **kwargs): step (`int`, *optional*): The run step. If included, the log will be affiliated with this step. """ - pass def finish(self): """ Should run any finalizing functions within the tracking API. If the API should not have one, just don't overwrite that method. """ - pass class TensorBoardTracker(GeneralTracker): @@ -269,7 +265,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): logger.debug("Successfully logged to TensorBoard") @on_main_process - def log_images(self, values: dict, step: Optional[int], **kwargs): + def log_images(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `images` to the current run. @@ -637,7 +633,7 @@ def store_init_configuration(self, values: dict): self.writer["hparams"] = values @on_main_process - def log(self, values: dict, step: Optional[int], **kwargs): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. @@ -813,7 +809,7 @@ def store_init_configuration(self, values: dict): logger.debug("Stored initial configuration hyperparameters to MLflow") @on_main_process - def log(self, values: dict, step: Optional[int]): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. @@ -822,6 +818,8 @@ def log(self, values: dict, step: Optional[int]): Values to be logged as key-value pairs. step (`int`, *optional*): The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `mlflow.log_metrics` method. """ metrics = {} for k, v in values.items(): @@ -834,7 +832,7 @@ def log(self, values: dict, step: Optional[int]): ) import mlflow - mlflow.log_metrics(metrics, step=step) + mlflow.log_metrics(metrics, step=step, **kwargs) logger.debug("Successfully logged to mlflow") @on_main_process diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 4fee94a61f8..43065805e9e 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -75,7 +75,7 @@ if is_tensorboard_available(): import struct - import tensorboard.compat.proto.event_pb2 as event_pb2 + from tensorboard.compat.proto import event_pb2 if is_dvclive_available(): from dvclive.plots.metric import Metric @@ -322,6 +322,28 @@ def test_log_artifacts(self): ], ) + def test_log_without_step(self): + """`MLflowTracker.log` should treat `step` as optional, matching the docstring.""" + tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name) + accelerator = Accelerator(log_with=tracker) + accelerator.init_trackers(project_name="test_exp") + # Should not raise; previously raised TypeError because `step` had no default. + tracker.log({"loss": 0.1}) + accelerator.end_training() + + def test_log_accepts_extra_kwargs(self): + """`MLflowTracker.log` should accept extra kwargs forwarded by `Accelerator.log(log_kwargs=...)`. + + Previously the signature omitted `**kwargs`, so any per-tracker kwarg (e.g. `synchronous`) + passed via `log_kwargs={"mlflow": {...}}` raised TypeError. + """ + tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name) + accelerator = Accelerator(log_with=tracker) + accelerator.init_trackers(project_name="test_exp") + # `synchronous` is a real mlflow.log_metrics kwarg; the tracker must forward it. + accelerator.log({"loss": 0.1}, step=1, log_kwargs={"mlflow": {"synchronous": True}}) + accelerator.end_training() + @require_comet_ml class CometMLTest(unittest.TestCase): @@ -678,7 +700,7 @@ def store_init_configuration(self, values: dict): logger.info("Call init") self.writer.writerow(values) - def log(self, values: dict, step: Optional[int]): + def log(self, values: dict, step: Optional[int] = None): logger.info("Call log") self.writer.writerow(values) @@ -770,7 +792,7 @@ def test_log(self, mock_repo): assert latest.pop("step") == 3 assert latest == values scalars = os.path.join(live.plots_dir, Metric.subfolder) - for val in values.keys(): + for val in values: val_path = os.path.join(scalars, f"{val}.tsv") steps = [int(row["step"]) for row in logs[val_path]] assert steps == [0, 1, 3] From 7436eca33d9c29d64f83e38c88ea3bf2ef3d8ce0 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Tue, 12 May 2026 02:01:19 -0700 Subject: [PATCH 2/2] test: assert MLflowTracker.log forwards extra kwargs to log_metrics --- tests/test_tracking.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 43065805e9e..389befe0d1a 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -335,14 +335,24 @@ def test_log_accepts_extra_kwargs(self): """`MLflowTracker.log` should accept extra kwargs forwarded by `Accelerator.log(log_kwargs=...)`. Previously the signature omitted `**kwargs`, so any per-tracker kwarg (e.g. `synchronous`) - passed via `log_kwargs={"mlflow": {...}}` raised TypeError. + passed via `log_kwargs={"mlflow": {...}}` raised TypeError. Beyond not raising, the kwarg + must reach `mlflow.log_metrics` so it actually changes mlflow's behavior. """ tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name) accelerator = Accelerator(log_with=tracker) accelerator.init_trackers(project_name="test_exp") # `synchronous` is a real mlflow.log_metrics kwarg; the tracker must forward it. - accelerator.log({"loss": 0.1}, step=1, log_kwargs={"mlflow": {"synchronous": True}}) + with mock.patch("mlflow.log_metrics") as mock_log_metrics: + accelerator.log({"loss": 0.1}, step=1, log_kwargs={"mlflow": {"synchronous": True}}) accelerator.end_training() + mock_log_metrics.assert_called_once() + call_kwargs = mock_log_metrics.call_args.kwargs + assert call_kwargs.get("synchronous") is True, ( + f"expected synchronous=True forwarded to mlflow.log_metrics, got kwargs={call_kwargs}" + ) + assert call_kwargs.get("step") == 1, ( + f"expected step=1 forwarded to mlflow.log_metrics, got kwargs={call_kwargs}" + ) @require_comet_ml