diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9da0f3fda..dcc898cbb0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ ENHANCEMENTS:
 * Pass OIDC vars directly to the devcontainer ([#4871](https://github.com/microsoft/AzureTRE/issues/4871))
 
 BUG FIXES:
+* Implement service bus consumer monitoring with heartbeat detection, automatic recovery, and /health endpoint integration to prevent operations getting stuck indefinitely ([#4464](https://github.com/microsoft/AzureTRE/issues/4464))
 * Fix property substitution not occuring where there is only a main step in the pipeline ([#4824](https://github.com/microsoft/AzureTRE/issues/4824))
 * Fix Mysql template ignored storage_mb ([#4846](https://github.com/microsoft/AzureTRE/issues/4846))
 * Fix duplicate `TOPIC_SUBSCRIPTION_NAME` in `core/terraform/airlock/airlock_processor.tf` ([#4847](https://github.com/microsoft/AzureTRE/pull/4847))
diff --git a/api_app/_version.py b/api_app/_version.py
index 6623c5202f..025f4c5d0b 100644
--- a/api_app/_version.py
+++ b/api_app/_version.py
@@ -1 +1 @@
-__version__ = "0.25.14"
+__version__ = "0.26.1"
diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py
index 2cefe21266..c9674d6179 100644
--- a/api_app/api/routes/health.py
+++ b/api_app/api/routes/health.py
@@ -3,7 +3,7 @@
 from core import credentials
 from models.schemas.status import HealthCheck, ServiceStatus, StatusEnum
 from resources import strings
-from services.health_checker import create_resource_processor_status, create_state_store_status, create_service_bus_status
+from services.health_checker import create_airlock_consumer_status, create_deployment_consumer_status, create_resource_processor_status, create_state_store_status, create_service_bus_status
 from services.logging import logger
 
 router = APIRouter()
@@ -14,22 +14,28 @@ async def health_check(request: Request) -> HealthCheck:
     # The health endpoint checks the status of key components of the system.
     # Note that Resource Processor checks incur Azure management calls, so
     # calling this endpoint frequently may result in API throttling.
+ deployment_consumer = getattr(request.app.state, 'deployment_status_updater', None) + airlock_consumer = getattr(request.app.state, 'airlock_status_updater', None) + async with credentials.get_credential_async_context() as credential: - cosmos, sb, rp = await asyncio.gather( + cosmos, sb, rp, deploy, airlock = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), - create_resource_processor_status(credential) + create_resource_processor_status(credential), + create_deployment_consumer_status(deployment_consumer), + create_airlock_consumer_status(airlock_consumer), ) - cosmos_status, cosmos_message = cosmos - sb_status, sb_message = sb - rp_status, rp_message = rp - if cosmos_status == StatusEnum.not_ok or sb_status == StatusEnum.not_ok or rp_status == StatusEnum.not_ok: - logger.error(f'Cosmos Status: {cosmos_status}, message: {cosmos_message}') - logger.error(f'Service Bus Status: {sb_status}, message: {sb_message}') - logger.error(f'Resource Processor Status: {rp_status}, message: {rp_message}') - services = [ServiceStatus(service=strings.COSMOS_DB, status=cosmos_status, message=cosmos_message), - ServiceStatus(service=strings.SERVICE_BUS, status=sb_status, message=sb_message), - ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp_status, message=rp_message)] + services = [ + ServiceStatus(service=strings.COSMOS_DB, status=cosmos[0], message=cosmos[1]), + ServiceStatus(service=strings.SERVICE_BUS, status=sb[0], message=sb[1]), + ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp[0], message=rp[1]), + ServiceStatus(service=strings.DEPLOYMENT_STATUS_CONSUMER, status=deploy[0], message=deploy[1]), + ServiceStatus(service=strings.AIRLOCK_STATUS_CONSUMER, status=airlock[0], message=airlock[1]), + ] + + for svc in services: + if svc.status == StatusEnum.not_ok: + logger.error(f'{svc.service} Status: {svc.status}, message: {svc.message}') return HealthCheck(services=services) diff --git a/api_app/main.py b/api_app/main.py index 0bdc769141..f3f53fecc9 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -34,8 +34,12 @@ async def lifespan(app: FastAPI): airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() - asyncio.create_task(deploymentStatusUpdater.receive_messages()) - asyncio.create_task(airlockStatusUpdater.receive_messages()) + # Store consumer references on app.state so the /health endpoint can check their heartbeats + app.state.deployment_status_updater = deploymentStatusUpdater + app.state.airlock_status_updater = airlockStatusUpdater + + asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check()) + asyncio.create_task(airlockStatusUpdater.supervisor_with_heartbeat_check()) yield diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index c54a40ba52..07726b43ad 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -106,6 +106,12 @@ RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE = "Resource Processor is not responding" RESOURCE_PROCESSOR_HEALTHY_MESSAGE = "HealthState/healthy" +# Service bus consumer status +DEPLOYMENT_STATUS_CONSUMER = "Deployment Status Consumer" +AIRLOCK_STATUS_CONSUMER = "Airlock Status Consumer" +CONSUMER_HEARTBEAT_STALE = "{} heartbeat is stale or missing" +CONSUMER_NOT_INITIALIZED = "{} has not been initialized" + # Error strings ACCESS_APP_IS_MISSING_ROLE = "The App is missing role" ACCESS_PLEASE_SUPPLY_CLIENT_ID = "Please supply the client_id for the AAD application" diff --git 
a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index a643404a86..ce3ad3c093 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -16,12 +16,13 @@ from models.domain.airlock_operations import StepResultStatusUpdateMessage from core import config, credentials from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer -class AirlockStatusUpdater(): +class AirlockStatusUpdater(ServiceBusConsumer): def __init__(self): - pass + super().__init__("airlock_status_updater") async def init_repos(self): self.airlock_request_repo = await AirlockRequestRepository.create() @@ -36,9 +37,13 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + + # Update heartbeat for supervisor monitoring + self.update_heartbeat() + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_STEP_RESULT_QUEUE} queue polled {polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 @@ -64,13 +69,13 @@ async def receive_messages(self): # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. 
Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") async def process_message(self, msg): with tracer.start_as_current_span("process_message") as current_span: diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 41670464c7..8cf4d53659 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -1,7 +1,7 @@ -import asyncio import json import uuid import time +from typing import Dict, List, Any from pydantic import ValidationError, parse_obj_as @@ -21,11 +21,12 @@ from models.domain.operation import DeploymentStatusUpdateMessage, Operation, OperationStep, Status from resources import strings from services.logging import logger, tracer +from service_bus.service_bus_consumer import ServiceBusConsumer -class DeploymentStatusUpdater(): +class DeploymentStatusUpdater(ServiceBusConsumer): def __init__(self): - pass + super().__init__("deployment_status_updater") async def init_repos(self): self.operations_repo = await OperationRepository.create() @@ -33,9 +34,6 @@ async def init_repos(self): self.resource_template_repo = await ResourceTemplateRepository.create() self.resource_history_repo = await ResourceHistoryRepository.create() - def run(self, *args, **kwargs): - asyncio.run(self.receive_messages()) - async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): last_heartbeat_time = 0 @@ -45,9 +43,12 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + + # Update heartbeat for supervisor monitoring + self.update_heartbeat() # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue polled {polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 @@ -73,15 +74,15 @@ async def receive_messages(self): # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. 
Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") - async def process_message(self, msg): + async def process_message(self, msg) -> bool: complete_message = False message = "" @@ -115,6 +116,11 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage try: # update the op operation = await self.operations_repo.get_operation_by_id(str(message.operationId)) + + # Add null safety for operation steps + if not operation.steps: + raise ValueError(f"Operation {message.operationId} has no steps") + step_to_update = None is_last_step = False @@ -128,7 +134,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage is_last_step = True if step_to_update is None: - raise f"Error finding step {message.stepId} in operation {message.operationId}" + raise ValueError(f"Step {message.stepId} not found in operation {message.operationId}") # update the step status step_to_update.status = message.status @@ -159,7 +165,8 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage # more steps in the op to do? if is_last_step is False: - assert current_step_index < (len(operation.steps) - 1) + if current_step_index >= len(operation.steps) - 1: + raise ValueError(f"Step index {current_step_index} is the last step in operation (has {len(operation.steps)} steps), but more steps were expected") next_step = operation.steps[current_step_index + 1] # catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op @@ -255,7 +262,7 @@ def get_failure_status_for_action(self, action: RequestAction): return status - def create_updated_resource_document(self, resource: dict, message: DeploymentStatusUpdateMessage): + def create_updated_resource_document(self, resource: Dict[str, Any], message: DeploymentStatusUpdateMessage) -> Dict[str, Any]: """ Merge the outputs with the resource document to persist """ @@ -268,7 +275,7 @@ def create_updated_resource_document(self, resource: dict, message: DeploymentSt return resource - def convert_outputs_to_dict(self, outputs_list: [Output]): + def convert_outputs_to_dict(self, outputs_list: List[Output]) -> Dict[str, Any]: """ Convert a list of Porter outputs to a dictionary """ diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py new file mode 100644 index 0000000000..34b6163451 --- /dev/null +++ b/api_app/service_bus/service_bus_consumer.py @@ -0,0 +1,95 @@ +import asyncio +import time + +from services.logging import logger + +# Configuration constants for monitoring intervals +HEARTBEAT_CHECK_INTERVAL_SECONDS = 60 +HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300 +RESTART_DELAY_SECONDS = 5 +MAX_RESTART_DELAY_SECONDS = 300 +SUPERVISOR_ERROR_DELAY_SECONDS = 30 + + +class ServiceBusConsumer: + + def __init__(self, consumer_name: str): + self.service_name = consumer_name.replace('_', ' ').title() + self._last_heartbeat: float = time.monotonic() + self._restart_delay: float = RESTART_DELAY_SECONDS + logger.info(f"Initializing {self.service_name}") + + def update_heartbeat(self): + self._last_heartbeat = time.monotonic() + + def check_heartbeat(self, max_age_seconds: int = HEARTBEAT_STALENESS_THRESHOLD_SECONDS) -> bool: + age = time.monotonic() - self._last_heartbeat + if age > max_age_seconds: + logger.warning(f"{self.service_name} heartbeat is {age:.1f}s old (threshold: {max_age_seconds}s)") + return False + return True + + async def _receive_messages_loop(self): + 
"""Run receive_messages() in a loop with exponential backoff on failure.""" + while True: + try: + start_time = time.monotonic() + logger.info(f"Starting {self.service_name} receive_messages loop...") + await self.receive_messages() + logger.warning(f"{self.service_name} receive_messages() returned unexpectedly") + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"{self.service_name} receive_messages failed: {e}") + + # Reset backoff if the consumer ran long enough to be considered healthy + elapsed = time.monotonic() - start_time + if elapsed > self._restart_delay: + self._restart_delay = RESTART_DELAY_SECONDS + + logger.info(f"{self.service_name} restarting in {self._restart_delay:.0f}s...") + await asyncio.sleep(self._restart_delay) + self._restart_delay = min(self._restart_delay * 2, MAX_RESTART_DELAY_SECONDS) + + async def supervisor_with_heartbeat_check(self): + task = None + try: + while True: + try: + if task is None or task.done(): + if task and task.done(): + try: + await task + except Exception as e: + logger.exception(f"{self.service_name} task failed unexpectedly: {e}") + await asyncio.sleep(RESTART_DELAY_SECONDS) + + logger.info(f"Starting {self.service_name} task...") + task = asyncio.create_task(self._receive_messages_loop()) + self.update_heartbeat() + + await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) + + if not self.check_heartbeat(): + logger.warning(f"{self.service_name} heartbeat stale, restarting...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + task = None + self._restart_delay = RESTART_DELAY_SECONDS + except Exception as e: + logger.exception(f"{self.service_name} supervisor error: {e}") + await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS) + finally: + if task and not task.done(): + logger.info(f"Cleaning up {self.service_name} task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def receive_messages(self): + raise NotImplementedError("Subclasses must implement receive_messages()") diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index a4d53067b0..8b56757317 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple from azure.core import exceptions from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient @@ -11,6 +11,7 @@ from core import config from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer from services.logging import logger @@ -55,6 +56,22 @@ async def create_service_bus_status(credential) -> Tuple[StatusEnum, str]: return status, message +def create_consumer_status(consumer: Optional[ServiceBusConsumer], name: str) -> Tuple[StatusEnum, str]: + if consumer is None: + return StatusEnum.not_ok, strings.CONSUMER_NOT_INITIALIZED.format(name) + if consumer.check_heartbeat(): + return StatusEnum.ok, "" + return StatusEnum.not_ok, strings.CONSUMER_HEARTBEAT_STALE.format(name) + + +async def create_deployment_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.DEPLOYMENT_STATUS_CONSUMER) + + +async def create_airlock_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.AIRLOCK_STATUS_CONSUMER) + + async def 
create_resource_processor_status(credential) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" diff --git a/api_app/services/logging.py b/api_app/services/logging.py index ad6966b6d8..195fbc82fb 100644 --- a/api_app/services/logging.py +++ b/api_app/services/logging.py @@ -1,10 +1,14 @@ import logging +import os from opentelemetry.instrumentation.logging import LoggingInstrumentor from opentelemetry import trace from azure.monitor.opentelemetry import configure_azure_monitor from core.config import APPLICATIONINSIGHTS_CONNECTION_STRING, LOGGING_LEVEL +# Standard log format with worker ID +LOG_FORMAT = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s' + UNWANTED_LOGGERS = [ "azure.core.pipeline.policies.http_logging_policy", "azure.eventhub._eventprocessor.event_processor", @@ -45,6 +49,24 @@ "urllib3.connectionpool" ] + +class WorkerIdFilter(logging.Filter): + """ + A filter that adds worker_id to all log records. + """ + + def __init__(self): + super().__init__() + # Get the process ID as a unique worker identifier + self.worker_id = os.getpid() + + def filter(self, record: logging.LogRecord) -> bool: + # Add worker_id as an attribute to the log record if not already set + if not hasattr(record, 'worker_id'): + record.worker_id = self.worker_id + return True + + logger = logging.getLogger("azuretre_api") tracer = trace.get_tracer("azuretre_api") @@ -57,6 +79,20 @@ def configure_loggers(): logging.getLogger(logger_name).setLevel(logging.CRITICAL) +def apply_worker_id_to_logger(logger_instance): + """ + Apply the worker ID filter to a logger instance. + """ + worker_filter = WorkerIdFilter() + logger_instance.addFilter(worker_filter) + + # Update handlers to include worker_id in the format + for handler in logger_instance.handlers: + if isinstance(handler, logging.StreamHandler): + formatter = logging.Formatter(LOG_FORMAT) + handler.setFormatter(formatter) + + def initialize_logging() -> logging.Logger: configure_loggers() @@ -88,9 +124,18 @@ def initialize_logging() -> logging.Logger: LoggingInstrumentor().instrument( set_logging_format=True, log_level=logging_level, - tracer_provider=tracer._real_tracer + tracer_provider=tracer._real_tracer, + log_format=LOG_FORMAT ) + # Set up a handler if none exists + if not logger.handlers: + handler = logging.StreamHandler() + logger.addHandler(handler) + + # Apply worker ID filter + apply_worker_id_to_logger(logger) + logger.info("Logging initialized with level: %s", LOGGING_LEVEL) return logger diff --git a/api_app/tests_ma/test_api/test_routes/test_health.py b/api_app/tests_ma/test_api/test_routes/test_health.py index 21e795b619..ab52fe87f6 100644 --- a/api_app/tests_ma/test_api/test_routes/test_health.py +++ b/api_app/tests_ma/test_api/test_routes/test_health.py @@ -1,9 +1,10 @@ import pytest from httpx import AsyncClient -from mock import patch +from mock import patch, MagicMock from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer pytestmark = pytest.mark.asyncio @@ -48,3 +49,84 @@ async def test_health_response_contains_resource_processor_status(health_check_c response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) assert {"message": message, "service": strings.RESOURCE_PROCESSOR, "status": strings.OK} in response.json()["services"] + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") 
+@patch("api.routes.health.create_state_store_status") +async def test_health_response_contains_consumer_statuses(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint includes deployment and airlock consumer status.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate consumers stored on app.state with healthy heartbeats + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = True + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + assert {"message": "", "service": strings.DEPLOYMENT_STATUS_CONSUMER, "status": strings.OK} in services + assert {"message": "", "service": strings.AIRLOCK_STATUS_CONSUMER, "status": strings.OK} in services + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_reports_stale_consumer(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint reports not_ok when a consumer heartbeat is stale.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate deployment consumer with stale heartbeat + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = False + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_HEARTBEAT_STALE.format(strings.DEPLOYMENT_STATUS_CONSUMER) + + airlock_svc = next(s for s in services if s["service"] == strings.AIRLOCK_STATUS_CONSUMER) + assert airlock_svc["status"] == strings.OK + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_handles_missing_consumers(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint handles missing consumer references gracefully.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Remove consumer 
references from app.state if they exist + if hasattr(app.state, 'deployment_status_updater'): + delattr(app.state, 'deployment_status_updater') + if hasattr(app.state, 'airlock_status_updater'): + delattr(app.state, 'airlock_status_updater') + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_NOT_INITIALIZED.format(strings.DEPLOYMENT_STATUS_CONSUMER) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py new file mode 100644 index 0000000000..7d74a99af3 --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -0,0 +1,243 @@ +import asyncio +import time +import pytest +from unittest.mock import patch + +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + RESTART_DELAY_SECONDS, + MAX_RESTART_DELAY_SECONDS, +) + + +# Create a concrete implementation for testing +class MockConsumer(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.1) + return + + +@pytest.mark.asyncio +async def test_init(): + """Test initialization of ServiceBusConsumer.""" + consumer = MockConsumer() + assert consumer.service_name == "Test Consumer" + assert consumer._restart_delay == RESTART_DELAY_SECONDS + assert consumer._last_heartbeat > 0 + + +@pytest.mark.asyncio +async def test_update_heartbeat(): + """Test updating heartbeat updates timestamp.""" + consumer = MockConsumer() + old_heartbeat = consumer._last_heartbeat + await asyncio.sleep(0.01) + consumer.update_heartbeat() + + assert consumer._last_heartbeat > old_heartbeat + + +@pytest.mark.asyncio +async def test_check_heartbeat_recent(): + """Test checking a recent heartbeat returns True.""" + consumer = MockConsumer() + assert consumer.check_heartbeat(max_age_seconds=300) is True + + +@pytest.mark.asyncio +async def test_check_heartbeat_stale(): + """Test checking a stale heartbeat returns False.""" + consumer = MockConsumer() + consumer._last_heartbeat = time.monotonic() - 400 + assert consumer.check_heartbeat(max_age_seconds=300) is False + + +@pytest.mark.asyncio +async def test_check_heartbeat_default_uses_constant(): + """Test that check_heartbeat default max_age_seconds uses the module constant.""" + import inspect + sig = inspect.signature(ServiceBusConsumer.check_heartbeat) + default = sig.parameters['max_age_seconds'].default + assert default == HEARTBEAT_STALENESS_THRESHOLD_SECONDS + + +@pytest.mark.asyncio +async def test_backoff_increases_on_consecutive_failures(): + """Test that restart delay increases exponentially on immediate failures.""" + consumer = MockConsumer() + + async def failing_receive(): + raise RuntimeError("Simulated failure") + + consumer.receive_messages = failing_receive + + sleep_calls = [] + call_count = 0 + + async def mock_sleep(duration): + nonlocal call_count + sleep_calls.append(duration) + call_count += 1 + if call_count >= 3: + raise asyncio.CancelledError() + + # Patch time.monotonic to always return the same value so elapsed time is 0 (immediate failure) + fixed_time = time.monotonic() + with 
patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.time.monotonic", return_value=fixed_time): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass + + assert sleep_calls[0] == RESTART_DELAY_SECONDS + assert sleep_calls[1] == RESTART_DELAY_SECONDS * 2 + assert sleep_calls[2] == RESTART_DELAY_SECONDS * 4 + + +@pytest.mark.asyncio +async def test_backoff_caps_at_maximum(): + """Test that restart delay caps at MAX_RESTART_DELAY_SECONDS.""" + consumer = MockConsumer() + consumer._restart_delay = MAX_RESTART_DELAY_SECONDS + + async def failing_receive(): + raise RuntimeError("Simulated failure") + + consumer.receive_messages = failing_receive + + sleep_calls = [] + + async def mock_sleep(duration): + sleep_calls.append(duration) + raise asyncio.CancelledError() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass + + assert sleep_calls[0] == MAX_RESTART_DELAY_SECONDS + + +@pytest.mark.asyncio +async def test_supervisor_restarts_failed_task(): + """Test supervisor restarts the receive_messages task when it fails.""" + consumer = MockConsumer() + + task_create_calls = 0 + sleep_calls = [] + + class FailOnFirstDoneTask: + """A mock task that reports done() immediately to simulate task failure.""" + + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + self._is_first_task = (task_create_calls == 1) + + def done(self): + # First task always reports done (crashed) + # Second task always reports running + return self._is_first_task + + def cancel(self): + pass + + def __await__(self): + async def _await(): + if self._is_first_task: + raise RuntimeError("Simulated task failure") + return None + return _await().__await__() + + iteration = 0 + + async def mock_sleep(duration): + nonlocal iteration + sleep_calls.append(duration) + iteration += 1 + if iteration >= 4: + raise KeyboardInterrupt("Test complete") + + consumer.check_heartbeat = lambda **kwargs: True + + def create_fail_task(coro): + coro.close() + return FailOnFirstDoneTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_fail_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert task_create_calls >= 2 + + +@pytest.mark.asyncio +async def test_supervisor_restarts_on_stale_heartbeat(): + """Test supervisor cancels and restarts task when heartbeat goes stale.""" + consumer = MockConsumer() + + heartbeat_calls = 0 + task_create_calls = 0 + task_cancel_calls = 0 + sleep_calls = [] + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls == 1: + return True # Heartbeat is fresh + elif heartbeat_calls == 2: + return False # Heartbeat is stale, should trigger restart + else: + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + sleep_calls.append(duration) + + class MockTask: + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + + def cancel(self): + nonlocal task_cancel_calls + task_cancel_calls += 1 + + def done(self): + return False + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + 
coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert heartbeat_calls >= 2 + assert task_create_calls >= 2 + assert task_cancel_calls >= 1 + assert 60 in sleep_calls + assert consumer._restart_delay == RESTART_DELAY_SECONDS diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py new file mode 100644 index 0000000000..bb7c4996e8 --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -0,0 +1,182 @@ +import asyncio +import time +import pytest +from unittest.mock import patch +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + RESTART_DELAY_SECONDS, + HEARTBEAT_CHECK_INTERVAL_SECONDS, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + MAX_RESTART_DELAY_SECONDS, + SUPERVISOR_ERROR_DELAY_SECONDS, +) + + +# Create a concrete implementation for testing edge cases +class MockConsumerForEdgeCases(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer_edge") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.01) + return + + +@pytest.mark.asyncio +async def test_stale_heartbeat_detection(): + """Test that stale heartbeat is correctly detected.""" + consumer = MockConsumerForEdgeCases() + consumer._last_heartbeat = time.monotonic() - 400 + assert consumer.check_heartbeat(max_age_seconds=300) is False + + +@pytest.mark.asyncio +async def test_fresh_heartbeat_detection(): + """Test that fresh heartbeat is correctly detected.""" + consumer = MockConsumerForEdgeCases() + assert consumer.check_heartbeat(max_age_seconds=300) is True + + +@pytest.mark.asyncio +async def test_backoff_resets_after_long_running_receive(): + """Test that backoff resets when receive_messages ran longer than the current delay.""" + consumer = MockConsumerForEdgeCases() + consumer._restart_delay = 80 + + monotonic_values = iter([100.0, 200.0, 200.0]) # start=100, elapsed_check=200 (ran 100s > 80s delay), then next start + + async def long_running_receive(): + raise RuntimeError("Failure after running a while") + + consumer.receive_messages = long_running_receive + + async def mock_sleep(duration): + raise asyncio.CancelledError() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.time.monotonic", side_effect=monotonic_values): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass + + # Backoff should have reset to base since elapsed (100s) > old delay (80s) + assert consumer._restart_delay == RESTART_DELAY_SECONDS + + +@pytest.mark.asyncio +async def test_supervisor_cleanup_on_exception(): + """Test supervisor properly cleans up tasks when interrupted.""" + consumer = MockConsumerForEdgeCases() + + # Track task lifecycle + task_created = False + task_cancelled = False + + class MockTask: + def __init__(self): + nonlocal task_created + task_created = True + self.done_count = 0 + + def done(self): + return self.done_count > 0 + + def cancel(self): + nonlocal task_cancelled + task_cancelled = True + self.done_count = 1 + + def __await__(self): + async def _await(): + if task_cancelled: + raise 
asyncio.CancelledError() + return None + return _await().__await__() + + # Mock to trigger cleanup after one iteration + call_count = 0 + + async def mock_sleep(duration): + nonlocal call_count + call_count += 1 + if call_count >= 2: # Trigger cleanup after heartbeat check + raise KeyboardInterrupt("Test cleanup") + + def create_mock_task(coro): + # Close the coroutine to avoid "coroutine never awaited" warning + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task), \ + patch.object(consumer, "check_heartbeat", return_value=True): + + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + # Intentionally ignore KeyboardInterrupt to test cleanup logic after interruption + pass + + # Verify task was created and cancelled during cleanup + assert task_created, "Task should have been created" + assert task_cancelled, "Task should have been cancelled during cleanup" + + +def test_restart_delay_configuration(): + """Test that configuration constants exist and have reasonable values.""" + assert RESTART_DELAY_SECONDS > 0 + assert RESTART_DELAY_SECONDS <= 10 + assert MAX_RESTART_DELAY_SECONDS >= RESTART_DELAY_SECONDS + assert MAX_RESTART_DELAY_SECONDS <= 600 + assert HEARTBEAT_CHECK_INTERVAL_SECONDS > 0 + assert HEARTBEAT_STALENESS_THRESHOLD_SECONDS > HEARTBEAT_CHECK_INTERVAL_SECONDS + assert SUPERVISOR_ERROR_DELAY_SECONDS > 0 + + +@pytest.mark.asyncio +async def test_supervisor_resets_backoff_on_stale_heartbeat_restart(): + """Test that supervisor resets backoff when restarting due to stale heartbeat.""" + consumer = MockConsumerForEdgeCases() + consumer._restart_delay = 160 + + heartbeat_calls = 0 + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls == 1: + return False + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + pass + + class MockTask: + def done(self): + return False + + def cancel(self): + pass + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert consumer._restart_delay == RESTART_DELAY_SECONDS
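
A minimal sketch of how a consumer plugs into the `ServiceBusConsumer` base class introduced in this diff. `ExampleConsumer`, its polling body, and `run()` are hypothetical and only illustrate the wiring that `main.py` applies to `DeploymentStatusUpdater` and `AirlockStatusUpdater`: call `update_heartbeat()` inside the receive loop, and start `supervisor_with_heartbeat_check()` rather than `receive_messages()` directly.

```python
# Illustrative only; ExampleConsumer and run() are not part of this PR.
import asyncio

from service_bus.service_bus_consumer import ServiceBusConsumer


class ExampleConsumer(ServiceBusConsumer):
    def __init__(self):
        # service_name becomes "Example Consumer" in supervisor and health-check logs
        super().__init__("example_consumer")

    async def receive_messages(self):
        while True:
            # ... receive and process one batch of Service Bus messages here (hypothetical) ...
            # Mark the loop as alive so the supervisor does not treat it as stuck
            self.update_heartbeat()
            await asyncio.sleep(1)


async def run():
    consumer = ExampleConsumer()
    # The supervisor restarts receive_messages() with exponential backoff on failure
    # and cancels/restarts it when the heartbeat goes stale.
    asyncio.create_task(consumer.supervisor_with_heartbeat_check())
    # The /health endpoint can then call consumer.check_heartbeat() to report its status.
```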