From e749910b2abba2dc39d6d95fb0d8b337ac14602b Mon Sep 17 00:00:00 2001 From: Charlie Tonneslan Date: Fri, 22 May 2026 18:45:09 -0400 Subject: [PATCH] Stop warning_once from crashing on unhashable kwargs `MultiProcessAdapter.warning_once` was decorated with `@functools.lru_cache(None)`, which hashes every positional and keyword argument. The standard `logging` API accepts an `extra={...}` kwarg, and a dict isn't hashable, so a perfectly normal call like logger.warning_once("only once", extra={"id": 1}) raised `TypeError: unhashable type: 'dict'`. Cache by the message text on a per-adapter set instead. That matches the docstring (`"the same message only once"`), accepts any kwargs the underlying `warning()` accepts, and also drops the implicit `self` retention that `lru_cache` on a method caused. Added a regression test in `tests/test_logging.py` that calls `warning_once` with an `extra={...}` kwarg, twice with the same message and once with a different message, and asserts each unique message is emitted exactly once. Signed-off-by: Charlie Tonneslan --- src/accelerate/logging.py | 21 +++++++++++++-------- tests/test_logging.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/src/accelerate/logging.py b/src/accelerate/logging.py index 9132f8cbf4b..58cf7ca4eeb 100644 --- a/src/accelerate/logging.py +++ b/src/accelerate/logging.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import functools import logging import os @@ -78,16 +77,22 @@ def log(self, level, msg, *args, **kwargs): self.logger.log(level, msg, *args, **kwargs) state.wait_for_everyone() - @functools.lru_cache(None) - def warning_once(self, *args, **kwargs): + def warning_once(self, msg, *args, **kwargs): """ - This method is identical to `logger.warning()`, but will emit the warning with the same message only once + Like `warning()`, but emits each unique message only once per adapter. - Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the - cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to - switch to another type of cache that includes the caller frame information in the hashing function. + The cache is keyed on the message text, so passing the standard + `extra={...}` kwarg (or any other unhashable value) doesn't crash + the way it did with the previous `lru_cache` decorator. """ - self.warning(*args, **kwargs) + cache = getattr(self, "_warning_once_cache", None) + if cache is None: + cache = set() + self._warning_once_cache = cache + if msg in cache: + return + cache.add(msg) + self.warning(msg, *args, **kwargs) def get_logger(name: str, log_level: str | None = None): diff --git a/tests/test_logging.py b/tests/test_logging.py index fd788fcb1a2..48ef88d0a10 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -70,6 +70,23 @@ def test_log_stack(caplog): assert rec.message == expected_message +@pytest.mark.usefixtures("accelerator") +def test_warning_once_handles_unhashable_kwargs(caplog): + # The previous lru_cache-based implementation raised + # `TypeError: unhashable type: 'dict'` whenever the caller passed a + # standard logging kwarg like `extra={...}`, because lru_cache hashes + # every keyword argument. + logger = get_logger(__name__) + with caplog.at_level(logging.WARNING): + logger.warning_once("only once", extra={"id": 1}) + logger.warning_once("only once", extra={"id": 2}) + logger.warning_once("a different message", extra={"id": 3}) + + messages = [r.message for r in caplog.records] + assert sum(m.endswith("only once") for m in messages) == 1 + assert sum(m.endswith("a different message") for m in messages) == 1 + + @pytest.mark.usefixtures("accelerator") def test_custom_stacklevel(caplog): wrapped_logger = get_logger(__name__)