From 4da58504a2d6ee36a8c0f30001e52c29da90a65a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Jun 2025 10:53:44 +0000 Subject: [PATCH 01/20] Initial plan From 6cd0b5f0fc9a8818463ba268cbfc54abfd15d15a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Jun 2025 11:06:33 +0000 Subject: [PATCH 02/20] Add restart mechanism to deployment status updater to fix stuck operations Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- api_app/main.py | 2 +- .../service_bus/deployment_status_updater.py | 14 +++++++- .../test_deployment_status_update.py | 36 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/api_app/main.py b/api_app/main.py index 0bdc769141..fb6a826c19 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -34,7 +34,7 @@ async def lifespan(app: FastAPI): airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() - asyncio.create_task(deploymentStatusUpdater.receive_messages()) + asyncio.create_task(deploymentStatusUpdater.receive_messages_with_restart_check()) asyncio.create_task(airlockStatusUpdater.receive_messages()) yield diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 4bac477754..a861eb0615 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -33,7 +33,19 @@ async def init_repos(self): self.resource_history_repo = await ResourceHistoryRepository.create() def run(self, *args, **kwargs): - asyncio.run(self.receive_messages()) + asyncio.run(self.receive_messages_with_restart_check()) + + async def receive_messages_with_restart_check(self): + """ + Continuously run the receive_messages method, restarting it if it stops unexpectedly. + """ + while True: + try: + logger.info("Starting the receive_messages loop...") + await self.receive_messages() + except Exception as e: + logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") + await asyncio.sleep(5) async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index db80c5b1f7..0fbd46e2d2 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -413,3 +413,39 @@ async def test_convert_outputs_to_dict(): 'list2': ['one', 'two'] } assert status_updater.convert_outputs_to_dict(deployment_status_update_message.outputs) == expected_result + + +@patch('service_bus.deployment_status_updater.asyncio.sleep') +@patch('services.logging.logger.exception') +@patch('services.logging.logger.info') +async def test_receive_messages_with_restart_check_restarts_on_exception(mock_logger_info, mock_logger_exception, mock_sleep): + """Test that receive_messages_with_restart_check properly restarts when receive_messages fails""" + status_updater = DeploymentStatusUpdater() + + # Mock receive_messages to fail once, then succeed (stopping the loop) + call_count = 0 + + async def mock_receive_messages(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Exception("Test exception") + # Second call succeeds but we need to stop the loop somehow + # For testing purposes, we'll just raise a different exception to break the loop + raise KeyboardInterrupt("Test interrupt to stop loop") + + status_updater.receive_messages = mock_receive_messages + + # Test that the restart mechanism works + try: + await status_updater.receive_messages_with_restart_check() + except KeyboardInterrupt: + pass # Expected to stop the loop + + # Verify the restart mechanism worked + assert call_count == 2, "receive_messages should have been called twice (once failed, once succeeded)" + + # Verify logging calls + mock_logger_info.assert_called_with("Starting the receive_messages loop...") + mock_logger_exception.assert_called_once_with("receive_messages stopped unexpectedly. Restarting... - Test exception") + mock_sleep.assert_called_once_with(5) From c7072b9206b84f3a2eb65d9e6df4843bcd9ab40e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Jun 2025 21:00:13 +0000 Subject: [PATCH 03/20] Add heartbeat monitoring to supervisor function for stuck process detection Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- resource_processor/tests_rp/test_runner.py | 50 ++++++++++++--- resource_processor/vmss_porter/runner.py | 71 +++++++++++++++++++++- 2 files changed, 111 insertions(+), 10 deletions(-) diff --git a/resource_processor/tests_rp/test_runner.py b/resource_processor/tests_rp/test_runner.py index 7c1dc6d2e4..ff2040af19 100644 --- a/resource_processor/tests_rp/test_runner.py +++ b/resource_processor/tests_rp/test_runner.py @@ -1,8 +1,9 @@ import json -from unittest.mock import patch, AsyncMock, Mock +from unittest.mock import patch, AsyncMock, Mock, mock_open import pytest from resource_processor.vmss_porter.runner import ( - set_up_config, receive_message, invoke_porter_action, get_porter_outputs, check_runners, runner + set_up_config, receive_message, invoke_porter_action, get_porter_outputs, check_runners, runner, + update_heartbeat, check_process_heartbeat ) from azure.servicebus.aio import ServiceBusClient from azure.servicebus import ServiceBusSessionFilter @@ -59,7 +60,7 @@ async def test_runner(mock_receive_message, mock_service_bus_client, mock_defaul mock_default_credential.assert_called_once_with('test_msi_id') mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) @pytest.mark.asyncio @@ -74,7 +75,7 @@ async def test_runner_no_msi_id(mock_receive_message, mock_service_bus_client, m mock_default_credential.assert_called_once_with(None) mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) @pytest.mark.asyncio @@ -91,7 +92,7 @@ async def test_runner_exception(mock_receive_message, mock_service_bus_client, m mock_default_credential.assert_called_once_with('test_msi_id') mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) @pytest.mark.asyncio @@ -113,7 +114,7 @@ async def test_receive_message(mock_invoke_porter_action, mock_service_bus_clien config = {"resource_request_queue": "test_queue"} - await receive_message(mock_service_bus_client_instance, config, keep_running=run_once) + await receive_message(mock_service_bus_client_instance, config, 0, keep_running=run_once) mock_receiver.complete_message.assert_called_once() mock_service_bus_client_instance.get_queue_receiver.assert_called_once_with(queue_name="test_queue", max_wait_time=1, session_id=ServiceBusSessionFilter.NEXT_AVAILABLE) @@ -138,7 +139,7 @@ async def test_receive_message_unknown_exception(mock_auto_lock_renewer, mock_se config = {"resource_request_queue": "test_queue"} with patch("resource_processor.vmss_porter.runner.receive_message", side_effect=Exception("Test Exception")): - await receive_message(mock_service_bus_client_instance, config, keep_running=run_once) + await receive_message(mock_service_bus_client_instance, config, 0, keep_running=run_once) mock_logger.exception.assert_any_call("Unknown exception. Will retry...") @@ -282,3 +283,38 @@ async def test_check_runners(_): await check_runners(processes, mock_httpserver, keep_running=run_once) mock_httpserver.kill.assert_called_once() + + +@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0 + 100) # 100 seconds later +@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=True) +@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open, read_data="1234567890.0") +def test_check_process_heartbeat_recent(mock_file, mock_exists, mock_time): + """Test checking a recent heartbeat.""" + result = check_process_heartbeat(0, max_age_seconds=300) + assert result is True + + +@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0 + 400) # 400 seconds later +@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=True) +@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open, read_data="1234567890.0") +def test_check_process_heartbeat_stale(mock_file, mock_exists, mock_time): + """Test checking a stale heartbeat.""" + result = check_process_heartbeat(0, max_age_seconds=300) + assert result is False + + +@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=False) +def test_check_process_heartbeat_no_file(mock_exists): + """Test checking heartbeat when file doesn't exist.""" + result = check_process_heartbeat(0) + assert result is False + + +@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0) +@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open) +def test_update_heartbeat(mock_file, mock_time): + """Test updating heartbeat.""" + update_heartbeat(0) + mock_file.assert_called_once_with("/tmp/resource_processor_heartbeat_0.txt", 'w') + handle = mock_file.return_value.__enter__.return_value + handle.write.assert_called_once_with("1234567890.0") diff --git a/resource_processor/vmss_porter/runner.py b/resource_processor/vmss_porter/runner.py index 21c5d424ee..75d96807bf 100644 --- a/resource_processor/vmss_porter/runner.py +++ b/resource_processor/vmss_porter/runner.py @@ -5,6 +5,7 @@ import asyncio import logging import sys +import os from helpers.commands import azure_acr_login_command, azure_login_command, build_porter_command, build_porter_command_for_outputs, apply_porter_credentials_sets_command from shared.config import get_config from helpers.httpserver import start_server @@ -38,7 +39,19 @@ async def default_credentials(msi_id): await credential.close() -async def receive_message(service_bus_client, config: dict, keep_running=lambda: True): +def update_heartbeat(process_number: int): + """ + Update heartbeat file for this process + """ + heartbeat_file = f"/tmp/resource_processor_heartbeat_{process_number}.txt" + try: + with open(heartbeat_file, 'w') as f: + f.write(str(time.time())) + except Exception as e: + logger.warning(f"Failed to update heartbeat for process {process_number}: {e}") + + +async def receive_message(service_bus_client, config: dict, process_number: int, keep_running=lambda: True): """ This method is run per process. Each process will connect to service bus and try to establish a session. If messages are there, the process will continue to receive all the messages associated with that session. @@ -52,6 +65,10 @@ async def receive_message(service_bus_client, config: dict, keep_running=lambda: try: current_time = time.time() polling_count += 1 + + # Update heartbeat file for supervisor monitoring + update_heartbeat(process_number) + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: logger.info(f"Queue reader heartbeat: Polled for sessions {polling_count} times in the last minute") @@ -274,7 +291,28 @@ async def runner(process_number: int, config: dict): with tracer.start_as_current_span(process_number): async with default_credentials(config["vmss_msi_id"]) as credential: service_bus_client = ServiceBusClient(config["service_bus_namespace"], credential) - await receive_message(service_bus_client, config) + await receive_message(service_bus_client, config, process_number) + + +def check_process_heartbeat(process_number: int, max_age_seconds: int = 300) -> bool: + """ + Check if a process heartbeat is recent enough + """ + heartbeat_file = f"/tmp/resource_processor_heartbeat_{process_number}.txt" + try: + if not os.path.exists(heartbeat_file): + return False + + with open(heartbeat_file, 'r') as f: + heartbeat_time = float(f.read().strip()) + + current_time = time.time() + age = current_time - heartbeat_time + + return age <= max_age_seconds + except (ValueError, IOError) as e: + logger.warning(f"Failed to read heartbeat for process {process_number}: {e}") + return False async def check_runners(processes: list, httpserver: Process, keep_running=lambda: True): @@ -282,9 +320,36 @@ async def check_runners(processes: list, httpserver: Process, keep_running=lambd while keep_running(): await asyncio.sleep(30) - if all(not process.is_alive() for process in processes): + + # Check if all processes are alive + all_dead = all(not process.is_alive() for process in processes) + if all_dead: logger.error("All runner processes have failed!") httpserver.kill() + return + + # Check heartbeats for alive processes + stale_processes = [] + for i, process in enumerate(processes): + if process.is_alive() and not check_process_heartbeat(i): + logger.warning(f"Process {i} appears to be stuck (no heartbeat update)") + stale_processes.append((i, process)) + + # Restart stale processes + for process_num, process in stale_processes: + logger.warning(f"Terminating and restarting stuck process {process_num}") + process.terminate() + process.join(timeout=10) # Wait up to 10 seconds for graceful termination + if process.is_alive(): + logger.warning(f"Force killing process {process_num}") + process.kill() + + # Start new process + logger.info(f"Restarting process {process_num}") + config = set_up_config() # Get fresh config + new_process = Process(target=lambda pnum=process_num: asyncio.run(runner(pnum, config))) + processes[process_num] = new_process + new_process.start() if __name__ == "__main__": From 202e72602ecf575cb2a5692aff032ed057c13ff6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Jun 2025 21:16:50 +0000 Subject: [PATCH 04/20] Move heartbeat monitoring from resource processor to deployment status updater Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- api_app/main.py | 2 +- .../service_bus/deployment_status_updater.py | 71 ++++++++++++++++++- .../test_deployment_status_update.py | 43 +++++++++++ resource_processor/tests_rp/test_runner.py | 50 ++----------- resource_processor/vmss_porter/runner.py | 71 +------------------ 5 files changed, 124 insertions(+), 113 deletions(-) diff --git a/api_app/main.py b/api_app/main.py index fb6a826c19..c0769c753d 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -34,7 +34,7 @@ async def lifespan(app: FastAPI): airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() - asyncio.create_task(deploymentStatusUpdater.receive_messages_with_restart_check()) + asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check()) asyncio.create_task(airlockStatusUpdater.receive_messages()) yield diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 27d6ac493b..c45ffaeded 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -2,6 +2,7 @@ import json import uuid import time +import os from pydantic import ValidationError, parse_obj_as @@ -25,7 +26,7 @@ class DeploymentStatusUpdater(): def __init__(self): - pass + self.heartbeat_file = "/tmp/deployment_status_updater_heartbeat.txt" async def init_repos(self): self.operations_repo = await OperationRepository.create() @@ -36,6 +37,35 @@ async def init_repos(self): def run(self, *args, **kwargs): asyncio.run(self.receive_messages_with_restart_check()) + def update_heartbeat(self): + """ + Update heartbeat file for monitoring + """ + try: + with open(self.heartbeat_file, 'w') as f: + f.write(str(time.time())) + except Exception as e: + logger.warning(f"Failed to update heartbeat: {e}") + + def check_heartbeat(self, max_age_seconds: int = 300) -> bool: + """ + Check if the heartbeat is recent enough + """ + try: + if not os.path.exists(self.heartbeat_file): + return False + + with open(self.heartbeat_file, 'r') as f: + heartbeat_time = float(f.read().strip()) + + current_time = time.time() + age = current_time - heartbeat_time + + return age <= max_age_seconds + except (ValueError, IOError) as e: + logger.warning(f"Failed to read heartbeat: {e}") + return False + async def receive_messages_with_restart_check(self): """ Continuously run the receive_messages method, restarting it if it stops unexpectedly. @@ -48,6 +78,41 @@ async def receive_messages_with_restart_check(self): logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") await asyncio.sleep(5) + async def supervisor_with_heartbeat_check(self): + """ + Supervisor function that monitors the heartbeat and restarts if stuck. + """ + task = None + while True: + try: + # Start the receive_messages task if not running + if task is None or task.done(): + if task and task.done(): + try: + await task # Check for any exception + except Exception as e: + logger.exception(f"receive_messages task failed: {e}") + + logger.info("Starting receive_messages task...") + task = asyncio.create_task(self.receive_messages()) + + # Wait before checking heartbeat + await asyncio.sleep(60) # Check every minute + + # Check if heartbeat is stale + if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age + logger.warning("Heartbeat is stale, restarting receive_messages task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + task = None + + except Exception as e: + logger.exception(f"Supervisor error: {e}") + await asyncio.sleep(30) + async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): last_heartbeat_time = 0 @@ -57,6 +122,10 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + + # Update heartbeat file for supervisor monitoring + self.update_heartbeat() + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute") diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index 0fbd46e2d2..1384eacf91 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -449,3 +449,46 @@ async def mock_receive_messages(): mock_logger_info.assert_called_with("Starting the receive_messages loop...") mock_logger_exception.assert_called_once_with("receive_messages stopped unexpectedly. Restarting... - Test exception") mock_sleep.assert_called_once_with(5) + + +@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0 + 100) # 100 seconds later +@patch("service_bus.deployment_status_updater.os.path.exists", return_value=True) +@patch("builtins.open", create=True) +async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): + """Test checking a recent heartbeat.""" + mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" + + status_updater = DeploymentStatusUpdater() + result = status_updater.check_heartbeat(max_age_seconds=300) + assert result is True + + +@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0 + 400) # 400 seconds later +@patch("service_bus.deployment_status_updater.os.path.exists", return_value=True) +@patch("builtins.open", create=True) +async def test_check_heartbeat_stale(mock_open, mock_exists, mock_time): + """Test checking a stale heartbeat.""" + mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" + + status_updater = DeploymentStatusUpdater() + result = status_updater.check_heartbeat(max_age_seconds=300) + assert result is False + + +@patch("service_bus.deployment_status_updater.os.path.exists", return_value=False) +async def test_check_heartbeat_no_file(mock_exists): + """Test checking heartbeat when file doesn't exist.""" + status_updater = DeploymentStatusUpdater() + result = status_updater.check_heartbeat() + assert result is False + + +@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0) +@patch("builtins.open", create=True) +async def test_update_heartbeat(mock_open, mock_time): + """Test updating heartbeat.""" + status_updater = DeploymentStatusUpdater() + status_updater.update_heartbeat() + + mock_open.assert_called_once_with("/tmp/deployment_status_updater_heartbeat.txt", 'w') + mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") diff --git a/resource_processor/tests_rp/test_runner.py b/resource_processor/tests_rp/test_runner.py index ff2040af19..7c1dc6d2e4 100644 --- a/resource_processor/tests_rp/test_runner.py +++ b/resource_processor/tests_rp/test_runner.py @@ -1,9 +1,8 @@ import json -from unittest.mock import patch, AsyncMock, Mock, mock_open +from unittest.mock import patch, AsyncMock, Mock import pytest from resource_processor.vmss_porter.runner import ( - set_up_config, receive_message, invoke_porter_action, get_porter_outputs, check_runners, runner, - update_heartbeat, check_process_heartbeat + set_up_config, receive_message, invoke_porter_action, get_porter_outputs, check_runners, runner ) from azure.servicebus.aio import ServiceBusClient from azure.servicebus import ServiceBusSessionFilter @@ -60,7 +59,7 @@ async def test_runner(mock_receive_message, mock_service_bus_client, mock_defaul mock_default_credential.assert_called_once_with('test_msi_id') mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) @pytest.mark.asyncio @@ -75,7 +74,7 @@ async def test_runner_no_msi_id(mock_receive_message, mock_service_bus_client, m mock_default_credential.assert_called_once_with(None) mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) @pytest.mark.asyncio @@ -92,7 +91,7 @@ async def test_runner_exception(mock_receive_message, mock_service_bus_client, m mock_default_credential.assert_called_once_with('test_msi_id') mock_service_bus_client.assert_called_once_with("test_namespace", mock_credential) - mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config, 0) + mock_receive_message.assert_called_once_with(mock_service_bus_client_instance, config) @pytest.mark.asyncio @@ -114,7 +113,7 @@ async def test_receive_message(mock_invoke_porter_action, mock_service_bus_clien config = {"resource_request_queue": "test_queue"} - await receive_message(mock_service_bus_client_instance, config, 0, keep_running=run_once) + await receive_message(mock_service_bus_client_instance, config, keep_running=run_once) mock_receiver.complete_message.assert_called_once() mock_service_bus_client_instance.get_queue_receiver.assert_called_once_with(queue_name="test_queue", max_wait_time=1, session_id=ServiceBusSessionFilter.NEXT_AVAILABLE) @@ -139,7 +138,7 @@ async def test_receive_message_unknown_exception(mock_auto_lock_renewer, mock_se config = {"resource_request_queue": "test_queue"} with patch("resource_processor.vmss_porter.runner.receive_message", side_effect=Exception("Test Exception")): - await receive_message(mock_service_bus_client_instance, config, 0, keep_running=run_once) + await receive_message(mock_service_bus_client_instance, config, keep_running=run_once) mock_logger.exception.assert_any_call("Unknown exception. Will retry...") @@ -283,38 +282,3 @@ async def test_check_runners(_): await check_runners(processes, mock_httpserver, keep_running=run_once) mock_httpserver.kill.assert_called_once() - - -@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0 + 100) # 100 seconds later -@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=True) -@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open, read_data="1234567890.0") -def test_check_process_heartbeat_recent(mock_file, mock_exists, mock_time): - """Test checking a recent heartbeat.""" - result = check_process_heartbeat(0, max_age_seconds=300) - assert result is True - - -@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0 + 400) # 400 seconds later -@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=True) -@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open, read_data="1234567890.0") -def test_check_process_heartbeat_stale(mock_file, mock_exists, mock_time): - """Test checking a stale heartbeat.""" - result = check_process_heartbeat(0, max_age_seconds=300) - assert result is False - - -@patch("resource_processor.vmss_porter.runner.os.path.exists", return_value=False) -def test_check_process_heartbeat_no_file(mock_exists): - """Test checking heartbeat when file doesn't exist.""" - result = check_process_heartbeat(0) - assert result is False - - -@patch("resource_processor.vmss_porter.runner.time.time", return_value=1234567890.0) -@patch("resource_processor.vmss_porter.runner.open", new_callable=mock_open) -def test_update_heartbeat(mock_file, mock_time): - """Test updating heartbeat.""" - update_heartbeat(0) - mock_file.assert_called_once_with("/tmp/resource_processor_heartbeat_0.txt", 'w') - handle = mock_file.return_value.__enter__.return_value - handle.write.assert_called_once_with("1234567890.0") diff --git a/resource_processor/vmss_porter/runner.py b/resource_processor/vmss_porter/runner.py index 75d96807bf..21c5d424ee 100644 --- a/resource_processor/vmss_porter/runner.py +++ b/resource_processor/vmss_porter/runner.py @@ -5,7 +5,6 @@ import asyncio import logging import sys -import os from helpers.commands import azure_acr_login_command, azure_login_command, build_porter_command, build_porter_command_for_outputs, apply_porter_credentials_sets_command from shared.config import get_config from helpers.httpserver import start_server @@ -39,19 +38,7 @@ async def default_credentials(msi_id): await credential.close() -def update_heartbeat(process_number: int): - """ - Update heartbeat file for this process - """ - heartbeat_file = f"/tmp/resource_processor_heartbeat_{process_number}.txt" - try: - with open(heartbeat_file, 'w') as f: - f.write(str(time.time())) - except Exception as e: - logger.warning(f"Failed to update heartbeat for process {process_number}: {e}") - - -async def receive_message(service_bus_client, config: dict, process_number: int, keep_running=lambda: True): +async def receive_message(service_bus_client, config: dict, keep_running=lambda: True): """ This method is run per process. Each process will connect to service bus and try to establish a session. If messages are there, the process will continue to receive all the messages associated with that session. @@ -65,10 +52,6 @@ async def receive_message(service_bus_client, config: dict, process_number: int, try: current_time = time.time() polling_count += 1 - - # Update heartbeat file for supervisor monitoring - update_heartbeat(process_number) - # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: logger.info(f"Queue reader heartbeat: Polled for sessions {polling_count} times in the last minute") @@ -291,28 +274,7 @@ async def runner(process_number: int, config: dict): with tracer.start_as_current_span(process_number): async with default_credentials(config["vmss_msi_id"]) as credential: service_bus_client = ServiceBusClient(config["service_bus_namespace"], credential) - await receive_message(service_bus_client, config, process_number) - - -def check_process_heartbeat(process_number: int, max_age_seconds: int = 300) -> bool: - """ - Check if a process heartbeat is recent enough - """ - heartbeat_file = f"/tmp/resource_processor_heartbeat_{process_number}.txt" - try: - if not os.path.exists(heartbeat_file): - return False - - with open(heartbeat_file, 'r') as f: - heartbeat_time = float(f.read().strip()) - - current_time = time.time() - age = current_time - heartbeat_time - - return age <= max_age_seconds - except (ValueError, IOError) as e: - logger.warning(f"Failed to read heartbeat for process {process_number}: {e}") - return False + await receive_message(service_bus_client, config) async def check_runners(processes: list, httpserver: Process, keep_running=lambda: True): @@ -320,36 +282,9 @@ async def check_runners(processes: list, httpserver: Process, keep_running=lambd while keep_running(): await asyncio.sleep(30) - - # Check if all processes are alive - all_dead = all(not process.is_alive() for process in processes) - if all_dead: + if all(not process.is_alive() for process in processes): logger.error("All runner processes have failed!") httpserver.kill() - return - - # Check heartbeats for alive processes - stale_processes = [] - for i, process in enumerate(processes): - if process.is_alive() and not check_process_heartbeat(i): - logger.warning(f"Process {i} appears to be stuck (no heartbeat update)") - stale_processes.append((i, process)) - - # Restart stale processes - for process_num, process in stale_processes: - logger.warning(f"Terminating and restarting stuck process {process_num}") - process.terminate() - process.join(timeout=10) # Wait up to 10 seconds for graceful termination - if process.is_alive(): - logger.warning(f"Force killing process {process_num}") - process.kill() - - # Start new process - logger.info(f"Restarting process {process_num}") - config = set_up_config() # Get fresh config - new_process = Process(target=lambda pnum=process_num: asyncio.run(runner(pnum, config))) - processes[process_num] = new_process - new_process.start() if __name__ == "__main__": From 381bd9cf28656161e97b13ca2329ed00063286d2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Jun 2025 21:39:15 +0000 Subject: [PATCH 05/20] Fix linting issues and increment API version Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- api_app/_version.py | 2 +- api_app/service_bus/deployment_status_updater.py | 14 ++++++-------- .../test_deployment_status_update.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/api_app/_version.py b/api_app/_version.py index a304c3734f..2743a504c7 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.24.3" +__version__ = "0.24.4" diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index c45ffaeded..9b4a01f51d 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -54,13 +54,13 @@ def check_heartbeat(self, max_age_seconds: int = 300) -> bool: try: if not os.path.exists(self.heartbeat_file): return False - + with open(self.heartbeat_file, 'r') as f: heartbeat_time = float(f.read().strip()) - + current_time = time.time() age = current_time - heartbeat_time - + return age <= max_age_seconds except (ValueError, IOError) as e: logger.warning(f"Failed to read heartbeat: {e}") @@ -92,13 +92,13 @@ async def supervisor_with_heartbeat_check(self): await task # Check for any exception except Exception as e: logger.exception(f"receive_messages task failed: {e}") - + logger.info("Starting receive_messages task...") task = asyncio.create_task(self.receive_messages()) # Wait before checking heartbeat await asyncio.sleep(60) # Check every minute - + # Check if heartbeat is stale if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age logger.warning("Heartbeat is stale, restarting receive_messages task...") @@ -108,7 +108,6 @@ async def supervisor_with_heartbeat_check(self): except asyncio.CancelledError: pass task = None - except Exception as e: logger.exception(f"Supervisor error: {e}") await asyncio.sleep(30) @@ -122,10 +121,9 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 - + # Update heartbeat file for supervisor monitoring self.update_heartbeat() - # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute") diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index 1384eacf91..dd561005ab 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -457,7 +457,7 @@ async def mock_receive_messages(): async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): """Test checking a recent heartbeat.""" mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" - + status_updater = DeploymentStatusUpdater() result = status_updater.check_heartbeat(max_age_seconds=300) assert result is True @@ -469,7 +469,7 @@ async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): async def test_check_heartbeat_stale(mock_open, mock_exists, mock_time): """Test checking a stale heartbeat.""" mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" - + status_updater = DeploymentStatusUpdater() result = status_updater.check_heartbeat(max_age_seconds=300) assert result is False @@ -489,6 +489,6 @@ async def test_update_heartbeat(mock_open, mock_time): """Test updating heartbeat.""" status_updater = DeploymentStatusUpdater() status_updater.update_heartbeat() - + mock_open.assert_called_once_with("/tmp/deployment_status_updater_heartbeat.txt", 'w') mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") From 7c5ff5d5a7f7d7b10c7d86a8375969f5686a4d33 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Thu, 26 Jun 2025 09:40:43 +0000 Subject: [PATCH 06/20] Refactor service bus components to implement heartbeat monitoring and improve logging structure --- api_app/main.py | 2 +- .../airlock_request_status_update.py | 11 +- .../service_bus/deployment_status_updater.py | 83 +---------- api_app/service_bus/service_bus_consumer.py | 87 +++++++++++ api_app/services/logging.py | 46 +++++- .../test_deployment_status_update.py | 15 +- .../test_service_bus_consumer.py | 138 ++++++++++++++++++ 7 files changed, 292 insertions(+), 90 deletions(-) create mode 100644 api_app/service_bus/service_bus_consumer.py create mode 100644 api_app/tests_ma/test_service_bus/test_service_bus_consumer.py diff --git a/api_app/main.py b/api_app/main.py index c0769c753d..d4d9e56723 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -35,7 +35,7 @@ async def lifespan(app: FastAPI): await airlockStatusUpdater.init_repos() asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check()) - asyncio.create_task(airlockStatusUpdater.receive_messages()) + asyncio.create_task(airlockStatusUpdater.supervisor_with_heartbeat_check()) yield diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index a643404a86..90892babba 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -16,12 +16,13 @@ from models.domain.airlock_operations import StepResultStatusUpdateMessage from core import config, credentials from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer -class AirlockStatusUpdater(): +class AirlockStatusUpdater(ServiceBusConsumer): def __init__(self): - pass + super().__init__("airlock_status_updater") async def init_repos(self): self.airlock_request_repo = await AirlockRequestRepository.create() @@ -36,9 +37,13 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + + # Update heartbeat file for supervisor monitoring + self.update_heartbeat() + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_STEP_RESULT_QUEUE} queue polled {polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 9b4a01f51d..2c20e1dd51 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -2,7 +2,6 @@ import json import uuid import time -import os from pydantic import ValidationError, parse_obj_as @@ -22,11 +21,12 @@ from models.domain.operation import DeploymentStatusUpdateMessage, Operation, OperationStep, Status from resources import strings from services.logging import logger, tracer +from service_bus.service_bus_consumer import ServiceBusConsumer -class DeploymentStatusUpdater(): +class DeploymentStatusUpdater(ServiceBusConsumer): def __init__(self): - self.heartbeat_file = "/tmp/deployment_status_updater_heartbeat.txt" + super().__init__("deployment_status_updater") async def init_repos(self): self.operations_repo = await OperationRepository.create() @@ -37,81 +37,6 @@ async def init_repos(self): def run(self, *args, **kwargs): asyncio.run(self.receive_messages_with_restart_check()) - def update_heartbeat(self): - """ - Update heartbeat file for monitoring - """ - try: - with open(self.heartbeat_file, 'w') as f: - f.write(str(time.time())) - except Exception as e: - logger.warning(f"Failed to update heartbeat: {e}") - - def check_heartbeat(self, max_age_seconds: int = 300) -> bool: - """ - Check if the heartbeat is recent enough - """ - try: - if not os.path.exists(self.heartbeat_file): - return False - - with open(self.heartbeat_file, 'r') as f: - heartbeat_time = float(f.read().strip()) - - current_time = time.time() - age = current_time - heartbeat_time - - return age <= max_age_seconds - except (ValueError, IOError) as e: - logger.warning(f"Failed to read heartbeat: {e}") - return False - - async def receive_messages_with_restart_check(self): - """ - Continuously run the receive_messages method, restarting it if it stops unexpectedly. - """ - while True: - try: - logger.info("Starting the receive_messages loop...") - await self.receive_messages() - except Exception as e: - logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") - await asyncio.sleep(5) - - async def supervisor_with_heartbeat_check(self): - """ - Supervisor function that monitors the heartbeat and restarts if stuck. - """ - task = None - while True: - try: - # Start the receive_messages task if not running - if task is None or task.done(): - if task and task.done(): - try: - await task # Check for any exception - except Exception as e: - logger.exception(f"receive_messages task failed: {e}") - - logger.info("Starting receive_messages task...") - task = asyncio.create_task(self.receive_messages()) - - # Wait before checking heartbeat - await asyncio.sleep(60) # Check every minute - - # Check if heartbeat is stale - if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age - logger.warning("Heartbeat is stale, restarting receive_messages task...") - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - task = None - except Exception as e: - logger.exception(f"Supervisor error: {e}") - await asyncio.sleep(30) - async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): last_heartbeat_time = 0 @@ -126,7 +51,7 @@ async def receive_messages(self): self.update_heartbeat() # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue polled {polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py new file mode 100644 index 0000000000..af3b4aea3d --- /dev/null +++ b/api_app/service_bus/service_bus_consumer.py @@ -0,0 +1,87 @@ +import asyncio +import os +import time + +from services.logging import logger + + +class ServiceBusConsumer: + + def __init__(self, heartbeat_file_prefix: str): + # Create a unique identifier for this worker process + import tempfile + self.worker_id = os.getpid() + temp_dir = tempfile.gettempdir() + self.heartbeat_file = os.path.join(temp_dir, f"{heartbeat_file_prefix}_heartbeat_{self.worker_id}.txt") + self.service_name = heartbeat_file_prefix.replace('_', ' ').title() + logger.info(f"Initializing {self.service_name}") + + def update_heartbeat(self): + try: + with open(self.heartbeat_file, 'w') as f: + f.write(str(time.time())) + except Exception as e: + logger.warning(f"Failed to update heartbeat: {e}") + + def check_heartbeat(self, max_age_seconds: int = 300) -> bool: + try: + if not os.path.exists(self.heartbeat_file): + logger.warning("Heartbeat file does not exist") + return False + + with open(self.heartbeat_file, 'r') as f: + heartbeat_time = float(f.read().strip()) + + current_time = time.time() + age = current_time - heartbeat_time + + if age > max_age_seconds: + logger.warning(f"Heartbeat is {age:.1f} seconds old, exceeding the limit of {max_age_seconds} seconds") + + return age <= max_age_seconds + except (ValueError, IOError) as e: + logger.warning(f"Failed to read heartbeat: {e}") + return False + + async def receive_messages_with_restart_check(self): + while True: + try: + logger.info("Starting the receive_messages loop...") + await self.receive_messages() + except Exception as e: + logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") + await asyncio.sleep(5) + + async def supervisor_with_heartbeat_check(self): + task = None + while True: + try: + # Start the receive_messages task if not running + if task is None or task.done(): + if task and task.done(): + try: + await task # Check for any exception + except Exception as e: + logger.exception(f"receive_messages task failed: {e}") + + logger.info("Starting receive_messages task...") + task = asyncio.create_task(self.receive_messages()) + + # Wait before checking heartbeat + await asyncio.sleep(60) # Check every minute + + # Check if heartbeat is stale + if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age + logger.warning("Heartbeat is stale, restarting receive_messages task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + task = None + except Exception as e: + logger.exception(f"Supervisor error: {e}") + await asyncio.sleep(30) + + async def receive_messages(self): + raise NotImplementedError("Subclasses must implement receive_messages()") diff --git a/api_app/services/logging.py b/api_app/services/logging.py index eeb9ca17cd..91fe90d0da 100644 --- a/api_app/services/logging.py +++ b/api_app/services/logging.py @@ -1,4 +1,5 @@ import logging +import os from opentelemetry.instrumentation.logging import LoggingInstrumentor from opentelemetry import trace from azure.monitor.opentelemetry import configure_azure_monitor @@ -45,6 +46,23 @@ "urllib3.connectionpool" ] + +class WorkerIdFilter(logging.Filter): + """ + A filter that adds worker_id to all log records. + """ + + def __init__(self): + super().__init__() + # Get the process ID as a unique worker identifier + self.worker_id = os.getpid() + + def filter(self, record: logging.LogRecord) -> bool: + # Add worker_id as an attribute to the log record + record.worker_id = self.worker_id + return True + + logger = logging.getLogger("azuretre_api") tracer = trace.get_tracer("azuretre_api") @@ -57,6 +75,20 @@ def configure_loggers(): logging.getLogger(logger_name).setLevel(logging.CRITICAL) +def apply_worker_id_to_logger(logger_instance): + """ + Apply the worker ID filter to a logger instance. + """ + worker_filter = WorkerIdFilter() + logger_instance.addFilter(worker_filter) + + # Update handlers to include worker_id in the format + for handler in logger_instance.handlers: + if isinstance(handler, logging.StreamHandler): + formatter = logging.Formatter('%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + + def initialize_logging() -> logging.Logger: configure_loggers() @@ -84,12 +116,24 @@ def initialize_logging() -> logging.Logger: } ) + # Custom log format including worker_id + log_format = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s' + LoggingInstrumentor().instrument( set_logging_format=True, log_level=logging_level, - tracer_provider=tracer._real_tracer + tracer_provider=tracer._real_tracer, + log_format=log_format ) + # Set up a handler if none exists + if not logger.handlers: + handler = logging.StreamHandler() + logger.addHandler(handler) + + # Apply worker ID filter + apply_worker_id_to_logger(logger) + logger.info("Logging initialized with level: %s", LOGGING_LEVEL) return logger diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index dd561005ab..55adf8f8e7 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -452,7 +452,7 @@ async def mock_receive_messages(): @patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0 + 100) # 100 seconds later -@patch("service_bus.deployment_status_updater.os.path.exists", return_value=True) +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) @patch("builtins.open", create=True) async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): """Test checking a recent heartbeat.""" @@ -463,8 +463,8 @@ async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): assert result is True -@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0 + 400) # 400 seconds later -@patch("service_bus.deployment_status_updater.os.path.exists", return_value=True) +@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 400) # 400 seconds later +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) @patch("builtins.open", create=True) async def test_check_heartbeat_stale(mock_open, mock_exists, mock_time): """Test checking a stale heartbeat.""" @@ -475,7 +475,7 @@ async def test_check_heartbeat_stale(mock_open, mock_exists, mock_time): assert result is False -@patch("service_bus.deployment_status_updater.os.path.exists", return_value=False) +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=False) async def test_check_heartbeat_no_file(mock_exists): """Test checking heartbeat when file doesn't exist.""" status_updater = DeploymentStatusUpdater() @@ -483,12 +483,15 @@ async def test_check_heartbeat_no_file(mock_exists): assert result is False -@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0) +@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) @patch("builtins.open", create=True) async def test_update_heartbeat(mock_open, mock_time): """Test updating heartbeat.""" status_updater = DeploymentStatusUpdater() status_updater.update_heartbeat() - mock_open.assert_called_once_with("/tmp/deployment_status_updater_heartbeat.txt", 'w') + # Using the worker_id in the assertion + import tempfile + expected_path = f"{tempfile.gettempdir()}/deployment_status_updater_heartbeat_{status_updater.worker_id}.txt" + mock_open.assert_called_once_with(expected_path, 'w') mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py new file mode 100644 index 0000000000..b8c82e9023 --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -0,0 +1,138 @@ +import asyncio +import pytest +from unittest.mock import patch, AsyncMock + +from service_bus.service_bus_consumer import ServiceBusConsumer + + +# Create a concrete implementation for testing +class TestConsumer(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + # Simulate running once and then exiting + await asyncio.sleep(0.1) + return + + +@pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) +async def test_init(mock_getpid): + """Test initialization of ServiceBusConsumer.""" + consumer = TestConsumer() + assert consumer.worker_id == 12345 + assert consumer.heartbeat_file == "/tmp/test_consumer_heartbeat_12345.txt" + assert consumer.service_name == "Test Consumer" + + +@pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) +@patch("builtins.open", create=True) +async def test_check_heartbeat_recent(mock_open, mock_exists): + """Test checking a recent heartbeat.""" + mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" + + with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 60): # 60 seconds later + consumer = TestConsumer() + result = consumer.check_heartbeat(max_age_seconds=300) + assert result is True + + +@pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) +@patch("builtins.open", create=True) +async def test_check_heartbeat_stale(mock_open, mock_exists): + """Test checking a stale heartbeat.""" + mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" + + with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 400): # 400 seconds later + consumer = TestConsumer() + result = consumer.check_heartbeat(max_age_seconds=300) + assert result is False + + +@pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) +@patch("builtins.open", create=True) +async def test_update_heartbeat(mock_open, mock_time): + """Test updating heartbeat.""" + consumer = TestConsumer() + with patch("service_bus.service_bus_consumer.os.getpid", return_value=12345): + consumer.worker_id = 12345 # Set worker_id explicitly for test + consumer.update_heartbeat() + + import tempfile + expected_path = f"{tempfile.gettempdir()}/test_consumer_heartbeat_12345.txt" + mock_open.assert_called_once_with(expected_path, 'w') + mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") + + +@pytest.mark.asyncio +async def test_receive_messages_with_restart_check(): + """Test receive_messages_with_restart_check calls receive_messages.""" + consumer = TestConsumer() + + # Mock the receive_messages method to raise an exception after first call + original_receive_messages = consumer.receive_messages + call_count = 0 + + async def mock_receive_messages(): + nonlocal call_count + call_count += 1 + if call_count == 1: + await original_receive_messages() + else: + raise Exception("Test exception") + + consumer.receive_messages = mock_receive_messages + + # Mock asyncio.sleep to avoid actual waiting + with patch("asyncio.sleep", new_callable=AsyncMock): + # Schedule the coroutine and run it for a short time + task = asyncio.create_task(consumer.receive_messages_with_restart_check()) + await asyncio.sleep(0.2) # Give it a chance to run + task.cancel() # Cancel to exit the infinite loop + + try: + await task + except asyncio.CancelledError: + pass + + assert consumer.receive_messages_called is True + assert call_count > 0 # Should have called at least once + + +@pytest.mark.asyncio +async def test_supervisor_with_heartbeat_check(): + """Test supervisor_with_heartbeat_check manages the receive_messages task.""" + consumer = TestConsumer() + + # Mock check_heartbeat to control the test flow + # Save original for potential future use + # original_check_heartbeat = consumer.check_heartbeat + heartbeat_calls = 0 + + def mock_check_heartbeat(max_age_seconds=300): + nonlocal heartbeat_calls + heartbeat_calls += 1 + # First call returns True, second call returns False to trigger restart + return heartbeat_calls != 2 + + consumer.check_heartbeat = mock_check_heartbeat + + # Mock asyncio.sleep to avoid actual waiting + with patch("asyncio.sleep", new_callable=AsyncMock): + # Schedule the coroutine and run it for a short time + task = asyncio.create_task(consumer.supervisor_with_heartbeat_check()) + await asyncio.sleep(0.2) # Give it a chance to run + task.cancel() # Cancel to exit the infinite loop + + try: + await task + except asyncio.CancelledError: + pass + + assert heartbeat_calls > 0 # Should have checked heartbeat at least once From 9329c94bc177f4a0ef34037dc1175a7f3109cb34 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Thu, 26 Jun 2025 11:39:45 +0000 Subject: [PATCH 07/20] update tests and fix issue. --- .../service_bus/deployment_status_updater.py | 3 - api_app/service_bus/service_bus_consumer.py | 2 +- .../test_service_bus_consumer.py | 141 ++++++++++++------ 3 files changed, 96 insertions(+), 50 deletions(-) diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 2c20e1dd51..6b14c743e4 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -34,9 +34,6 @@ async def init_repos(self): self.resource_template_repo = await ResourceTemplateRepository.create() self.resource_history_repo = await ResourceHistoryRepository.create() - def run(self, *args, **kwargs): - asyncio.run(self.receive_messages_with_restart_check()) - async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): last_heartbeat_time = 0 diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py index af3b4aea3d..56508bcd82 100644 --- a/api_app/service_bus/service_bus_consumer.py +++ b/api_app/service_bus/service_bus_consumer.py @@ -65,7 +65,7 @@ async def supervisor_with_heartbeat_check(self): logger.exception(f"receive_messages task failed: {e}") logger.info("Starting receive_messages task...") - task = asyncio.create_task(self.receive_messages()) + task = asyncio.create_task(self.receive_messages_with_restart_check()) # Wait before checking heartbeat await asyncio.sleep(60) # Check every minute diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py index b8c82e9023..34cbaf830f 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -1,12 +1,12 @@ import asyncio import pytest -from unittest.mock import patch, AsyncMock +from unittest.mock import patch from service_bus.service_bus_consumer import ServiceBusConsumer # Create a concrete implementation for testing -class TestConsumer(ServiceBusConsumer): +class MockConsumer(ServiceBusConsumer): def __init__(self): super().__init__("test_consumer") self.receive_messages_called = False @@ -22,7 +22,7 @@ async def receive_messages(self): @patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) async def test_init(mock_getpid): """Test initialization of ServiceBusConsumer.""" - consumer = TestConsumer() + consumer = MockConsumer() assert consumer.worker_id == 12345 assert consumer.heartbeat_file == "/tmp/test_consumer_heartbeat_12345.txt" assert consumer.service_name == "Test Consumer" @@ -36,7 +36,7 @@ async def test_check_heartbeat_recent(mock_open, mock_exists): mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 60): # 60 seconds later - consumer = TestConsumer() + consumer = MockConsumer() result = consumer.check_heartbeat(max_age_seconds=300) assert result is True @@ -49,20 +49,19 @@ async def test_check_heartbeat_stale(mock_open, mock_exists): mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 400): # 400 seconds later - consumer = TestConsumer() + consumer = MockConsumer() result = consumer.check_heartbeat(max_age_seconds=300) assert result is False @pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) @patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) @patch("builtins.open", create=True) -async def test_update_heartbeat(mock_open, mock_time): +async def test_update_heartbeat(mock_open, mock_time, mock_getpid): """Test updating heartbeat.""" - consumer = TestConsumer() - with patch("service_bus.service_bus_consumer.os.getpid", return_value=12345): - consumer.worker_id = 12345 # Set worker_id explicitly for test - consumer.update_heartbeat() + consumer = MockConsumer() + consumer.update_heartbeat() import tempfile expected_path = f"{tempfile.gettempdir()}/test_consumer_heartbeat_12345.txt" @@ -72,67 +71,117 @@ async def test_update_heartbeat(mock_open, mock_time): @pytest.mark.asyncio async def test_receive_messages_with_restart_check(): - """Test receive_messages_with_restart_check calls receive_messages.""" - consumer = TestConsumer() + """Test receive_messages_with_restart_check calls receive_messages and handles exceptions.""" + consumer = MockConsumer() - # Mock the receive_messages method to raise an exception after first call - original_receive_messages = consumer.receive_messages - call_count = 0 + # Track how many times receive_messages has been called + receive_messages_call_count = 0 + sleep_calls = [] async def mock_receive_messages(): - nonlocal call_count - call_count += 1 - if call_count == 1: - await original_receive_messages() - else: + nonlocal receive_messages_call_count + receive_messages_call_count += 1 + if receive_messages_call_count == 1: + # First call raises an exception raise Exception("Test exception") + elif receive_messages_call_count == 2: + # Second call succeeds, but we need to break the infinite loop + # Let's raise a special exception to break out + raise KeyboardInterrupt("Break out of loop for test") + else: + # Should not get here in this test + return - consumer.receive_messages = mock_receive_messages + async def mock_sleep(duration): + sleep_calls.append(duration) + # Just return immediately instead of sleeping + return - # Mock asyncio.sleep to avoid actual waiting - with patch("asyncio.sleep", new_callable=AsyncMock): - # Schedule the coroutine and run it for a short time - task = asyncio.create_task(consumer.receive_messages_with_restart_check()) - await asyncio.sleep(0.2) # Give it a chance to run - task.cancel() # Cancel to exit the infinite loop + # Override the method with our mock + consumer.receive_messages = mock_receive_messages + # Patch asyncio.sleep in the service_bus_consumer module + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): try: - await task - except asyncio.CancelledError: + # Run the method, expecting it to call receive_messages twice and then break + await consumer.receive_messages_with_restart_check() + except KeyboardInterrupt: + # This is our expected way out of the infinite loop pass - assert consumer.receive_messages_called is True - assert call_count > 0 # Should have called at least once + # Verify that receive_messages was called twice and sleep was called once + assert receive_messages_call_count == 2, f"Expected exactly 2 calls to receive_messages, got {receive_messages_call_count}" + assert len(sleep_calls) == 1, f"Expected exactly 1 sleep call for restart delay, got {len(sleep_calls)}" + assert sleep_calls[0] == 5, f"Expected sleep(5) call for restart delay, got {sleep_calls}" @pytest.mark.asyncio async def test_supervisor_with_heartbeat_check(): """Test supervisor_with_heartbeat_check manages the receive_messages task.""" - consumer = TestConsumer() + consumer = MockConsumer() - # Mock check_heartbeat to control the test flow - # Save original for potential future use - # original_check_heartbeat = consumer.check_heartbeat + # Track method calls and task lifecycle heartbeat_calls = 0 + task_create_calls = 0 + task_cancel_calls = 0 + sleep_calls = [] + # Mock check_heartbeat to return False on second call (to trigger restart) def mock_check_heartbeat(max_age_seconds=300): nonlocal heartbeat_calls heartbeat_calls += 1 - # First call returns True, second call returns False to trigger restart - return heartbeat_calls != 2 + # Return True first, then False to trigger restart, then break the loop + if heartbeat_calls == 1: + return True # Heartbeat is fresh + elif heartbeat_calls == 2: + return False # Heartbeat is stale, should trigger restart + else: + # Break out of the infinite loop after testing restart logic + raise KeyboardInterrupt("Test complete") + + # Track sleep calls + async def mock_sleep(duration): + sleep_calls.append(duration) + # Don't actually sleep + return + + # Mock task to track creation and cancellation + class MockTask: + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + + def cancel(self): + nonlocal task_cancel_calls + task_cancel_calls += 1 + + def done(self): + # Return False so task appears to be running + return False + + def __await__(self): + # Mock awaiting the task (for task cleanup) + async def _await(): + return None + return _await().__await__() + # Apply mocks consumer.check_heartbeat = mock_check_heartbeat - # Mock asyncio.sleep to avoid actual waiting - with patch("asyncio.sleep", new_callable=AsyncMock): - # Schedule the coroutine and run it for a short time - task = asyncio.create_task(consumer.supervisor_with_heartbeat_check()) - await asyncio.sleep(0.2) # Give it a chance to run - task.cancel() # Cancel to exit the infinite loop + # Mock asyncio functions in the service_bus_consumer module + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=lambda coro: MockTask()): try: - await task - except asyncio.CancelledError: + # Run the supervisor - it will break out when KeyboardInterrupt is raised + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + # Expected way to exit the infinite loop pass - assert heartbeat_calls > 0 # Should have checked heartbeat at least once + # Verify expected behavior occurred + assert heartbeat_calls >= 2, f"Expected at least 2 heartbeat checks, got {heartbeat_calls}" + assert task_create_calls >= 2, f"Expected at least 2 tasks created (initial + restart), got {task_create_calls}" + assert task_cancel_calls >= 1, f"Expected at least 1 task cancellation, got {task_cancel_calls}" + assert len(sleep_calls) >= 2, f"Expected at least 2 sleep calls, got {len(sleep_calls)}" + assert 60 in sleep_calls, f"Expected sleep(60) for heartbeat check interval, got {sleep_calls}" From 96e39b2ad0016b645cd1d99cc63bdb1c33deefaa Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Thu, 26 Jun 2025 11:42:18 +0000 Subject: [PATCH 08/20] Fix lint. --- api_app/service_bus/deployment_status_updater.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 6b14c743e4..b379093ce8 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -1,4 +1,3 @@ -import asyncio import json import uuid import time From 75d77ddcfdb2e51381ff955a8325d3cb952d8650 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Thu, 26 Jun 2025 11:57:40 +0000 Subject: [PATCH 09/20] remove duplicate tests. --- .../test_deployment_status_update.py | 82 ------------------- .../test_service_bus_consumer.py | 9 ++ 2 files changed, 9 insertions(+), 82 deletions(-) diff --git a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py index 55adf8f8e7..db80c5b1f7 100644 --- a/api_app/tests_ma/test_service_bus/test_deployment_status_update.py +++ b/api_app/tests_ma/test_service_bus/test_deployment_status_update.py @@ -413,85 +413,3 @@ async def test_convert_outputs_to_dict(): 'list2': ['one', 'two'] } assert status_updater.convert_outputs_to_dict(deployment_status_update_message.outputs) == expected_result - - -@patch('service_bus.deployment_status_updater.asyncio.sleep') -@patch('services.logging.logger.exception') -@patch('services.logging.logger.info') -async def test_receive_messages_with_restart_check_restarts_on_exception(mock_logger_info, mock_logger_exception, mock_sleep): - """Test that receive_messages_with_restart_check properly restarts when receive_messages fails""" - status_updater = DeploymentStatusUpdater() - - # Mock receive_messages to fail once, then succeed (stopping the loop) - call_count = 0 - - async def mock_receive_messages(): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise Exception("Test exception") - # Second call succeeds but we need to stop the loop somehow - # For testing purposes, we'll just raise a different exception to break the loop - raise KeyboardInterrupt("Test interrupt to stop loop") - - status_updater.receive_messages = mock_receive_messages - - # Test that the restart mechanism works - try: - await status_updater.receive_messages_with_restart_check() - except KeyboardInterrupt: - pass # Expected to stop the loop - - # Verify the restart mechanism worked - assert call_count == 2, "receive_messages should have been called twice (once failed, once succeeded)" - - # Verify logging calls - mock_logger_info.assert_called_with("Starting the receive_messages loop...") - mock_logger_exception.assert_called_once_with("receive_messages stopped unexpectedly. Restarting... - Test exception") - mock_sleep.assert_called_once_with(5) - - -@patch("service_bus.deployment_status_updater.time.time", return_value=1234567890.0 + 100) # 100 seconds later -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) -@patch("builtins.open", create=True) -async def test_check_heartbeat_recent(mock_open, mock_exists, mock_time): - """Test checking a recent heartbeat.""" - mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" - - status_updater = DeploymentStatusUpdater() - result = status_updater.check_heartbeat(max_age_seconds=300) - assert result is True - - -@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 400) # 400 seconds later -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) -@patch("builtins.open", create=True) -async def test_check_heartbeat_stale(mock_open, mock_exists, mock_time): - """Test checking a stale heartbeat.""" - mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" - - status_updater = DeploymentStatusUpdater() - result = status_updater.check_heartbeat(max_age_seconds=300) - assert result is False - - -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=False) -async def test_check_heartbeat_no_file(mock_exists): - """Test checking heartbeat when file doesn't exist.""" - status_updater = DeploymentStatusUpdater() - result = status_updater.check_heartbeat() - assert result is False - - -@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) -@patch("builtins.open", create=True) -async def test_update_heartbeat(mock_open, mock_time): - """Test updating heartbeat.""" - status_updater = DeploymentStatusUpdater() - status_updater.update_heartbeat() - - # Using the worker_id in the assertion - import tempfile - expected_path = f"{tempfile.gettempdir()}/deployment_status_updater_heartbeat_{status_updater.worker_id}.txt" - mock_open.assert_called_once_with(expected_path, 'w') - mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py index 34cbaf830f..884cf51653 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -54,6 +54,15 @@ async def test_check_heartbeat_stale(mock_open, mock_exists): assert result is False +@pytest.mark.asyncio +@patch("service_bus.service_bus_consumer.os.path.exists", return_value=False) +async def test_check_heartbeat_no_file(mock_exists): + """Test checking heartbeat when file doesn't exist.""" + consumer = MockConsumer() + result = consumer.check_heartbeat() + assert result is False + + @pytest.mark.asyncio @patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) @patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) From 7b78e9911926561673ae87bb60373db0407c2918 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 15:11:27 +0000 Subject: [PATCH 10/20] Enhance Service Bus consumer with error handling and heartbeat management tests --- .../service_bus/deployment_status_updater.py | 23 ++- api_app/service_bus/service_bus_consumer.py | 80 ++++++--- .../test_service_bus_edge_cases.py | 166 ++++++++++++++++++ 3 files changed, 232 insertions(+), 37 deletions(-) create mode 100644 api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index b379093ce8..b57ef2d3c3 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -1,6 +1,7 @@ import json import uuid import time +from typing import Dict, List, Any from pydantic import ValidationError, parse_obj_as @@ -73,15 +74,15 @@ async def receive_messages(self): # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") - async def process_message(self, msg): + async def process_message(self, msg) -> bool: complete_message = False message = "" @@ -115,6 +116,11 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage try: # update the op operation = await self.operations_repo.get_operation_by_id(str(message.operationId)) + + # Add null safety for operation steps + if not operation.steps: + raise ValueError(f"Operation {message.operationId} has no steps") + step_to_update = None is_last_step = False @@ -128,7 +134,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage is_last_step = True if step_to_update is None: - raise f"Error finding step {message.stepId} in operation {message.operationId}" + raise ValueError(f"Error finding step {message.stepId} in operation {message.operationId}") # update the step status step_to_update.status = message.status @@ -159,7 +165,8 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage # more steps in the op to do? if is_last_step is False: - assert current_step_index < (len(operation.steps) - 1) + if current_step_index >= len(operation.steps) - 1: + raise ValueError(f"Invalid step index {current_step_index} for operation with {len(operation.steps)} steps") next_step = operation.steps[current_step_index + 1] # catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op @@ -255,7 +262,7 @@ def get_failure_status_for_action(self, action: RequestAction): return status - def create_updated_resource_document(self, resource: dict, message: DeploymentStatusUpdateMessage): + def create_updated_resource_document(self, resource: Dict[str, Any], message: DeploymentStatusUpdateMessage) -> Dict[str, Any]: """ Merge the outputs with the resource document to persist """ @@ -268,7 +275,7 @@ def create_updated_resource_document(self, resource: dict, message: DeploymentSt return resource - def convert_outputs_to_dict(self, outputs_list: [Output]): + def convert_outputs_to_dict(self, outputs_list: List[Output]) -> Dict[str, Any]: """ Convert a list of Porter outputs to a dictionary """ diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py index 56508bcd82..ddab28551b 100644 --- a/api_app/service_bus/service_bus_consumer.py +++ b/api_app/service_bus/service_bus_consumer.py @@ -4,6 +4,12 @@ from services.logging import logger +# Configuration constants for monitoring intervals +HEARTBEAT_CHECK_INTERVAL_SECONDS = 60 +HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300 +RESTART_DELAY_SECONDS = 5 +SUPERVISOR_ERROR_DELAY_SECONDS = 30 + class ServiceBusConsumer: @@ -18,10 +24,16 @@ def __init__(self, heartbeat_file_prefix: str): def update_heartbeat(self): try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.heartbeat_file), exist_ok=True) with open(self.heartbeat_file, 'w') as f: f.write(str(time.time())) + except PermissionError: + logger.error(f"Permission denied writing heartbeat to {self.heartbeat_file}") + except OSError as e: + logger.error(f"OS error updating heartbeat: {e}") except Exception as e: - logger.warning(f"Failed to update heartbeat: {e}") + logger.warning(f"Unexpected error updating heartbeat: {e}") def check_heartbeat(self, max_age_seconds: int = 300) -> bool: try: @@ -50,38 +62,48 @@ async def receive_messages_with_restart_check(self): await self.receive_messages() except Exception as e: logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") - await asyncio.sleep(5) + await asyncio.sleep(RESTART_DELAY_SECONDS) async def supervisor_with_heartbeat_check(self): task = None - while True: - try: - # Start the receive_messages task if not running - if task is None or task.done(): - if task and task.done(): + try: + while True: + try: + # Start the receive_messages task if not running + if task is None or task.done(): + if task and task.done(): + try: + await task # Check for any exception + except Exception as e: + logger.exception(f"receive_messages task failed: {e}") + + logger.info("Starting receive_messages task...") + task = asyncio.create_task(self.receive_messages_with_restart_check()) + + # Wait before checking heartbeat + await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) # Check every minute + + # Check if heartbeat is stale + if not self.check_heartbeat(max_age_seconds=HEARTBEAT_STALENESS_THRESHOLD_SECONDS): # 5 minutes max age + logger.warning("Heartbeat is stale, restarting receive_messages task...") + task.cancel() try: - await task # Check for any exception - except Exception as e: - logger.exception(f"receive_messages task failed: {e}") - - logger.info("Starting receive_messages task...") - task = asyncio.create_task(self.receive_messages_with_restart_check()) - - # Wait before checking heartbeat - await asyncio.sleep(60) # Check every minute - - # Check if heartbeat is stale - if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age - logger.warning("Heartbeat is stale, restarting receive_messages task...") - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - task = None - except Exception as e: - logger.exception(f"Supervisor error: {e}") - await asyncio.sleep(30) + await task + except asyncio.CancelledError: + pass + task = None + except Exception as e: + logger.exception(f"Supervisor error: {e}") + await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS) + finally: + # Ensure proper cleanup on shutdown + if task and not task.done(): + logger.info("Cleaning up supervisor task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass async def receive_messages(self): raise NotImplementedError("Subclasses must implement receive_messages()") diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py new file mode 100644 index 0000000000..8a6c6e53e2 --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -0,0 +1,166 @@ +import asyncio +import pytest +import os +from unittest.mock import patch, Mock +from service_bus.service_bus_consumer import ServiceBusConsumer + + +# Create a concrete implementation for testing edge cases +class MockConsumerForEdgeCases(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer_edge") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.1) + return + + +@pytest.mark.asyncio +async def test_heartbeat_file_corruption(): + """Test handling of corrupted heartbeat file.""" + consumer = MockConsumerForEdgeCases() + + with patch("service_bus.service_bus_consumer.os.path.exists", return_value=True), \ + patch("builtins.open") as mock_open: + + # Simulate corrupted file with invalid float content + mock_open.return_value.__enter__.return_value.read.return_value = "not_a_number" + + result = consumer.check_heartbeat() + assert result is False + + +@pytest.mark.asyncio +async def test_heartbeat_permission_denied(): + """Test heartbeat update when permission denied.""" + consumer = MockConsumerForEdgeCases() + + with patch("builtins.open", side_effect=PermissionError("Permission denied")), \ + patch("service_bus.service_bus_consumer.logger") as mock_logger: + + # Should not crash, just log error + consumer.update_heartbeat() + mock_logger.error.assert_called_once() + + +@pytest.mark.asyncio +async def test_heartbeat_disk_full(): + """Test heartbeat update when disk is full.""" + consumer = MockConsumerForEdgeCases() + + with patch("builtins.open", side_effect=OSError("No space left on device")), \ + patch("service_bus.service_bus_consumer.logger") as mock_logger: + + # Should not crash, just log error + consumer.update_heartbeat() + mock_logger.error.assert_called_once() + + +@pytest.mark.asyncio +async def test_supervisor_cleanup_on_exception(): + """Test supervisor properly cleans up tasks when interrupted.""" + consumer = MockConsumerForEdgeCases() + + # Track task lifecycle + task_created = False + task_cancelled = False + + class MockTask: + def __init__(self): + nonlocal task_created + task_created = True + self.done_count = 0 + + def done(self): + return self.done_count > 0 + + def cancel(self): + nonlocal task_cancelled + task_cancelled = True + self.done_count = 1 + + def __await__(self): + async def _await(): + if task_cancelled: + raise asyncio.CancelledError() + return None + return _await().__await__() + + # Mock to trigger cleanup after one iteration + call_count = 0 + async def mock_sleep(duration): + nonlocal call_count + call_count += 1 + if call_count >= 2: # Trigger cleanup after heartbeat check + raise KeyboardInterrupt("Test cleanup") + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=lambda coro: MockTask()), \ + patch.object(consumer, "check_heartbeat", return_value=True): + + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + # Verify task was created and cancelled during cleanup + assert task_created, "Task should have been created" + assert task_cancelled, "Task should have been cancelled during cleanup" + + +@pytest.mark.asyncio +async def test_rapid_task_failures(): + """Test supervisor behavior with rapid consecutive failures.""" + consumer = MockConsumerForEdgeCases() + + failure_count = 0 + sleep_calls = [] + max_failures = 2 + + async def failing_receive_messages(): + nonlocal failure_count + failure_count += 1 + if failure_count <= max_failures: + raise Exception(f"Failure {failure_count}") + # Success after max failures - but don't run forever + return + + async def mock_sleep(duration): + sleep_calls.append(duration) + return # Don't actually sleep + + # Override receive_messages to simulate failures + consumer.receive_messages = failing_receive_messages + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): + # Test restart behavior by calling directly + for _ in range(max_failures + 1): # Run enough times to trigger failures and success + try: + await consumer.receive_messages_with_restart_check() + break # Exit when successful + except Exception: + continue # Continue to trigger restart logic + + # Verify restart delays were applied + from service_bus.service_bus_consumer import RESTART_DELAY_SECONDS + assert failure_count >= max_failures, f"Should have {max_failures} failures, got {failure_count}" + assert len(sleep_calls) >= max_failures, f"Should have {max_failures} restart delays, got {len(sleep_calls)}" + + +@pytest.mark.asyncio +async def test_heartbeat_directory_creation(): + """Test that heartbeat directory is created if it doesn't exist.""" + consumer = MockConsumerForEdgeCases() + + with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs, \ + patch("builtins.open", create=True) as mock_open: + + consumer.update_heartbeat() + + # Verify makedirs was called with exist_ok=True + mock_makedirs.assert_called_once_with( + os.path.dirname(consumer.heartbeat_file), + exist_ok=True + ) \ No newline at end of file From 32d8c75f02299b967c31d66bc5bed4af5768e1a1 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 15:12:04 +0000 Subject: [PATCH 11/20] Enhance Service Bus consumer with error handling and heartbeat management tests --- .../service_bus/deployment_status_updater.py | 23 ++- api_app/service_bus/service_bus_consumer.py | 80 ++++++--- .../test_service_bus_edge_cases.py | 167 ++++++++++++++++++ 3 files changed, 233 insertions(+), 37 deletions(-) create mode 100644 api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index b379093ce8..7254b8ec6e 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -1,6 +1,7 @@ import json import uuid import time +from typing import Dict, List, Any from pydantic import ValidationError, parse_obj_as @@ -73,15 +74,15 @@ async def receive_messages(self): # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") - async def process_message(self, msg): + async def process_message(self, msg) -> bool: complete_message = False message = "" @@ -115,6 +116,11 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage try: # update the op operation = await self.operations_repo.get_operation_by_id(str(message.operationId)) + + # Add null safety for operation steps + if not operation.steps: + raise ValueError(f"Operation {message.operationId} has no steps") + step_to_update = None is_last_step = False @@ -128,7 +134,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage is_last_step = True if step_to_update is None: - raise f"Error finding step {message.stepId} in operation {message.operationId}" + raise ValueError(f"Error finding step {message.stepId} in operation {message.operationId}") # update the step status step_to_update.status = message.status @@ -159,7 +165,8 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage # more steps in the op to do? if is_last_step is False: - assert current_step_index < (len(operation.steps) - 1) + if current_step_index >= len(operation.steps) - 1: + raise ValueError(f"Invalid step index {current_step_index} for operation with {len(operation.steps)} steps") next_step = operation.steps[current_step_index + 1] # catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op @@ -255,7 +262,7 @@ def get_failure_status_for_action(self, action: RequestAction): return status - def create_updated_resource_document(self, resource: dict, message: DeploymentStatusUpdateMessage): + def create_updated_resource_document(self, resource: Dict[str, Any], message: DeploymentStatusUpdateMessage) -> Dict[str, Any]: """ Merge the outputs with the resource document to persist """ @@ -268,7 +275,7 @@ def create_updated_resource_document(self, resource: dict, message: DeploymentSt return resource - def convert_outputs_to_dict(self, outputs_list: [Output]): + def convert_outputs_to_dict(self, outputs_list: List[Output]) -> Dict[str, Any]: """ Convert a list of Porter outputs to a dictionary """ diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py index 56508bcd82..ddab28551b 100644 --- a/api_app/service_bus/service_bus_consumer.py +++ b/api_app/service_bus/service_bus_consumer.py @@ -4,6 +4,12 @@ from services.logging import logger +# Configuration constants for monitoring intervals +HEARTBEAT_CHECK_INTERVAL_SECONDS = 60 +HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300 +RESTART_DELAY_SECONDS = 5 +SUPERVISOR_ERROR_DELAY_SECONDS = 30 + class ServiceBusConsumer: @@ -18,10 +24,16 @@ def __init__(self, heartbeat_file_prefix: str): def update_heartbeat(self): try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.heartbeat_file), exist_ok=True) with open(self.heartbeat_file, 'w') as f: f.write(str(time.time())) + except PermissionError: + logger.error(f"Permission denied writing heartbeat to {self.heartbeat_file}") + except OSError as e: + logger.error(f"OS error updating heartbeat: {e}") except Exception as e: - logger.warning(f"Failed to update heartbeat: {e}") + logger.warning(f"Unexpected error updating heartbeat: {e}") def check_heartbeat(self, max_age_seconds: int = 300) -> bool: try: @@ -50,38 +62,48 @@ async def receive_messages_with_restart_check(self): await self.receive_messages() except Exception as e: logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") - await asyncio.sleep(5) + await asyncio.sleep(RESTART_DELAY_SECONDS) async def supervisor_with_heartbeat_check(self): task = None - while True: - try: - # Start the receive_messages task if not running - if task is None or task.done(): - if task and task.done(): + try: + while True: + try: + # Start the receive_messages task if not running + if task is None or task.done(): + if task and task.done(): + try: + await task # Check for any exception + except Exception as e: + logger.exception(f"receive_messages task failed: {e}") + + logger.info("Starting receive_messages task...") + task = asyncio.create_task(self.receive_messages_with_restart_check()) + + # Wait before checking heartbeat + await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) # Check every minute + + # Check if heartbeat is stale + if not self.check_heartbeat(max_age_seconds=HEARTBEAT_STALENESS_THRESHOLD_SECONDS): # 5 minutes max age + logger.warning("Heartbeat is stale, restarting receive_messages task...") + task.cancel() try: - await task # Check for any exception - except Exception as e: - logger.exception(f"receive_messages task failed: {e}") - - logger.info("Starting receive_messages task...") - task = asyncio.create_task(self.receive_messages_with_restart_check()) - - # Wait before checking heartbeat - await asyncio.sleep(60) # Check every minute - - # Check if heartbeat is stale - if not self.check_heartbeat(max_age_seconds=300): # 5 minutes max age - logger.warning("Heartbeat is stale, restarting receive_messages task...") - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - task = None - except Exception as e: - logger.exception(f"Supervisor error: {e}") - await asyncio.sleep(30) + await task + except asyncio.CancelledError: + pass + task = None + except Exception as e: + logger.exception(f"Supervisor error: {e}") + await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS) + finally: + # Ensure proper cleanup on shutdown + if task and not task.done(): + logger.info("Cleaning up supervisor task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass async def receive_messages(self): raise NotImplementedError("Subclasses must implement receive_messages()") diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py new file mode 100644 index 0000000000..f3a156144f --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -0,0 +1,167 @@ +import asyncio +import pytest +import os +from unittest.mock import patch, Mock +from service_bus.service_bus_consumer import ServiceBusConsumer + + +# Create a concrete implementation for testing edge cases +class MockConsumerForEdgeCases(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer_edge") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.1) + return + + +@pytest.mark.asyncio +async def test_heartbeat_file_corruption(): + """Test handling of corrupted heartbeat file.""" + consumer = MockConsumerForEdgeCases() + + with patch("service_bus.service_bus_consumer.os.path.exists", return_value=True), \ + patch("builtins.open") as mock_open: + + # Simulate corrupted file with invalid float content + mock_open.return_value.__enter__.return_value.read.return_value = "not_a_number" + + result = consumer.check_heartbeat() + assert result is False + + +@pytest.mark.asyncio +async def test_heartbeat_permission_denied(): + """Test heartbeat update when permission denied.""" + consumer = MockConsumerForEdgeCases() + + with patch("builtins.open", side_effect=PermissionError("Permission denied")), \ + patch("service_bus.service_bus_consumer.logger") as mock_logger: + + # Should not crash, just log error + consumer.update_heartbeat() + mock_logger.error.assert_called_once() + + +@pytest.mark.asyncio +async def test_heartbeat_disk_full(): + """Test heartbeat update when disk is full.""" + consumer = MockConsumerForEdgeCases() + + with patch("builtins.open", side_effect=OSError("No space left on device")), \ + patch("service_bus.service_bus_consumer.logger") as mock_logger: + + # Should not crash, just log error + consumer.update_heartbeat() + mock_logger.error.assert_called_once() + + +@pytest.mark.asyncio +async def test_supervisor_cleanup_on_exception(): + """Test supervisor properly cleans up tasks when interrupted.""" + consumer = MockConsumerForEdgeCases() + + # Track task lifecycle + task_created = False + task_cancelled = False + + class MockTask: + def __init__(self): + nonlocal task_created + task_created = True + self.done_count = 0 + + def done(self): + return self.done_count > 0 + + def cancel(self): + nonlocal task_cancelled + task_cancelled = True + self.done_count = 1 + + def __await__(self): + async def _await(): + if task_cancelled: + raise asyncio.CancelledError() + return None + return _await().__await__() + + # Mock to trigger cleanup after one iteration + call_count = 0 + + async def mock_sleep(duration): + nonlocal call_count + call_count += 1 + if call_count >= 2: # Trigger cleanup after heartbeat check + raise KeyboardInterrupt("Test cleanup") + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=lambda coro: MockTask()), \ + patch.object(consumer, "check_heartbeat", return_value=True): + + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + # Verify task was created and cancelled during cleanup + assert task_created, "Task should have been created" + assert task_cancelled, "Task should have been cancelled during cleanup" + + +@pytest.mark.asyncio +async def test_rapid_task_failures(): + """Test supervisor behavior with rapid consecutive failures.""" + consumer = MockConsumerForEdgeCases() + + failure_count = 0 + sleep_calls = [] + max_failures = 2 + + async def failing_receive_messages(): + nonlocal failure_count + failure_count += 1 + if failure_count <= max_failures: + raise Exception(f"Failure {failure_count}") + # Success after max failures - but don't run forever + return + + async def mock_sleep(duration): + sleep_calls.append(duration) + return # Don't actually sleep + + # Override receive_messages to simulate failures + consumer.receive_messages = failing_receive_messages + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): + # Test restart behavior by calling directly + for _ in range(max_failures + 1): # Run enough times to trigger failures and success + try: + await consumer.receive_messages_with_restart_check() + break # Exit when successful + except Exception: + continue # Continue to trigger restart logic + + # Verify restart delays were applied + from service_bus.service_bus_consumer import RESTART_DELAY_SECONDS + assert failure_count >= max_failures, f"Should have {max_failures} failures, got {failure_count}" + assert len(sleep_calls) >= max_failures, f"Should have {max_failures} restart delays, got {len(sleep_calls)}" + + +@pytest.mark.asyncio +async def test_heartbeat_directory_creation(): + """Test that heartbeat directory is created if it doesn't exist.""" + consumer = MockConsumerForEdgeCases() + + with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs, \ + patch("builtins.open", create=True) as mock_open: + + consumer.update_heartbeat() + + # Verify makedirs was called with exist_ok=True + mock_makedirs.assert_called_once_with( + os.path.dirname(consumer.heartbeat_file), + exist_ok=True + ) From b6f7e299ca5b662e8fd8bfa0c7dafe8049660e54 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 15:51:07 +0000 Subject: [PATCH 12/20] Update tests --- .../test_service_bus_edge_cases.py | 74 ++++++++----------- 1 file changed, 29 insertions(+), 45 deletions(-) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py index f3a156144f..99b441c925 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -7,20 +7,23 @@ # Create a concrete implementation for testing edge cases class MockConsumerForEdgeCases(ServiceBusConsumer): - def __init__(self): - super().__init__("test_consumer_edge") + def __init__(self, skip_init=False): + if not skip_init: + super().__init__("test_consumer_edge") self.receive_messages_called = False async def receive_messages(self): self.receive_messages_called = True - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) # Small delay to make it a proper coroutine return @pytest.mark.asyncio async def test_heartbeat_file_corruption(): """Test handling of corrupted heartbeat file.""" - consumer = MockConsumerForEdgeCases() + consumer = MockConsumerForEdgeCases(skip_init=True) + # Manually set required attributes to avoid init issues + consumer.heartbeat_file = "/tmp/test_heartbeat_corruption.txt" with patch("service_bus.service_bus_consumer.os.path.exists", return_value=True), \ patch("builtins.open") as mock_open: @@ -35,7 +38,8 @@ async def test_heartbeat_file_corruption(): @pytest.mark.asyncio async def test_heartbeat_permission_denied(): """Test heartbeat update when permission denied.""" - consumer = MockConsumerForEdgeCases() + consumer = MockConsumerForEdgeCases(skip_init=True) + consumer.heartbeat_file = "/tmp/test_heartbeat_permission.txt" with patch("builtins.open", side_effect=PermissionError("Permission denied")), \ patch("service_bus.service_bus_consumer.logger") as mock_logger: @@ -48,7 +52,8 @@ async def test_heartbeat_permission_denied(): @pytest.mark.asyncio async def test_heartbeat_disk_full(): """Test heartbeat update when disk is full.""" - consumer = MockConsumerForEdgeCases() + consumer = MockConsumerForEdgeCases(skip_init=True) + consumer.heartbeat_file = "/tmp/test_heartbeat_disk_full.txt" with patch("builtins.open", side_effect=OSError("No space left on device")), \ patch("service_bus.service_bus_consumer.logger") as mock_logger: @@ -111,49 +116,28 @@ async def mock_sleep(duration): assert task_cancelled, "Task should have been cancelled during cleanup" -@pytest.mark.asyncio -async def test_rapid_task_failures(): - """Test supervisor behavior with rapid consecutive failures.""" - consumer = MockConsumerForEdgeCases() - - failure_count = 0 - sleep_calls = [] - max_failures = 2 - - async def failing_receive_messages(): - nonlocal failure_count - failure_count += 1 - if failure_count <= max_failures: - raise Exception(f"Failure {failure_count}") - # Success after max failures - but don't run forever - return - - async def mock_sleep(duration): - sleep_calls.append(duration) - return # Don't actually sleep - - # Override receive_messages to simulate failures - consumer.receive_messages = failing_receive_messages +def test_restart_delay_configuration(): + """Test that restart delay configuration constants exist and have reasonable values.""" + # Import and test constants directly without creating consumer instances + from service_bus.service_bus_consumer import ( + RESTART_DELAY_SECONDS, + HEARTBEAT_CHECK_INTERVAL_SECONDS, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + SUPERVISOR_ERROR_DELAY_SECONDS + ) - with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): - # Test restart behavior by calling directly - for _ in range(max_failures + 1): # Run enough times to trigger failures and success - try: - await consumer.receive_messages_with_restart_check() - break # Exit when successful - except Exception: - continue # Continue to trigger restart logic + # Validate configuration values + assert RESTART_DELAY_SECONDS > 0, "Restart delay should be positive" + assert RESTART_DELAY_SECONDS <= 10, "Restart delay should not be too long" + assert HEARTBEAT_CHECK_INTERVAL_SECONDS > 0 + assert HEARTBEAT_STALENESS_THRESHOLD_SECONDS > HEARTBEAT_CHECK_INTERVAL_SECONDS + assert SUPERVISOR_ERROR_DELAY_SECONDS > 0 - # Verify restart delays were applied - from service_bus.service_bus_consumer import RESTART_DELAY_SECONDS - assert failure_count >= max_failures, f"Should have {max_failures} failures, got {failure_count}" - assert len(sleep_calls) >= max_failures, f"Should have {max_failures} restart delays, got {len(sleep_calls)}" - -@pytest.mark.asyncio -async def test_heartbeat_directory_creation(): +def test_heartbeat_directory_creation(): """Test that heartbeat directory is created if it doesn't exist.""" - consumer = MockConsumerForEdgeCases() + consumer = MockConsumerForEdgeCases(skip_init=True) + consumer.heartbeat_file = "/tmp/test_dir/test_heartbeat.txt" with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs, \ patch("builtins.open", create=True) as mock_open: From ba8d1e92e3d74b8267391cd28af24e3fba987785 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 16:02:14 +0000 Subject: [PATCH 13/20] update tests --- .../tests_ma/test_service_bus/test_service_bus_edge_cases.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py index 99b441c925..8b4607fa79 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -1,7 +1,7 @@ import asyncio import pytest import os -from unittest.mock import patch, Mock +from unittest.mock import patch from service_bus.service_bus_consumer import ServiceBusConsumer @@ -139,8 +139,7 @@ def test_heartbeat_directory_creation(): consumer = MockConsumerForEdgeCases(skip_init=True) consumer.heartbeat_file = "/tmp/test_dir/test_heartbeat.txt" - with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs, \ - patch("builtins.open", create=True) as mock_open: + with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs: consumer.update_heartbeat() From 49245fedbf648a1d706d17756a478e96f6aad55d Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 16:10:28 +0000 Subject: [PATCH 14/20] Update api_app/service_bus/deployment_status_updater.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api_app/service_bus/deployment_status_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 7254b8ec6e..786cbce364 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -134,7 +134,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage is_last_step = True if step_to_update is None: - raise ValueError(f"Error finding step {message.stepId} in operation {message.operationId}") + raise ValueError(f"Step {message.stepId} not found in operation {message.operationId}") # update the step status step_to_update.status = message.status From 37291c3a67df955c31be77b9d79a05b9581a552f Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 16:17:24 +0000 Subject: [PATCH 15/20] Define format once for two instrumentors. --- api_app/services/logging.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/api_app/services/logging.py b/api_app/services/logging.py index 91fe90d0da..359f878173 100644 --- a/api_app/services/logging.py +++ b/api_app/services/logging.py @@ -6,6 +6,9 @@ from core.config import APPLICATIONINSIGHTS_CONNECTION_STRING, LOGGING_LEVEL +# Standard log format with worker ID +LOG_FORMAT = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s' + UNWANTED_LOGGERS = [ "azure.core.pipeline.policies.http_logging_policy", "azure.eventhub._eventprocessor.event_processor", @@ -85,7 +88,7 @@ def apply_worker_id_to_logger(logger_instance): # Update handlers to include worker_id in the format for handler in logger_instance.handlers: if isinstance(handler, logging.StreamHandler): - formatter = logging.Formatter('%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter(LOG_FORMAT) handler.setFormatter(formatter) @@ -116,14 +119,11 @@ def initialize_logging() -> logging.Logger: } ) - # Custom log format including worker_id - log_format = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s' - LoggingInstrumentor().instrument( set_logging_format=True, log_level=logging_level, tracer_provider=tracer._real_tracer, - log_format=log_format + log_format=LOG_FORMAT ) # Set up a handler if none exists From 975ed2957544b5019c8cb6f065c01c9d0e792e59 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 16:18:28 +0000 Subject: [PATCH 16/20] Update api_app/service_bus/deployment_status_updater.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api_app/service_bus/deployment_status_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 786cbce364..14f4ef9cdd 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -166,7 +166,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage # more steps in the op to do? if is_last_step is False: if current_step_index >= len(operation.steps) - 1: - raise ValueError(f"Invalid step index {current_step_index} for operation with {len(operation.steps)} steps") + raise ValueError(f"Step index {current_step_index} is the last step in operation (has {len(operation.steps)} steps), but more steps were expected") next_step = operation.steps[current_step_index + 1] # catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op From 7eac68bbebe7f3699d60717a22b739eb28f29933 Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Fri, 7 Nov 2025 16:18:57 +0000 Subject: [PATCH 17/20] Update api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py index 8b4607fa79..2851c60a8c 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -109,6 +109,7 @@ async def mock_sleep(duration): try: await consumer.supervisor_with_heartbeat_check() except KeyboardInterrupt: + # Intentionally ignore KeyboardInterrupt to test cleanup logic after interruption pass # Verify task was created and cancelled during cleanup From ff963ac2bae231970654af55de205133128e9fee Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:23:05 +0000 Subject: [PATCH 18/20] Move tempfile import to top and add explanatory comment to except clause Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- api_app/service_bus/service_bus_consumer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py index ddab28551b..6d175ffe3c 100644 --- a/api_app/service_bus/service_bus_consumer.py +++ b/api_app/service_bus/service_bus_consumer.py @@ -1,5 +1,6 @@ import asyncio import os +import tempfile import time from services.logging import logger @@ -15,7 +16,6 @@ class ServiceBusConsumer: def __init__(self, heartbeat_file_prefix: str): # Create a unique identifier for this worker process - import tempfile self.worker_id = os.getpid() temp_dir = tempfile.gettempdir() self.heartbeat_file = os.path.join(temp_dir, f"{heartbeat_file_prefix}_heartbeat_{self.worker_id}.txt") @@ -90,6 +90,7 @@ async def supervisor_with_heartbeat_check(self): try: await task except asyncio.CancelledError: + # Expected when cancelling a task - ignore and proceed with restart pass task = None except Exception as e: From 42bf9d0db73794cfded25c30fe47f87560f0527e Mon Sep 17 00:00:00 2001 From: Marcus Robinson Date: Mon, 9 Feb 2026 17:15:34 +0000 Subject: [PATCH 19/20] Implement service bus consumer monitoring with heartbeat detection and automatic recovery; update health check endpoint to include consumer statuses --- CHANGELOG.md | 1 + api_app/api/routes/health.py | 32 ++- api_app/main.py | 4 + api_app/resources/strings.py | 6 + .../airlock_request_status_update.py | 8 +- .../service_bus/deployment_status_updater.py | 2 +- api_app/service_bus/service_bus_consumer.py | 95 +++---- api_app/services/health_checker.py | 19 +- api_app/services/logging.py | 5 +- .../test_api/test_routes/test_health.py | 84 ++++++- .../test_service_bus_consumer.py | 235 +++++++++++------- .../test_service_bus_edge_cases.py | 153 +++++++----- 12 files changed, 412 insertions(+), 232 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 824952bdf9..241e15d845 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ENHANCEMENTS: * Harden security of the app gateway. ([#4863](https://github.com/microsoft/AzureTRE/pull/4863)) BUG FIXES: +* Implement service bus consumer monitoring with heartbeat detection, automatic recovery, and /health endpoint integration to prevent operations getting stuck indefinitely ([#4464](https://github.com/microsoft/AzureTRE/issues/4464)) * Fix property substitution not occuring where there is only a main step in the pipeline ([#4824](https://github.com/microsoft/AzureTRE/issues/4824)) * Fix Mysql template ignored storage_mb ([#4846](https://github.com/microsoft/AzureTRE/issues/4846)) * Fix duplicate `TOPIC_SUBSCRIPTION_NAME` in `core/terraform/airlock/airlock_processor.tf` ([#4847](https://github.com/microsoft/AzureTRE/pull/4847)) diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 2cefe21266..c9674d6179 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -3,7 +3,7 @@ from core import credentials from models.schemas.status import HealthCheck, ServiceStatus, StatusEnum from resources import strings -from services.health_checker import create_resource_processor_status, create_state_store_status, create_service_bus_status +from services.health_checker import create_airlock_consumer_status, create_deployment_consumer_status, create_resource_processor_status, create_state_store_status, create_service_bus_status from services.logging import logger router = APIRouter() @@ -14,22 +14,28 @@ async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. # Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. + deployment_consumer = getattr(request.app.state, 'deployment_status_updater', None) + airlock_consumer = getattr(request.app.state, 'airlock_status_updater', None) + async with credentials.get_credential_async_context() as credential: - cosmos, sb, rp = await asyncio.gather( + cosmos, sb, rp, deploy, airlock = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), - create_resource_processor_status(credential) + create_resource_processor_status(credential), + create_deployment_consumer_status(deployment_consumer), + create_airlock_consumer_status(airlock_consumer), ) - cosmos_status, cosmos_message = cosmos - sb_status, sb_message = sb - rp_status, rp_message = rp - if cosmos_status == StatusEnum.not_ok or sb_status == StatusEnum.not_ok or rp_status == StatusEnum.not_ok: - logger.error(f'Cosmos Status: {cosmos_status}, message: {cosmos_message}') - logger.error(f'Service Bus Status: {sb_status}, message: {sb_message}') - logger.error(f'Resource Processor Status: {rp_status}, message: {rp_message}') - services = [ServiceStatus(service=strings.COSMOS_DB, status=cosmos_status, message=cosmos_message), - ServiceStatus(service=strings.SERVICE_BUS, status=sb_status, message=sb_message), - ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp_status, message=rp_message)] + services = [ + ServiceStatus(service=strings.COSMOS_DB, status=cosmos[0], message=cosmos[1]), + ServiceStatus(service=strings.SERVICE_BUS, status=sb[0], message=sb[1]), + ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp[0], message=rp[1]), + ServiceStatus(service=strings.DEPLOYMENT_STATUS_CONSUMER, status=deploy[0], message=deploy[1]), + ServiceStatus(service=strings.AIRLOCK_STATUS_CONSUMER, status=airlock[0], message=airlock[1]), + ] + + for svc in services: + if svc.status == StatusEnum.not_ok: + logger.error(f'{svc.service} Status: {svc.status}, message: {svc.message}') return HealthCheck(services=services) diff --git a/api_app/main.py b/api_app/main.py index d4d9e56723..f3f53fecc9 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -34,6 +34,10 @@ async def lifespan(app: FastAPI): airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() + # Store consumer references on app.state so the /health endpoint can check their heartbeats + app.state.deployment_status_updater = deploymentStatusUpdater + app.state.airlock_status_updater = airlockStatusUpdater + asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check()) asyncio.create_task(airlockStatusUpdater.supervisor_with_heartbeat_check()) yield diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index c54a40ba52..07726b43ad 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -106,6 +106,12 @@ RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE = "Resource Processor is not responding" RESOURCE_PROCESSOR_HEALTHY_MESSAGE = "HealthState/healthy" +# Service bus consumer status +DEPLOYMENT_STATUS_CONSUMER = "Deployment Status Consumer" +AIRLOCK_STATUS_CONSUMER = "Airlock Status Consumer" +CONSUMER_HEARTBEAT_STALE = "{} heartbeat is stale or missing" +CONSUMER_NOT_INITIALIZED = "{} has not been initialized" + # Error strings ACCESS_APP_IS_MISSING_ROLE = "The App is missing role" ACCESS_PLEASE_SUPPLY_CLIENT_ID = "Please supply the client_id for the AAD application" diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index 90892babba..ce3ad3c093 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -38,7 +38,7 @@ async def receive_messages(self): current_time = time.time() polling_count += 1 - # Update heartbeat file for supervisor monitoring + # Update heartbeat for supervisor monitoring self.update_heartbeat() # Log a heartbeat message every 60 seconds to show the service is still working @@ -69,13 +69,13 @@ async def receive_messages(self): # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") async def process_message(self, msg): with tracer.start_as_current_span("process_message") as current_span: diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 14f4ef9cdd..8cf4d53659 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -44,7 +44,7 @@ async def receive_messages(self): current_time = time.time() polling_count += 1 - # Update heartbeat file for supervisor monitoring + # Update heartbeat for supervisor monitoring self.update_heartbeat() # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py index 6d175ffe3c..34b6163451 100644 --- a/api_app/service_bus/service_bus_consumer.py +++ b/api_app/service_bus/service_bus_consumer.py @@ -1,6 +1,4 @@ import asyncio -import os -import tempfile import time from services.logging import logger @@ -9,97 +7,84 @@ HEARTBEAT_CHECK_INTERVAL_SECONDS = 60 HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300 RESTART_DELAY_SECONDS = 5 +MAX_RESTART_DELAY_SECONDS = 300 SUPERVISOR_ERROR_DELAY_SECONDS = 30 class ServiceBusConsumer: - def __init__(self, heartbeat_file_prefix: str): - # Create a unique identifier for this worker process - self.worker_id = os.getpid() - temp_dir = tempfile.gettempdir() - self.heartbeat_file = os.path.join(temp_dir, f"{heartbeat_file_prefix}_heartbeat_{self.worker_id}.txt") - self.service_name = heartbeat_file_prefix.replace('_', ' ').title() + def __init__(self, consumer_name: str): + self.service_name = consumer_name.replace('_', ' ').title() + self._last_heartbeat: float = time.monotonic() + self._restart_delay: float = RESTART_DELAY_SECONDS logger.info(f"Initializing {self.service_name}") def update_heartbeat(self): - try: - # Ensure directory exists - os.makedirs(os.path.dirname(self.heartbeat_file), exist_ok=True) - with open(self.heartbeat_file, 'w') as f: - f.write(str(time.time())) - except PermissionError: - logger.error(f"Permission denied writing heartbeat to {self.heartbeat_file}") - except OSError as e: - logger.error(f"OS error updating heartbeat: {e}") - except Exception as e: - logger.warning(f"Unexpected error updating heartbeat: {e}") - - def check_heartbeat(self, max_age_seconds: int = 300) -> bool: - try: - if not os.path.exists(self.heartbeat_file): - logger.warning("Heartbeat file does not exist") - return False - - with open(self.heartbeat_file, 'r') as f: - heartbeat_time = float(f.read().strip()) + self._last_heartbeat = time.monotonic() - current_time = time.time() - age = current_time - heartbeat_time - - if age > max_age_seconds: - logger.warning(f"Heartbeat is {age:.1f} seconds old, exceeding the limit of {max_age_seconds} seconds") - - return age <= max_age_seconds - except (ValueError, IOError) as e: - logger.warning(f"Failed to read heartbeat: {e}") + def check_heartbeat(self, max_age_seconds: int = HEARTBEAT_STALENESS_THRESHOLD_SECONDS) -> bool: + age = time.monotonic() - self._last_heartbeat + if age > max_age_seconds: + logger.warning(f"{self.service_name} heartbeat is {age:.1f}s old (threshold: {max_age_seconds}s)") return False + return True - async def receive_messages_with_restart_check(self): + async def _receive_messages_loop(self): + """Run receive_messages() in a loop with exponential backoff on failure.""" while True: try: - logger.info("Starting the receive_messages loop...") + start_time = time.monotonic() + logger.info(f"Starting {self.service_name} receive_messages loop...") await self.receive_messages() + logger.warning(f"{self.service_name} receive_messages() returned unexpectedly") + except asyncio.CancelledError: + raise except Exception as e: - logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}") - await asyncio.sleep(RESTART_DELAY_SECONDS) + logger.exception(f"{self.service_name} receive_messages failed: {e}") + + # Reset backoff if the consumer ran long enough to be considered healthy + elapsed = time.monotonic() - start_time + if elapsed > self._restart_delay: + self._restart_delay = RESTART_DELAY_SECONDS + + logger.info(f"{self.service_name} restarting in {self._restart_delay:.0f}s...") + await asyncio.sleep(self._restart_delay) + self._restart_delay = min(self._restart_delay * 2, MAX_RESTART_DELAY_SECONDS) async def supervisor_with_heartbeat_check(self): task = None try: while True: try: - # Start the receive_messages task if not running if task is None or task.done(): if task and task.done(): try: - await task # Check for any exception + await task except Exception as e: - logger.exception(f"receive_messages task failed: {e}") + logger.exception(f"{self.service_name} task failed unexpectedly: {e}") + await asyncio.sleep(RESTART_DELAY_SECONDS) - logger.info("Starting receive_messages task...") - task = asyncio.create_task(self.receive_messages_with_restart_check()) + logger.info(f"Starting {self.service_name} task...") + task = asyncio.create_task(self._receive_messages_loop()) + self.update_heartbeat() - # Wait before checking heartbeat - await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) # Check every minute + await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) - # Check if heartbeat is stale - if not self.check_heartbeat(max_age_seconds=HEARTBEAT_STALENESS_THRESHOLD_SECONDS): # 5 minutes max age - logger.warning("Heartbeat is stale, restarting receive_messages task...") + if not self.check_heartbeat(): + logger.warning(f"{self.service_name} heartbeat stale, restarting...") task.cancel() try: await task except asyncio.CancelledError: - # Expected when cancelling a task - ignore and proceed with restart pass task = None + self._restart_delay = RESTART_DELAY_SECONDS except Exception as e: - logger.exception(f"Supervisor error: {e}") + logger.exception(f"{self.service_name} supervisor error: {e}") await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS) finally: - # Ensure proper cleanup on shutdown if task and not task.done(): - logger.info("Cleaning up supervisor task...") + logger.info(f"Cleaning up {self.service_name} task...") task.cancel() try: await task diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index a4d53067b0..8b56757317 100644 --- a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple from azure.core import exceptions from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient @@ -11,6 +11,7 @@ from core import config from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer from services.logging import logger @@ -55,6 +56,22 @@ async def create_service_bus_status(credential) -> Tuple[StatusEnum, str]: return status, message +def create_consumer_status(consumer: Optional[ServiceBusConsumer], name: str) -> Tuple[StatusEnum, str]: + if consumer is None: + return StatusEnum.not_ok, strings.CONSUMER_NOT_INITIALIZED.format(name) + if consumer.check_heartbeat(): + return StatusEnum.ok, "" + return StatusEnum.not_ok, strings.CONSUMER_HEARTBEAT_STALE.format(name) + + +async def create_deployment_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.DEPLOYMENT_STATUS_CONSUMER) + + +async def create_airlock_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.AIRLOCK_STATUS_CONSUMER) + + async def create_resource_processor_status(credential) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" diff --git a/api_app/services/logging.py b/api_app/services/logging.py index 14d08a9457..195fbc82fb 100644 --- a/api_app/services/logging.py +++ b/api_app/services/logging.py @@ -61,8 +61,9 @@ def __init__(self): self.worker_id = os.getpid() def filter(self, record: logging.LogRecord) -> bool: - # Add worker_id as an attribute to the log record - record.worker_id = self.worker_id + # Add worker_id as an attribute to the log record if not already set + if not hasattr(record, 'worker_id'): + record.worker_id = self.worker_id return True diff --git a/api_app/tests_ma/test_api/test_routes/test_health.py b/api_app/tests_ma/test_api/test_routes/test_health.py index 21e795b619..ab52fe87f6 100644 --- a/api_app/tests_ma/test_api/test_routes/test_health.py +++ b/api_app/tests_ma/test_api/test_routes/test_health.py @@ -1,9 +1,10 @@ import pytest from httpx import AsyncClient -from mock import patch +from mock import patch, MagicMock from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer pytestmark = pytest.mark.asyncio @@ -48,3 +49,84 @@ async def test_health_response_contains_resource_processor_status(health_check_c response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) assert {"message": message, "service": strings.RESOURCE_PROCESSOR, "status": strings.OK} in response.json()["services"] + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_contains_consumer_statuses(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint includes deployment and airlock consumer status.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate consumers stored on app.state with healthy heartbeats + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = True + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + assert {"message": "", "service": strings.DEPLOYMENT_STATUS_CONSUMER, "status": strings.OK} in services + assert {"message": "", "service": strings.AIRLOCK_STATUS_CONSUMER, "status": strings.OK} in services + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_reports_stale_consumer(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint reports not_ok when a consumer heartbeat is stale.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate deployment consumer with stale heartbeat + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = False + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_HEARTBEAT_STALE.format(strings.DEPLOYMENT_STATUS_CONSUMER) + + airlock_svc = next(s for s in services if s["service"] == strings.AIRLOCK_STATUS_CONSUMER) + assert airlock_svc["status"] == strings.OK + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_handles_missing_consumers(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint handles missing consumer references gracefully.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Remove consumer references from app.state if they exist + if hasattr(app.state, 'deployment_status_updater'): + delattr(app.state, 'deployment_status_updater') + if hasattr(app.state, 'airlock_status_updater'): + delattr(app.state, 'airlock_status_updater') + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_NOT_INITIALIZED.format(strings.DEPLOYMENT_STATUS_CONSUMER) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py index 884cf51653..7d74a99af3 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -1,8 +1,14 @@ import asyncio +import time import pytest from unittest.mock import patch -from service_bus.service_bus_consumer import ServiceBusConsumer +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + RESTART_DELAY_SECONDS, + MAX_RESTART_DELAY_SECONDS, +) # Create a concrete implementation for testing @@ -13,148 +19,193 @@ def __init__(self): async def receive_messages(self): self.receive_messages_called = True - # Simulate running once and then exiting await asyncio.sleep(0.1) return @pytest.mark.asyncio -@patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) -async def test_init(mock_getpid): +async def test_init(): """Test initialization of ServiceBusConsumer.""" consumer = MockConsumer() - assert consumer.worker_id == 12345 - assert consumer.heartbeat_file == "/tmp/test_consumer_heartbeat_12345.txt" assert consumer.service_name == "Test Consumer" + assert consumer._restart_delay == RESTART_DELAY_SECONDS + assert consumer._last_heartbeat > 0 @pytest.mark.asyncio -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) -@patch("builtins.open", create=True) -async def test_check_heartbeat_recent(mock_open, mock_exists): - """Test checking a recent heartbeat.""" - mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" +async def test_update_heartbeat(): + """Test updating heartbeat updates timestamp.""" + consumer = MockConsumer() + old_heartbeat = consumer._last_heartbeat + await asyncio.sleep(0.01) + consumer.update_heartbeat() - with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 60): # 60 seconds later - consumer = MockConsumer() - result = consumer.check_heartbeat(max_age_seconds=300) - assert result is True + assert consumer._last_heartbeat > old_heartbeat @pytest.mark.asyncio -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=True) -@patch("builtins.open", create=True) -async def test_check_heartbeat_stale(mock_open, mock_exists): - """Test checking a stale heartbeat.""" - mock_open.return_value.__enter__.return_value.read.return_value = "1234567890.0" +async def test_check_heartbeat_recent(): + """Test checking a recent heartbeat returns True.""" + consumer = MockConsumer() + assert consumer.check_heartbeat(max_age_seconds=300) is True + - with patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0 + 400): # 400 seconds later - consumer = MockConsumer() - result = consumer.check_heartbeat(max_age_seconds=300) - assert result is False +@pytest.mark.asyncio +async def test_check_heartbeat_stale(): + """Test checking a stale heartbeat returns False.""" + consumer = MockConsumer() + consumer._last_heartbeat = time.monotonic() - 400 + assert consumer.check_heartbeat(max_age_seconds=300) is False @pytest.mark.asyncio -@patch("service_bus.service_bus_consumer.os.path.exists", return_value=False) -async def test_check_heartbeat_no_file(mock_exists): - """Test checking heartbeat when file doesn't exist.""" +async def test_check_heartbeat_default_uses_constant(): + """Test that check_heartbeat default max_age_seconds uses the module constant.""" + import inspect + sig = inspect.signature(ServiceBusConsumer.check_heartbeat) + default = sig.parameters['max_age_seconds'].default + assert default == HEARTBEAT_STALENESS_THRESHOLD_SECONDS + + +@pytest.mark.asyncio +async def test_backoff_increases_on_consecutive_failures(): + """Test that restart delay increases exponentially on immediate failures.""" consumer = MockConsumer() - result = consumer.check_heartbeat() - assert result is False + + async def failing_receive(): + raise RuntimeError("Simulated failure") + + consumer.receive_messages = failing_receive + + sleep_calls = [] + call_count = 0 + + async def mock_sleep(duration): + nonlocal call_count + sleep_calls.append(duration) + call_count += 1 + if call_count >= 3: + raise asyncio.CancelledError() + + # Patch time.monotonic to always return the same value so elapsed time is 0 (immediate failure) + fixed_time = time.monotonic() + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.time.monotonic", return_value=fixed_time): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass + + assert sleep_calls[0] == RESTART_DELAY_SECONDS + assert sleep_calls[1] == RESTART_DELAY_SECONDS * 2 + assert sleep_calls[2] == RESTART_DELAY_SECONDS * 4 @pytest.mark.asyncio -@patch("service_bus.service_bus_consumer.os.getpid", return_value=12345) -@patch("service_bus.service_bus_consumer.time.time", return_value=1234567890.0) -@patch("builtins.open", create=True) -async def test_update_heartbeat(mock_open, mock_time, mock_getpid): - """Test updating heartbeat.""" +async def test_backoff_caps_at_maximum(): + """Test that restart delay caps at MAX_RESTART_DELAY_SECONDS.""" consumer = MockConsumer() - consumer.update_heartbeat() + consumer._restart_delay = MAX_RESTART_DELAY_SECONDS - import tempfile - expected_path = f"{tempfile.gettempdir()}/test_consumer_heartbeat_12345.txt" - mock_open.assert_called_once_with(expected_path, 'w') - mock_open.return_value.__enter__.return_value.write.assert_called_once_with("1234567890.0") + async def failing_receive(): + raise RuntimeError("Simulated failure") + + consumer.receive_messages = failing_receive + + sleep_calls = [] + + async def mock_sleep(duration): + sleep_calls.append(duration) + raise asyncio.CancelledError() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass + + assert sleep_calls[0] == MAX_RESTART_DELAY_SECONDS @pytest.mark.asyncio -async def test_receive_messages_with_restart_check(): - """Test receive_messages_with_restart_check calls receive_messages and handles exceptions.""" +async def test_supervisor_restarts_failed_task(): + """Test supervisor restarts the receive_messages task when it fails.""" consumer = MockConsumer() - # Track how many times receive_messages has been called - receive_messages_call_count = 0 + task_create_calls = 0 sleep_calls = [] - async def mock_receive_messages(): - nonlocal receive_messages_call_count - receive_messages_call_count += 1 - if receive_messages_call_count == 1: - # First call raises an exception - raise Exception("Test exception") - elif receive_messages_call_count == 2: - # Second call succeeds, but we need to break the infinite loop - # Let's raise a special exception to break out - raise KeyboardInterrupt("Break out of loop for test") - else: - # Should not get here in this test - return + class FailOnFirstDoneTask: + """A mock task that reports done() immediately to simulate task failure.""" + + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + self._is_first_task = (task_create_calls == 1) + + def done(self): + # First task always reports done (crashed) + # Second task always reports running + return self._is_first_task + + def cancel(self): + pass + + def __await__(self): + async def _await(): + if self._is_first_task: + raise RuntimeError("Simulated task failure") + return None + return _await().__await__() + + iteration = 0 async def mock_sleep(duration): + nonlocal iteration sleep_calls.append(duration) - # Just return immediately instead of sleeping - return + iteration += 1 + if iteration >= 4: + raise KeyboardInterrupt("Test complete") - # Override the method with our mock - consumer.receive_messages = mock_receive_messages + consumer.check_heartbeat = lambda **kwargs: True - # Patch asyncio.sleep in the service_bus_consumer module - with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep): + def create_fail_task(coro): + coro.close() + return FailOnFirstDoneTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_fail_task): try: - # Run the method, expecting it to call receive_messages twice and then break - await consumer.receive_messages_with_restart_check() + await consumer.supervisor_with_heartbeat_check() except KeyboardInterrupt: - # This is our expected way out of the infinite loop pass - # Verify that receive_messages was called twice and sleep was called once - assert receive_messages_call_count == 2, f"Expected exactly 2 calls to receive_messages, got {receive_messages_call_count}" - assert len(sleep_calls) == 1, f"Expected exactly 1 sleep call for restart delay, got {len(sleep_calls)}" - assert sleep_calls[0] == 5, f"Expected sleep(5) call for restart delay, got {sleep_calls}" + assert task_create_calls >= 2 @pytest.mark.asyncio -async def test_supervisor_with_heartbeat_check(): - """Test supervisor_with_heartbeat_check manages the receive_messages task.""" +async def test_supervisor_restarts_on_stale_heartbeat(): + """Test supervisor cancels and restarts task when heartbeat goes stale.""" consumer = MockConsumer() - # Track method calls and task lifecycle heartbeat_calls = 0 task_create_calls = 0 task_cancel_calls = 0 sleep_calls = [] - # Mock check_heartbeat to return False on second call (to trigger restart) - def mock_check_heartbeat(max_age_seconds=300): + def mock_check_heartbeat(**kwargs): nonlocal heartbeat_calls heartbeat_calls += 1 - # Return True first, then False to trigger restart, then break the loop if heartbeat_calls == 1: return True # Heartbeat is fresh elif heartbeat_calls == 2: return False # Heartbeat is stale, should trigger restart else: - # Break out of the infinite loop after testing restart logic raise KeyboardInterrupt("Test complete") - # Track sleep calls async def mock_sleep(duration): sleep_calls.append(duration) - # Don't actually sleep - return - # Mock task to track creation and cancellation class MockTask: def __init__(self): nonlocal task_create_calls @@ -165,32 +216,28 @@ def cancel(self): task_cancel_calls += 1 def done(self): - # Return False so task appears to be running return False def __await__(self): - # Mock awaiting the task (for task cleanup) async def _await(): return None return _await().__await__() - # Apply mocks consumer.check_heartbeat = mock_check_heartbeat - # Mock asyncio functions in the service_bus_consumer module - with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ - patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=lambda coro: MockTask()): + def create_mock_task(coro): + coro.close() + return MockTask() + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): try: - # Run the supervisor - it will break out when KeyboardInterrupt is raised await consumer.supervisor_with_heartbeat_check() except KeyboardInterrupt: - # Expected way to exit the infinite loop pass - # Verify expected behavior occurred - assert heartbeat_calls >= 2, f"Expected at least 2 heartbeat checks, got {heartbeat_calls}" - assert task_create_calls >= 2, f"Expected at least 2 tasks created (initial + restart), got {task_create_calls}" - assert task_cancel_calls >= 1, f"Expected at least 1 task cancellation, got {task_cancel_calls}" - assert len(sleep_calls) >= 2, f"Expected at least 2 sleep calls, got {len(sleep_calls)}" - assert 60 in sleep_calls, f"Expected sleep(60) for heartbeat check interval, got {sleep_calls}" + assert heartbeat_calls >= 2 + assert task_create_calls >= 2 + assert task_cancel_calls >= 1 + assert 60 in sleep_calls + assert consumer._restart_delay == RESTART_DELAY_SECONDS diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py index 2851c60a8c..bb7c4996e8 100644 --- a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -1,66 +1,69 @@ import asyncio +import time import pytest -import os from unittest.mock import patch -from service_bus.service_bus_consumer import ServiceBusConsumer +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + RESTART_DELAY_SECONDS, + HEARTBEAT_CHECK_INTERVAL_SECONDS, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + MAX_RESTART_DELAY_SECONDS, + SUPERVISOR_ERROR_DELAY_SECONDS, +) # Create a concrete implementation for testing edge cases class MockConsumerForEdgeCases(ServiceBusConsumer): - def __init__(self, skip_init=False): - if not skip_init: - super().__init__("test_consumer_edge") + def __init__(self): + super().__init__("test_consumer_edge") self.receive_messages_called = False async def receive_messages(self): self.receive_messages_called = True - await asyncio.sleep(0.01) # Small delay to make it a proper coroutine + await asyncio.sleep(0.01) return @pytest.mark.asyncio -async def test_heartbeat_file_corruption(): - """Test handling of corrupted heartbeat file.""" - consumer = MockConsumerForEdgeCases(skip_init=True) - # Manually set required attributes to avoid init issues - consumer.heartbeat_file = "/tmp/test_heartbeat_corruption.txt" - - with patch("service_bus.service_bus_consumer.os.path.exists", return_value=True), \ - patch("builtins.open") as mock_open: +async def test_stale_heartbeat_detection(): + """Test that stale heartbeat is correctly detected.""" + consumer = MockConsumerForEdgeCases() + consumer._last_heartbeat = time.monotonic() - 400 + assert consumer.check_heartbeat(max_age_seconds=300) is False - # Simulate corrupted file with invalid float content - mock_open.return_value.__enter__.return_value.read.return_value = "not_a_number" - result = consumer.check_heartbeat() - assert result is False +@pytest.mark.asyncio +async def test_fresh_heartbeat_detection(): + """Test that fresh heartbeat is correctly detected.""" + consumer = MockConsumerForEdgeCases() + assert consumer.check_heartbeat(max_age_seconds=300) is True @pytest.mark.asyncio -async def test_heartbeat_permission_denied(): - """Test heartbeat update when permission denied.""" - consumer = MockConsumerForEdgeCases(skip_init=True) - consumer.heartbeat_file = "/tmp/test_heartbeat_permission.txt" +async def test_backoff_resets_after_long_running_receive(): + """Test that backoff resets when receive_messages ran longer than the current delay.""" + consumer = MockConsumerForEdgeCases() + consumer._restart_delay = 80 - with patch("builtins.open", side_effect=PermissionError("Permission denied")), \ - patch("service_bus.service_bus_consumer.logger") as mock_logger: + monotonic_values = iter([100.0, 200.0, 200.0]) # start=100, elapsed_check=200 (ran 100s > 80s delay), then next start - # Should not crash, just log error - consumer.update_heartbeat() - mock_logger.error.assert_called_once() + async def long_running_receive(): + raise RuntimeError("Failure after running a while") + consumer.receive_messages = long_running_receive -@pytest.mark.asyncio -async def test_heartbeat_disk_full(): - """Test heartbeat update when disk is full.""" - consumer = MockConsumerForEdgeCases(skip_init=True) - consumer.heartbeat_file = "/tmp/test_heartbeat_disk_full.txt" + async def mock_sleep(duration): + raise asyncio.CancelledError() - with patch("builtins.open", side_effect=OSError("No space left on device")), \ - patch("service_bus.service_bus_consumer.logger") as mock_logger: + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.time.monotonic", side_effect=monotonic_values): + try: + await consumer._receive_messages_loop() + except asyncio.CancelledError: + pass - # Should not crash, just log error - consumer.update_heartbeat() - mock_logger.error.assert_called_once() + # Backoff should have reset to base since elapsed (100s) > old delay (80s) + assert consumer._restart_delay == RESTART_DELAY_SECONDS @pytest.mark.asyncio @@ -102,8 +105,13 @@ async def mock_sleep(duration): if call_count >= 2: # Trigger cleanup after heartbeat check raise KeyboardInterrupt("Test cleanup") + def create_mock_task(coro): + # Close the coroutine to avoid "coroutine never awaited" warning + coro.close() + return MockTask() + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ - patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=lambda coro: MockTask()), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task), \ patch.object(consumer, "check_heartbeat", return_value=True): try: @@ -118,34 +126,57 @@ async def mock_sleep(duration): def test_restart_delay_configuration(): - """Test that restart delay configuration constants exist and have reasonable values.""" - # Import and test constants directly without creating consumer instances - from service_bus.service_bus_consumer import ( - RESTART_DELAY_SECONDS, - HEARTBEAT_CHECK_INTERVAL_SECONDS, - HEARTBEAT_STALENESS_THRESHOLD_SECONDS, - SUPERVISOR_ERROR_DELAY_SECONDS - ) - - # Validate configuration values - assert RESTART_DELAY_SECONDS > 0, "Restart delay should be positive" - assert RESTART_DELAY_SECONDS <= 10, "Restart delay should not be too long" + """Test that configuration constants exist and have reasonable values.""" + assert RESTART_DELAY_SECONDS > 0 + assert RESTART_DELAY_SECONDS <= 10 + assert MAX_RESTART_DELAY_SECONDS >= RESTART_DELAY_SECONDS + assert MAX_RESTART_DELAY_SECONDS <= 600 assert HEARTBEAT_CHECK_INTERVAL_SECONDS > 0 assert HEARTBEAT_STALENESS_THRESHOLD_SECONDS > HEARTBEAT_CHECK_INTERVAL_SECONDS assert SUPERVISOR_ERROR_DELAY_SECONDS > 0 -def test_heartbeat_directory_creation(): - """Test that heartbeat directory is created if it doesn't exist.""" - consumer = MockConsumerForEdgeCases(skip_init=True) - consumer.heartbeat_file = "/tmp/test_dir/test_heartbeat.txt" +@pytest.mark.asyncio +async def test_supervisor_resets_backoff_on_stale_heartbeat_restart(): + """Test that supervisor resets backoff when restarting due to stale heartbeat.""" + consumer = MockConsumerForEdgeCases() + consumer._restart_delay = 160 + + heartbeat_calls = 0 + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls == 1: + return False + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + pass + + class MockTask: + def done(self): + return False - with patch("service_bus.service_bus_consumer.os.makedirs") as mock_makedirs: + def cancel(self): + pass - consumer.update_heartbeat() + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass - # Verify makedirs was called with exist_ok=True - mock_makedirs.assert_called_once_with( - os.path.dirname(consumer.heartbeat_file), - exist_ok=True - ) + assert consumer._restart_delay == RESTART_DELAY_SECONDS From d8ce5cf20bf18f20c9343201fc934a646b8da3af Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:32:53 +0000 Subject: [PATCH 20/20] Increment API version to 0.26.1 for bug fix Co-authored-by: marrobi <17089773+marrobi@users.noreply.github.com> --- api_app/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_app/_version.py b/api_app/_version.py index 7c4a9591e1..025f4c5d0b 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.26.0" +__version__ = "0.26.1"