diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9bf7fc96eb8..613d895961c 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -3268,7 +3268,7 @@ def wait_for_everyone(self): wait_for_everyone() @on_main_process - def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}): + def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = None): """ Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations @@ -3297,6 +3297,8 @@ def init_trackers(self, project_name: str, config: dict | None = None, init_kwar ... ) ``` """ + if init_kwargs is None: + init_kwargs = {} for tracker in self.log_with: if issubclass(type(tracker), GeneralTracker): # Custom trackers are already initialized @@ -3351,7 +3353,7 @@ def get_tracker(self, name: str, unwrap: bool = False): return GeneralTracker(_blank=True) @on_main_process - def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}): + def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = None): """ Logs `values` to all stored trackers in `self.trackers` on the main process only. @@ -3377,6 +3379,8 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = { >>> accelerator.log({"loss": 0.5, "accuracy": 0.9}) ``` """ + if log_kwargs is None: + log_kwargs = {} for tracker in self.trackers: tracker.log(values, step=step, **log_kwargs.get(tracker.name, {})) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 4fee94a61f8..004799cf3b4 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -730,6 +730,24 @@ def test_log(self): } assert data == truth + def test_init_trackers_none_kwargs(self): + with tempfile.TemporaryDirectory() as d: + tracker = MyCustomTracker(d) + accelerator = Accelerator(log_with=tracker) + # Passing init_kwargs=None must not raise AttributeError + accelerator.init_trackers("Some name", init_kwargs=None) + accelerator.end_training() + + def test_log_none_kwargs(self): + with tempfile.TemporaryDirectory() as d: + tracker = MyCustomTracker(d) + accelerator = Accelerator(log_with=tracker) + accelerator.init_trackers("Some name") + values = {"total_loss": 0.1} + # Passing log_kwargs=None must not raise AttributeError + accelerator.log(values, step=0, log_kwargs=None) + accelerator.end_training() + @require_dvclive @mock.patch("dvclive.live.get_dvc_repo", return_value=None)