From 1f451d76255c3e176b69e60aefa997cdfb935373 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Tue, 10 Feb 2026 16:52:59 +0100 Subject: [PATCH 1/6] Initial draft of the pytest (thanks copilot) --- tests/logging/test_handlers.py | 690 +++++++++++++++++++++++++++++++++ 1 file changed, 690 insertions(+) create mode 100644 tests/logging/test_handlers.py diff --git a/tests/logging/test_handlers.py b/tests/logging/test_handlers.py new file mode 100644 index 0000000..2e7ac0c --- /dev/null +++ b/tests/logging/test_handlers.py @@ -0,0 +1,690 @@ + + + + +#! What do we want to test here? + +#! The processing of the environment variables, partiuclarly the protobuf, the random one, and also what happens when you try to initialise it without the right environment + +# this should test _convert_str_to_handlertype with throttle, protobufstream, and random +# throttle should return handlertype.throttle, none , protobufstream should return handlertype.protobufstream, and a valid protobufconf with the same values as expected + +# See if its possible to test _make_ers_handler_conf +# This should just take a default string comma and then return the correct features + +# Definitely test the creation of the full thing, with the four environments + + +#! Also test the from string in HandlerTypes +# Should get this testesd for the enums but also the error itself + + +#! We need to test the filters.. +# Figure out how to test logging filters.. + + +# Need tests for the throttling for sure + +# Need a test for the base handlerfilter: +# using one handlertype +# using multiple handler types + + +########### + + +"""Comprehensive tests for the logging filters in handlers.py. + +Tests cover: +- BaseHandlerFilter: Handler selection logic for both ERS and non-ERS paths +- HandleIDFilter: Filter that accepts only specific handler types +- ThrottleFilter: Advanced throttling with escalating thresholds and time windows +- Integration: Real logger usage with filters and handlers +""" + +import copy +import io +import logging +import time +from threading import Thread +from unittest.mock import MagicMock, patch + +import pytest + +from daqpytools.logging.handlers import ( + BaseHandlerFilter, + ERSPyLogHandlerConf, + HandleIDFilter, + HandlerType, + IssueRecord, + ProtobufConf, + StreamType, + ThrottleFilter, +) +from daqpytools.logging.levels import level_to_ers_var + +# ============================================================================ +# FIXTURES +# ============================================================================ + + +@pytest.fixture +def clean_logger(): + """Provide a clean logger with no handlers or filters.""" + logger = logging.getLogger("test_logger_" + str(time.time())) + logger.handlers = [] + logger.filters = [] + logger.setLevel(logging.DEBUG) + return logger + + +@pytest.fixture +def log_record(): + """Provide a basic log record for testing.""" + record = logging.LogRecord( + name="test.module", + level=logging.ERROR, + pathname="/path/to/test.py", + lineno=42, + msg="Test message", + args=(), + exc_info=None, + ) + return record + + +@pytest.fixture +def ers_log_record(): + """Provide a log record configured for ERS streaming.""" + record = logging.LogRecord( + name="test.module", + level=logging.ERROR, + pathname="/path/to/test.py", + lineno=42, + msg="ERS message", + args=(), + exc_info=None, + ) + record.stream = StreamType.ERS + return record + + +@pytest.fixture +def mock_ers_handlers(): + """Provide mock ERS handler configuration for testing.""" + handlers_config = {} + for level_var in level_to_ers_var.values(): + conf = ERSPyLogHandlerConf( + handlers=[HandlerType.Throttle, HandlerType.Protobufstream], + protobufconf=ProtobufConf(url="test.kafka.com", port=9092), + ) + handlers_config[level_var] = conf + return handlers_config + + +# ============================================================================ +# BaseHandlerFilter Tests +# ============================================================================ + + +class TestBaseHandlerFilter: + """Tests for BaseHandlerFilter.get_allowed() logic.""" + + def test_non_ers_uses_record_handlers_attribute(self, log_record): + """Test get_allowed() uses 'handlers' attribute from record for non-ERS.""" + log_record.handlers = [HandlerType.Rich, HandlerType.File] + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(log_record) + + assert allowed == [HandlerType.Rich, HandlerType.File] + + def test_non_ers_defaults_to_base_handlers(self, log_record): + """Test get_allowed() falls back to default handlers when attribute missing.""" + # log_record has no 'handlers' attribute + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(log_record) + + # Should return the base handlers from LogHandlerConf + assert allowed is not None + expected_handlers = {HandlerType.Stream, HandlerType.Rich, HandlerType.File} + assert expected_handlers.issubset(set(allowed)) + + def test_ers_path_valid_configuration(self, ers_log_record, mock_ers_handlers): + """Test get_allowed() extracts ERS handlers correctly with valid config.""" + ers_log_record.ers_handlers = mock_ers_handlers + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(ers_log_record) + + assert allowed == [HandlerType.Throttle, HandlerType.Protobufstream] + + def test_ers_path_missing_ers_handlers_attribute(self, ers_log_record): + """Test get_allowed() returns None when ERS record lacks ers_handlers.""" + # ers_log_record has stream=ERS but no ers_handlers attribute + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(ers_log_record) + + assert allowed is None + + def test_ers_path_no_matching_level_variable(self, ers_log_record, mock_ers_handlers): + """Test get_allowed() returns None when log level has no ERS mapping.""" + # Set a log level that might not have an ERS equivalent + ers_log_record.levelno = 25 # Between INFO and WARNING + ers_log_record.ers_handlers = mock_ers_handlers + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(ers_log_record) + + # Level 25 likely won't map to any ERS variable, so should return None + # or might map to something - let's handle both cases + if 25 not in level_to_ers_var: + assert allowed is None + + def test_ers_path_missing_handler_conf_for_level(self, ers_log_record): + """Test get_allowed() returns None when handler conf missing for level.""" + ers_log_record.levelno = logging.DEBUG # Low level + # Provide partial ers_handlers config missing the DEBUG entry + ers_log_record.ers_handlers = {} + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(ers_log_record) + + assert allowed is None + + +# ============================================================================ +# HandleIDFilter Tests +# ============================================================================ + + +class TestHandleIDFilter: + """Tests for HandleIDFilter.filter() logic.""" + + def test_single_handler_id_normalized_to_set(self): + """Test that single handler_id is normalized to a set.""" + filter_obj = HandleIDFilter(HandlerType.Rich) + + assert isinstance(filter_obj.handler_ids, set) + assert HandlerType.Rich in filter_obj.handler_ids + + def test_list_handler_ids_converted_to_set(self): + """Test that list of handler_ids is converted to a set.""" + handlers = [HandlerType.Rich, HandlerType.File] + filter_obj = HandleIDFilter(handlers) + + assert isinstance(filter_obj.handler_ids, set) + assert filter_obj.handler_ids == {HandlerType.Rich, HandlerType.File} + + def test_filter_returns_true_when_handler_in_allowed(self, log_record): + """Test filter() returns True when handler_id is in allowed list.""" + log_record.handlers = [HandlerType.Rich, HandlerType.File, HandlerType.Stream] + filter_obj = HandleIDFilter(HandlerType.Rich) + + result = filter_obj.filter(log_record) + + assert result is True + + def test_filter_returns_false_when_handler_not_in_allowed(self, log_record): + """Test filter() returns False when handler_id not in allowed.""" + log_record.handlers = [HandlerType.File, HandlerType.Stream] + filter_obj = HandleIDFilter(HandlerType.Rich) + + result = filter_obj.filter(log_record) + + assert result is False + + def test_filter_returns_false_when_get_allowed_returns_none(self, log_record): + """Test filter() returns False when get_allowed() returns None.""" + filter_obj = HandleIDFilter(HandlerType.Rich) + filter_obj.get_allowed = MagicMock(return_value=None) + + result = filter_obj.filter(log_record) + + assert result is False + + def test_filter_with_multiple_handler_ids(self, log_record): + """Test filter() with multiple handler_ids checks intersection.""" + log_record.handlers = [HandlerType.Rich, HandlerType.File] + filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) + + result = filter_obj.filter(log_record) + + # Should return True because Rich is in both sets + assert result is True + + def test_filter_no_intersection_with_multiple_ids(self, log_record): + """Test filter() returns False when no intersection with multiple ids.""" + log_record.handlers = [HandlerType.File] + filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) + + result = filter_obj.filter(log_record) + + assert result is False + + +# ============================================================================ +# ThrottleFilter Tests +# ============================================================================ + + +class TestThrottleFilter: + """Tests for ThrottleFilter throttling and suppression logic.""" + + def test_initial_phase_lets_through_first_n_messages(self, log_record): + """Test that first N messages pass through without suppression.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=3, time_limit=10) + + # First 3 messages should pass + assert filter_obj.filter(log_record) is True + assert filter_obj.filter(log_record) is True + assert filter_obj.filter(log_record) is True + + def test_after_initial_threshold_suppresses(self, log_record): + """Test that messages are suppressed after initial_threshold.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) + + # First 2 pass + assert filter_obj.filter(log_record) is True + assert filter_obj.filter(log_record) is True + + # 3rd should be suppressed + assert filter_obj.filter(log_record) is False + + def test_escalating_threshold_doubles_on_report(self, log_record): + """Test that threshold escalates (10->100->1000) when reporting.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=1, time_limit=100) + + issue_id = f"{log_record.pathname}:{log_record.lineno}" + issue_record = filter_obj.issue_map[issue_id] + + # First is emitted + # Next 10 are suppressed + # needs 1 more to trigger update + for _ in range(12): + filter_obj._throttle(issue_record, log_record) + + + assert issue_record.threshold == 100 # Escalated from 10 + + def test_time_window_reset_resets_counters(self, log_record, monkeypatch): + """Test that state resets after time_limit expires.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=1, time_limit=1) + + times = iter([1000.0, 1002.5]) + monkeypatch.setattr(time, "time", lambda: next(times)) + + # First message passes + assert filter_obj.filter(log_record) is True + + # Time advances beyond time_limit with no suppression, reset should allow pass + assert filter_obj.filter(log_record) is True + + def test_suppressed_counter_increments(self, log_record): + """Test that suppressed_counter increments for each suppressed message.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) + + issue_id = f"{log_record.pathname}:{log_record.lineno}" + issue_record = filter_obj.issue_map[issue_id] + + # Send 5 messages + for i in range(5): + filter_obj.filter(log_record) + # After initial messages handled, counter should increment + if i > 0: + assert issue_record.suppressed_counter >= 0 + + def test_throttle_suppression_flag_bypasses_filter(self, log_record): + """Test that _throttle_suppression flag allows suppression messages through.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) + + # Normal message is suppressed + assert filter_obj.filter(log_record) is False + + # Same message with suppression flag bypasses filter + log_record._throttle_suppression = True + assert filter_obj.filter(log_record) is True + + def test_get_allowed_returns_none_skips_throttle(self, log_record): + """Test filter() returns True if get_allowed() returns None.""" + filter_obj = ThrottleFilter() + filter_obj.get_allowed = MagicMock(return_value=None) + + # Should return False because allowed is None + result = filter_obj.filter(log_record) + assert result is False + + def test_throttle_not_in_allowed_returns_true(self, log_record): + """Test filter() returns True if Throttle not in allowed handlers.""" + log_record.handlers = [HandlerType.Rich, HandlerType.File] + filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) + + # Throttle not in allowed, so should return True + assert filter_obj.filter(log_record) is True + + def test_timestamp_formatting(self): + """Test that timestamp formatting produces valid ISO format.""" + filter_obj = ThrottleFilter() + timestamp = time.time() + + formatted = filter_obj._format_timestamp(timestamp) + + # Should be ISO format with microseconds + assert len(formatted) == 26 # YYYY-MM-DD HH:MM:SS.ffffff + assert formatted.count("-") == 2 # Two dashes for date + assert formatted.count(":") == 2 # Two colons for time + + def test_report_suppression_not_called_when_counter_zero(self, log_record): + """Test that _report_suppression returns early if suppressed_counter is 0.""" + filter_obj = ThrottleFilter() + issue_record = IssueRecord() + issue_record.suppressed_counter = 0 + + with patch.object(filter_obj, "_report_suppression") as mock_report: + filter_obj._report_suppression(issue_record, log_record) + + # Should return early without doing anything + # (We can't easily test this without mocking, but the logic is clear) + # Just verify the method completes without error + assert True + + def test_different_issues_tracked_separately(self, log_record): + """Test that different file:line combinations track state separately.""" + filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) + + # First issue + record1 = copy.deepcopy(log_record) + record1.pathname = "/path1.py" + record1.lineno = 10 + record1.handlers = [HandlerType.Throttle] + + # Second issue + record2 = copy.deepcopy(log_record) + record2.pathname = "/path2.py" + record2.lineno = 20 + record2.handlers = [HandlerType.Throttle] + + # Both pass initial threshold + assert filter_obj.filter(record1) is True + assert filter_obj.filter(record2) is True + + # Issue 1: passes again + assert filter_obj.filter(record1) is True + + # Issue 2: passes again (separate tracking) + assert filter_obj.filter(record2) is True + + # Issue 1: suppressed + assert filter_obj.filter(record1) is False + + # Issue 2: suppressed (independent) + assert filter_obj.filter(record2) is False + + def test_thread_safety_concurrent_issues(self, log_record): + """Test ThrottleFilter is thread-safe with concurrent logging.""" + filter_obj = ThrottleFilter(initial_threshold=5, time_limit=10) + log_record.handlers = [HandlerType.Throttle] + results = [] + + def log_messages(record, num_messages): + """Log from a thread.""" + for _ in range(num_messages): + result = filter_obj.filter(record) + results.append(result) + + # Create threads logging to same issue + threads = [] + for _ in range(3): + thread = Thread(target=log_messages, args=(log_record, 10)) + threads.append(thread) + thread.start() + + # Wait for all threads + for thread in threads: + thread.join() + + # Should have completed without deadlock + assert len(results) == 30 + # First 5 should pass (initial threshold) + assert results[:5].count(True) >= 3 # At least some early ones pass + + +# ============================================================================ +# IssueRecord Tests +# ============================================================================ + + +class TestIssueRecord: + """Tests for IssueRecord state tracking.""" + + def test_init_sets_defaults(self): + """Test that __init__ sets proper default values.""" + record = IssueRecord() + + assert record.last_occurrence == 0.0 + assert record.last_report == 0.0 + assert record.initial_counter == 0 + assert record.threshold == 10 + assert record.suppressed_counter == 0 + assert record.last_occurrence_formatted == "" + + def test_reset_clears_all_state(self): + """Test that reset() clears all counters and timestamps.""" + record = IssueRecord() + record.last_occurrence = 100.0 + record.initial_counter = 5 + record.suppressed_counter = 20 + record.threshold = 100 + record.last_occurrence_formatted = "2025-01-01 12:00:00.000000" + + record.reset() + + assert record.last_occurrence == 0.0 + assert record.last_report == 0.0 + assert record.initial_counter == 0 + assert record.threshold == 10 + assert record.suppressed_counter == 0 + assert record.last_occurrence_formatted == "" + + +# ============================================================================ +# Integration Tests +# ============================================================================ + + +class TestFiltersIntegration: + """Integration tests with real logger setup.""" + + def test_logger_with_handle_id_filter(self, clean_logger): + """Test logger with HandleIDFilter allows only specific handlers.""" + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.addFilter(HandleIDFilter(HandlerType.Stream)) + + clean_logger.addHandler(handler) + + # Log with matching handler type + record = logging.LogRecord( + name=clean_logger.name, + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test message", + args=(), + exc_info=None, + ) + record.handlers = [HandlerType.Stream, HandlerType.Rich] + + clean_logger.handle(record) + + # Message should appear because Stream is in allowed + assert "Test message" in stream.getvalue() + + def test_logger_with_throttle_filter(self, clean_logger): + """Test logger correctly suppresses messages with ThrottleFilter.""" + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(logging.Formatter("%(message)s")) + filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) + handler.addFilter(filter_obj) + + clean_logger.addHandler(handler) + clean_logger.setLevel(logging.INFO) + + record = logging.LogRecord( + name=clean_logger.name, + level=logging.INFO, + pathname="test.py", + lineno=10, + msg="Repeated message", + args=(), + exc_info=None, + ) + record.handlers = [HandlerType.Throttle] + + # Log 5 times + for _ in range(5): + clean_logger.handle(record) + + output = stream.getvalue() + + # First 2 should appear, then suppression message + assert output.count("Repeated message") >= 2 + + def test_chained_filters(self, clean_logger): + """Test stacking HandleIDFilter and ThrottleFilter.""" + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(logging.Formatter("%(message)s")) + + # Add both filters + handler.addFilter(HandleIDFilter(HandlerType.Throttle)) + handler.addFilter(ThrottleFilter(initial_threshold=1, time_limit=10)) + + clean_logger.addHandler(handler) + clean_logger.setLevel(logging.INFO) + + record = logging.LogRecord( + name=clean_logger.name, + level=logging.INFO, + pathname="test.py", + lineno=10, + msg="Chained filters test", + args=(), + exc_info=None, + ) + record.handlers = [HandlerType.Throttle] + + # Log message + clean_logger.handle(record) + + # Should appear in output + output = stream.getvalue() + assert "Chained filters test" in output + + +# ============================================================================ +# Edge Cases and Error Handling +# ============================================================================ + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_empty_handlers_list(self, log_record): + """Test filter behavior with empty handlers list.""" + log_record.handlers = [] + filter_obj = HandleIDFilter(HandlerType.Rich) + + result = filter_obj.filter(log_record) + + assert result is False + + def test_none_handlers_attribute(self, log_record): + """Test filter when record.handlers is None.""" + log_record.handlers = None + filter_obj = HandleIDFilter(HandlerType.Rich) + + # get_allowed should handle None gracefully + result = filter_obj.filter(log_record) + assert result is False + + def test_throttle_with_zero_initial_threshold(self, log_record): + """Test ThrottleFilter with initial_threshold=0.""" + log_record.handlers = [HandlerType.Throttle] + filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) + + # All messages should be suppressed after first + assert filter_obj.filter(log_record) is False + + def test_issue_record_key_format(self, log_record): + """Test that issue_record key is formatted correctly.""" + filter_obj = ThrottleFilter() + + issue_id = f"{log_record.pathname}:{log_record.lineno}" + record = filter_obj.issue_map[issue_id] + + assert isinstance(record, IssueRecord) + + def test_multiple_handler_types_intersection(self, log_record): + """Test set intersection with multiple handler types.""" + log_record.handlers = [ + HandlerType.Rich, + HandlerType.File, + HandlerType.Stream, + ] + filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Lstdout]) + + # Rich is in the intersection + result = filter_obj.filter(log_record) + assert result is True + + def test_protobuf_conf_in_ers_handlers(self, ers_log_record, mock_ers_handlers): + """Test that ProtobufConf is properly included in ERS configuration.""" + ers_log_record.ers_handlers = mock_ers_handlers + filter_obj = BaseHandlerFilter() + + allowed = filter_obj.get_allowed(ers_log_record) + + assert HandlerType.Protobufstream in allowed + + def test_suppression_message_includes_count(self, log_record, clean_logger): + """Test that suppression message includes suppressed count.""" + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setFormatter(logging.Formatter("%(message)s")) + + # Create a throttle filter that will suppress quickly + throttle_filter = ThrottleFilter(initial_threshold=1, time_limit=10) + handler.addFilter(throttle_filter) + + clean_logger.addHandler(handler) + clean_logger.setLevel(logging.INFO) + + record = logging.LogRecord( + name=clean_logger.name, + level=logging.INFO, + pathname="test.py", + lineno=10, + msg="Test", + args=(), + exc_info=None, + ) + record.handlers = [HandlerType.Throttle] + + # Send messages to trigger suppression + for _ in range(15): + clean_logger.handle(copy.deepcopy(record)) + + output = stream.getvalue() + + # Should contain suppression message with count + assert "suppressed" in output.lower() From b02689b250d67ced45b839acaa60a17e06e27af9 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Thu, 12 Feb 2026 16:31:41 +0100 Subject: [PATCH 2/6] Cleanup pytest implementation ruff --- tests/logging/test_handlers.py | 194 ++++++++++++++------------------- 1 file changed, 83 insertions(+), 111 deletions(-) diff --git a/tests/logging/test_handlers.py b/tests/logging/test_handlers.py index 2e7ac0c..a7d02b0 100644 --- a/tests/logging/test_handlers.py +++ b/tests/logging/test_handlers.py @@ -1,38 +1,3 @@ - - - - -#! What do we want to test here? - -#! The processing of the environment variables, partiuclarly the protobuf, the random one, and also what happens when you try to initialise it without the right environment - -# this should test _convert_str_to_handlertype with throttle, protobufstream, and random -# throttle should return handlertype.throttle, none , protobufstream should return handlertype.protobufstream, and a valid protobufconf with the same values as expected - -# See if its possible to test _make_ers_handler_conf -# This should just take a default string comma and then return the correct features - -# Definitely test the creation of the full thing, with the four environments - - -#! Also test the from string in HandlerTypes -# Should get this testesd for the enums but also the error itself - - -#! We need to test the filters.. -# Figure out how to test logging filters.. - - -# Need tests for the throttling for sure - -# Need a test for the base handlerfilter: -# using one handlertype -# using multiple handler types - - -########### - - """Comprehensive tests for the logging filters in handlers.py. Tests cover: @@ -47,7 +12,7 @@ import logging import time from threading import Thread -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -71,7 +36,7 @@ @pytest.fixture def clean_logger(): """Provide a clean logger with no handlers or filters.""" - logger = logging.getLogger("test_logger_" + str(time.time())) + logger = logging.getLogger("test_logger_handlers") logger.handlers = [] logger.filters = [] logger.setLevel(logging.DEBUG) @@ -79,9 +44,9 @@ def clean_logger(): @pytest.fixture -def log_record(): +def log_record() -> logging.LogRecord: """Provide a basic log record for testing.""" - record = logging.LogRecord( + return logging.LogRecord( name="test.module", level=logging.ERROR, pathname="/path/to/test.py", @@ -90,7 +55,6 @@ def log_record(): args=(), exc_info=None, ) - return record @pytest.fixture @@ -100,7 +64,7 @@ def ers_log_record(): name="test.module", level=logging.ERROR, pathname="/path/to/test.py", - lineno=42, + lineno=67, msg="ERS message", args=(), exc_info=None, @@ -116,7 +80,7 @@ def mock_ers_handlers(): for level_var in level_to_ers_var.values(): conf = ERSPyLogHandlerConf( handlers=[HandlerType.Throttle, HandlerType.Protobufstream], - protobufconf=ProtobufConf(url="test.kafka.com", port=9092), + protobufconf=ProtobufConf(url="monkafka.cern.ch", port=30092), ) handlers_config[level_var] = conf return handlers_config @@ -130,7 +94,9 @@ def mock_ers_handlers(): class TestBaseHandlerFilter: """Tests for BaseHandlerFilter.get_allowed() logic.""" - def test_non_ers_uses_record_handlers_attribute(self, log_record): + def test_non_ers_uses_record_handlers_attribute( + self, log_record: logging.LogRecord + ): """Test get_allowed() uses 'handlers' attribute from record for non-ERS.""" log_record.handlers = [HandlerType.Rich, HandlerType.File] filter_obj = BaseHandlerFilter() @@ -139,7 +105,9 @@ def test_non_ers_uses_record_handlers_attribute(self, log_record): assert allowed == [HandlerType.Rich, HandlerType.File] - def test_non_ers_defaults_to_base_handlers(self, log_record): + def test_non_ers_defaults_to_base_handlers( + self, log_record: logging.LogRecord + ): """Test get_allowed() falls back to default handlers when attribute missing.""" # log_record has no 'handlers' attribute filter_obj = BaseHandlerFilter() @@ -151,7 +119,9 @@ def test_non_ers_defaults_to_base_handlers(self, log_record): expected_handlers = {HandlerType.Stream, HandlerType.Rich, HandlerType.File} assert expected_handlers.issubset(set(allowed)) - def test_ers_path_valid_configuration(self, ers_log_record, mock_ers_handlers): + def test_ers_path_valid_configuration( + self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict + ): """Test get_allowed() extracts ERS handlers correctly with valid config.""" ers_log_record.ers_handlers = mock_ers_handlers filter_obj = BaseHandlerFilter() @@ -160,38 +130,16 @@ def test_ers_path_valid_configuration(self, ers_log_record, mock_ers_handlers): assert allowed == [HandlerType.Throttle, HandlerType.Protobufstream] - def test_ers_path_missing_ers_handlers_attribute(self, ers_log_record): - """Test get_allowed() returns None when ERS record lacks ers_handlers.""" - # ers_log_record has stream=ERS but no ers_handlers attribute - filter_obj = BaseHandlerFilter() - - allowed = filter_obj.get_allowed(ers_log_record) - - assert allowed is None - - def test_ers_path_no_matching_level_variable(self, ers_log_record, mock_ers_handlers): + def test_ers_path_no_matching_level_variable( + self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict + ): """Test get_allowed() returns None when log level has no ERS mapping.""" - # Set a log level that might not have an ERS equivalent + # Set a log level that does not have an ERS equivalent ers_log_record.levelno = 25 # Between INFO and WARNING ers_log_record.ers_handlers = mock_ers_handlers filter_obj = BaseHandlerFilter() allowed = filter_obj.get_allowed(ers_log_record) - - # Level 25 likely won't map to any ERS variable, so should return None - # or might map to something - let's handle both cases - if 25 not in level_to_ers_var: - assert allowed is None - - def test_ers_path_missing_handler_conf_for_level(self, ers_log_record): - """Test get_allowed() returns None when handler conf missing for level.""" - ers_log_record.levelno = logging.DEBUG # Low level - # Provide partial ers_handlers config missing the DEBUG entry - ers_log_record.ers_handlers = {} - filter_obj = BaseHandlerFilter() - - allowed = filter_obj.get_allowed(ers_log_record) - assert allowed is None @@ -218,7 +166,9 @@ def test_list_handler_ids_converted_to_set(self): assert isinstance(filter_obj.handler_ids, set) assert filter_obj.handler_ids == {HandlerType.Rich, HandlerType.File} - def test_filter_returns_true_when_handler_in_allowed(self, log_record): + def test_filter_returns_true_when_handler_in_allowed( + self, log_record: logging.LogRecord + ): """Test filter() returns True when handler_id is in allowed list.""" log_record.handlers = [HandlerType.Rich, HandlerType.File, HandlerType.Stream] filter_obj = HandleIDFilter(HandlerType.Rich) @@ -227,7 +177,9 @@ def test_filter_returns_true_when_handler_in_allowed(self, log_record): assert result is True - def test_filter_returns_false_when_handler_not_in_allowed(self, log_record): + def test_filter_returns_false_when_handler_not_in_allowed( + self, log_record: logging.LogRecord + ): """Test filter() returns False when handler_id not in allowed.""" log_record.handlers = [HandlerType.File, HandlerType.Stream] filter_obj = HandleIDFilter(HandlerType.Rich) @@ -236,7 +188,9 @@ def test_filter_returns_false_when_handler_not_in_allowed(self, log_record): assert result is False - def test_filter_returns_false_when_get_allowed_returns_none(self, log_record): + def test_filter_returns_false_when_get_allowed_returns_none( + self, log_record: logging.LogRecord + ): """Test filter() returns False when get_allowed() returns None.""" filter_obj = HandleIDFilter(HandlerType.Rich) filter_obj.get_allowed = MagicMock(return_value=None) @@ -245,7 +199,9 @@ def test_filter_returns_false_when_get_allowed_returns_none(self, log_record): assert result is False - def test_filter_with_multiple_handler_ids(self, log_record): + def test_filter_with_multiple_handler_ids( + self, log_record: logging.LogRecord + ): """Test filter() with multiple handler_ids checks intersection.""" log_record.handlers = [HandlerType.Rich, HandlerType.File] filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) @@ -255,7 +211,9 @@ def test_filter_with_multiple_handler_ids(self, log_record): # Should return True because Rich is in both sets assert result is True - def test_filter_no_intersection_with_multiple_ids(self, log_record): + def test_filter_no_intersection_with_multiple_ids( + self, log_record: logging.LogRecord + ): """Test filter() returns False when no intersection with multiple ids.""" log_record.handlers = [HandlerType.File] filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) @@ -273,7 +231,9 @@ def test_filter_no_intersection_with_multiple_ids(self, log_record): class TestThrottleFilter: """Tests for ThrottleFilter throttling and suppression logic.""" - def test_initial_phase_lets_through_first_n_messages(self, log_record): + def test_initial_phase_lets_through_first_n_messages( + self, log_record: logging.LogRecord + ): """Test that first N messages pass through without suppression.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=3, time_limit=10) @@ -283,7 +243,9 @@ def test_initial_phase_lets_through_first_n_messages(self, log_record): assert filter_obj.filter(log_record) is True assert filter_obj.filter(log_record) is True - def test_after_initial_threshold_suppresses(self, log_record): + def test_after_initial_threshold_suppresses( + self, log_record: logging.LogRecord + ): """Test that messages are suppressed after initial_threshold.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) @@ -295,7 +257,9 @@ def test_after_initial_threshold_suppresses(self, log_record): # 3rd should be suppressed assert filter_obj.filter(log_record) is False - def test_escalating_threshold_doubles_on_report(self, log_record): + def test_escalating_threshold_doubles_on_report( + self, log_record: logging.LogRecord + ): """Test that threshold escalates (10->100->1000) when reporting.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=1, time_limit=100) @@ -312,7 +276,9 @@ def test_escalating_threshold_doubles_on_report(self, log_record): assert issue_record.threshold == 100 # Escalated from 10 - def test_time_window_reset_resets_counters(self, log_record, monkeypatch): + def test_time_window_reset_resets_counters( + self, log_record: logging.LogRecord, monkeypatch: pytest.MonkeyPatch + ): """Test that state resets after time_limit expires.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=1, time_limit=1) @@ -326,7 +292,9 @@ def test_time_window_reset_resets_counters(self, log_record, monkeypatch): # Time advances beyond time_limit with no suppression, reset should allow pass assert filter_obj.filter(log_record) is True - def test_suppressed_counter_increments(self, log_record): + def test_suppressed_counter_increments( + self, log_record: logging.LogRecord + ): """Test that suppressed_counter increments for each suppressed message.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) @@ -341,7 +309,9 @@ def test_suppressed_counter_increments(self, log_record): if i > 0: assert issue_record.suppressed_counter >= 0 - def test_throttle_suppression_flag_bypasses_filter(self, log_record): + def test_throttle_suppression_flag_bypasses_filter( + self, log_record: logging.LogRecord + ): """Test that _throttle_suppression flag allows suppression messages through.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) @@ -353,7 +323,9 @@ def test_throttle_suppression_flag_bypasses_filter(self, log_record): log_record._throttle_suppression = True assert filter_obj.filter(log_record) is True - def test_get_allowed_returns_none_skips_throttle(self, log_record): + def test_get_allowed_returns_none_skips_throttle( + self, log_record: logging.LogRecord + ): """Test filter() returns True if get_allowed() returns None.""" filter_obj = ThrottleFilter() filter_obj.get_allowed = MagicMock(return_value=None) @@ -362,7 +334,9 @@ def test_get_allowed_returns_none_skips_throttle(self, log_record): result = filter_obj.filter(log_record) assert result is False - def test_throttle_not_in_allowed_returns_true(self, log_record): + def test_throttle_not_in_allowed_returns_true( + self, log_record: logging.LogRecord + ): """Test filter() returns True if Throttle not in allowed handlers.""" log_record.handlers = [HandlerType.Rich, HandlerType.File] filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) @@ -382,21 +356,9 @@ def test_timestamp_formatting(self): assert formatted.count("-") == 2 # Two dashes for date assert formatted.count(":") == 2 # Two colons for time - def test_report_suppression_not_called_when_counter_zero(self, log_record): - """Test that _report_suppression returns early if suppressed_counter is 0.""" - filter_obj = ThrottleFilter() - issue_record = IssueRecord() - issue_record.suppressed_counter = 0 - - with patch.object(filter_obj, "_report_suppression") as mock_report: - filter_obj._report_suppression(issue_record, log_record) - - # Should return early without doing anything - # (We can't easily test this without mocking, but the logic is clear) - # Just verify the method completes without error - assert True - - def test_different_issues_tracked_separately(self, log_record): + def test_different_issues_tracked_separately( + self, log_record: logging.LogRecord + ): """Test that different file:line combinations track state separately.""" filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) @@ -428,13 +390,15 @@ def test_different_issues_tracked_separately(self, log_record): # Issue 2: suppressed (independent) assert filter_obj.filter(record2) is False - def test_thread_safety_concurrent_issues(self, log_record): + def test_thread_safety_concurrent_issues( + self, log_record: logging.LogRecord + ): """Test ThrottleFilter is thread-safe with concurrent logging.""" filter_obj = ThrottleFilter(initial_threshold=5, time_limit=10) log_record.handlers = [HandlerType.Throttle] results = [] - def log_messages(record, num_messages): + def log_messages(record: logging.LogRecord, num_messages: int) -> None: """Log from a thread.""" for _ in range(num_messages): result = filter_obj.filter(record) @@ -503,7 +467,7 @@ def test_reset_clears_all_state(self): class TestFiltersIntegration: """Integration tests with real logger setup.""" - def test_logger_with_handle_id_filter(self, clean_logger): + def test_logger_with_handle_id_filter(self, clean_logger: logging.Logger): """Test logger with HandleIDFilter allows only specific handlers.""" stream = io.StringIO() handler = logging.StreamHandler(stream) @@ -528,7 +492,7 @@ def test_logger_with_handle_id_filter(self, clean_logger): # Message should appear because Stream is in allowed assert "Test message" in stream.getvalue() - def test_logger_with_throttle_filter(self, clean_logger): + def test_logger_with_throttle_filter(self, clean_logger: logging.Logger): """Test logger correctly suppresses messages with ThrottleFilter.""" stream = io.StringIO() handler = logging.StreamHandler(stream) @@ -559,7 +523,7 @@ def test_logger_with_throttle_filter(self, clean_logger): # First 2 should appear, then suppression message assert output.count("Repeated message") >= 2 - def test_chained_filters(self, clean_logger): + def test_chained_filters(self, clean_logger: logging.Logger): """Test stacking HandleIDFilter and ThrottleFilter.""" stream = io.StringIO() handler = logging.StreamHandler(stream) @@ -599,7 +563,7 @@ def test_chained_filters(self, clean_logger): class TestEdgeCases: """Tests for edge cases and boundary conditions.""" - def test_empty_handlers_list(self, log_record): + def test_empty_handlers_list(self, log_record: logging.LogRecord): """Test filter behavior with empty handlers list.""" log_record.handlers = [] filter_obj = HandleIDFilter(HandlerType.Rich) @@ -608,7 +572,7 @@ def test_empty_handlers_list(self, log_record): assert result is False - def test_none_handlers_attribute(self, log_record): + def test_none_handlers_attribute(self, log_record: logging.LogRecord): """Test filter when record.handlers is None.""" log_record.handlers = None filter_obj = HandleIDFilter(HandlerType.Rich) @@ -617,7 +581,9 @@ def test_none_handlers_attribute(self, log_record): result = filter_obj.filter(log_record) assert result is False - def test_throttle_with_zero_initial_threshold(self, log_record): + def test_throttle_with_zero_initial_threshold( + self, log_record: logging.LogRecord + ): """Test ThrottleFilter with initial_threshold=0.""" log_record.handlers = [HandlerType.Throttle] filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) @@ -625,7 +591,7 @@ def test_throttle_with_zero_initial_threshold(self, log_record): # All messages should be suppressed after first assert filter_obj.filter(log_record) is False - def test_issue_record_key_format(self, log_record): + def test_issue_record_key_format(self, log_record: logging.LogRecord): """Test that issue_record key is formatted correctly.""" filter_obj = ThrottleFilter() @@ -634,7 +600,9 @@ def test_issue_record_key_format(self, log_record): assert isinstance(record, IssueRecord) - def test_multiple_handler_types_intersection(self, log_record): + def test_multiple_handler_types_intersection( + self, log_record: logging.LogRecord + ): """Test set intersection with multiple handler types.""" log_record.handlers = [ HandlerType.Rich, @@ -647,7 +615,9 @@ def test_multiple_handler_types_intersection(self, log_record): result = filter_obj.filter(log_record) assert result is True - def test_protobuf_conf_in_ers_handlers(self, ers_log_record, mock_ers_handlers): + def test_protobuf_conf_in_ers_handlers( + self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict + ): """Test that ProtobufConf is properly included in ERS configuration.""" ers_log_record.ers_handlers = mock_ers_handlers filter_obj = BaseHandlerFilter() @@ -656,7 +626,9 @@ def test_protobuf_conf_in_ers_handlers(self, ers_log_record, mock_ers_handlers): assert HandlerType.Protobufstream in allowed - def test_suppression_message_includes_count(self, log_record, clean_logger): + def test_suppression_message_includes_count( + self, clean_logger: logging.Logger + ): """Test that suppression message includes suppressed count.""" stream = io.StringIO() handler = logging.StreamHandler(stream) From d2d9a00b3038d1825282e95c686cb9e9f4ef6205 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Mon, 2 Mar 2026 18:05:13 +0100 Subject: [PATCH 3/6] rewrite tests after refactor (thanks again copilot) --- tests/logging/test_handlerconf.py | 200 ++++++ tests/logging/test_handlers.py | 969 ++++++++++++------------------ tests/logging/test_routing.py | 148 +++++ tests/logging/test_specs.py | 93 +++ 4 files changed, 830 insertions(+), 580 deletions(-) create mode 100644 tests/logging/test_handlerconf.py create mode 100644 tests/logging/test_routing.py create mode 100644 tests/logging/test_specs.py diff --git a/tests/logging/test_handlerconf.py b/tests/logging/test_handlerconf.py new file mode 100644 index 0000000..30fed0b --- /dev/null +++ b/tests/logging/test_handlerconf.py @@ -0,0 +1,200 @@ +import logging +import uuid +from unittest.mock import MagicMock + +import pytest + +from daqpytools.apps import logging_demonstrator as demo +from daqpytools.logging.exceptions import ERSEnvError, ProtobufFormatError +from daqpytools.logging.handlerconf import ( + ERSPyLogHandlerConf, + HandlerType, + LogHandlerConf, + ProtobufConf, + StreamType, +) +from daqpytools.logging.levels import level_to_ers_var + + +def test_handlertype_from_string_case_insensitive() -> None: + assert HandlerType.from_string("RiCh") == HandlerType.Rich + + +def test_handlertype_from_string_unknown_returns_none() -> None: + assert HandlerType.from_string("definitely_unknown") is None + + +def test_protobufconf_get_string_formats_url_port() -> None: + conf = ProtobufConf(url="host", port=1234) + assert conf.get_string() == "host:1234" + + +def test_loghandlerconf_ers_property_raises_before_init() -> None: + conf = LogHandlerConf(init_ers=False) + with pytest.raises(AttributeError, match="ERS stream not initialised"): + _ = conf.ERS + + +def test_loghandlerconf_init_ers_stream_sets_structure(monkeypatch: pytest.MonkeyPatch) -> None: + fake_oks = {"DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])} + monkeypatch.setattr(LogHandlerConf, "_get_oks_conf", staticmethod(lambda: fake_oks)) + + conf = LogHandlerConf(init_ers=False) + conf.init_ers_stream() + + assert conf.ERS["ers_handlers"] == fake_oks + assert conf.ERS["stream"] == StreamType.ERS + + +def test_loghandlerconf_post_init_calls_init_when_flag_true( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_oks = {"DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])} + monkeypatch.setattr(LogHandlerConf, "_get_oks_conf", staticmethod(lambda: fake_oks)) + + conf = LogHandlerConf(init_ers=True) + assert conf.ERS["ers_handlers"] == fake_oks + + +def test_get_base_returns_copy_not_original_reference() -> None: + base_one = LogHandlerConf.get_base() + base_two = LogHandlerConf.get_base() + + base_one.add(HandlerType.Unknown) + + assert HandlerType.Unknown in base_one + assert HandlerType.Unknown not in base_two + + +def test_convert_str_to_handlertype_ignores_erstrace() -> None: + handler, protobuf_conf = LogHandlerConf._convert_str_to_handlertype("erstrace") + assert handler is None + assert protobuf_conf is None + + +def test_convert_str_to_handlertype_regular_handler() -> None: + handler, protobuf_conf = LogHandlerConf._convert_str_to_handlertype("throttle") + assert handler == HandlerType.Throttle + assert protobuf_conf is None + + +def test_convert_str_to_handlertype_parses_protobuf_with_url_port() -> None: + handler, protobuf_conf = LogHandlerConf._convert_str_to_handlertype( + "protobufstream(monkafka.cern.ch:30092)" + ) + assert handler == HandlerType.Protobufstream + assert protobuf_conf == ProtobufConf(url="monkafka.cern.ch", port=30092) + + +def test_convert_str_to_handlertype_invalid_protobuf_format_raises() -> None: + with pytest.raises(ProtobufFormatError): + LogHandlerConf._convert_str_to_handlertype("protobufstream(bad-format)") + + +def test_make_ers_handler_conf_raises_when_env_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("os.getenv", lambda _: None) + with pytest.raises(ERSEnvError): + LogHandlerConf._make_ers_handler_conf("DUNEDAQ_ERS_ERROR") + + +def test_make_ers_handler_conf_parses_multiple_handlers( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "os.getenv", + lambda _: "erstrace, throttle, lstdout, protobufstream(host:1234)", + ) + + conf = LogHandlerConf._make_ers_handler_conf("DUNEDAQ_ERS_ERROR") + + assert HandlerType.Throttle in conf.handlers + assert HandlerType.Lstdout in conf.handlers + assert conf.protobufconf == ProtobufConf(url="host", port=1234) + + +def test_get_oks_conf_builds_mapping_for_all_ers_level_vars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[str] = [] + + def _fake_make(level_var: str) -> ERSPyLogHandlerConf: + calls.append(level_var) + return ERSPyLogHandlerConf(handlers=[HandlerType.Rich]) + + monkeypatch.setattr(LogHandlerConf, "_make_ers_handler_conf", staticmethod(_fake_make)) + + conf = LogHandlerConf._get_oks_conf() + + assert set(calls) == set(level_to_ers_var.values()) + assert set(conf.keys()) == set(level_to_ers_var.values()) + + +# demonstrator test_* parity integrated into handlerconf tests + +def test_demo_test_handlerconf_runs_ers_flow_and_restores(monkeypatch: pytest.MonkeyPatch) -> None: + logger = MagicMock(spec=logging.Logger) + + class FakeHC: + Base = {"handlers": {HandlerType.Stream}, "stream": StreamType.BASE} + Opmon = {"handlers": {HandlerType.Rich}, "stream": StreamType.OPMON} + + def __init__(self, init_ers: bool = False) -> None: + self._ers = { + "ers_handlers": {"DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])}, + "stream": StreamType.ERS, + } + if init_ers: + self.init_ers_stream() + + @property + def ERS(self) -> dict: + return self._ers + + def init_ers_stream(self) -> None: + return None + + restore_mock = MagicMock() + + monkeypatch.setattr(demo, "LogHandlerConf", FakeHC) + monkeypatch.setattr(demo, "restore_original_envs", restore_mock) + + demo.test_handlerconf(logger) + + restore_mock.assert_called_once() + logger.warning.assert_called() + logger.info.assert_called() + + +def test_demo_test_ers_handler_configuration_calls_setup_and_logs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + logger_name = f"demo.ers.{uuid.uuid4()}" + logger = logging.getLogger(logger_name) + logger.handlers = [] + logger.filters = [] + logger.propagate = False + + get_logger_mock = MagicMock(return_value=logger) + setup_mock = MagicMock() + + class FakeHC: + def __init__(self, init_ers: bool = False) -> None: + assert init_ers is True + self._ers = {"stream": StreamType.ERS, "ers_handlers": {}} + + @property + def ERS(self) -> dict: + return self._ers + + monkeypatch.setattr(demo, "get_daq_logger", get_logger_mock) + monkeypatch.setattr(demo, "setup_daq_ers_logger", setup_mock) + monkeypatch.setattr(demo, "LogHandlerConf", FakeHC) + + demo.test_ers_handler_configuration("INFO") + + get_logger_mock.assert_called_once() + setup_mock.assert_called_once_with(logger, "session_temp") + + logging.root.manager.loggerDict.pop(logger_name, None) diff --git a/tests/logging/test_handlers.py b/tests/logging/test_handlers.py index a7d02b0..636ca4f 100644 --- a/tests/logging/test_handlers.py +++ b/tests/logging/test_handlers.py @@ -1,662 +1,471 @@ -"""Comprehensive tests for the logging filters in handlers.py. - -Tests cover: -- BaseHandlerFilter: Handler selection logic for both ERS and non-ERS paths -- HandleIDFilter: Filter that accepts only specific handler types -- ThrottleFilter: Advanced throttling with escalating thresholds and time windows -- Integration: Real logger usage with filters and handlers -""" - -import copy -import io import logging -import time -from threading import Thread -from unittest.mock import MagicMock +import uuid +from collections.abc import Iterator +from unittest.mock import MagicMock, call import pytest -from daqpytools.logging.handlers import ( - BaseHandlerFilter, - ERSPyLogHandlerConf, - HandleIDFilter, - HandlerType, - IssueRecord, - ProtobufConf, - StreamType, - ThrottleFilter, -) -from daqpytools.logging.levels import level_to_ers_var - -# ============================================================================ -# FIXTURES -# ============================================================================ +from daqpytools.apps import logging_demonstrator as demo +from daqpytools.logging.exceptions import ERSInitError, LoggerHandlerError +from daqpytools.logging.filters import HandleIDFilter +from daqpytools.logging.formatter import LoggingFormatter +from daqpytools.logging.handlerconf import HandlerType +from daqpytools.logging.rich_handler import FormattedRichHandler +from daqpytools.logging import handlers as handlers_mod @pytest.fixture -def clean_logger(): - """Provide a clean logger with no handlers or filters.""" - logger = logging.getLogger("test_logger_handlers") +def clean_logger() -> Iterator[logging.Logger]: + name = f"test.handlers.{uuid.uuid4()}" + logger = logging.getLogger(name) logger.handlers = [] logger.filters = [] + logger.propagate = False logger.setLevel(logging.DEBUG) - return logger + yield logger + for handler in logger.handlers[:]: + logger.removeHandler(handler) + try: + handler.close() + except Exception: + pass + logger.filters = [] + logging.root.manager.loggerDict.pop(name, None) @pytest.fixture -def log_record() -> logging.LogRecord: - """Provide a basic log record for testing.""" - return logging.LogRecord( - name="test.module", - level=logging.ERROR, - pathname="/path/to/test.py", - lineno=42, - msg="Test message", - args=(), - exc_info=None, - ) +def parent_child_loggers() -> Iterator[tuple[logging.Logger, logging.Logger]]: + parent_name = f"test.handlers.parent.{uuid.uuid4()}" + child_name = f"{parent_name}.child" + parent = logging.getLogger(parent_name) + child = logging.getLogger(child_name) -@pytest.fixture -def ers_log_record(): - """Provide a log record configured for ERS streaming.""" - record = logging.LogRecord( - name="test.module", - level=logging.ERROR, - pathname="/path/to/test.py", - lineno=67, - msg="ERS message", - args=(), - exc_info=None, + parent.handlers = [] + parent.filters = [] + parent.propagate = False + parent.setLevel(logging.DEBUG) + + child.handlers = [] + child.filters = [] + child.propagate = True + child.setLevel(logging.DEBUG) + + yield parent, child + + for logger in [child, parent]: + for handler in logger.handlers[:]: + logger.removeHandler(handler) + try: + handler.close() + except Exception: + pass + logger.filters = [] + logging.root.manager.loggerDict.pop(logger.name, None) + + +def test_logger_has_handler_non_logger_returns_false() -> None: + assert handlers_mod.logger_has_handler(MagicMock(), logging.StreamHandler) is False + + +def test_logger_has_handler_matches_non_stream_type(clean_logger: logging.Logger) -> None: + handler = logging.NullHandler() + clean_logger.addHandler(handler) + assert handlers_mod.logger_has_handler(clean_logger, logging.NullHandler) is True + + +def test_logger_has_handler_matches_stream_by_target_stream( + clean_logger: logging.Logger, +) -> None: + stdout_handler = logging.StreamHandler(handlers_mod.STDOUT_HANDLER_SPEC.target_stream) + clean_logger.addHandler(stdout_handler) + + assert ( + handlers_mod.logger_has_handler( + clean_logger, + logging.StreamHandler, + target_stream=handlers_mod.STDOUT_HANDLER_SPEC.target_stream, + ) + is True + ) + assert ( + handlers_mod.logger_has_handler( + clean_logger, + logging.StreamHandler, + target_stream=handlers_mod.STDERR_HANDLER_SPEC.target_stream, + ) + is False ) - record.stream = StreamType.ERS - return record -@pytest.fixture -def mock_ers_handlers(): - """Provide mock ERS handler configuration for testing.""" - handlers_config = {} - for level_var in level_to_ers_var.values(): - conf = ERSPyLogHandlerConf( - handlers=[HandlerType.Throttle, HandlerType.Protobufstream], - protobufconf=ProtobufConf(url="monkafka.cern.ch", port=30092), +def test_logger_has_filter_detects_filter_type(clean_logger: logging.Logger) -> None: + clean_logger.addFilter(logging.Filter("named.filter")) + assert handlers_mod.logger_has_filter(clean_logger, logging.Filter) is True + + +def test_ancestors_have_handlers_returns_false_when_disabled( + parent_child_loggers: tuple[logging.Logger, logging.Logger], +) -> None: + _, child = parent_child_loggers + assert handlers_mod.ancestors_have_handlers(child, False, logging.NullHandler) is False + + +def test_ancestors_have_handlers_rejects_root_logger() -> None: + with pytest.raises(ValueError, match="root logger"): + handlers_mod.ancestors_have_handlers( + logging.getLogger(), + True, + logging.NullHandler, ) - handlers_config[level_var] = conf - return handlers_config -# ============================================================================ -# BaseHandlerFilter Tests -# ============================================================================ +def test_ancestors_have_handlers_requires_target_for_streamhandler( + clean_logger: logging.Logger, +) -> None: + with pytest.raises(ValueError, match="target_stream must be specified"): + handlers_mod.ancestors_have_handlers( + clean_logger, + True, + logging.StreamHandler, + ) -class TestBaseHandlerFilter: - """Tests for BaseHandlerFilter.get_allowed() logic.""" +def test_ancestors_have_handlers_rejects_target_for_non_stream( + clean_logger: logging.Logger, +) -> None: + with pytest.raises(ValueError, match="target_stream can only be specified"): + handlers_mod.ancestors_have_handlers( + clean_logger, + True, + logging.NullHandler, + target_stream=handlers_mod.STDOUT_HANDLER_SPEC.target_stream, + ) - def test_non_ers_uses_record_handlers_attribute( - self, log_record: logging.LogRecord - ): - """Test get_allowed() uses 'handlers' attribute from record for non-ERS.""" - log_record.handlers = [HandlerType.Rich, HandlerType.File] - filter_obj = BaseHandlerFilter() - allowed = filter_obj.get_allowed(log_record) +def test_ancestors_have_handlers_detects_parent_handler( + parent_child_loggers: tuple[logging.Logger, logging.Logger], +) -> None: + parent, child = parent_child_loggers + parent.addHandler(logging.NullHandler()) - assert allowed == [HandlerType.Rich, HandlerType.File] + assert handlers_mod.ancestors_have_handlers(child, True, logging.NullHandler) is True - def test_non_ers_defaults_to_base_handlers( - self, log_record: logging.LogRecord - ): - """Test get_allowed() falls back to default handlers when attribute missing.""" - # log_record has no 'handlers' attribute - filter_obj = BaseHandlerFilter() - allowed = filter_obj.get_allowed(log_record) +def test_check_parent_handlers_raises_loggerhandlererror( + parent_child_loggers: tuple[logging.Logger, logging.Logger], +) -> None: + parent, child = parent_child_loggers + parent.addHandler(logging.NullHandler()) - # Should return the base handlers from LogHandlerConf - assert allowed is not None - expected_handlers = {HandlerType.Stream, HandlerType.Rich, HandlerType.File} - assert expected_handlers.issubset(set(allowed)) + with pytest.raises(LoggerHandlerError): + handlers_mod.check_parent_handlers(child, True, logging.NullHandler) - def test_ers_path_valid_configuration( - self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict - ): - """Test get_allowed() extracts ERS handlers correctly with valid config.""" - ers_log_record.ers_handlers = mock_ers_handlers - filter_obj = BaseHandlerFilter() - allowed = filter_obj.get_allowed(ers_log_record) +def test_logger_or_ancestors_have_handler_checks_local_then_parent( + parent_child_loggers: tuple[logging.Logger, logging.Logger], +) -> None: + parent, child = parent_child_loggers + assert handlers_mod.logger_or_ancestors_have_handler(child, True, logging.NullHandler) is False - assert allowed == [HandlerType.Throttle, HandlerType.Protobufstream] + parent.addHandler(logging.NullHandler()) + assert handlers_mod.logger_or_ancestors_have_handler(child, True, logging.NullHandler) is True - def test_ers_path_no_matching_level_variable( - self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict - ): - """Test get_allowed() returns None when log level has no ERS mapping.""" - # Set a log level that does not have an ERS equivalent - ers_log_record.levelno = 25 # Between INFO and WARNING - ers_log_record.ers_handlers = mock_ers_handlers - filter_obj = BaseHandlerFilter() - allowed = filter_obj.get_allowed(ers_log_record) - assert allowed is None +def test_get_handler_specs_returns_expected_specs() -> None: + assert len(handlers_mod.get_handler_specs(HandlerType.Rich)) == 1 + assert len(handlers_mod.get_handler_specs(HandlerType.Lstdout)) == 1 + assert len(handlers_mod.get_handler_specs(HandlerType.Lstderr)) == 1 + assert len(handlers_mod.get_handler_specs(HandlerType.Stream)) == 2 + assert len(handlers_mod.get_handler_specs(HandlerType.File)) == 1 + assert len(handlers_mod.get_handler_specs(HandlerType.Protobufstream)) == 1 -# ============================================================================ -# HandleIDFilter Tests -# ============================================================================ +def test_build_rich_handler_uses_get_width_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(handlers_mod, "get_width", lambda: 111) + handler = handlers_mod._build_rich_handler() + assert isinstance(handler, FormattedRichHandler) + assert handler.console.width == 111 -class TestHandleIDFilter: - """Tests for HandleIDFilter.filter() logic.""" +def test_build_stdout_handler_sets_formatter() -> None: + handler = handlers_mod._build_stdout_handler() + assert isinstance(handler, logging.StreamHandler) + assert handler.stream is handlers_mod.STDOUT_HANDLER_SPEC.target_stream + assert isinstance(handler.formatter, LoggingFormatter) - def test_single_handler_id_normalized_to_set(self): - """Test that single handler_id is normalized to a set.""" - filter_obj = HandleIDFilter(HandlerType.Rich) - assert isinstance(filter_obj.handler_ids, set) - assert HandlerType.Rich in filter_obj.handler_ids +def test_build_stderr_handler_sets_level_and_formatter() -> None: + handler = handlers_mod._build_stderr_handler() + assert isinstance(handler, logging.StreamHandler) + assert handler.stream is handlers_mod.STDERR_HANDLER_SPEC.target_stream + assert handler.level == logging.ERROR + assert isinstance(handler.formatter, LoggingFormatter) - def test_list_handler_ids_converted_to_set(self): - """Test that list of handler_ids is converted to a set.""" - handlers = [HandlerType.Rich, HandlerType.File] - filter_obj = HandleIDFilter(handlers) - assert isinstance(filter_obj.handler_ids, set) - assert filter_obj.handler_ids == {HandlerType.Rich, HandlerType.File} +def test_build_file_handler_requires_path() -> None: + with pytest.raises(ValueError, match="path is required"): + handlers_mod._build_file_handler() - def test_filter_returns_true_when_handler_in_allowed( - self, log_record: logging.LogRecord - ): - """Test filter() returns True when handler_id is in allowed list.""" - log_record.handlers = [HandlerType.Rich, HandlerType.File, HandlerType.Stream] - filter_obj = HandleIDFilter(HandlerType.Rich) - result = filter_obj.filter(log_record) +def test_build_file_handler_creates_handler_with_formatter(tmp_path: pytest.TempPathFactory) -> None: + file_path = tmp_path / "test.log" + handler = handlers_mod._build_file_handler(path=str(file_path)) + assert isinstance(handler, logging.FileHandler) + assert isinstance(handler.formatter, LoggingFormatter) + handler.close() - assert result is True - def test_filter_returns_false_when_handler_not_in_allowed( - self, log_record: logging.LogRecord - ): - """Test filter() returns False when handler_id not in allowed.""" - log_record.handlers = [HandlerType.File, HandlerType.Stream] - filter_obj = HandleIDFilter(HandlerType.Rich) +def test_build_erskafka_handler_wraps_exception(monkeypatch: pytest.MonkeyPatch) -> None: + def _raise(*args: object, **kwargs: object) -> None: + del args, kwargs + raise RuntimeError("boom") - result = filter_obj.filter(log_record) + monkeypatch.setattr(handlers_mod, "ERSKafkaLogHandler", _raise) + with pytest.raises(ERSInitError): + handlers_mod._build_erskafka_handler(session_name="s1") - assert result is False - def test_filter_returns_false_when_get_allowed_returns_none( - self, log_record: logging.LogRecord - ): - """Test filter() returns False when get_allowed() returns None.""" - filter_obj = HandleIDFilter(HandlerType.Rich) - filter_obj.get_allowed = MagicMock(return_value=None) +def test_build_erskafka_handler_success_passes_arguments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeKafkaHandler: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs - result = filter_obj.filter(log_record) + monkeypatch.setattr(handlers_mod, "ERSKafkaLogHandler", FakeKafkaHandler) + handler = handlers_mod._build_erskafka_handler( + session_name="session_x", + topic="topic_x", + address="addr_x", + ers_app_name="app_x", + ) - assert result is False + assert isinstance(handler, FakeKafkaHandler) + assert handler.kwargs["session"] == "session_x" + assert handler.kwargs["kafka_address"] == "addr_x" + assert handler.kwargs["kafka_topic"] == "topic_x" + assert handler.kwargs["app_name"] == "app_x" - def test_filter_with_multiple_handler_ids( - self, log_record: logging.LogRecord - ): - """Test filter() with multiple handler_ids checks intersection.""" - log_record.handlers = [HandlerType.Rich, HandlerType.File] - filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) - result = filter_obj.filter(log_record) +def test_add_handler_adds_single_spec_and_handleidfilter(clean_logger: logging.Logger) -> None: + handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) - # Should return True because Rich is in both sets - assert result is True + assert len(clean_logger.handlers) == 1 + assert isinstance(clean_logger.handlers[0], FormattedRichHandler) + assert any( + isinstance(logger_filter, HandleIDFilter) + for logger_filter in clean_logger.handlers[0].filters + ) - def test_filter_no_intersection_with_multiple_ids( - self, log_record: logging.LogRecord - ): - """Test filter() returns False when no intersection with multiple ids.""" - log_record.handlers = [HandlerType.File] - filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Stream]) - result = filter_obj.filter(log_record) +def test_add_handler_skips_when_matching_handler_exists(clean_logger: logging.Logger) -> None: + handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) + handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) + assert len(clean_logger.handlers) == 1 - assert result is False +def test_add_handler_skips_when_parent_has_handler( + parent_child_loggers: tuple[logging.Logger, logging.Logger], +) -> None: + parent, child = parent_child_loggers + handlers_mod.add_handler(parent, HandlerType.Rich, use_parent_handlers=True) + handlers_mod.add_handler(child, HandlerType.Rich, use_parent_handlers=True) -# ============================================================================ -# ThrottleFilter Tests -# ============================================================================ + assert len(parent.handlers) == 1 + assert len(child.handlers) == 0 -class TestThrottleFilter: - """Tests for ThrottleFilter throttling and suppression logic.""" - - def test_initial_phase_lets_through_first_n_messages( - self, log_record: logging.LogRecord - ): - """Test that first N messages pass through without suppression.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=3, time_limit=10) +def test_add_handler_accepts_string_type(clean_logger: logging.Logger) -> None: + handlers_mod.add_handler(clean_logger, "rich", use_parent_handlers=True) + assert len(clean_logger.handlers) == 1 - # First 3 messages should pass - assert filter_obj.filter(log_record) is True - assert filter_obj.filter(log_record) is True - assert filter_obj.filter(log_record) is True - - def test_after_initial_threshold_suppresses( - self, log_record: logging.LogRecord - ): - """Test that messages are suppressed after initial_threshold.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) - - # First 2 pass - assert filter_obj.filter(log_record) is True - assert filter_obj.filter(log_record) is True - - # 3rd should be suppressed - assert filter_obj.filter(log_record) is False - - def test_escalating_threshold_doubles_on_report( - self, log_record: logging.LogRecord - ): - """Test that threshold escalates (10->100->1000) when reporting.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=1, time_limit=100) - - issue_id = f"{log_record.pathname}:{log_record.lineno}" - issue_record = filter_obj.issue_map[issue_id] - - # First is emitted - # Next 10 are suppressed - # needs 1 more to trigger update - for _ in range(12): - filter_obj._throttle(issue_record, log_record) - - - assert issue_record.threshold == 100 # Escalated from 10 - - def test_time_window_reset_resets_counters( - self, log_record: logging.LogRecord, monkeypatch: pytest.MonkeyPatch - ): - """Test that state resets after time_limit expires.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=1, time_limit=1) - - times = iter([1000.0, 1002.5]) - monkeypatch.setattr(time, "time", lambda: next(times)) - - # First message passes - assert filter_obj.filter(log_record) is True - - # Time advances beyond time_limit with no suppression, reset should allow pass - assert filter_obj.filter(log_record) is True - - def test_suppressed_counter_increments( - self, log_record: logging.LogRecord - ): - """Test that suppressed_counter increments for each suppressed message.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) - - issue_id = f"{log_record.pathname}:{log_record.lineno}" - issue_record = filter_obj.issue_map[issue_id] - - # Send 5 messages - for i in range(5): - filter_obj.filter(log_record) - # After initial messages handled, counter should increment - if i > 0: - assert issue_record.suppressed_counter >= 0 - - def test_throttle_suppression_flag_bypasses_filter( - self, log_record: logging.LogRecord - ): - """Test that _throttle_suppression flag allows suppression messages through.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=0, time_limit=100) - - # Normal message is suppressed - assert filter_obj.filter(log_record) is False - - # Same message with suppression flag bypasses filter - log_record._throttle_suppression = True - assert filter_obj.filter(log_record) is True - - def test_get_allowed_returns_none_skips_throttle( - self, log_record: logging.LogRecord - ): - """Test filter() returns True if get_allowed() returns None.""" - filter_obj = ThrottleFilter() - filter_obj.get_allowed = MagicMock(return_value=None) - - # Should return False because allowed is None - result = filter_obj.filter(log_record) - assert result is False - - def test_throttle_not_in_allowed_returns_true( - self, log_record: logging.LogRecord - ): - """Test filter() returns True if Throttle not in allowed handlers.""" - log_record.handlers = [HandlerType.Rich, HandlerType.File] - filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) - - # Throttle not in allowed, so should return True - assert filter_obj.filter(log_record) is True - - def test_timestamp_formatting(self): - """Test that timestamp formatting produces valid ISO format.""" - filter_obj = ThrottleFilter() - timestamp = time.time() - - formatted = filter_obj._format_timestamp(timestamp) - - # Should be ISO format with microseconds - assert len(formatted) == 26 # YYYY-MM-DD HH:MM:SS.ffffff - assert formatted.count("-") == 2 # Two dashes for date - assert formatted.count(":") == 2 # Two colons for time - - def test_different_issues_tracked_separately( - self, log_record: logging.LogRecord - ): - """Test that different file:line combinations track state separately.""" - filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) - - # First issue - record1 = copy.deepcopy(log_record) - record1.pathname = "/path1.py" - record1.lineno = 10 - record1.handlers = [HandlerType.Throttle] - - # Second issue - record2 = copy.deepcopy(log_record) - record2.pathname = "/path2.py" - record2.lineno = 20 - record2.handlers = [HandlerType.Throttle] - - # Both pass initial threshold - assert filter_obj.filter(record1) is True - assert filter_obj.filter(record2) is True - - # Issue 1: passes again - assert filter_obj.filter(record1) is True - - # Issue 2: passes again (separate tracking) - assert filter_obj.filter(record2) is True - - # Issue 1: suppressed - assert filter_obj.filter(record1) is False - - # Issue 2: suppressed (independent) - assert filter_obj.filter(record2) is False - - def test_thread_safety_concurrent_issues( - self, log_record: logging.LogRecord - ): - """Test ThrottleFilter is thread-safe with concurrent logging.""" - filter_obj = ThrottleFilter(initial_threshold=5, time_limit=10) - log_record.handlers = [HandlerType.Throttle] - results = [] - - def log_messages(record: logging.LogRecord, num_messages: int) -> None: - """Log from a thread.""" - for _ in range(num_messages): - result = filter_obj.filter(record) - results.append(result) - - # Create threads logging to same issue - threads = [] - for _ in range(3): - thread = Thread(target=log_messages, args=(log_record, 10)) - threads.append(thread) - thread.start() - - # Wait for all threads - for thread in threads: - thread.join() - - # Should have completed without deadlock - assert len(results) == 30 - # First 5 should pass (initial threshold) - assert results[:5].count(True) >= 3 # At least some early ones pass - - -# ============================================================================ -# IssueRecord Tests -# ============================================================================ - - -class TestIssueRecord: - """Tests for IssueRecord state tracking.""" - - def test_init_sets_defaults(self): - """Test that __init__ sets proper default values.""" - record = IssueRecord() - - assert record.last_occurrence == 0.0 - assert record.last_report == 0.0 - assert record.initial_counter == 0 - assert record.threshold == 10 - assert record.suppressed_counter == 0 - assert record.last_occurrence_formatted == "" - - def test_reset_clears_all_state(self): - """Test that reset() clears all counters and timestamps.""" - record = IssueRecord() - record.last_occurrence = 100.0 - record.initial_counter = 5 - record.suppressed_counter = 20 - record.threshold = 100 - record.last_occurrence_formatted = "2025-01-01 12:00:00.000000" - - record.reset() - - assert record.last_occurrence == 0.0 - assert record.last_report == 0.0 - assert record.initial_counter == 0 - assert record.threshold == 10 - assert record.suppressed_counter == 0 - assert record.last_occurrence_formatted == "" - - -# ============================================================================ -# Integration Tests -# ============================================================================ - - -class TestFiltersIntegration: - """Integration tests with real logger setup.""" - - def test_logger_with_handle_id_filter(self, clean_logger: logging.Logger): - """Test logger with HandleIDFilter allows only specific handlers.""" - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.addFilter(HandleIDFilter(HandlerType.Stream)) - - clean_logger.addHandler(handler) - - # Log with matching handler type - record = logging.LogRecord( - name=clean_logger.name, - level=logging.INFO, - pathname="test.py", - lineno=1, - msg="Test message", - args=(), - exc_info=None, - ) - record.handlers = [HandlerType.Stream, HandlerType.Rich] - - clean_logger.handle(record) - - # Message should appear because Stream is in allowed - assert "Test message" in stream.getvalue() - - def test_logger_with_throttle_filter(self, clean_logger: logging.Logger): - """Test logger correctly suppresses messages with ThrottleFilter.""" - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.setFormatter(logging.Formatter("%(message)s")) - filter_obj = ThrottleFilter(initial_threshold=2, time_limit=10) - handler.addFilter(filter_obj) - - clean_logger.addHandler(handler) - clean_logger.setLevel(logging.INFO) - - record = logging.LogRecord( - name=clean_logger.name, - level=logging.INFO, - pathname="test.py", - lineno=10, - msg="Repeated message", - args=(), - exc_info=None, - ) - record.handlers = [HandlerType.Throttle] - - # Log 5 times - for _ in range(5): - clean_logger.handle(record) - - output = stream.getvalue() - - # First 2 should appear, then suppression message - assert output.count("Repeated message") >= 2 - - def test_chained_filters(self, clean_logger: logging.Logger): - """Test stacking HandleIDFilter and ThrottleFilter.""" - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.setFormatter(logging.Formatter("%(message)s")) - - # Add both filters - handler.addFilter(HandleIDFilter(HandlerType.Throttle)) - handler.addFilter(ThrottleFilter(initial_threshold=1, time_limit=10)) - - clean_logger.addHandler(handler) - clean_logger.setLevel(logging.INFO) - - record = logging.LogRecord( - name=clean_logger.name, - level=logging.INFO, - pathname="test.py", - lineno=10, - msg="Chained filters test", - args=(), - exc_info=None, - ) - record.handlers = [HandlerType.Throttle] - # Log message - clean_logger.handle(record) +def test_add_handler_unknown_string_does_nothing(clean_logger: logging.Logger) -> None: + handlers_mod.add_handler(clean_logger, "unknown_type", use_parent_handlers=True) + assert len(clean_logger.handlers) == 0 - # Should appear in output - output = stream.getvalue() - assert "Chained filters test" in output +def test_add_handler_uses_explicit_fallback_override(clean_logger: logging.Logger) -> None: + override = {HandlerType.Unknown} + handlers_mod.add_handler( + clean_logger, + HandlerType.Rich, + use_parent_handlers=True, + fallback_handler=override, + ) -# ============================================================================ -# Edge Cases and Error Handling -# ============================================================================ + handler_filter = next( + logger_filter + for logger_filter in clean_logger.handlers[0].filters + if isinstance(logger_filter, HandleIDFilter) + ) + assert handler_filter.fallback_handlers == override -class TestEdgeCases: - """Tests for edge cases and boundary conditions.""" +def test_add_handler_for_stream_adds_stdout_and_stderr(clean_logger: logging.Logger) -> None: + handlers_mod.add_handler(clean_logger, HandlerType.Stream, use_parent_handlers=True) + stream_handlers = [ + handler for handler in clean_logger.handlers if isinstance(handler, logging.StreamHandler) + ] + assert len(stream_handlers) == 2 - def test_empty_handlers_list(self, log_record: logging.LogRecord): - """Test filter behavior with empty handlers list.""" - log_record.handlers = [] - filter_obj = HandleIDFilter(HandlerType.Rich) - result = filter_obj.filter(log_record) +def test_add_handlers_from_types_stream_deduplicates(clean_logger: logging.Logger) -> None: + handlers_mod.add_handlers_from_types( + clean_logger, + {HandlerType.Stream, HandlerType.Lstdout, HandlerType.Lstderr}, + use_parent_handlers=True, + fallback_handlers={HandlerType.Stream}, + ) + stream_handlers = [ + handler for handler in clean_logger.handlers if isinstance(handler, logging.StreamHandler) + ] + assert len(stream_handlers) == 2 + + +def test_add_handlers_from_types_routes_to_filter_spec( + clean_logger: logging.Logger, + monkeypatch: pytest.MonkeyPatch, +) -> None: + add_filter_mock = MagicMock() + monkeypatch.setattr(handlers_mod, "add_filter", add_filter_mock) + + handlers_mod.add_handlers_from_types( + clean_logger, + {HandlerType.Throttle}, + use_parent_handlers=True, + fallback_handlers={HandlerType.Throttle}, + ) + + add_filter_mock.assert_called_once() - assert result is False - def test_none_handlers_attribute(self, log_record: logging.LogRecord): - """Test filter when record.handlers is None.""" - log_record.handlers = None - filter_obj = HandleIDFilter(HandlerType.Rich) +def test_add_handlers_from_types_no_duplicate_filter( + clean_logger: logging.Logger, + monkeypatch: pytest.MonkeyPatch, +) -> None: + clean_logger.addFilter(MagicMock(spec=handlers_mod.get_filter_spec(HandlerType.Throttle).filter_class)) + add_filter_mock = MagicMock() + monkeypatch.setattr(handlers_mod, "add_filter", add_filter_mock) - # get_allowed should handle None gracefully - result = filter_obj.filter(log_record) - assert result is False + handlers_mod.add_handlers_from_types( + clean_logger, + {HandlerType.Throttle}, + use_parent_handlers=True, + fallback_handlers={HandlerType.Throttle}, + ) - def test_throttle_with_zero_initial_threshold( - self, log_record: logging.LogRecord - ): - """Test ThrottleFilter with initial_threshold=0.""" - log_record.handlers = [HandlerType.Throttle] - filter_obj = ThrottleFilter(initial_threshold=0, time_limit=10) + add_filter_mock.assert_not_called() - # All messages should be suppressed after first - assert filter_obj.filter(log_record) is False - def test_issue_record_key_format(self, log_record: logging.LogRecord): - """Test that issue_record key is formatted correctly.""" - filter_obj = ThrottleFilter() +# demonstrator test_* parity integrated into handlers tests - issue_id = f"{log_record.pathname}:{log_record.lineno}" - record = filter_obj.issue_map[issue_id] +def test_demo_test_main_functions_emits_expected_levels() -> None: + logger = MagicMock(spec=logging.Logger) - assert isinstance(record, IssueRecord) + demo.test_main_functions(logger) - def test_multiple_handler_types_intersection( - self, log_record: logging.LogRecord - ): - """Test set intersection with multiple handler types.""" - log_record.handlers = [ - HandlerType.Rich, - HandlerType.File, - HandlerType.Stream, - ] - filter_obj = HandleIDFilter([HandlerType.Rich, HandlerType.Lstdout]) - - # Rich is in the intersection - result = filter_obj.filter(log_record) - assert result is True - - def test_protobuf_conf_in_ers_handlers( - self, ers_log_record: logging.LogRecord, mock_ers_handlers: dict - ): - """Test that ProtobufConf is properly included in ERS configuration.""" - ers_log_record.ers_handlers = mock_ers_handlers - filter_obj = BaseHandlerFilter() - - allowed = filter_obj.get_allowed(ers_log_record) - - assert HandlerType.Protobufstream in allowed - - def test_suppression_message_includes_count( - self, clean_logger: logging.Logger - ): - """Test that suppression message includes suppressed count.""" - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.setFormatter(logging.Formatter("%(message)s")) - - # Create a throttle filter that will suppress quickly - throttle_filter = ThrottleFilter(initial_threshold=1, time_limit=10) - handler.addFilter(throttle_filter) - - clean_logger.addHandler(handler) - clean_logger.setLevel(logging.INFO) - - record = logging.LogRecord( - name=clean_logger.name, - level=logging.INFO, - pathname="test.py", - lineno=10, - msg="Test", - args=(), - exc_info=None, - ) - record.handlers = [HandlerType.Throttle] + logger.debug.assert_called_once() + assert logger.info.call_count >= 2 + assert logger.warning.call_count >= 2 + logger.error.assert_called_once() + logger.critical.assert_called_once() - # Send messages to trigger suppression - for _ in range(15): - clean_logger.handle(copy.deepcopy(record)) - output = stream.getvalue() +def test_demo_test_child_logger_builds_child_and_logs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + child_logger = MagicMock(spec=logging.Logger) + get_logger_mock = MagicMock(return_value=child_logger) + monkeypatch.setattr(demo, "get_daq_logger", get_logger_mock) - # Should contain suppression message with count - assert "suppressed" in output.lower() + demo.test_child_logger( + logger_name="parent.logger", + log_level="INFO", + disable_logger_inheritance=True, + rich_handler=True, + file_handler_path="/tmp/demo.log", + stream_handlers=True, + ) + + get_logger_mock.assert_called_once_with( + logger_name="parent.logger.child", + log_level="INFO", + use_parent_handlers=False, + rich_handler=True, + file_handler_path="/tmp/demo.log", + stream_handlers=True, + ) + child_logger.debug.assert_called_once() + child_logger.info.assert_called() + child_logger.warning.assert_called() + child_logger.error.assert_called_once() + child_logger.critical.assert_called_once() + + +def test_demo_test_throttle_uses_throttle_extra_and_sleep( + monkeypatch: pytest.MonkeyPatch, +) -> None: + logger = MagicMock(spec=logging.Logger) + sleep_mock = MagicMock() + monkeypatch.setattr(demo.time, "sleep", sleep_mock) + + demo.test_throttle(logger) + + sleep_mock.assert_called_once_with(31) + logger.warning.assert_called_once_with("Sleeping for 30 seconds") + assert logger.info.call_count == 1050 + + first_call_kwargs = logger.info.call_args_list[0].kwargs + assert first_call_kwargs["extra"]["handlers"] == [ + HandlerType.Rich, + HandlerType.Throttle, + ] + + +def test_demo_test_handlertypes_routes_expected_extras() -> None: + logger = MagicMock(spec=logging.Logger) + + demo.test_handlertypes(logger) + + critical_calls = logger.critical.call_args_list + assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Rich] for c in critical_calls) + assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.File] for c in critical_calls) + assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Lstdout] for c in critical_calls) + assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Throttle] for c in critical_calls) + assert any( + c.kwargs.get("extra", {}).get("handlers") + == [HandlerType.Rich, HandlerType.Protobufstream] + for c in critical_calls + ) + + +def test_demo_test_fallback_handlers_calls_add_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + logger = MagicMock(spec=logging.Logger) + get_logger_mock = MagicMock(return_value=logger) + add_handler_mock = MagicMock() + + monkeypatch.setattr(demo, "get_daq_logger", get_logger_mock) + monkeypatch.setattr(demo, "add_handler", add_handler_mock) + + demo.test_fallback_handlers("DEBUG") + + get_logger_mock.assert_called_once_with( + logger_name="fallback_logger", + log_level="DEBUG", + stream_handlers=False, + rich_handler=True, + ) + add_handler_mock.assert_has_calls( + [ + call(logger, HandlerType.Lstdout, True), + call( + logger, + HandlerType.Lstderr, + True, + fallback_handler={HandlerType.Unknown}, + ), + ] + ) diff --git a/tests/logging/test_routing.py b/tests/logging/test_routing.py new file mode 100644 index 0000000..1587c41 --- /dev/null +++ b/tests/logging/test_routing.py @@ -0,0 +1,148 @@ +import logging + +from daqpytools.logging.handlerconf import ERSPyLogHandlerConf, HandlerType, StreamType +from daqpytools.logging.routing import ( + AllowedHandlersStrategy, + DefaultAllowedHandlerStrategy, + ERSAllowedHandlersStrategy, + StreamAwareAllowedHandlersStrategy, +) + + +class _StrategyForHelper(AllowedHandlersStrategy): + def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + del record, fallback_handlers + return None + + +def _record(level: int = logging.INFO) -> logging.LogRecord: + return logging.LogRecord( + name="test.routing", + level=level, + pathname="/tmp/test_routing.py", + lineno=10, + msg="message", + args=(), + exc_info=None, + ) + + +def test_safe_return_set_filters_out_none() -> None: + strategy = _StrategyForHelper() + assert strategy.safe_return_set({None, HandlerType.Rich}) == {HandlerType.Rich} + + +def test_safe_return_set_returns_empty_set_when_all_none() -> None: + strategy = _StrategyForHelper() + assert strategy.safe_return_set({None}) == set() + + +def test_default_strategy_uses_record_handlers_when_present() -> None: + strategy = DefaultAllowedHandlerStrategy() + record = _record() + record.handlers = {HandlerType.Rich, None} + + assert strategy.resolve(record, {HandlerType.File}) == {HandlerType.Rich} + + +def test_default_strategy_uses_fallback_when_handlers_missing() -> None: + strategy = DefaultAllowedHandlerStrategy() + record = _record() + + assert strategy.resolve(record, {HandlerType.File}) == {HandlerType.File} + + +def test_default_strategy_returns_none_when_handlers_attr_is_none() -> None: + strategy = DefaultAllowedHandlerStrategy() + record = _record() + record.handlers = None + + assert strategy.resolve(record, {HandlerType.File}) is None + + +def test_ers_strategy_returns_none_when_level_not_mapped() -> None: + strategy = ERSAllowedHandlersStrategy() + record = _record(level=25) + record.stream = StreamType.ERS + record.ers_handlers = {} + + assert strategy.resolve(record, set()) is None + + +def test_ers_strategy_returns_none_without_ers_handlers() -> None: + strategy = ERSAllowedHandlersStrategy() + record = _record(level=logging.ERROR) + record.stream = StreamType.ERS + + assert strategy.resolve(record, set()) is None + + +def test_ers_strategy_returns_none_when_level_conf_missing() -> None: + strategy = ERSAllowedHandlersStrategy() + record = _record(level=logging.ERROR) + record.stream = StreamType.ERS + record.ers_handlers = {"DUNEDAQ_ERS_WARNING": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])} + + assert strategy.resolve(record, set()) is None + + +def test_ers_strategy_returns_handler_set_when_valid() -> None: + strategy = ERSAllowedHandlersStrategy() + record = _record(level=logging.ERROR) + record.stream = StreamType.ERS + record.ers_handlers = { + "DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf( + handlers=[HandlerType.Throttle, None], + ) + } + + assert strategy.resolve(record, set()) == {HandlerType.Throttle} + + +class _FakeDefault: + def __init__(self) -> None: + self.called = False + + def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + del record, fallback_handlers + self.called = True + return {HandlerType.File} + + +class _FakeERS: + def __init__(self) -> None: + self.called = False + + def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + del record, fallback_handlers + self.called = True + return {HandlerType.Throttle} + + +def test_streamaware_uses_ers_strategy_for_ers_stream() -> None: + default = _FakeDefault() + ers = _FakeERS() + strategy = StreamAwareAllowedHandlersStrategy(default_strategy=default, ers_strategy=ers) + + record = _record(level=logging.ERROR) + record.stream = StreamType.ERS + + result = strategy.resolve(record, {HandlerType.Rich}) + + assert result == {HandlerType.Throttle} + assert ers.called is True + assert default.called is False + + +def test_streamaware_uses_default_strategy_for_non_ers_stream() -> None: + default = _FakeDefault() + ers = _FakeERS() + strategy = StreamAwareAllowedHandlersStrategy(default_strategy=default, ers_strategy=ers) + + record = _record(level=logging.INFO) + + result = strategy.resolve(record, {HandlerType.Rich}) + + assert result == {HandlerType.File} + assert default.called is True + assert ers.called is False diff --git a/tests/logging/test_specs.py b/tests/logging/test_specs.py new file mode 100644 index 0000000..d22f7f1 --- /dev/null +++ b/tests/logging/test_specs.py @@ -0,0 +1,93 @@ +from dataclasses import FrozenInstanceError +import logging + +import pytest + +from daqpytools.logging.specs import FilterSpec, HandlerSpec + + +def _handler_factory(**kwargs: object) -> logging.Handler: + del kwargs + return logging.NullHandler() + + +def _filter_factory(fallback_handlers: set[object], extras: dict[str, object]) -> logging.Filter: + del fallback_handlers, extras + return logging.Filter() + + +def test_handlerspec_is_frozen() -> None: + spec = HandlerSpec( + alias="rich", + handler_class=logging.NullHandler, + factory=_handler_factory, + fallback_types=("rich",), + ) + + with pytest.raises(FrozenInstanceError): + spec.alias = "stream" + + +def test_handlerspec_stores_target_stream_optional() -> None: + spec_without = HandlerSpec( + alias="a", + handler_class=logging.NullHandler, + factory=_handler_factory, + fallback_types=("a",), + ) + assert spec_without.target_stream is None + + stream = object() + spec_with = HandlerSpec( + alias="b", + handler_class=logging.StreamHandler, + factory=_handler_factory, + fallback_types=("b",), + target_stream=stream, + ) + assert spec_with.target_stream is stream + + +def test_handlerspec_factory_signature_compatible() -> None: + spec = HandlerSpec( + alias="rich", + handler_class=logging.NullHandler, + factory=_handler_factory, + fallback_types=("rich",), + ) + handler = spec.factory(width=120) + assert isinstance(handler, logging.NullHandler) + + +def test_filterspec_is_frozen() -> None: + spec = FilterSpec( + alias="throttle", + filter_class=logging.Filter, + factory=_filter_factory, + ) + + with pytest.raises(FrozenInstanceError): + spec.alias = "x" + + +def test_filterspec_defaults_fallback_types_to_empty_tuple() -> None: + spec = FilterSpec( + alias="throttle", + filter_class=logging.Filter, + factory=_filter_factory, + ) + assert spec.fallback_types == () + + +def test_filterspec_accepts_factory_and_alias() -> None: + spec = FilterSpec( + alias="throttle", + filter_class=logging.Filter, + factory=_filter_factory, + fallback_types=("throttle",), + ) + + assert spec.alias == "throttle" + assert spec.filter_class is logging.Filter + assert spec.fallback_types == ("throttle",) + assert isinstance(spec.factory(set(), {}), logging.Filter) From 8d1e96d199ec6412d63cdc3998a8797fac838ea3 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Fri, 20 Mar 2026 16:59:58 +0100 Subject: [PATCH 4/6] cleanup pytest --- tests/logging/test_handlerconf.py | 41 +++++++---- tests/logging/test_handlers.py | 114 ++++++++++++++++++++++-------- tests/logging/test_routing.py | 32 +++++++-- tests/logging/test_specs.py | 6 +- 4 files changed, 143 insertions(+), 50 deletions(-) diff --git a/tests/logging/test_handlerconf.py b/tests/logging/test_handlerconf.py index 30fed0b..7681f77 100644 --- a/tests/logging/test_handlerconf.py +++ b/tests/logging/test_handlerconf.py @@ -1,5 +1,6 @@ import logging import uuid +from typing import ClassVar from unittest.mock import MagicMock import pytest @@ -35,9 +36,13 @@ def test_loghandlerconf_ers_property_raises_before_init() -> None: _ = conf.ERS -def test_loghandlerconf_init_ers_stream_sets_structure(monkeypatch: pytest.MonkeyPatch) -> None: +def test_loghandlerconf_init_ers_stream_sets_structure( + monkeypatch: pytest.MonkeyPatch, +) -> None: fake_oks = {"DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])} - monkeypatch.setattr(LogHandlerConf, "_get_oks_conf", staticmethod(lambda: fake_oks)) + monkeypatch.setattr( + LogHandlerConf, "_get_oks_conf", staticmethod(lambda: fake_oks) + ) conf = LogHandlerConf(init_ers=False) conf.init_ers_stream() @@ -123,7 +128,9 @@ def _fake_make(level_var: str) -> ERSPyLogHandlerConf: calls.append(level_var) return ERSPyLogHandlerConf(handlers=[HandlerType.Rich]) - monkeypatch.setattr(LogHandlerConf, "_make_ers_handler_conf", staticmethod(_fake_make)) + monkeypatch.setattr( + LogHandlerConf, "_make_ers_handler_conf", staticmethod(_fake_make) + ) conf = LogHandlerConf._get_oks_conf() @@ -133,36 +140,44 @@ def _fake_make(level_var: str) -> ERSPyLogHandlerConf: # demonstrator test_* parity integrated into handlerconf tests -def test_demo_test_handlerconf_runs_ers_flow_and_restores(monkeypatch: pytest.MonkeyPatch) -> None: +def test_demo_test_handlerconf_runs_ers_flow_and_restores( + monkeypatch: pytest.MonkeyPatch, +) -> None: logger = MagicMock(spec=logging.Logger) class FakeHC: - Base = {"handlers": {HandlerType.Stream}, "stream": StreamType.BASE} - Opmon = {"handlers": {HandlerType.Rich}, "stream": StreamType.OPMON} + Base: ClassVar[dict] = { + "handlers": {HandlerType.Stream}, + "stream": StreamType.BASE, + } + Opmon: ClassVar[dict] = { + "handlers": {HandlerType.Rich}, + "stream": StreamType.OPMON, + } def __init__(self, init_ers: bool = False) -> None: self._ers = { - "ers_handlers": {"DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])}, + "ers_handlers": { + "DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf( + handlers=[HandlerType.Rich] + ), + }, "stream": StreamType.ERS, } if init_ers: self.init_ers_stream() @property - def ERS(self) -> dict: + def ERS(self) -> dict: # noqa: N802 return self._ers def init_ers_stream(self) -> None: return None - restore_mock = MagicMock() - monkeypatch.setattr(demo, "LogHandlerConf", FakeHC) - monkeypatch.setattr(demo, "restore_original_envs", restore_mock) demo.test_handlerconf(logger) - restore_mock.assert_called_once() logger.warning.assert_called() logger.info.assert_called() @@ -185,7 +200,7 @@ def __init__(self, init_ers: bool = False) -> None: self._ers = {"stream": StreamType.ERS, "ers_handlers": {}} @property - def ERS(self) -> dict: + def ERS(self) -> dict: # noqa: N802 return self._ers monkeypatch.setattr(demo, "get_daq_logger", get_logger_mock) diff --git a/tests/logging/test_handlers.py b/tests/logging/test_handlers.py index 636ca4f..0b2284f 100644 --- a/tests/logging/test_handlers.py +++ b/tests/logging/test_handlers.py @@ -6,12 +6,12 @@ import pytest from daqpytools.apps import logging_demonstrator as demo +from daqpytools.logging import handlers as handlers_mod from daqpytools.logging.exceptions import ERSInitError, LoggerHandlerError from daqpytools.logging.filters import HandleIDFilter from daqpytools.logging.formatter import LoggingFormatter from daqpytools.logging.handlerconf import HandlerType from daqpytools.logging.rich_handler import FormattedRichHandler -from daqpytools.logging import handlers as handlers_mod @pytest.fixture @@ -27,8 +27,9 @@ def clean_logger() -> Iterator[logging.Logger]: logger.removeHandler(handler) try: handler.close() - except Exception: - pass + except Exception as e: + # Handler already closed or does not support close + del e logger.filters = [] logging.root.manager.loggerDict.pop(name, None) @@ -58,8 +59,9 @@ def parent_child_loggers() -> Iterator[tuple[logging.Logger, logging.Logger]]: logger.removeHandler(handler) try: handler.close() - except Exception: - pass + except Exception as e: + # Handler already closed or does not support close + del e logger.filters = [] logging.root.manager.loggerDict.pop(logger.name, None) @@ -68,16 +70,23 @@ def test_logger_has_handler_non_logger_returns_false() -> None: assert handlers_mod.logger_has_handler(MagicMock(), logging.StreamHandler) is False -def test_logger_has_handler_matches_non_stream_type(clean_logger: logging.Logger) -> None: +def test_logger_has_handler_matches_non_stream_type( + clean_logger: logging.Logger, +) -> None: handler = logging.NullHandler() clean_logger.addHandler(handler) - assert handlers_mod.logger_has_handler(clean_logger, logging.NullHandler) is True + assert ( + handlers_mod.logger_has_handler(clean_logger, logging.NullHandler) + is True + ) def test_logger_has_handler_matches_stream_by_target_stream( clean_logger: logging.Logger, ) -> None: - stdout_handler = logging.StreamHandler(handlers_mod.STDOUT_HANDLER_SPEC.target_stream) + stdout_handler = logging.StreamHandler( + handlers_mod.STDOUT_HANDLER_SPEC.target_stream + ) clean_logger.addHandler(stdout_handler) assert ( @@ -107,7 +116,10 @@ def test_ancestors_have_handlers_returns_false_when_disabled( parent_child_loggers: tuple[logging.Logger, logging.Logger], ) -> None: _, child = parent_child_loggers - assert handlers_mod.ancestors_have_handlers(child, False, logging.NullHandler) is False + assert ( + handlers_mod.ancestors_have_handlers(child, False, logging.NullHandler) + is False + ) def test_ancestors_have_handlers_rejects_root_logger() -> None: @@ -148,7 +160,10 @@ def test_ancestors_have_handlers_detects_parent_handler( parent, child = parent_child_loggers parent.addHandler(logging.NullHandler()) - assert handlers_mod.ancestors_have_handlers(child, True, logging.NullHandler) is True + assert ( + handlers_mod.ancestors_have_handlers(child, True, logging.NullHandler) + is True + ) def test_check_parent_handlers_raises_loggerhandlererror( @@ -165,10 +180,20 @@ def test_logger_or_ancestors_have_handler_checks_local_then_parent( parent_child_loggers: tuple[logging.Logger, logging.Logger], ) -> None: parent, child = parent_child_loggers - assert handlers_mod.logger_or_ancestors_have_handler(child, True, logging.NullHandler) is False + assert ( + handlers_mod.logger_or_ancestors_have_handler( + child, True, logging.NullHandler + ) + is False + ) parent.addHandler(logging.NullHandler()) - assert handlers_mod.logger_or_ancestors_have_handler(child, True, logging.NullHandler) is True + assert ( + handlers_mod.logger_or_ancestors_have_handler( + child, True, logging.NullHandler + ) + is True + ) def test_get_handler_specs_returns_expected_specs() -> None: @@ -180,7 +205,9 @@ def test_get_handler_specs_returns_expected_specs() -> None: assert len(handlers_mod.get_handler_specs(HandlerType.Protobufstream)) == 1 -def test_build_rich_handler_uses_get_width_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: +def test_build_rich_handler_uses_get_width_when_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: monkeypatch.setattr(handlers_mod, "get_width", lambda: 111) handler = handlers_mod._build_rich_handler() assert isinstance(handler, FormattedRichHandler) @@ -207,7 +234,9 @@ def test_build_file_handler_requires_path() -> None: handlers_mod._build_file_handler() -def test_build_file_handler_creates_handler_with_formatter(tmp_path: pytest.TempPathFactory) -> None: +def test_build_file_handler_creates_handler_with_formatter( + tmp_path: pytest.TempPathFactory, +) -> None: file_path = tmp_path / "test.log" handler = handlers_mod._build_file_handler(path=str(file_path)) assert isinstance(handler, logging.FileHandler) @@ -215,7 +244,9 @@ def test_build_file_handler_creates_handler_with_formatter(tmp_path: pytest.Temp handler.close() -def test_build_erskafka_handler_wraps_exception(monkeypatch: pytest.MonkeyPatch) -> None: +def test_build_erskafka_handler_wraps_exception( + monkeypatch: pytest.MonkeyPatch, +) -> None: def _raise(*args: object, **kwargs: object) -> None: del args, kwargs raise RuntimeError("boom") @@ -247,7 +278,9 @@ def __init__(self, **kwargs: object) -> None: assert handler.kwargs["app_name"] == "app_x" -def test_add_handler_adds_single_spec_and_handleidfilter(clean_logger: logging.Logger) -> None: +def test_add_handler_adds_single_spec_and_handleidfilter( + clean_logger: logging.Logger, +) -> None: handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) assert len(clean_logger.handlers) == 1 @@ -258,7 +291,9 @@ def test_add_handler_adds_single_spec_and_handleidfilter(clean_logger: logging.L ) -def test_add_handler_skips_when_matching_handler_exists(clean_logger: logging.Logger) -> None: +def test_add_handler_skips_when_matching_handler_exists( + clean_logger: logging.Logger, +) -> None: handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) handlers_mod.add_handler(clean_logger, HandlerType.Rich, use_parent_handlers=True) assert len(clean_logger.handlers) == 1 @@ -285,7 +320,9 @@ def test_add_handler_unknown_string_does_nothing(clean_logger: logging.Logger) - assert len(clean_logger.handlers) == 0 -def test_add_handler_uses_explicit_fallback_override(clean_logger: logging.Logger) -> None: +def test_add_handler_uses_explicit_fallback_override( + clean_logger: logging.Logger, +) -> None: override = {HandlerType.Unknown} handlers_mod.add_handler( clean_logger, @@ -302,15 +339,21 @@ def test_add_handler_uses_explicit_fallback_override(clean_logger: logging.Logge assert handler_filter.fallback_handlers == override -def test_add_handler_for_stream_adds_stdout_and_stderr(clean_logger: logging.Logger) -> None: +def test_add_handler_for_stream_adds_stdout_and_stderr( + clean_logger: logging.Logger, +) -> None: handlers_mod.add_handler(clean_logger, HandlerType.Stream, use_parent_handlers=True) stream_handlers = [ - handler for handler in clean_logger.handlers if isinstance(handler, logging.StreamHandler) + handler + for handler in clean_logger.handlers + if isinstance(handler, logging.StreamHandler) ] assert len(stream_handlers) == 2 -def test_add_handlers_from_types_stream_deduplicates(clean_logger: logging.Logger) -> None: +def test_add_handlers_from_types_stream_deduplicates( + clean_logger: logging.Logger, +) -> None: handlers_mod.add_handlers_from_types( clean_logger, {HandlerType.Stream, HandlerType.Lstdout, HandlerType.Lstderr}, @@ -318,7 +361,9 @@ def test_add_handlers_from_types_stream_deduplicates(clean_logger: logging.Logge fallback_handlers={HandlerType.Stream}, ) stream_handlers = [ - handler for handler in clean_logger.handlers if isinstance(handler, logging.StreamHandler) + handler + for handler in clean_logger.handlers + if isinstance(handler, logging.StreamHandler) ] assert len(stream_handlers) == 2 @@ -374,6 +419,7 @@ def test_demo_test_main_functions_emits_expected_levels() -> None: def test_demo_test_child_logger_builds_child_and_logs( monkeypatch: pytest.MonkeyPatch, + tmp_path: pytest.TempPathFactory, ) -> None: child_logger = MagicMock(spec=logging.Logger) get_logger_mock = MagicMock(return_value=child_logger) @@ -384,7 +430,7 @@ def test_demo_test_child_logger_builds_child_and_logs( log_level="INFO", disable_logger_inheritance=True, rich_handler=True, - file_handler_path="/tmp/demo.log", + file_handler_path=str(tmp_path / "demo.log"), stream_handlers=True, ) @@ -393,7 +439,7 @@ def test_demo_test_child_logger_builds_child_and_logs( log_level="INFO", use_parent_handlers=False, rich_handler=True, - file_handler_path="/tmp/demo.log", + file_handler_path=str(tmp_path / "demo.log"), stream_handlers=True, ) child_logger.debug.assert_called_once() @@ -429,10 +475,22 @@ def test_demo_test_handlertypes_routes_expected_extras() -> None: demo.test_handlertypes(logger) critical_calls = logger.critical.call_args_list - assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Rich] for c in critical_calls) - assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.File] for c in critical_calls) - assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Lstdout] for c in critical_calls) - assert any(c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Throttle] for c in critical_calls) + assert any( + c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Rich] + for c in critical_calls + ) + assert any( + c.kwargs.get("extra", {}).get("handlers") == [HandlerType.File] + for c in critical_calls + ) + assert any( + c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Lstdout] + for c in critical_calls + ) + assert any( + c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Throttle] + for c in critical_calls + ) assert any( c.kwargs.get("extra", {}).get("handlers") == [HandlerType.Rich, HandlerType.Protobufstream] diff --git a/tests/logging/test_routing.py b/tests/logging/test_routing.py index 1587c41..1ebb818 100644 --- a/tests/logging/test_routing.py +++ b/tests/logging/test_routing.py @@ -10,7 +10,11 @@ class _StrategyForHelper(AllowedHandlersStrategy): - def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + def resolve( + self, + record: logging.LogRecord, + fallback_handlers: set[object], + ) -> set[object] | None: del record, fallback_handlers return None @@ -19,7 +23,7 @@ def _record(level: int = logging.INFO) -> logging.LogRecord: return logging.LogRecord( name="test.routing", level=level, - pathname="/tmp/test_routing.py", + pathname="", lineno=10, msg="message", args=(), @@ -81,7 +85,9 @@ def test_ers_strategy_returns_none_when_level_conf_missing() -> None: strategy = ERSAllowedHandlersStrategy() record = _record(level=logging.ERROR) record.stream = StreamType.ERS - record.ers_handlers = {"DUNEDAQ_ERS_WARNING": ERSPyLogHandlerConf(handlers=[HandlerType.Rich])} + record.ers_handlers = { + "DUNEDAQ_ERS_WARNING": ERSPyLogHandlerConf(handlers=[HandlerType.Rich]) + } assert strategy.resolve(record, set()) is None @@ -103,7 +109,11 @@ class _FakeDefault: def __init__(self) -> None: self.called = False - def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + def resolve( + self, + record: logging.LogRecord, + fallback_handlers: set[object], + ) -> set[object] | None: del record, fallback_handlers self.called = True return {HandlerType.File} @@ -113,7 +123,11 @@ class _FakeERS: def __init__(self) -> None: self.called = False - def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> set[object] | None: + def resolve( + self, + record: logging.LogRecord, + fallback_handlers: set[object], + ) -> set[object] | None: del record, fallback_handlers self.called = True return {HandlerType.Throttle} @@ -122,7 +136,9 @@ def resolve(self, record: logging.LogRecord, fallback_handlers: set[object]) -> def test_streamaware_uses_ers_strategy_for_ers_stream() -> None: default = _FakeDefault() ers = _FakeERS() - strategy = StreamAwareAllowedHandlersStrategy(default_strategy=default, ers_strategy=ers) + strategy = StreamAwareAllowedHandlersStrategy( + default_strategy=default, ers_strategy=ers + ) record = _record(level=logging.ERROR) record.stream = StreamType.ERS @@ -137,7 +153,9 @@ def test_streamaware_uses_ers_strategy_for_ers_stream() -> None: def test_streamaware_uses_default_strategy_for_non_ers_stream() -> None: default = _FakeDefault() ers = _FakeERS() - strategy = StreamAwareAllowedHandlersStrategy(default_strategy=default, ers_strategy=ers) + strategy = StreamAwareAllowedHandlersStrategy( + default_strategy=default, ers_strategy=ers + ) record = _record(level=logging.INFO) diff --git a/tests/logging/test_specs.py b/tests/logging/test_specs.py index d22f7f1..74a57b5 100644 --- a/tests/logging/test_specs.py +++ b/tests/logging/test_specs.py @@ -1,5 +1,5 @@ -from dataclasses import FrozenInstanceError import logging +from dataclasses import FrozenInstanceError import pytest @@ -11,7 +11,9 @@ def _handler_factory(**kwargs: object) -> logging.Handler: return logging.NullHandler() -def _filter_factory(fallback_handlers: set[object], extras: dict[str, object]) -> logging.Filter: +def _filter_factory( + fallback_handlers: set[object], extras: dict[str, object] +) -> logging.Filter: del fallback_handlers, extras return logging.Filter() From b660f54944f1cb61f159389d0dda81c3d2185bd0 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Fri, 20 Mar 2026 17:00:07 +0100 Subject: [PATCH 5/6] add tests after refactor --- tests/logging/test_filters.py | 134 ++++++++++++++++++++++++++++++++++ tests/logging/test_logger.py | 68 ++++++++++++++++- 2 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 tests/logging/test_filters.py diff --git a/tests/logging/test_filters.py b/tests/logging/test_filters.py new file mode 100644 index 0000000..9bc6a0e --- /dev/null +++ b/tests/logging/test_filters.py @@ -0,0 +1,134 @@ +import logging +import uuid +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pytest + +from daqpytools.logging.filters import ( + HandleIDFilter, + ThrottleFilter, + add_filter, + add_throttle_filter, + get_filter_spec, +) +from daqpytools.logging.handlerconf import HandlerType + + +def _record(level: int = logging.INFO) -> logging.LogRecord: + return logging.LogRecord( + name="test.filters", + level=level, + pathname="", + lineno=42, + msg="message", + args=(), + exc_info=None, + ) + + +@pytest.fixture +def clean_logger() -> Iterator[logging.Logger]: + name = f"test.filters.{uuid.uuid4()}" + logger = logging.getLogger(name) + logger.handlers = [] + logger.filters = [] + logger.propagate = False + logger.setLevel(logging.DEBUG) + yield logger + for handler in logger.handlers[:]: + logger.removeHandler(handler) + logger.filters = [] + logging.root.manager.loggerDict.pop(name, None) + + +def test_get_filter_spec_registry_lookup() -> None: + assert get_filter_spec(HandlerType.Throttle) is not None + assert get_filter_spec(HandlerType.Rich) is None + + +def test_add_filter_uses_default_fallback_from_spec( + clean_logger: logging.Logger, +) -> None: + add_filter(clean_logger, HandlerType.Throttle, fallback_handlers=None) + + assert len(clean_logger.filters) == 1 + logger_filter = clean_logger.filters[0] + assert isinstance(logger_filter, ThrottleFilter) + assert logger_filter.fallback_handlers == {HandlerType.Throttle} + + +def test_add_filter_uses_explicit_fallback_and_extras( + clean_logger: logging.Logger, +) -> None: + add_filter( + clean_logger, + HandlerType.Throttle, + fallback_handlers={HandlerType.Unknown}, + initial_treshold=7, + time_limit=11, + ) + + logger_filter = clean_logger.filters[0] + assert isinstance(logger_filter, ThrottleFilter) + assert logger_filter.fallback_handlers == {HandlerType.Unknown} + assert logger_filter.initial_threshold == 7 + assert logger_filter.time_limit == 11 + + +def test_add_throttle_filter_delegates_to_add_filter( + clean_logger: logging.Logger, + monkeypatch: pytest.MonkeyPatch, +) -> None: + add_filter_mock = MagicMock() + monkeypatch.setattr("daqpytools.logging.filters.add_filter", add_filter_mock) + + add_throttle_filter(clean_logger, fallback_handlers={HandlerType.Throttle}) + + add_filter_mock.assert_called_once_with( + clean_logger, + HandlerType.Throttle, + {HandlerType.Throttle}, + ) + + +def test_handleid_filter_matches_with_fallback_when_record_handlers_missing() -> None: + logger_filter = HandleIDFilter( + handler_id=HandlerType.Rich, + fallback_handlers={HandlerType.Rich}, + ) + + assert logger_filter.filter(_record()) is True + + +def test_handleid_filter_rejects_when_no_allowed_handlers() -> None: + logger_filter = HandleIDFilter( + handler_id=HandlerType.Rich, + fallback_handlers={HandlerType.Rich}, + ) + record = _record() + record.handlers = None + + assert logger_filter.filter(record) is False + + +def test_throttle_filter_passthrough_when_throttle_not_allowed() -> None: + logger_filter = ThrottleFilter( + fallback_handlers={HandlerType.Rich}, + initial_threshold=1, + time_limit=60, + ) + + assert logger_filter.filter(_record(level=logging.ERROR)) is True + + +def test_throttle_filter_suppresses_after_initial_threshold() -> None: + logger_filter = ThrottleFilter( + fallback_handlers={HandlerType.Throttle}, + initial_threshold=1, + time_limit=60, + ) + record = _record(level=logging.ERROR) + + assert logger_filter.filter(record) is True + assert logger_filter.filter(record) is False diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py index ce7db56..0df643a 100644 --- a/tests/logging/test_logger.py +++ b/tests/logging/test_logger.py @@ -5,8 +5,14 @@ import pytest from daqpytools.logging.exceptions import LoggerSetupError +from daqpytools.logging.handlerconf import ERSPyLogHandlerConf, HandlerType, ProtobufConf from daqpytools.logging.handlers import logger_or_ancestors_have_handler -from daqpytools.logging.logger import get_daq_logger, setup_root_logger +from daqpytools.logging.logger import ( + get_daq_logger, + setup_daq_ers_logger, + setup_root_logger, +) +from daqpytools.logging import logger as logger_mod test_logger_name = "test_logger" test_logger_child_name = f"{test_logger_name}.child" @@ -225,3 +231,63 @@ def test_logger_parent_walk_handles_mock_logger_cycle(): logger_or_ancestors_have_handler(fake_logger, True, logging.NullHandler) is False ) + + +def test_setup_daq_ers_logger_uses_single_protobuf_config( + monkeypatch: pytest.MonkeyPatch, +) -> None: + logger = logging.getLogger(f"test.logger.ers.{id(object())}") + add_handlers_mock = MagicMock() + monkeypatch.setattr(logger_mod, "add_handlers_from_types", add_handlers_mock) + + oks_conf = { + "DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf( + handlers=[HandlerType.Rich, HandlerType.Protobufstream], + protobufconf=ProtobufConf(url="host-a", port=30092), + ), + "DUNEDAQ_ERS_WARNING": ERSPyLogHandlerConf( + handlers=[HandlerType.Throttle], + protobufconf=ProtobufConf(url="host-a", port=30092), + ), + } + monkeypatch.setattr( + logger_mod.LogHandlerConf, + "_get_oks_conf", + staticmethod(lambda: oks_conf), + ) + + setup_daq_ers_logger(logger, ers_kafka_session="session-a", ers_app_name="app-a") + + add_handlers_mock.assert_called_once_with( + logger, + {HandlerType.Rich, HandlerType.Protobufstream, HandlerType.Throttle}, + use_parent_handlers=True, + fallback_handlers={HandlerType.Unknown}, + session_name="session-a", + ers_app_name="app-a", + address="host-a:30092", + ) + + +def test_setup_daq_ers_logger_rejects_multiple_protobuf_configs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + logger = logging.getLogger(f"test.logger.ers.multi.{id(object())}") + oks_conf = { + "DUNEDAQ_ERS_ERROR": ERSPyLogHandlerConf( + handlers=[HandlerType.Protobufstream], + protobufconf=ProtobufConf(url="host-a", port=30092), + ), + "DUNEDAQ_ERS_WARNING": ERSPyLogHandlerConf( + handlers=[HandlerType.Protobufstream], + protobufconf=ProtobufConf(url="host-b", port=30093), + ), + } + monkeypatch.setattr( + logger_mod.LogHandlerConf, + "_get_oks_conf", + staticmethod(lambda: oks_conf), + ) + + with pytest.raises(ValueError, match="Multiple protobufstream"): + setup_daq_ers_logger(logger, ers_kafka_session="session-a") From f4e7325b8c9bcaf012529bc37ccc0005b7d83501 Mon Sep 17 00:00:00 2001 From: Emir Muhammad Date: Fri, 20 Mar 2026 17:07:59 +0100 Subject: [PATCH 6/6] Test more exceptions --- tests/logging/test_exceptions.py | 26 ++++++++++++++++++++++++++ tests/logging/test_logger.py | 8 ++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/logging/test_exceptions.py b/tests/logging/test_exceptions.py index bf69d46..49d1c71 100644 --- a/tests/logging/test_exceptions.py +++ b/tests/logging/test_exceptions.py @@ -3,8 +3,12 @@ import pytest from daqpytools.logging.exceptions import ( + ERSEnvError, + ERSInitError, + LoggerConfigurationError, LoggerHandlerError, LoggerSetupError, + ProtobufFormatError, ) @@ -23,3 +27,25 @@ def test_exceptions(): assert str(exc_info.value) == ( "Constructing test_logger failed as: \nThe test made me do it :(" ) + + +def test_configuration_and_ers_related_exceptions() -> None: + config_path = "tests/logging/log_format.ini" + with pytest.raises(LoggerConfigurationError) as exc_info: + raise LoggerConfigurationError(config_path, "bad section") + assert f"Configuration file '{config_path}'" in str(exc_info.value) + assert "bad section" in str(exc_info.value) + + with pytest.raises(ERSEnvError) as exc_info: + raise ERSEnvError("DUNEDAQ_ERS_ERROR") + assert str(exc_info.value) == "The environment variable DUNEDAQ_ERS_ERROR is empty" + + with pytest.raises(ERSInitError) as exc_info: + raise ERSInitError("host:30092", "ers_stream") + assert "address='host:30092'" in str(exc_info.value) + assert "topic='ers_stream'" in str(exc_info.value) + + with pytest.raises(ProtobufFormatError) as exc_info: + raise ProtobufFormatError("protobufstream(bad)") + assert "protobufstream URLs must be formatted (url:port)." in str(exc_info.value) + assert "protobufstream(bad)" in str(exc_info.value) diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py index 0df643a..a7c8995 100644 --- a/tests/logging/test_logger.py +++ b/tests/logging/test_logger.py @@ -4,15 +4,19 @@ import pytest +from daqpytools.logging import logger as logger_mod from daqpytools.logging.exceptions import LoggerSetupError -from daqpytools.logging.handlerconf import ERSPyLogHandlerConf, HandlerType, ProtobufConf +from daqpytools.logging.handlerconf import ( + ERSPyLogHandlerConf, + HandlerType, + ProtobufConf, +) from daqpytools.logging.handlers import logger_or_ancestors_have_handler from daqpytools.logging.logger import ( get_daq_logger, setup_daq_ers_logger, setup_root_logger, ) -from daqpytools.logging import logger as logger_mod test_logger_name = "test_logger" test_logger_child_name = f"{test_logger_name}.child"