Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3268,7 +3268,7 @@ def wait_for_everyone(self):
wait_for_everyone()

@on_main_process
def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}):
def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = None):
"""
Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations

Expand Down Expand Up @@ -3297,6 +3297,8 @@ def init_trackers(self, project_name: str, config: dict | None = None, init_kwar
... )
```
"""
if init_kwargs is None:
init_kwargs = {}
for tracker in self.log_with:
if issubclass(type(tracker), GeneralTracker):
# Custom trackers are already initialized
Expand Down Expand Up @@ -3351,7 +3353,7 @@ def get_tracker(self, name: str, unwrap: bool = False):
return GeneralTracker(_blank=True)

@on_main_process
def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = None):
"""
Logs `values` to all stored trackers in `self.trackers` on the main process only.

Expand All @@ -3377,6 +3379,8 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {
>>> accelerator.log({"loss": 0.5, "accuracy": 0.9})
```
"""
if log_kwargs is None:
log_kwargs = {}
for tracker in self.trackers:
tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))

Expand Down
18 changes: 18 additions & 0 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,24 @@ def test_log(self):
}
assert data == truth

def test_init_trackers_none_kwargs(self):
with tempfile.TemporaryDirectory() as d:
tracker = MyCustomTracker(d)
accelerator = Accelerator(log_with=tracker)
# Passing init_kwargs=None must not raise AttributeError
accelerator.init_trackers("Some name", init_kwargs=None)
accelerator.end_training()

def test_log_none_kwargs(self):
with tempfile.TemporaryDirectory() as d:
tracker = MyCustomTracker(d)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers("Some name")
values = {"total_loss": 0.1}
# Passing log_kwargs=None must not raise AttributeError
accelerator.log(values, step=0, log_kwargs=None)
accelerator.end_training()


@require_dvclive
@mock.patch("dvclive.live.get_dvc_repo", return_value=None)
Expand Down