diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py
index b9b199bb3a5..e6868fb2975 100644
--- a/src/accelerate/tracking.py
+++ b/src/accelerate/tracking.py
@@ -795,16 +795,16 @@ def store_init_configuration(self, values: dict):
         """
         import mlflow
 
-        for name, value in list(values.items()):
+        values_list = []
+        for name, value in values.items():
             # internally, all values are converted to str in MLflow
             if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
                 logger.warning_once(
                     f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
                     f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
                 )
-                del values[name]
-
-        values_list = list(values.items())
+            else:
+                values_list.append((name, value))
 
         # MLflow cannot log more than 100 values in one go, so we have to split it
         for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
diff --git a/tests/test_tracking.py b/tests/test_tracking.py
index 4fee94a61f8..618f27619bb 100644
--- a/tests/test_tracking.py
+++ b/tests/test_tracking.py
@@ -322,6 +322,32 @@ def test_log_artifacts(self):
             ],
         )
 
+    def test_store_init_configuration_does_not_mutate_input(self):
+        """`store_init_configuration` must not mutate the caller's dict when dropping over-long values.
+
+        The same `config` object is shared across all trackers in `Accelerator.init_trackers`, so deleting
+        keys in place removed them from the user's dict and from every other tracker too.
+        """
+        import mlflow
+
+        too_long = "x" * (mlflow.utils.validation.MAX_PARAM_VAL_LENGTH + 1)
+        values = {"learning_rate": 0.001, "too_long": too_long}
+        tracker = MLflowTracker(experiment_name="test_exp", logging_dir=self.tmpdir.name)
+        accelerator = Accelerator(log_with=tracker)
+        accelerator.init_trackers(project_name="test_exp")
+        tracker.store_init_configuration(values)
+
+        run_id = tracker.active_run.info.run_id
+        accelerator.end_training()
+
+        # The caller's dict must be left untouched (previously `too_long` was deleted in place).
+        self.assertEqual(values, {"learning_rate": 0.001, "too_long": too_long})
+
+        # The over-long value is still dropped from what is actually logged to MLflow.
+        params = mlflow.get_run(run_id).data.params
+        self.assertIn("learning_rate", params)
+        self.assertNotIn("too_long", params)
+
 
 @require_comet_ml
 class CometMLTest(unittest.TestCase):