diff --git a/src/accelerate/logging.py b/src/accelerate/logging.py index 9132f8cbf4b..58cf7ca4eeb 100644 --- a/src/accelerate/logging.py +++ b/src/accelerate/logging.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import functools import logging import os @@ -78,16 +77,22 @@ def log(self, level, msg, *args, **kwargs): self.logger.log(level, msg, *args, **kwargs) state.wait_for_everyone() - @functools.lru_cache(None) - def warning_once(self, *args, **kwargs): + def warning_once(self, msg, *args, **kwargs): """ - This method is identical to `logger.warning()`, but will emit the warning with the same message only once + Like `warning()`, but emits each unique message only once per adapter. - Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the - cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to - switch to another type of cache that includes the caller frame information in the hashing function. + The cache is keyed on the message text, so passing the standard + `extra={...}` kwarg (or any other unhashable value) doesn't crash + the way it did with the previous `lru_cache` decorator. """ - self.warning(*args, **kwargs) + cache = getattr(self, "_warning_once_cache", None) + if cache is None: + cache = set() + self._warning_once_cache = cache + if msg in cache: + return + cache.add(msg) + self.warning(msg, *args, **kwargs) def get_logger(name: str, log_level: str | None = None): diff --git a/tests/test_logging.py b/tests/test_logging.py index fd788fcb1a2..48ef88d0a10 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -70,6 +70,23 @@ def test_log_stack(caplog): assert rec.message == expected_message +@pytest.mark.usefixtures("accelerator") +def test_warning_once_handles_unhashable_kwargs(caplog): + # The previous lru_cache-based implementation raised + # `TypeError: unhashable type: 'dict'` whenever the caller passed a + # standard logging kwarg like `extra={...}`, because lru_cache hashes + # every keyword argument. + logger = get_logger(__name__) + with caplog.at_level(logging.WARNING): + logger.warning_once("only once", extra={"id": 1}) + logger.warning_once("only once", extra={"id": 2}) + logger.warning_once("a different message", extra={"id": 3}) + + messages = [r.message for r in caplog.records] + assert sum(m.endswith("only once") for m in messages) == 1 + assert sum(m.endswith("a different message") for m in messages) == 1 + + @pytest.mark.usefixtures("accelerator") def test_custom_stacklevel(caplog): wrapped_logger = get_logger(__name__)