Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def start(self):
Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before
InitProcessGroupKwargs.
"""
pass

def store_init_configuration(self, values: dict):
"""
Expand All @@ -156,9 +155,8 @@ def store_init_configuration(self, values: dict):
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
`str`, `float`, `int`, or `None`.
"""
pass

def log(self, values: dict, step: Optional[int], **kwargs):
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
special behavior for the `step parameter.
Expand All @@ -169,14 +167,12 @@ def log(self, values: dict, step: Optional[int], **kwargs):
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
"""
pass

def finish(self):
"""
Should run any finalizing functions within the tracking API. If the API should not have one, just don't
overwrite that method.
"""
pass


class TensorBoardTracker(GeneralTracker):
Expand Down Expand Up @@ -269,7 +265,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
logger.debug("Successfully logged to TensorBoard")

@on_main_process
def log_images(self, values: dict, step: Optional[int], **kwargs):
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `images` to the current run.

Expand Down Expand Up @@ -637,7 +633,7 @@ def store_init_configuration(self, values: dict):
self.writer["hparams"] = values

@on_main_process
def log(self, values: dict, step: Optional[int], **kwargs):
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.

Expand Down Expand Up @@ -813,7 +809,7 @@ def store_init_configuration(self, values: dict):
logger.debug("Stored initial configuration hyperparameters to MLflow")

@on_main_process
def log(self, values: dict, step: Optional[int]):
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.

Expand All @@ -822,6 +818,8 @@ def log(self, values: dict, step: Optional[int]):
Values to be logged as key-value pairs.
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs:
Additional key word arguments passed along to the `mlflow.log_metrics` method.
"""
metrics = {}
for k, v in values.items():
Expand All @@ -834,7 +832,7 @@ def log(self, values: dict, step: Optional[int]):
)
import mlflow

mlflow.log_metrics(metrics, step=step)
mlflow.log_metrics(metrics, step=step, **kwargs)
logger.debug("Successfully logged to mlflow")

@on_main_process
Expand Down
38 changes: 35 additions & 3 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
if is_tensorboard_available():
import struct

import tensorboard.compat.proto.event_pb2 as event_pb2
from tensorboard.compat.proto import event_pb2

if is_dvclive_available():
from dvclive.plots.metric import Metric
Expand Down Expand Up @@ -322,6 +322,38 @@ def test_log_artifacts(self):
],
)

def test_log_without_step(self):
"""`MLflowTracker.log` should treat `step` as optional, matching the docstring."""
tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers(project_name="test_exp")
# Should not raise; previously raised TypeError because `step` had no default.
tracker.log({"loss": 0.1})
accelerator.end_training()

def test_log_accepts_extra_kwargs(self):
"""`MLflowTracker.log` should accept extra kwargs forwarded by `Accelerator.log(log_kwargs=...)`.

Previously the signature omitted `**kwargs`, so any per-tracker kwarg (e.g. `synchronous`)
passed via `log_kwargs={"mlflow": {...}}` raised TypeError. Beyond not raising, the kwarg
must reach `mlflow.log_metrics` so it actually changes mlflow's behavior.
"""
tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers(project_name="test_exp")
# `synchronous` is a real mlflow.log_metrics kwarg; the tracker must forward it.
with mock.patch("mlflow.log_metrics") as mock_log_metrics:
accelerator.log({"loss": 0.1}, step=1, log_kwargs={"mlflow": {"synchronous": True}})
accelerator.end_training()
mock_log_metrics.assert_called_once()
call_kwargs = mock_log_metrics.call_args.kwargs
assert call_kwargs.get("synchronous") is True, (
f"expected synchronous=True forwarded to mlflow.log_metrics, got kwargs={call_kwargs}"
)
assert call_kwargs.get("step") == 1, (
f"expected step=1 forwarded to mlflow.log_metrics, got kwargs={call_kwargs}"
)


@require_comet_ml
class CometMLTest(unittest.TestCase):
Expand Down Expand Up @@ -678,7 +710,7 @@ def store_init_configuration(self, values: dict):
logger.info("Call init")
self.writer.writerow(values)

def log(self, values: dict, step: Optional[int]):
def log(self, values: dict, step: Optional[int] = None):
logger.info("Call log")
self.writer.writerow(values)

Expand Down Expand Up @@ -770,7 +802,7 @@ def test_log(self, mock_repo):
assert latest.pop("step") == 3
assert latest == values
scalars = os.path.join(live.plots_dir, Metric.subfolder)
for val in values.keys():
for val in values:
val_path = os.path.join(scalars, f"{val}.tsv")
steps = [int(row["step"]) for row in logs[val_path]]
assert steps == [0, 1, 3]
Expand Down