Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/accelerate/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import functools
import logging
import os

Expand Down Expand Up @@ -78,16 +77,22 @@ def log(self, level, msg, *args, **kwargs):
self.logger.log(level, msg, *args, **kwargs)
state.wait_for_everyone()

@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
def warning_once(self, msg, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once
Like `warning()`, but emits each unique message only once per adapter.

Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
switch to another type of cache that includes the caller frame information in the hashing function.
The cache is keyed on the message text, so passing the standard
`extra={...}` kwarg (or any other unhashable value) doesn't crash
the way it did with the previous `lru_cache` decorator.
"""
self.warning(*args, **kwargs)
cache = getattr(self, "_warning_once_cache", None)
if cache is None:
cache = set()
self._warning_once_cache = cache
if msg in cache:
return
cache.add(msg)
self.warning(msg, *args, **kwargs)


def get_logger(name: str, log_level: str | None = None):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def test_log_stack(caplog):
assert rec.message == expected_message


@pytest.mark.usefixtures("accelerator")
def test_warning_once_handles_unhashable_kwargs(caplog):
# The previous lru_cache-based implementation raised
# `TypeError: unhashable type: 'dict'` whenever the caller passed a
# standard logging kwarg like `extra={...}`, because lru_cache hashes
# every keyword argument.
logger = get_logger(__name__)
with caplog.at_level(logging.WARNING):
logger.warning_once("only once", extra={"id": 1})
logger.warning_once("only once", extra={"id": 2})
logger.warning_once("a different message", extra={"id": 3})

messages = [r.message for r in caplog.records]
assert sum(m.endswith("only once") for m in messages) == 1
assert sum(m.endswith("a different message") for m in messages) == 1


@pytest.mark.usefixtures("accelerator")
def test_custom_stacklevel(caplog):
wrapped_logger = get_logger(__name__)
Expand Down