diff --git a/workers/api-deployment/tasks.py b/workers/api-deployment/tasks.py index 5345d41d5e..f5551e14cb 100644 --- a/workers/api-deployment/tasks.py +++ b/workers/api-deployment/tasks.py @@ -168,10 +168,11 @@ def _unified_api_execution( Returns: Execution result dictionary """ + api_client = None try: # Set up execution context using shared utilities organization_id = schema_name - config, api_client = WorkerExecutionContext.setup_execution_context( + _, api_client = WorkerExecutionContext.setup_execution_context( organization_id, execution_id, workflow_id ) @@ -233,22 +234,13 @@ def _unified_api_execution( f"files_processed={len(converted_files)}", ) - # CRITICAL: Clean up StateStore to prevent data leaks between tasks - try: - from shared.infrastructure.context import StateStore - - StateStore.clear_all() - logger.debug("๐Ÿงน Cleaned up StateStore context to prevent data leaks") - except Exception as cleanup_error: - logger.warning(f"Failed to cleanup StateStore context: {cleanup_error}") - return result except Exception as e: logger.error(f"API execution failed: {e}") # Handle execution error with standardized pattern - if "api_client" in locals(): + if api_client is not None: WorkerExecutionContext.handle_execution_error( api_client, execution_id, e, logger, f"api_execution_{task_type}" ) @@ -261,19 +253,6 @@ def _unified_api_execution( f"error={str(e)}", ) - # CRITICAL: Clean up StateStore to prevent data leaks between tasks (error path) - try: - from shared.infrastructure.context import StateStore - - StateStore.clear_all() - logger.debug( - "๐Ÿงน Cleaned up StateStore context to prevent data leaks (error path)" - ) - except Exception as cleanup_error: - logger.warning( - f"Failed to cleanup StateStore context on error: {cleanup_error}" - ) - return { "execution_id": execution_id, "status": "ERROR", @@ -281,6 +260,22 @@ def _unified_api_execution( "files_processed": 0, } + finally: + # Clean up API client session to prevent socket FD leaks + if api_client is not None: + try: + api_client.close() + except Exception as e: + logger.debug("api_client.close() failed during cleanup: %s", e) + + # Clean up StateStore to prevent data leaks between tasks + try: + from shared.infrastructure.context import StateStore + + StateStore.clear_all() + except Exception as cleanup_error: + logger.warning(f"Failed to cleanup StateStore context: {cleanup_error}") + @app.task( bind=True, diff --git a/workers/callback/tasks.py b/workers/callback/tasks.py index 1e83559bbf..90ce75b029 100644 --- a/workers/callback/tasks.py +++ b/workers/callback/tasks.py @@ -1386,123 +1386,130 @@ def _process_batch_callback_core( f"organization_id={context.organization_id}, api_client={context.api_client is not None}" ) raise RuntimeError(f"Invalid context for execution {context.execution_id}") - with log_context( - task_id=context.task_id, - execution_id=context.execution_id, - workflow_id=context.workflow_id, - organization_id=context.organization_id, - pipeline_id=context.pipeline_id, - ): - logger.info( - f"Starting batch callback processing for execution {context.execution_id}" - ) + try: + with log_context( + task_id=context.task_id, + execution_id=context.execution_id, + workflow_id=context.workflow_id, + organization_id=context.organization_id, + pipeline_id=context.pipeline_id, + ): + logger.info( + f"Starting batch callback processing for execution {context.execution_id}" + ) - try: - # Use unified status determination with timeout detection (shared with API callback) - aggregated_results, execution_status, expected_files = ( - _determine_execution_status_unified( - file_batch_results=results, + try: + # Use unified status determination with timeout detection (shared with API callback) + aggregated_results, execution_status, expected_files = ( + _determine_execution_status_unified( + file_batch_results=results, + api_client=context.api_client, + execution_id=context.execution_id, + organization_id=context.organization_id, + ) + ) + # Update workflow execution status using unified function + execution_update_result = _update_execution_status_unified( api_client=context.api_client, execution_id=context.execution_id, + final_status=execution_status, + aggregated_results=aggregated_results, organization_id=context.organization_id, + error_message=None, + ) + # Handle pipeline updates using unified function (non-API deployment) + pipeline_result = _handle_pipeline_updates_unified( + context, execution_status, is_api_deployment=False ) - ) - # Update workflow execution status using unified function - execution_update_result = _update_execution_status_unified( - api_client=context.api_client, - execution_id=context.execution_id, - final_status=execution_status, - aggregated_results=aggregated_results, - organization_id=context.organization_id, - error_message=None, - ) - # Handle pipeline updates using unified function (non-API deployment) - pipeline_result = _handle_pipeline_updates_unified( - context, execution_status, is_api_deployment=False - ) - - # Track subscription usage if plugin is present - subscription_tracking_result = _track_subscription_usage_if_available( - context=context, - execution_status=execution_status, - ) - - # Add missing UI logs for cost and final workflow status (matching backend behavior) - _publish_final_workflow_ui_logs( - context=context, - aggregated_results=aggregated_results, - execution_status=execution_status, - ) - # Handle resource cleanup using existing function - cleanup_result = _cleanup_execution_resources(context) - callback_result = { - "status": "completed", - "execution_id": context.execution_id, - "workflow_id": context.workflow_id, - "task_id": context.task_id, - "aggregated_results": aggregated_results, - "execution_status": execution_status, - "expected_files": expected_files, - "execution_update_result": execution_update_result, - "pipeline_result": pipeline_result, - "subscription_tracking_result": subscription_tracking_result, - "cleanup_result": cleanup_result, - "pipeline_id": context.pipeline_id, - "unified_callback": True, - "shared_timeout_detection": True, - } + # Track subscription usage if plugin is present + subscription_tracking_result = _track_subscription_usage_if_available( + context=context, + execution_status=execution_status, + ) - logger.info( - f"Completed unified callback processing for execution {context.execution_id} " - f"with status {execution_status}" - ) - # Handle notifications using unified function (non-critical) - try: - notification_result = _handle_notifications_unified( - api_client=context.api_client, - status=execution_status, - organization_id=context.organization_id, - execution_id=context.execution_id, - pipeline_id=context.pipeline_id, - workflow_id=context.workflow_id, - pipeline_name=context.pipeline_name, - pipeline_type=context.pipeline_type, - error_message=None, + # Add missing UI logs for cost and final workflow status (matching backend behavior) + _publish_final_workflow_ui_logs( + context=context, + aggregated_results=aggregated_results, + execution_status=execution_status, ) - callback_result["notification_result"] = notification_result - except Exception as notif_error: - logger.warning(f"Failed to handle notifications: {notif_error}") - callback_result["notification_result"] = { - "status": "failed", - "error": str(notif_error), + + # Handle resource cleanup using existing function + cleanup_result = _cleanup_execution_resources(context) + callback_result = { + "status": "completed", + "execution_id": context.execution_id, + "workflow_id": context.workflow_id, + "task_id": context.task_id, + "aggregated_results": aggregated_results, + "execution_status": execution_status, + "expected_files": expected_files, + "execution_update_result": execution_update_result, + "pipeline_result": pipeline_result, + "subscription_tracking_result": subscription_tracking_result, + "cleanup_result": cleanup_result, + "pipeline_id": context.pipeline_id, + "unified_callback": True, + "shared_timeout_detection": True, } - return callback_result + logger.info( + f"Completed unified callback processing for execution {context.execution_id} " + f"with status {execution_status}" + ) + # Handle notifications using unified function (non-critical) + try: + notification_result = _handle_notifications_unified( + api_client=context.api_client, + status=execution_status, + organization_id=context.organization_id, + execution_id=context.execution_id, + pipeline_id=context.pipeline_id, + workflow_id=context.workflow_id, + pipeline_name=context.pipeline_name, + pipeline_type=context.pipeline_type, + error_message=None, + ) + callback_result["notification_result"] = notification_result + except Exception as notif_error: + logger.warning(f"Failed to handle notifications: {notif_error}") + callback_result["notification_result"] = { + "status": "failed", + "error": str(notif_error), + } - except Exception as e: - logger.error( - f"Unified batch callback processing failed for execution {context.execution_id}: {e}" - ) + return callback_result - # Try to mark execution as failed using unified function - try: - _update_execution_status_unified( - context.api_client, - context.execution_id, - ExecutionStatus.ERROR.value, - {"error": str(e)[:500]}, - context.organization_id, - error_message=str(e)[:500], - ) - logger.info( - f"Marked execution {context.execution_id} as failed using unified function" + except Exception as e: + logger.error( + f"Unified batch callback processing failed for execution {context.execution_id}: {e}" ) - except Exception as cleanup_error: - logger.error(f"Failed to mark execution as failed: {cleanup_error}") - # Re-raise for Celery retry mechanism - raise + # Try to mark execution as failed using unified function + try: + _update_execution_status_unified( + context.api_client, + context.execution_id, + ExecutionStatus.ERROR.value, + {"error": str(e)[:500]}, + context.organization_id, + error_message=str(e)[:500], + ) + logger.info( + f"Marked execution {context.execution_id} as failed using unified function" + ) + except Exception as cleanup_error: + logger.error(f"Failed to mark execution as failed: {cleanup_error}") + + # Re-raise for Celery retry mechanism + raise + finally: + if context.api_client is not None: + try: + context.api_client.close() + except Exception as e: + logger.debug("api_client.close() failed during cleanup: %s", e) @app.task( @@ -1581,194 +1588,200 @@ def process_batch_callback_api( api_client = create_api_client(organization_id) logger.info(f"Created organization-scoped API client: {organization_id}") - execution_response = api_client.get_workflow_execution( - execution_id, file_execution=False - ) - if not execution_response.success: - raise Exception(f"Failed to get execution context: {execution_response.error}") - execution_context = execution_response.data - workflow_execution = execution_context.get("execution", {}) - workflow = execution_context.get("workflow", {}) + try: + execution_response = api_client.get_workflow_execution( + execution_id, file_execution=False + ) + if not execution_response.success: + raise RuntimeError( + f"Failed to get execution context: {execution_response.error}" + ) + execution_context = execution_response.data + workflow_execution = execution_context.get("execution", {}) + workflow = execution_context.get("workflow", {}) - # Extract schema_name and workflow_id from context - schema_name = organization_id # For API callbacks, schema_name = organization_id - workflow_id = workflow_execution.get("workflow_id") or workflow.get("id") + # Extract schema_name and workflow_id from context + schema_name = organization_id # For API callbacks, schema_name = organization_id + workflow_id = workflow_execution.get("workflow_id") or workflow.get("id") - logger.info( - f"Extracted context: schema_name={schema_name}, workflow_id={workflow_id}, pipeline_id={pipeline_id}" - ) - - with log_context( - task_id=task_id, - execution_id=execution_id, - workflow_id=workflow_id, - organization_id=organization_id, - pipeline_id=pipeline_id, - ): logger.info( - f"Processing API callback for execution {execution_id} with {len(file_batch_results)} batch results" + f"Extracted context: schema_name={schema_name}, workflow_id={workflow_id}, pipeline_id={pipeline_id}" ) - try: - # Create organization-scoped API client using factory pattern - api_client = create_api_client(schema_name) - - # Get pipeline name and type (simplified approach) - if not pipeline_id: - logger.info( - f"No pipeline_id provided for API callback (testing mode). " - f"execution_id={execution_id}, workflow_id={workflow_id}. " - f"APIDeployment-specific updates and notifications will be skipped." - ) - pipeline_name = None - pipeline_type = None - else: - # Use simplified pipeline data fetching - pipeline_name, pipeline_type = _fetch_pipeline_data_simplified( - pipeline_id, schema_name, api_client, is_api_deployment=True - ) + with log_context( + task_id=task_id, + execution_id=execution_id, + workflow_id=workflow_id, + organization_id=organization_id, + pipeline_id=pipeline_id, + ): + logger.info( + f"Processing API callback for execution {execution_id} with {len(file_batch_results)} batch results" + ) - if pipeline_name: + try: + # Get pipeline name and type (simplified approach) + if not pipeline_id: logger.info( - f"โœ… Found pipeline: name='{pipeline_name}', type='{pipeline_type}'" + f"No pipeline_id provided for API callback (testing mode). " + f"execution_id={execution_id}, workflow_id={workflow_id}. " + f"APIDeployment-specific updates and notifications will be skipped." ) + pipeline_name = None + pipeline_type = None else: - logger.warning(f"Could not fetch pipeline data for {pipeline_id}") - pipeline_name = "Unknown API" - pipeline_type = PipelineType.API.value - - # Use unified status determination with timeout detection - aggregated_results, execution_status, expected_files = ( - _determine_execution_status_unified( - file_batch_results=file_batch_results, + # Use simplified pipeline data fetching + pipeline_name, pipeline_type = _fetch_pipeline_data_simplified( + pipeline_id, schema_name, api_client, is_api_deployment=True + ) + + if pipeline_name: + logger.info( + f"โœ… Found pipeline: name='{pipeline_name}', type='{pipeline_type}'" + ) + else: + logger.warning(f"Could not fetch pipeline data for {pipeline_id}") + pipeline_name = "Unknown API" + pipeline_type = PipelineType.API.value + + # Use unified status determination with timeout detection + aggregated_results, execution_status, expected_files = ( + _determine_execution_status_unified( + file_batch_results=file_batch_results, + api_client=api_client, + execution_id=execution_id, + organization_id=organization_id, + ) + ) + + # Update workflow execution status using unified function + execution_update_result = _update_execution_status_unified( api_client=api_client, execution_id=execution_id, + final_status=execution_status, + aggregated_results=aggregated_results, organization_id=organization_id, ) - ) - # Update workflow execution status using unified function - execution_update_result = _update_execution_status_unified( - api_client=api_client, - execution_id=execution_id, - final_status=execution_status, - aggregated_results=aggregated_results, - organization_id=organization_id, - ) - - # Create minimal context for unified pipeline handling - context = CallbackContext() - context.pipeline_id = pipeline_id - context.execution_id = execution_id - context.organization_id = organization_id - context.workflow_id = workflow_id - context.pipeline_name = pipeline_name - context.pipeline_type = pipeline_type - context.api_client = api_client - context.file_executions = execution_context.get("file_executions", []) - context.pipeline_data = { - "is_api": True - } # Mark as API execution for cleanup logic - - # Add missing UI logs for cost and final workflow status (matching backend behavior) - _publish_final_workflow_ui_logs_api( - context=context, - aggregated_results=aggregated_results, - execution_status=execution_status, - ) - - # Handle resource cleanup (matching ETL workflow behavior) - cleanup_result = _cleanup_execution_resources(context) - - # Handle pipeline updates (only if pipeline_id exists) - if context.pipeline_id: - pipeline_result = _handle_pipeline_updates_unified( - context=context, final_status=execution_status, is_api_deployment=True + # Create minimal context for unified pipeline handling + context = CallbackContext() + context.pipeline_id = pipeline_id + context.execution_id = execution_id + context.organization_id = organization_id + context.workflow_id = workflow_id + context.pipeline_name = pipeline_name + context.pipeline_type = pipeline_type + context.api_client = api_client + context.file_executions = execution_context.get("file_executions", []) + context.pipeline_data = { + "is_api": True + } # Mark as API execution for cleanup logic + + # Add missing UI logs for cost and final workflow status (matching backend behavior) + _publish_final_workflow_ui_logs_api( + context=context, + aggregated_results=aggregated_results, + execution_status=execution_status, ) - else: - logger.info( - f"Skipping pipeline status update for execution {execution_id} " - f"since no APIDeployment record" - ) - pipeline_result = { - "skipped": True, - "reason": "execute without APIDeployment record", - } - # Track subscription usage if plugin is present - subscription_tracking_result = _track_subscription_usage_if_available( - context=context, - execution_status=execution_status, - ) + # Handle resource cleanup (matching ETL workflow behavior) + cleanup_result = _cleanup_execution_resources(context) - # Handle notifications using unified function - notification_result = _handle_notifications_unified( - api_client=api_client, - status=execution_status, - organization_id=organization_id, - execution_id=execution_id, - pipeline_id=pipeline_id, - workflow_id=workflow_id, - pipeline_name=pipeline_name, - pipeline_type=pipeline_type, - error_message=None, - ) + # Handle pipeline updates (only if pipeline_id exists) + if context.pipeline_id: + pipeline_result = _handle_pipeline_updates_unified( + context=context, + final_status=execution_status, + is_api_deployment=True, + ) + else: + logger.info( + f"Skipping pipeline status update for execution {execution_id} " + f"since no APIDeployment record" + ) + pipeline_result = { + "skipped": True, + "reason": "execute without APIDeployment record", + } + + # Track subscription usage if plugin is present + subscription_tracking_result = _track_subscription_usage_if_available( + context=context, + execution_status=execution_status, + ) - callback_result = { - "execution_id": execution_id, - "workflow_id": workflow_id, - "pipeline_id": pipeline_id, - "status": "completed", - "total_files_processed": aggregated_results.get( - "total_files_processed", 0 - ), - "total_execution_time": aggregated_results.get("total_execution_time", 0), - "batches_processed": len(file_batch_results), - "task_id": task_id, - "expected_files": expected_files, # Include expected files for debugging - "execution_update": execution_update_result, - "pipeline_update": pipeline_result, - "notifications": notification_result, - "subscription_tracking_result": subscription_tracking_result, - "cleanup_result": cleanup_result, - "optimization": { - "method": "unified_callback_functions", - "eliminated_code_duplication": True, - "shared_timeout_detection": True, - }, - } + # Handle notifications using unified function + notification_result = _handle_notifications_unified( + api_client=api_client, + status=execution_status, + organization_id=organization_id, + execution_id=execution_id, + pipeline_id=pipeline_id, + workflow_id=workflow_id, + pipeline_name=pipeline_name, + pipeline_type=pipeline_type, + error_message=None, + ) - logger.info( - f"Successfully completed API callback for execution {execution_id}" - ) - return callback_result + callback_result = { + "execution_id": execution_id, + "workflow_id": workflow_id, + "pipeline_id": pipeline_id, + "status": "completed", + "total_files_processed": aggregated_results.get( + "total_files_processed", 0 + ), + "total_execution_time": aggregated_results.get( + "total_execution_time", 0 + ), + "batches_processed": len(file_batch_results), + "task_id": task_id, + "expected_files": expected_files, # Include expected files for debugging + "execution_update": execution_update_result, + "pipeline_update": pipeline_result, + "notifications": notification_result, + "subscription_tracking_result": subscription_tracking_result, + "cleanup_result": cleanup_result, + "optimization": { + "method": "unified_callback_functions", + "eliminated_code_duplication": True, + "shared_timeout_detection": True, + }, + } - except Exception as e: - logger.error( - f"API callback processing failed for execution {execution_id}: {e}" - ) + logger.info( + f"Successfully completed API callback for execution {execution_id}" + ) + return callback_result - # Try to update execution status to failed - try: - # Create organization-scoped API client for error handling - api_client = create_api_client(schema_name) - # Update execution status to error - api_client.update_workflow_execution_status( - execution_id=execution_id, - status=ExecutionStatus.ERROR.value, - error_message=str(e)[:500], # Limit error message length - organization_id=schema_name, + except Exception as e: + logger.error( + f"API callback processing failed for execution {execution_id}: {e}" ) - # OPTIMIZATION: Skip pipeline status update for API deployments on error - if pipeline_id: - logger.info( - f"OPTIMIZATION: Skipping pipeline status update for API deployment {pipeline_id} on error (no Pipeline record exists)" + # Try to update execution status to failed (reuse existing api_client) + try: + api_client.update_workflow_execution_status( + execution_id=execution_id, + status=ExecutionStatus.ERROR.value, + error_message=str(e)[:500], # Limit error message length + organization_id=schema_name, ) - except Exception as update_error: - logger.error(f"Failed to update execution status: {update_error}") - raise + # OPTIMIZATION: Skip pipeline status update for API deployments on error + if pipeline_id: + logger.info( + f"OPTIMIZATION: Skipping pipeline status update for API deployment {pipeline_id} on error (no Pipeline record exists)" + ) + except Exception as update_error: + logger.error(f"Failed to update execution status: {update_error}") + + raise + finally: + try: + api_client.close() + except Exception as e: + logger.debug("api_client.close() failed during cleanup: %s", e) def _publish_final_workflow_ui_logs( diff --git a/workers/conftest.py b/workers/conftest.py new file mode 100644 index 0000000000..137e8304e1 --- /dev/null +++ b/workers/conftest.py @@ -0,0 +1,18 @@ +"""Root conftest for workers test suite. + +Sets required environment variables before any workers modules are imported. +""" + +import os + +# These must be set before any workers module import because +# shared/constants/api_endpoints.py evaluates INTERNAL_API_BASE_URL at class definition time. +os.environ.setdefault("INTERNAL_API_BASE_URL", "http://test-backend:8000/internal") +os.environ.setdefault("INTERNAL_SERVICE_API_KEY", "test-key-123") +os.environ.setdefault("CELERY_BROKER_BASE_URL", "amqp://localhost:5672//") +os.environ.setdefault("CELERY_BROKER_USER", "guest") +os.environ.setdefault("CELERY_BROKER_PASS", "guest") +os.environ.setdefault("DB_HOST", "localhost") +os.environ.setdefault("DB_USER", "test") +os.environ.setdefault("DB_PASSWORD", "test") +os.environ.setdefault("DB_NAME", "testdb") diff --git a/workers/sample.env b/workers/sample.env index 516fc242e1..5e054ca854 100644 --- a/workers/sample.env +++ b/workers/sample.env @@ -102,8 +102,14 @@ ENABLE_API_CLIENT_SINGLETON=false DEBUG_API_CLIENT_INIT=false WORKER_INFRASTRUCTURE_HEALTH_CHECK=true -# API Client Configuration -API_CLIENT_POOL_SIZE=3 +# API Client Connection Pool Size +# Controls pool_connections (number of host connection pools) and pool_maxsize (max connections per pool). +# pool_connections = API_CLIENT_POOL_SIZE, pool_maxsize = API_CLIENT_POOL_SIZE * 2 +API_CLIENT_POOL_SIZE=10 + +# Session lifecycle safety valve - reset singleton after N tasks (0=disabled) +# Only effective when ENABLE_API_CLIENT_SINGLETON=true +WORKER_SINGLETON_RESET_THRESHOLD=1000 # Config Caching ENABLE_CONFIG_CACHE=true diff --git a/workers/shared/api/internal_client.py b/workers/shared/api/internal_client.py index 971ffb8457..f1ba5f6c58 100644 --- a/workers/shared/api/internal_client.py +++ b/workers/shared/api/internal_client.py @@ -10,6 +10,7 @@ """ import logging +import threading import uuid from typing import Any from uuid import UUID @@ -85,6 +86,11 @@ class InternalAPIClient(CachedAPIClientMixin): _shared_session = None _initialization_count = 0 + # Task counter for periodic singleton reset (FR-3.2) + _task_counter: int = 0 + _task_counter_lock: threading.Lock = threading.Lock() + _last_reset_time: float | None = None + def __init__(self, config: WorkerConfig | None = None): """Initialize the facade with all specialized clients and caching. @@ -126,6 +132,8 @@ def _initialize_core_clients_optimized(self) -> None: logger.info("Creating shared HTTP session (GIL-safe singleton pattern)") # Create the first base client to establish the session self.base_client = BaseAPIClient(self.config) + # The first client owns the session; reset_singleton() handles final cleanup + self.base_client._owns_session = False # Share the session for reuse (atomic assignment) InternalAPIClient._shared_session = self.base_client.session InternalAPIClient._shared_base_client = self.base_client @@ -140,35 +148,35 @@ def _initialize_core_clients_optimized(self) -> None: self.base_client.session = ( InternalAPIClient._shared_session ) # Use shared session + self.base_client._owns_session = False + + # Helper to share session on a client + def _share_session(client: BaseAPIClient) -> None: + client.session.close() # Close the freshly-created session + client.session = InternalAPIClient._shared_session + client._owns_session = False # Create specialized clients with shared session (outside lock for performance) self.execution_client = ExecutionAPIClient(self.config) - self.execution_client.session.close() - self.execution_client.session = InternalAPIClient._shared_session + _share_session(self.execution_client) self.workflow_client = WorkflowAPIClient() - self.workflow_client.session.close() - self.workflow_client.session = InternalAPIClient._shared_session + _share_session(self.workflow_client) self.file_client = FileAPIClient(self.config) - self.file_client.session.close() - self.file_client.session = InternalAPIClient._shared_session + _share_session(self.file_client) self.webhook_client = WebhookAPIClient(self.config) - self.webhook_client.session.close() - self.webhook_client.session = InternalAPIClient._shared_session + _share_session(self.webhook_client) self.organization_client = OrganizationAPIClient(self.config) - self.organization_client.session.close() - self.organization_client.session = InternalAPIClient._shared_session + _share_session(self.organization_client) self.tool_client = ToolAPIClient(self.config) - self.tool_client.session.close() - self.tool_client.session = InternalAPIClient._shared_session + _share_session(self.tool_client) self.usage_client = UsageAPIClient(self.config) - self.usage_client.session.close() - self.usage_client.session = InternalAPIClient._shared_session + _share_session(self.usage_client) def _initialize_core_clients_traditional(self) -> None: """Initialize clients the traditional way (for backward compatibility).""" @@ -247,6 +255,60 @@ def close(self): self.usage_client.close() logger.debug("Closed all InternalAPIClient sessions (traditional mode)") + @classmethod + def reset_singleton(cls): + """Reset shared singleton session state. Safe to call anytime. + + Note: In thread/gevent/eventlet pools, this may close the session while + other threads have in-flight requests. The BaseAPIClient retry logic + handles transient connection errors from this race. For prefork pools, + this is a non-issue (one task per process). + """ + if cls._shared_session is not None: + try: + cls._shared_session.close() + except Exception as e: + logger.warning("Failed to close shared session during reset: %s", e) + cls._shared_session = None + cls._shared_base_client = None + cls._initialization_count = 0 + cls._task_counter = 0 + logger.info("Reset InternalAPIClient singleton state") + + @classmethod + def increment_task_counter(cls) -> None: + """Increment task counter and reset singleton if threshold reached. + + Called via task_postrun signal after each task completes. + Uses a lock for thread safety in case threads/gevent/eventlet pools are used. + When singleton disabled (default), reset_singleton() is a no-op. + """ + from shared.infrastructure.config.worker_config import WorkerConfig + + threshold = WorkerConfig().singleton_reset_task_threshold + with cls._task_counter_lock: + cls._task_counter += 1 + if threshold > 0 and cls._task_counter >= threshold: + import time + + logger.info( + "Task counter reached threshold (%d/%d), resetting singleton session", + cls._task_counter, + threshold, + ) + cls.reset_singleton() + cls._last_reset_time = time.time() + + @classmethod + def get_task_counter_info(cls) -> dict: + """Get task counter state for observability.""" + return { + "task_counter": cls._task_counter, + "last_reset_time": cls._last_reset_time, + "shared_session_active": cls._shared_session is not None, + "initialization_count": cls._initialization_count, + } + def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit with safe cleanup.""" self.close() diff --git a/workers/shared/clients/base_client.py b/workers/shared/clients/base_client.py index 45d37d8166..07e7067055 100644 --- a/workers/shared/clients/base_client.py +++ b/workers/shared/clients/base_client.py @@ -103,8 +103,12 @@ def __init__(self, config: WorkerConfig | None = None): # It comes from task context, not from configuration self.organization_id = None + # Track whether this client owns its session (for singleton-aware close) + self._owns_session = True + # Initialize requests session with retry strategy self.session = requests.Session() + self._closed = False # Track session state for idempotent close self._setup_session() # Always log initialization @@ -149,10 +153,13 @@ def _setup_session(self): ) # HTTP adapter with connection pooling + # pool_connections: number of connection pools to cache (one per host:port) + # pool_maxsize: max connections per pool (concurrent requests per host) + pool_size = self.config.api_client_pool_size adapter = HTTPAdapter( max_retries=retry_strategy, - pool_connections=10, # Number of connection pools - pool_maxsize=20, # Maximum number of connections per pool + pool_connections=pool_size, + pool_maxsize=pool_size * 2, pool_block=False, # Don't block when pool is full ) self.session.mount("http://", adapter) @@ -540,8 +547,19 @@ def health_check(self) -> APIResponse: # Session management def close(self): - """Close the HTTP session.""" + """Close the HTTP session (idempotent). + + When _owns_session is False (singleton mode), this is a no-op to avoid + closing the shared session that other clients depend on. + """ + if self._closed: + return + if not self._owns_session: + logger.debug("Skipping session close (shared singleton session)") + self._closed = True + return self.session.close() + self._closed = True logger.debug("Closed API client session") def __enter__(self): @@ -551,3 +569,17 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.close() + + def __del__(self): + """Safety net: close session on garbage collection. + + Only closes if this client owns its session (not shared singleton). + """ + try: + if hasattr(self, "_closed") and not self._closed: + if hasattr(self, "_owns_session") and not self._owns_session: + return # Don't close shared singleton session + if hasattr(self, "session") and self.session is not None: + self.session.close() + except Exception: + pass # Never raise in __del__ diff --git a/workers/shared/infrastructure/config/worker_config.py b/workers/shared/infrastructure/config/worker_config.py index eb1c0b8e79..e1e11997c2 100644 --- a/workers/shared/infrastructure/config/worker_config.py +++ b/workers/shared/infrastructure/config/worker_config.py @@ -257,7 +257,14 @@ class WorkerConfig: == "true" ) api_client_pool_size: int = field( - default_factory=lambda: int(os.getenv("API_CLIENT_POOL_SIZE", "3")) + default_factory=lambda: int(os.getenv("API_CLIENT_POOL_SIZE", "10")) + ) + + # Session Lifecycle Safety Valve (FR-3.2) + # Reset singleton session after N task completions to prevent FD accumulation. + # Only effective when enable_api_client_singleton=true. Set to 0 to disable. + singleton_reset_task_threshold: int = field( + default_factory=lambda: int(os.getenv("WORKER_SINGLETON_RESET_THRESHOLD", "1000")) ) # Configuration Caching diff --git a/workers/shared/tests/__init__.py b/workers/shared/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/workers/shared/tests/test_session_lifecycle.py b/workers/shared/tests/test_session_lifecycle.py new file mode 100644 index 0000000000..46678c8bb9 --- /dev/null +++ b/workers/shared/tests/test_session_lifecycle.py @@ -0,0 +1,547 @@ +"""Tests for HTTP Session Lifecycle Management (UNS-205). + +Tests cover: +- FR-1: __del__ destructor safety net +- FR-2: Explicit cleanup in callback tasks (try/finally) +- FR-3: Singleton lifecycle management (reset_singleton, task counter) +- Gap #5: _owns_session flag for singleton-safe close() +- Gap #4: API_CLIENT_POOL_SIZE wired into HTTPAdapter +""" + +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_singleton_state(): + """Reset InternalAPIClient class-level singleton state between tests.""" + from shared.api.internal_client import InternalAPIClient + + InternalAPIClient._shared_session = None + InternalAPIClient._shared_base_client = None + InternalAPIClient._initialization_count = 0 + InternalAPIClient._task_counter = 0 + InternalAPIClient._last_reset_time = None + yield + InternalAPIClient._shared_session = None + InternalAPIClient._shared_base_client = None + InternalAPIClient._initialization_count = 0 + InternalAPIClient._task_counter = 0 + InternalAPIClient._last_reset_time = None + + +@pytest.fixture +def mock_config(): + """Create a WorkerConfig with defaults suitable for testing.""" + with patch.dict( + "os.environ", + { + "INTERNAL_API_BASE_URL": "http://test-backend:8000/internal", + "INTERNAL_SERVICE_API_KEY": "test-key-123", + "CELERY_BROKER_BASE_URL": "amqp://localhost:5672//", + "CELERY_BROKER_USER": "guest", + "CELERY_BROKER_PASS": "guest", + "DB_HOST": "localhost", + "DB_USER": "test", + "DB_PASSWORD": "test", + "DB_NAME": "testdb", + "API_CLIENT_POOL_SIZE": "5", + "ENABLE_API_CLIENT_SINGLETON": "false", + }, + ): + from shared.infrastructure.config.worker_config import WorkerConfig + + yield WorkerConfig() + + +@pytest.fixture +def mock_config_singleton(): + """Create a WorkerConfig with singleton enabled.""" + with patch.dict( + "os.environ", + { + "INTERNAL_API_BASE_URL": "http://test-backend:8000/internal", + "INTERNAL_SERVICE_API_KEY": "test-key-123", + "CELERY_BROKER_BASE_URL": "amqp://localhost:5672//", + "CELERY_BROKER_USER": "guest", + "CELERY_BROKER_PASS": "guest", + "DB_HOST": "localhost", + "DB_USER": "test", + "DB_PASSWORD": "test", + "DB_NAME": "testdb", + "API_CLIENT_POOL_SIZE": "5", + "ENABLE_API_CLIENT_SINGLETON": "true", + "WORKER_SINGLETON_RESET_THRESHOLD": "3", + }, + ): + from shared.infrastructure.config.worker_config import WorkerConfig + + yield WorkerConfig() + + +# =========================================================================== +# FR-1: __del__ destructor tests +# =========================================================================== + + +class TestBaseAPIClientDestructor: + """Tests for BaseAPIClient.__del__ safety net.""" + + def test_del_closes_unclosed_session(self, mock_config): + """__del__ should close the session if close() was never called.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + client._closed = False + client._owns_session = True + + client.__del__() + + mock_session.close.assert_called_once() + + def test_del_skips_already_closed_session(self, mock_config): + """__del__ should be a no-op if session is already closed.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + client._closed = True + + client.__del__() + + mock_session.close.assert_not_called() + + def test_del_skips_shared_session(self, mock_config): + """__del__ should NOT close a shared singleton session.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + client._closed = False + client._owns_session = False + + client.__del__() + + mock_session.close.assert_not_called() + + def test_del_handles_missing_attributes(self): + """__del__ should not raise even if init failed partially.""" + from shared.clients.base_client import BaseAPIClient + + client = object.__new__(BaseAPIClient) + # No attributes set at all - should not raise + client.__del__() + + def test_del_swallows_exceptions(self, mock_config): + """__del__ should never propagate exceptions.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + client._closed = False + client._owns_session = True + client.session = MagicMock() + client.session.close.side_effect = RuntimeError("connection broken") + + # Should not raise + client.__del__() + + +# =========================================================================== +# FR-1 + Gap #5: close() with _owns_session +# =========================================================================== + + +class TestBaseAPIClientClose: + """Tests for BaseAPIClient.close() behavior.""" + + def test_close_is_idempotent(self, mock_config): + """Calling close() multiple times should only close session once.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + + client.close() + client.close() + client.close() + + mock_session.close.assert_called_once() + + def test_close_skips_shared_session(self, mock_config): + """close() should NOT close the session when _owns_session=False.""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + client._owns_session = False + + client.close() + + mock_session.close.assert_not_called() + assert client._closed is True # Flag still set to prevent redundant calls + + def test_close_closes_owned_session(self, mock_config): + """close() should close the session when _owns_session=True (default).""" + from shared.clients.base_client import BaseAPIClient + + client = BaseAPIClient(mock_config) + mock_session = MagicMock() + client.session = mock_session + assert client._owns_session is True # Default + + client.close() + + mock_session.close.assert_called_once() + assert client._closed is True + + def test_context_manager_calls_close(self, mock_config): + """Using 'with' should call close() on exit.""" + from shared.clients.base_client import BaseAPIClient + + with BaseAPIClient(mock_config) as client: + mock_session = MagicMock() + client.session = mock_session + + mock_session.close.assert_called_once() + + +# =========================================================================== +# Gap #4: API_CLIENT_POOL_SIZE wired into HTTPAdapter +# =========================================================================== + + +class TestPoolSizeConfiguration: + """Tests for API_CLIENT_POOL_SIZE being wired into HTTPAdapter.""" + + def test_pool_size_from_config(self, mock_config): + """HTTPAdapter should use api_client_pool_size from config.""" + from shared.clients.base_client import BaseAPIClient + + assert mock_config.api_client_pool_size == 5 + + client = BaseAPIClient(mock_config) + + # Inspect the mounted adapter's internal pool settings + adapter = client.session.get_adapter("http://") + assert adapter._pool_connections == 5 + assert adapter._pool_maxsize == 10 # pool_size * 2 + + client.close() + + def test_default_pool_size(self): + """Default pool size should be 10 when not configured.""" + with patch.dict( + "os.environ", + { + "INTERNAL_API_BASE_URL": "http://test:8000/internal", + "INTERNAL_SERVICE_API_KEY": "test-key", + "CELERY_BROKER_BASE_URL": "amqp://localhost:5672//", + "CELERY_BROKER_USER": "guest", + "CELERY_BROKER_PASS": "guest", + "DB_HOST": "localhost", + "DB_USER": "test", + "DB_PASSWORD": "test", + "DB_NAME": "testdb", + }, + clear=False, + ): + from shared.infrastructure.config.worker_config import WorkerConfig + + config = WorkerConfig() + assert config.api_client_pool_size == 10 + + +# =========================================================================== +# FR-3: Singleton lifecycle management +# =========================================================================== + + +class TestResetSingleton: + """Tests for InternalAPIClient.reset_singleton().""" + + def test_reset_when_no_shared_session(self): + """reset_singleton() should be a no-op when there's no shared session.""" + from shared.api.internal_client import InternalAPIClient + + # Should not raise + InternalAPIClient.reset_singleton() + assert InternalAPIClient._shared_session is None + + def test_reset_closes_shared_session(self): + """reset_singleton() should close the shared session and clear state.""" + from shared.api.internal_client import InternalAPIClient + + mock_session = MagicMock() + InternalAPIClient._shared_session = mock_session + InternalAPIClient._shared_base_client = MagicMock() + InternalAPIClient._initialization_count = 5 + InternalAPIClient._task_counter = 42 + + InternalAPIClient.reset_singleton() + + mock_session.close.assert_called_once() + assert InternalAPIClient._shared_session is None + assert InternalAPIClient._shared_base_client is None + assert InternalAPIClient._initialization_count == 0 + assert InternalAPIClient._task_counter == 0 + + def test_reset_handles_close_exception(self): + """reset_singleton() should handle session.close() failure gracefully.""" + from shared.api.internal_client import InternalAPIClient + + mock_session = MagicMock() + mock_session.close.side_effect = RuntimeError("broken pipe") + InternalAPIClient._shared_session = mock_session + + # Should not raise + InternalAPIClient.reset_singleton() + + assert InternalAPIClient._shared_session is None + + +# =========================================================================== +# FR-3: Task counter +# =========================================================================== + + +class TestTaskCounter: + """Tests for InternalAPIClient.increment_task_counter().""" + + def test_increment_counter(self, mock_config_singleton): + """Counter should increment on each call.""" + from shared.api.internal_client import InternalAPIClient + + InternalAPIClient._task_counter = 0 + + with patch.object( + InternalAPIClient, "reset_singleton" + ) as mock_reset: + InternalAPIClient.increment_task_counter() + assert InternalAPIClient._task_counter == 1 + mock_reset.assert_not_called() + + def test_threshold_triggers_reset(self, mock_config_singleton): + """Counter reaching threshold should trigger reset_singleton().""" + from shared.api.internal_client import InternalAPIClient + + # Threshold is 3 from mock_config_singleton + InternalAPIClient._task_counter = 2 # One away from threshold + + with patch.object( + InternalAPIClient, "reset_singleton" + ) as mock_reset: + InternalAPIClient.increment_task_counter() + mock_reset.assert_called_once() + # Counter zeroing is done inside reset_singleton() itself, + # which is mocked here โ€” verified in TestResetSingleton instead + + def test_get_task_counter_info(self): + """get_task_counter_info() should return correct state.""" + from shared.api.internal_client import InternalAPIClient + + InternalAPIClient._task_counter = 42 + InternalAPIClient._shared_session = MagicMock() + + info = InternalAPIClient.get_task_counter_info() + + assert info["task_counter"] == 42 + assert info["shared_session_active"] is True + + +# =========================================================================== +# FR-3: on_task_postrun guard +# =========================================================================== + + +class TestOnTaskPostrunGuard: + """Tests for the singleton guard in on_task_postrun.""" + + def test_postrun_skips_when_singleton_disabled(self): + """on_task_postrun should skip increment when singleton=false.""" + import worker + + with patch.object(worker, "config") as mock_cfg: + mock_cfg.enable_api_client_singleton = False + with patch( + "shared.api.internal_client.InternalAPIClient" + ".increment_task_counter" + ) as mock_increment: + worker.on_task_postrun(sender=None, task_id=None) + mock_increment.assert_not_called() + + def test_postrun_calls_increment_when_singleton_enabled(self): + """on_task_postrun should call increment when singleton=true.""" + import worker + + with patch.object(worker, "config") as mock_cfg: + mock_cfg.enable_api_client_singleton = True + with patch( + "shared.api.internal_client.InternalAPIClient" + ".increment_task_counter" + ) as mock_increment: + worker.on_task_postrun(sender=None, task_id=None) + mock_increment.assert_called_once() + + +# =========================================================================== +# Gap #5: Singleton-aware close() in InternalAPIClient +# =========================================================================== + + +class TestInternalAPIClientSingletonClose: + """Tests for InternalAPIClient.close() respecting singleton mode.""" + + def test_close_traditional_mode_closes_all(self, mock_config): + """In traditional mode, close() should close all client sessions.""" + from shared.api.internal_client import InternalAPIClient + + with patch("shared.api.internal_client.get_client_plugin", return_value=None): + client = InternalAPIClient(mock_config) + + # Replace sessions with mocks + mock_sessions = {} + for attr in [ + "base_client", + "execution_client", + "file_client", + "webhook_client", + "organization_client", + "tool_client", + "workflow_client", + "usage_client", + ]: + sub_client = getattr(client, attr) + mock_session = MagicMock() + sub_client.session = mock_session + sub_client._closed = False + sub_client._owns_session = True + mock_sessions[attr] = mock_session + + client.close() + + for attr, mock_session in mock_sessions.items(): + mock_session.close.assert_called_once(), ( + f"{attr} session was not closed" + ) + + def test_close_singleton_mode_preserves_shared_session( + self, mock_config_singleton + ): + """In singleton mode, close() should NOT close the shared session.""" + from shared.api.internal_client import InternalAPIClient + + with patch("shared.api.internal_client.get_client_plugin", return_value=None): + client = InternalAPIClient(mock_config_singleton) + + shared_session = InternalAPIClient._shared_session + assert shared_session is not None + + # close() in singleton mode should preserve the session + client.close() + + # The shared session should still be alive + assert InternalAPIClient._shared_session is shared_session + + def test_singleton_clients_have_owns_session_false( + self, mock_config_singleton + ): + """All clients in singleton mode should have _owns_session=False.""" + from shared.api.internal_client import InternalAPIClient + + with patch("shared.api.internal_client.get_client_plugin", return_value=None): + client = InternalAPIClient(mock_config_singleton) + + for attr in [ + "base_client", + "execution_client", + "file_client", + "webhook_client", + "organization_client", + "tool_client", + "workflow_client", + "usage_client", + ]: + sub_client = getattr(client, attr) + assert sub_client._owns_session is False, ( + f"{attr} should have _owns_session=False in singleton mode" + ) + + def test_traditional_clients_have_owns_session_true(self, mock_config): + """All clients in traditional mode should have _owns_session=True.""" + from shared.api.internal_client import InternalAPIClient + + with patch("shared.api.internal_client.get_client_plugin", return_value=None): + client = InternalAPIClient(mock_config) + + for attr in [ + "base_client", + "execution_client", + "file_client", + "webhook_client", + "organization_client", + "tool_client", + "workflow_client", + "usage_client", + ]: + sub_client = getattr(client, attr) + assert sub_client._owns_session is True, ( + f"{attr} should have _owns_session=True in traditional mode" + ) + + +# =========================================================================== +# FR-2: WorkerExecutionContext managed_execution_context cleanup +# =========================================================================== + + +class TestManagedExecutionContextCleanup: + """Tests for context manager cleanup in WorkerExecutionContext.""" + + def test_managed_context_closes_client_on_success(self, mock_config): + """managed_execution_context should close client after successful use.""" + from shared.workflow.execution.context import WorkerExecutionContext + + with patch.object( + WorkerExecutionContext, + "setup_execution_context", + ) as mock_setup: + mock_client = MagicMock() + mock_setup.return_value = (mock_config, mock_client) + + with WorkerExecutionContext.managed_execution_context( + "org-1", "exec-1", "wf-1" + ) as (cfg, client): + pass # Simulate successful execution + + mock_client.close.assert_called_once() + + def test_managed_context_closes_client_on_exception(self, mock_config): + """managed_execution_context should close client even when exception occurs.""" + from shared.workflow.execution.context import WorkerExecutionContext + + with patch.object( + WorkerExecutionContext, + "setup_execution_context", + ) as mock_setup: + mock_client = MagicMock() + mock_setup.return_value = (mock_config, mock_client) + + with pytest.raises(ValueError): + with WorkerExecutionContext.managed_execution_context( + "org-1", "exec-1", "wf-1" + ) as (cfg, client): + raise ValueError("test error") + + mock_client.close.assert_called_once() diff --git a/workers/worker.py b/workers/worker.py index aa3d6b6007..93de02baf2 100755 --- a/workers/worker.py +++ b/workers/worker.py @@ -366,6 +366,50 @@ def on_worker_process_init(**kwargs): # ============= END OF WORKER PROCESS INIT HOOK ============= +# ============= WORKER PROCESS SHUTDOWN HOOK ============= +@signals.worker_process_shutdown.connect +def on_worker_process_shutdown(**kwargs): + """Clean up HTTP sessions during worker shutdown.""" + logger.info("Cleaning up API client resources (PID: %s)", os.getpid()) + try: + from shared.api.internal_client import InternalAPIClient + + InternalAPIClient.reset_singleton() + except Exception as e: + logger.warning(f"Failed to reset InternalAPIClient singleton: {e}") + try: + from shared.patterns.factory.client_factory import ClientFactory + + ClientFactory.reset_shared_state() + except Exception as e: + logger.warning(f"Failed to reset ClientFactory: {e}") + + +# ============= END OF WORKER PROCESS SHUTDOWN HOOK ============= + + +# ============= TASK COMPLETION HOOK (SESSION LIFECYCLE SAFETY VALVE) ============= +@signals.task_postrun.connect +def on_task_postrun(sender=None, task_id=None, **kwargs): + """Increment task counter for periodic singleton session reset (FR-3.2). + + After WORKER_SINGLETON_RESET_THRESHOLD tasks (default: 1000), the singleton + session is closed and recreated. Skipped when singleton is disabled (default) + since the counter and reset would be no-ops. + """ + if not config.enable_api_client_singleton: + return + try: + from shared.api.internal_client import InternalAPIClient + + InternalAPIClient.increment_task_counter() + except Exception as e: + logger.warning(f"Failed to increment task counter: {e}") + + +# ============= END OF TASK COMPLETION HOOK ============= + + # ============= REGISTER HEARTBEATKEEPER HERE ============= # Check if HeartbeatKeeper should be enabled (default: enabled) # Uses hierarchical config system: CLI args > worker-specific > global > default