diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index b9b199bb3a5..95d910cd9c6 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -144,7 +144,6 @@ def start(self): Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before InitProcessGroupKwargs. """ - pass def store_init_configuration(self, values: dict): """ @@ -156,9 +155,8 @@ def store_init_configuration(self, values: dict): Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, `str`, `float`, `int`, or `None`. """ - pass - def log(self, values: dict, step: Optional[int], **kwargs): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with special behavior for the `step parameter. @@ -169,14 +167,12 @@ def log(self, values: dict, step: Optional[int], **kwargs): step (`int`, *optional*): The run step. If included, the log will be affiliated with this step. """ - pass def finish(self): """ Should run any finalizing functions within the tracking API. If the API should not have one, just don't overwrite that method. """ - pass class TensorBoardTracker(GeneralTracker): @@ -269,7 +265,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): logger.debug("Successfully logged to TensorBoard") @on_main_process - def log_images(self, values: dict, step: Optional[int], **kwargs): + def log_images(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `images` to the current run. @@ -637,7 +633,7 @@ def store_init_configuration(self, values: dict): self.writer["hparams"] = values @on_main_process - def log(self, values: dict, step: Optional[int], **kwargs): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. @@ -813,7 +809,7 @@ def store_init_configuration(self, values: dict): logger.debug("Stored initial configuration hyperparameters to MLflow") @on_main_process - def log(self, values: dict, step: Optional[int]): + def log(self, values: dict, step: Optional[int] = None, **kwargs): """ Logs `values` to the current run. @@ -822,6 +818,8 @@ def log(self, values: dict, step: Optional[int]): Values to be logged as key-value pairs. step (`int`, *optional*): The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `mlflow.log_metrics` method. """ metrics = {} for k, v in values.items(): @@ -834,7 +832,7 @@ def log(self, values: dict, step: Optional[int]): ) import mlflow - mlflow.log_metrics(metrics, step=step) + mlflow.log_metrics(metrics, step=step, **kwargs) logger.debug("Successfully logged to mlflow") @on_main_process diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 4fee94a61f8..389befe0d1a 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -75,7 +75,7 @@ if is_tensorboard_available(): import struct - import tensorboard.compat.proto.event_pb2 as event_pb2 + from tensorboard.compat.proto import event_pb2 if is_dvclive_available(): from dvclive.plots.metric import Metric @@ -322,6 +322,38 @@ def test_log_artifacts(self): ], ) + def test_log_without_step(self): + """`MLflowTracker.log` should treat `step` as optional, matching the docstring.""" + tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name) + accelerator = Accelerator(log_with=tracker) + accelerator.init_trackers(project_name="test_exp") + # Should not raise; previously raised TypeError because `step` had no default. + tracker.log({"loss": 0.1}) + accelerator.end_training() + + def test_log_accepts_extra_kwargs(self): + """`MLflowTracker.log` should accept extra kwargs forwarded by `Accelerator.log(log_kwargs=...)`. + + Previously the signature omitted `**kwargs`, so any per-tracker kwarg (e.g. `synchronous`) + passed via `log_kwargs={"mlflow": {...}}` raised TypeError. Beyond not raising, the kwarg + must reach `mlflow.log_metrics` so it actually changes mlflow's behavior. + """ + tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name) + accelerator = Accelerator(log_with=tracker) + accelerator.init_trackers(project_name="test_exp") + # `synchronous` is a real mlflow.log_metrics kwarg; the tracker must forward it. + with mock.patch("mlflow.log_metrics") as mock_log_metrics: + accelerator.log({"loss": 0.1}, step=1, log_kwargs={"mlflow": {"synchronous": True}}) + accelerator.end_training() + mock_log_metrics.assert_called_once() + call_kwargs = mock_log_metrics.call_args.kwargs + assert call_kwargs.get("synchronous") is True, ( + f"expected synchronous=True forwarded to mlflow.log_metrics, got kwargs={call_kwargs}" + ) + assert call_kwargs.get("step") == 1, ( + f"expected step=1 forwarded to mlflow.log_metrics, got kwargs={call_kwargs}" + ) + @require_comet_ml class CometMLTest(unittest.TestCase): @@ -678,7 +710,7 @@ def store_init_configuration(self, values: dict): logger.info("Call init") self.writer.writerow(values) - def log(self, values: dict, step: Optional[int]): + def log(self, values: dict, step: Optional[int] = None): logger.info("Call log") self.writer.writerow(values) @@ -770,7 +802,7 @@ def test_log(self, mock_repo): assert latest.pop("step") == 3 assert latest == values scalars = os.path.join(live.plots_dir, Metric.subfolder) - for val in values.keys(): + for val in values: val_path = os.path.join(scalars, f"{val}.tsv") steps = [int(row["step"]) for row in logs[val_path]] assert steps == [0, 1, 3]