diff --git a/.coveragerc b/.coveragerc index 381b644b..a26ed3c6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,6 +11,7 @@ omit = */env/* */.pytest_cache/* */node_modules/* + src/backend/v4/api/router.py [paths] source = diff --git a/src/backend/app.py b/src/backend/app.py index 2cf7d6a6..38384fbe 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -17,6 +17,8 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + # Local imports from middleware.health_check import HealthCheckMiddleware from v4.api.router import app_v4 @@ -51,20 +53,6 @@ async def lifespan(app: FastAPI): logger.info("👋 MACAE application shutdown complete") -# Check if the Application Insights Instrumentation Key is set in the environment variables -connection_string = config.APPLICATIONINSIGHTS_CONNECTION_STRING -if connection_string: - # Configure Application Insights if the Instrumentation Key is found - configure_azure_monitor(connection_string=connection_string) - logging.info( - "Application Insights configured with the provided Instrumentation Key" - ) -else: - # Log a warning if the Instrumentation Key is not found - logging.warning( - "No Application Insights Instrumentation Key found. Skipping configuration" - ) - # Configure logging levels from environment variables # logging.basicConfig(level=getattr(logging, config.AZURE_BASIC_LOGGING_LEVEL.upper(), logging.INFO)) @@ -80,10 +68,32 @@ async def lifespan(app: FastAPI): logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) +# Suppress noisy Azure Monitor exporter "Transmission succeeded" logs +logging.getLogger("azure.monitor.opentelemetry.exporter.export._base").setLevel(logging.WARNING) + # Initialize the FastAPI app app = FastAPI(lifespan=lifespan) frontend_url = config.FRONTEND_SITE_NAME +# Configure Azure Monitor and instrument FastAPI for OpenTelemetry +# This enables automatic request tracing, dependency tracking, and proper operation_id +if config.APPLICATIONINSIGHTS_CONNECTION_STRING: + # Configure Application Insights telemetry with live metrics + configure_azure_monitor( + connection_string=config.APPLICATIONINSIGHTS_CONNECTION_STRING, + enable_live_metrics=True + ) + + # Instrument FastAPI app — exclude WebSocket URLs to reduce telemetry noise + FastAPIInstrumentor.instrument_app( + app, + excluded_urls="socket,ws" + ) + logging.info("Application Insights configured with live metrics and WebSocket filtering") +else: + logging.warning( + "No Application Insights connection string found. Telemetry disabled." + ) # Add this near the top of your app.py, after initializing the app app.add_middleware( diff --git a/src/backend/common/config/app_config.py b/src/backend/common/config/app_config.py index 594a528d..e4801ca2 100644 --- a/src/backend/common/config/app_config.py +++ b/src/backend/common/config/app_config.py @@ -6,6 +6,10 @@ from azure.ai.projects.aio import AIProjectClient from azure.cosmos import CosmosClient from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from azure.identity.aio import ( + DefaultAzureCredential as DefaultAzureCredentialAsync, + ManagedIdentityCredential as ManagedIdentityCredentialAsync, +) from dotenv import load_dotenv @@ -113,7 +117,8 @@ def get_azure_credential(self, client_id=None): """ Returns an Azure credential based on the application environment. - If the environment is 'dev', it uses DefaultAzureCredential. + If the environment is 'dev', it uses DefaultAzureCredential with exclude_environment_credential=True + to avoid EnvironmentCredential exceptions in Application Insights traces. Otherwise, it uses ManagedIdentityCredential. Args: @@ -123,10 +128,29 @@ def get_azure_credential(self, client_id=None): Credential object: Either DefaultAzureCredential or ManagedIdentityCredential. """ if self.APP_ENV == "dev": - return DefaultAzureCredential() # CodeQL [SM05139]: DefaultAzureCredential is safe here + return DefaultAzureCredential(exclude_environment_credential=True) # CodeQL [SM05139]: DefaultAzureCredential is safe here else: return ManagedIdentityCredential(client_id=client_id) + def get_azure_credential_async(self, client_id=None): + """ + Returns an async Azure credential based on the application environment. + + If the environment is 'dev', it uses DefaultAzureCredential (async) with exclude_environment_credential=True + to avoid EnvironmentCredential exceptions in Application Insights traces. + Otherwise, it uses ManagedIdentityCredential (async). + + Args: + client_id (str, optional): The client ID for the Managed Identity Credential. + + Returns: + Async Credential object: Either DefaultAzureCredentialAsync or ManagedIdentityCredentialAsync. + """ + if self.APP_ENV == "dev": + return DefaultAzureCredentialAsync(exclude_environment_credential=True) + else: + return ManagedIdentityCredentialAsync(client_id=client_id) + def get_azure_credentials(self): """Retrieve Azure credentials, either from environment variables or managed identity.""" if self._azure_credentials is None: diff --git a/src/backend/v4/api/router.py b/src/backend/v4/api/router.py index d9a8e7c1..2a3d5fd9 100644 --- a/src/backend/v4/api/router.py +++ b/src/backend/v4/api/router.py @@ -4,6 +4,8 @@ import uuid from typing import Optional +from opentelemetry import trace + import v4.models.messages as messages from v4.models.messages import WebsocketMessageType from auth.auth_utils import get_authenticated_user_details @@ -60,42 +62,62 @@ async def start_comms( user_id = user_id or "00000000-0000-0000-0000-000000000000" - # Add to the connection manager for backend updates - connection_config.add_connection( - process_id=process_id, connection=websocket, user_id=user_id - ) - track_event_if_configured( - "WebSocketConnectionAccepted", {"process_id": process_id, "user_id": user_id} - ) + # Manually create a span for WebSocket since excluded_urls suppresses auto-instrumentation. + # Without this, all track_event_if_configured calls inside WebSocket would get operation_Id = 0. + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span( + "WebSocket_Connection", + attributes={"process_id": process_id, "user_id": user_id}, + ) as ws_span: + # Resolve session_id from plan for telemetry + session_id = None + try: + memory_store = await DatabaseFactory.get_database(user_id=user_id) + plan = await memory_store.get_plan_by_plan_id(plan_id=process_id) + if plan: + session_id = getattr(plan, 'session_id', None) + if session_id: + ws_span.set_attribute("session_id", session_id) + except Exception as e: + logging.warning(f"[websocket] Failed to resolve session_id: {e}") + + # Add to the connection manager for backend updates + connection_config.add_connection( + process_id=process_id, connection=websocket, user_id=user_id + ) + ws_props = {"process_id": process_id, "user_id": user_id} + if session_id: + ws_props["session_id"] = session_id + track_event_if_configured("WebSocket_Connected", ws_props) - # Keep the connection open - FastAPI will close the connection if this returns - try: # Keep the connection open - FastAPI will close the connection if this returns - while True: - # no expectation that we will receive anything from the client but this keeps - # the connection open and does not take cpu cycle - try: - message = await websocket.receive_text() - logging.debug(f"Received WebSocket message from {user_id}: {message}") - except asyncio.TimeoutError: - # Ignore timeouts to keep the WebSocket connection open, but avoid a tight loop. - logging.debug( - f"WebSocket receive timeout for user {user_id}, process {process_id}" - ) - await asyncio.sleep(0.1) - except WebSocketDisconnect: - track_event_if_configured( - "WebSocketDisconnect", - {"process_id": process_id, "user_id": user_id}, - ) - logging.info(f"Client disconnected from batch {process_id}") - break - except Exception as e: - # Fixed logging syntax - removed the error= parameter - logging.error(f"Error in WebSocket connection: {str(e)}") - finally: - # Always clean up the connection - await connection_config.close_connection(process_id=process_id) + try: + # Keep the connection open - FastAPI will close the connection if this returns + while True: + # no expectation that we will receive anything from the client but this keeps + # the connection open and does not take cpu cycle + try: + message = await websocket.receive_text() + logging.debug(f"Received WebSocket message from {user_id}: {message}") + except asyncio.TimeoutError: + # Ignore timeouts to keep the WebSocket connection open, but avoid a tight loop. + logging.debug( + f"WebSocket receive timeout for user {user_id}, process {process_id}" + ) + await asyncio.sleep(0.1) + except WebSocketDisconnect: + dc_props = {"process_id": process_id, "user_id": user_id} + if session_id: + dc_props["session_id"] = session_id + track_event_if_configured("WebSocket_Disconnected", dc_props) + logging.info(f"Client disconnected from batch {process_id}") + break + except Exception as e: + # Fixed logging syntax - removed the error= parameter + logging.error(f"Error in WebSocket connection: {str(e)}") + finally: + # Always clean up the connection + await connection_config.close_connection(process_id=process_id) @app_v4.get("/init_team") @@ -115,7 +137,7 @@ async def init_team( user_id = authenticated_user["user_principal_id"] if not user_id: track_event_if_configured( - "UserIdNotFound", {"status_code": 400, "detail": "no user"} + "Error_User_Not_Found", {"status_code": 400, "detail": "no user"} ) raise HTTPException(status_code=400, detail="no user") @@ -186,7 +208,7 @@ async def init_team( except Exception as e: track_event_if_configured( - "InitTeamFailed", + "Error_Init_Team_Failed", { "error": str(e), }, @@ -251,9 +273,10 @@ async def process_request( authenticated_user = get_authenticated_user_details(request_headers=request.headers) user_id = authenticated_user["user_principal_id"] if not user_id: - track_event_if_configured( - "UserIdNotFound", {"status_code": 400, "detail": "no user"} - ) + event_props = {"status_code": 400, "detail": "no user"} + if input_task and hasattr(input_task, 'session_id') and input_task.session_id: + event_props["session_id"] = input_task.session_id + track_event_if_configured("Error_User_Not_Found", event_props) raise HTTPException(status_code=400, detail="no user found") try: memory_store = await DatabaseFactory.get_database(user_id=user_id) @@ -275,7 +298,7 @@ async def process_request( if not await rai_success(input_task.description, team, memory_store): track_event_if_configured( - "RAI failed", + "Error_RAI_Check_Failed", { "status": "Plan not created - RAI check failed", "description": input_task.description, @@ -289,6 +312,12 @@ async def process_request( if not input_task.session_id: input_task.session_id = str(uuid.uuid4()) + + # Attach session_id to current span for Application Insights + span = trace.get_current_span() + if span: + span.set_attribute("session_id", input_task.session_id) + try: plan_id = str(uuid.uuid4()) # Initialize memory store and service @@ -315,7 +344,7 @@ async def process_request( ) track_event_if_configured( - "PlanCreated", + "Plan_Created", { "status": "success", "plan_id": plan.plan_id, @@ -328,7 +357,7 @@ async def process_request( except Exception as e: print(f"Error creating plan: {e}") track_event_if_configured( - "PlanCreationFailed", + "Error_Plan_Creation_Failed", { "status": "error", "description": input_task.description, @@ -354,7 +383,7 @@ async def run_orchestration_task(): except Exception as e: track_event_if_configured( - "RequestStartFailed", + "Error_Request_Start_Failed", { "session_id": input_task.session_id, "description": input_task.description, @@ -424,6 +453,21 @@ async def plan_approval( raise HTTPException( status_code=401, detail="Missing or invalid user information" ) + + # Attach session_id to span if plan_id is available and capture for events + session_id = None + if human_feedback.plan_id: + try: + memory_store = await DatabaseFactory.get_database(user_id=user_id) + plan = await memory_store.get_plan_by_plan_id(plan_id=human_feedback.plan_id) + if plan and plan.session_id: + session_id = plan.session_id + span = trace.get_current_span() + if span: + span.set_attribute("session_id", session_id) + except Exception: + pass # Don't fail request if span attribute fails + # Set the approval in the orchestration config try: if user_id and human_feedback.m_plan_id: @@ -472,16 +516,19 @@ async def plan_approval( message_type=WebsocketMessageType.ERROR_MESSAGE, ) - track_event_if_configured( - "PlanApprovalReceived", - { - "plan_id": human_feedback.plan_id, - "m_plan_id": human_feedback.m_plan_id, - "approved": human_feedback.approved, - "user_id": user_id, - "feedback": human_feedback.feedback, - }, - ) + # Use dynamic event name based on approval status + approval_status = "Approved" if human_feedback.approved else "Rejected" + event_name = f"Plan_{approval_status}" + event_props = { + "plan_id": human_feedback.plan_id, + "m_plan_id": human_feedback.m_plan_id, + "approved": human_feedback.approved, + "user_id": user_id, + "feedback": human_feedback.feedback, + } + if session_id: + event_props["session_id"] = session_id + track_event_if_configured(event_name, event_props) return {"status": "approval recorded"} else: @@ -570,8 +617,22 @@ async def user_clarification( raise HTTPException( status_code=401, detail="Missing or invalid user information" ) + + # Attach session_id to span if plan_id is available and capture for events + session_id = None + try: memory_store = await DatabaseFactory.get_database(user_id=user_id) + if human_feedback.plan_id: + try: + plan = await memory_store.get_plan_by_plan_id(plan_id=human_feedback.plan_id) + if plan and plan.session_id: + session_id = plan.session_id + span = trace.get_current_span() + if span: + span.set_attribute("session_id", session_id) + except Exception: + pass # Don't fail request if span attribute fails user_current_team = await memory_store.get_current_team(user_id=user_id) team_id = None if user_current_team: @@ -590,16 +651,16 @@ async def user_clarification( # Set the approval in the orchestration config if user_id and human_feedback.request_id: # validate rai - if human_feedback.answer is not None or human_feedback.answer != "": + if human_feedback.answer is not None and str(human_feedback.answer).strip() != "": if not await rai_success(human_feedback.answer, team, memory_store): - track_event_if_configured( - "RAI failed", - { - "status": "Plan Clarification ", - "description": human_feedback.answer, - "request_id": human_feedback.request_id, - }, - ) + event_props = { + "status": "Plan Clarification ", + "description": human_feedback.answer, + "request_id": human_feedback.request_id, + } + if session_id: + event_props["session_id"] = session_id + track_event_if_configured("Error_RAI_Check_Failed", event_props) raise HTTPException( status_code=400, detail={ @@ -633,14 +694,14 @@ async def user_clarification( print(f"ValueError processing human clarification: {ve}") except Exception as e: print(f"Error processing human clarification: {e}") - track_event_if_configured( - "HumanClarificationReceived", - { - "request_id": human_feedback.request_id, - "answer": human_feedback.answer, - "user_id": user_id, - }, - ) + event_props = { + "request_id": human_feedback.request_id, + "answer": human_feedback.answer, + "user_id": user_id, + } + if session_id: + event_props["session_id"] = session_id + track_event_if_configured("Human_Clarification_Received", event_props) return { "status": "clarification recorded", } @@ -712,6 +773,21 @@ async def agent_message_user( raise HTTPException( status_code=401, detail="Missing or invalid user information" ) + + # Attach session_id to span if plan_id is available and capture for events + session_id = None + if agent_message.plan_id: + try: + memory_store = await DatabaseFactory.get_database(user_id=user_id) + plan = await memory_store.get_plan_by_plan_id(plan_id=agent_message.plan_id) + if plan and plan.session_id: + session_id = plan.session_id + span = trace.get_current_span() + if span: + span.set_attribute("session_id", session_id) + except Exception: + pass # Don't fail request if span attribute fails + # Set the approval in the orchestration config try: @@ -723,14 +799,16 @@ async def agent_message_user( except Exception as e: print(f"Error processing agent message: {e}") - track_event_if_configured( - "AgentMessageReceived", - { - "agent": agent_message.agent, - "content": agent_message.content, - "user_id": user_id, - }, - ) + # Use dynamic event name with agent identifier + event_name = f"Agent_Message_From_{agent_message.agent.replace(' ', '_')}" + event_props = { + "agent": agent_message.agent, + "content": agent_message.content, + "user_id": user_id, + } + if session_id: + event_props["session_id"] = session_id + track_event_if_configured(event_name, event_props) return { "status": "message recorded", } @@ -774,7 +852,7 @@ async def upload_team_config( user_id = authenticated_user["user_principal_id"] if not user_id: track_event_if_configured( - "UserIdNotFound", {"status_code": 400, "detail": "no user"} + "Error_User_Not_Found", {"status_code": 400, "detail": "no user"} ) raise HTTPException(status_code=400, detail="no user found") try: @@ -807,7 +885,7 @@ async def upload_team_config( rai_valid, rai_error = await rai_validate_team_config(json_data, memory_store) if not rai_valid: track_event_if_configured( - "Team configuration RAI validation failed", + "Error_Config_RAI_Validation_Failed", { "status": "failed", "user_id": user_id, @@ -818,7 +896,7 @@ async def upload_team_config( raise HTTPException(status_code=400, detail=rai_error) track_event_if_configured( - "Team configuration RAI validation passed", + "Config_RAI_Validation_Passed", {"status": "passed", "user_id": user_id, "filename": file.filename}, ) team_service = TeamService(memory_store) @@ -833,7 +911,7 @@ async def upload_team_config( f"Please deploy these models in Azure AI Foundry before uploading this team configuration." ) track_event_if_configured( - "Team configuration model validation failed", + "Error_Config_Model_Validation_Failed", { "status": "failed", "user_id": user_id, @@ -844,7 +922,7 @@ async def upload_team_config( raise HTTPException(status_code=400, detail=error_message) track_event_if_configured( - "Team configuration model validation passed", + "Config_Model_Validation_Passed", {"status": "passed", "user_id": user_id, "filename": file.filename}, ) @@ -860,7 +938,7 @@ async def upload_team_config( f"Please ensure all referenced search indexes exist in your Azure AI Search service." ) track_event_if_configured( - "Team configuration search validation failed", + "Error_Config_Search_Validation_Failed", { "status": "failed", "user_id": user_id, @@ -872,7 +950,7 @@ async def upload_team_config( logger.info(f"✅ Search validation passed for user: {user_id}") track_event_if_configured( - "Team configuration search validation passed", + "Config_Search_Validation_Passed", {"status": "passed", "user_id": user_id, "filename": file.filename}, ) @@ -897,7 +975,7 @@ async def upload_team_config( ) from e track_event_if_configured( - "Team configuration uploaded", + "Config_Team_Uploaded", { "status": "success", "team_id": team_id, @@ -1137,7 +1215,7 @@ async def delete_team_config(team_id: str, request: Request): # Track the event track_event_if_configured( - "Team configuration deleted", + "Config_Team_Deleted", {"status": "success", "team_id": team_id, "user_id": user_id}, ) @@ -1190,7 +1268,7 @@ async def select_team(selection: TeamSelectionRequest, request: Request): ) if not set_team: track_event_if_configured( - "Team selected", + "Error_Config_Team_Selection_Failed", { "status": "failed", "team_id": selection.team_id, @@ -1210,7 +1288,7 @@ async def select_team(selection: TeamSelectionRequest, request: Request): # Track the team selection event track_event_if_configured( - "Team selected", + "Config_Team_Selected", { "status": "success", "team_id": selection.team_id, @@ -1234,7 +1312,7 @@ async def select_team(selection: TeamSelectionRequest, request: Request): except Exception as e: logging.error(f"Error selecting team: {str(e)}") track_event_if_configured( - "Team selection error", + "Error_Config_Team_Selection", { "status": "error", "team_id": selection.team_id, @@ -1310,7 +1388,7 @@ async def get_plans(request: Request): user_id = authenticated_user["user_principal_id"] if not user_id: track_event_if_configured( - "UserIdNotFound", {"status_code": 400, "detail": "no user"} + "Error_User_Not_Found", {"status_code": 400, "detail": "no user"} ) raise HTTPException(status_code=400, detail="no user") @@ -1398,7 +1476,7 @@ async def get_plan_by_id( user_id = authenticated_user["user_principal_id"] if not user_id: track_event_if_configured( - "UserIdNotFound", {"status_code": 400, "detail": "no user"} + "Error_User_Not_Found", {"status_code": 400, "detail": "no user"} ) raise HTTPException(status_code=400, detail="no user") @@ -1410,12 +1488,17 @@ async def get_plan_by_id( if plan_id: plan = await memory_store.get_plan_by_plan_id(plan_id=plan_id) if not plan: - track_event_if_configured( - "GetPlanBySessionNotFound", - {"status_code": 400, "detail": "Plan not found"}, - ) + event_props = {"status_code": 400, "detail": "Plan not found"} + # No session_id available since plan not found + track_event_if_configured("Error_Plan_Not_Found", event_props) raise HTTPException(status_code=404, detail="Plan not found") + # Attach session_id to span + if plan.session_id: + span = trace.get_current_span() + if span: + span.set_attribute("session_id", plan.session_id) + # Use get_steps_by_plan to match the original implementation team = await memory_store.get_team_by_id(team_id=plan.team_id) diff --git a/src/backend/v4/common/services/plan_service.py b/src/backend/v4/common/services/plan_service.py index 6c1e24b6..045cf291 100644 --- a/src/backend/v4/common/services/plan_service.py +++ b/src/backend/v4/common/services/plan_service.py @@ -10,7 +10,6 @@ AgentType, PlanStatus, ) -from common.utils.event_utils import track_event_if_configured from v4.config.settings import orchestration_config logger = logging.getLogger(__name__) @@ -154,26 +153,10 @@ async def handle_plan_approval( plan.overall_status = PlanStatus.approved plan.m_plan = mplan.model_dump() await memory_store.update_plan(plan) - track_event_if_configured( - "PlanApproved", - { - "m_plan_id": human_feedback.m_plan_id, - "plan_id": human_feedback.plan_id, - "user_id": user_id, - }, - ) else: print("Plan not found in memory store.") return False else: # reject plan - track_event_if_configured( - "PlanRejected", - { - "m_plan_id": human_feedback.m_plan_id, - "plan_id": human_feedback.plan_id, - "user_id": user_id, - }, - ) await memory_store.delete_plan_by_plan_id(human_feedback.plan_id) except Exception as e: diff --git a/src/backend/v4/magentic_agents/common/lifecycle.py b/src/backend/v4/magentic_agents/common/lifecycle.py index b38e31ee..5bd02ff5 100644 --- a/src/backend/v4/magentic_agents/common/lifecycle.py +++ b/src/backend/v4/magentic_agents/common/lifecycle.py @@ -13,7 +13,7 @@ # from agent_framework.azure import AzureAIClient from agent_framework_azure_ai import AzureAIClient from azure.ai.agents.aio import AgentsClient -from azure.identity.aio import DefaultAzureCredential +from common.config.app_config import config from common.database.database_base import DatabaseBase from common.models.messages_af import TeamConfiguration from common.utils.utils_agents import ( @@ -52,7 +52,7 @@ def __init__( self.team_config: TeamConfiguration | None = team_config self.client: Optional[AgentsClient] = None self.project_endpoint = project_endpoint - self.creds: Optional[DefaultAzureCredential] = None + self.creds = None self.memory_store: Optional[DatabaseBase] = memory_store self.agent_name: str | None = agent_name self.agent_description: str | None = agent_description @@ -66,8 +66,8 @@ async def open(self) -> "MCPEnabledBase": return self self._stack = AsyncExitStack() - # Acquire credential - self.creds = DefaultAzureCredential() + # Acquire credential using centralized config method + self.creds = config.get_azure_credential_async(config.AZURE_CLIENT_ID) if self._stack: await self._stack.enter_async_context(self.creds) # Create AgentsClient diff --git a/src/tests/backend/common/config/test_app_config.py b/src/tests/backend/common/config/test_app_config.py index 2652d453..dbe445d1 100644 --- a/src/tests/backend/common/config/test_app_config.py +++ b/src/tests/backend/common/config/test_app_config.py @@ -251,7 +251,7 @@ def _get_minimal_env(self): @patch('backend.common.config.app_config.DefaultAzureCredential') def test_get_azure_credential_dev_environment(self, mock_default_credential): - """Test get_azure_credential method in dev environment.""" + """Test get_azure_credential method in dev environment with exclude_environment_credential.""" mock_credential = MagicMock() mock_default_credential.return_value = mock_credential @@ -259,7 +259,8 @@ def test_get_azure_credential_dev_environment(self, mock_default_credential): config = AppConfig() result = config.get_azure_credential() - mock_default_credential.assert_called_once() + # Verify it's called with exclude_environment_credential=True in dev + mock_default_credential.assert_called_once_with(exclude_environment_credential=True) assert result == mock_credential @patch('backend.common.config.app_config.ManagedIdentityCredential') @@ -333,6 +334,55 @@ def test_get_access_token_failure(self, mock_default_credential): with pytest.raises(Exception, match="Token retrieval failed"): credential.get_token(config.AZURE_COGNITIVE_SERVICES) + @patch('backend.common.config.app_config.DefaultAzureCredentialAsync') + def test_get_azure_credential_async_dev_environment(self, mock_default_credential_async): + """Test get_azure_credential_async method in dev environment with exclude_environment_credential.""" + mock_credential = MagicMock() + mock_default_credential_async.return_value = mock_credential + + with patch.dict(os.environ, self._get_minimal_env()): + config = AppConfig() + result = config.get_azure_credential_async() + + # Verify it's called with exclude_environment_credential=True in dev + mock_default_credential_async.assert_called_once_with(exclude_environment_credential=True) + assert result == mock_credential + + @patch('backend.common.config.app_config.ManagedIdentityCredentialAsync') + def test_get_azure_credential_async_prod_environment(self, mock_managed_credential_async): + """Test get_azure_credential_async method in production environment.""" + mock_credential = MagicMock() + mock_managed_credential_async.return_value = mock_credential + + env = self._get_minimal_env() + env["APP_ENV"] = "prod" + env["AZURE_CLIENT_ID"] = "test-client-id" + + with patch.dict(os.environ, env): + config = AppConfig() + result = config.get_azure_credential_async("test-client-id") + + mock_managed_credential_async.assert_called_once_with(client_id="test-client-id") + assert result == mock_credential + + @patch('backend.common.config.app_config.ManagedIdentityCredentialAsync') + def test_get_azure_credential_async_prod_uppercase(self, mock_managed_credential_async): + """Test get_azure_credential_async handles uppercase Prod environment value.""" + mock_credential = MagicMock() + mock_managed_credential_async.return_value = mock_credential + + env = self._get_minimal_env() + env["APP_ENV"] = "Prod" # Bicep sets it as "Prod" with capital P + env["AZURE_CLIENT_ID"] = "test-client-id" + + with patch.dict(os.environ, env): + config = AppConfig() + result = config.get_azure_credential_async("test-client-id") + + # Should use ManagedIdentityCredential even with capital "Prod" + mock_managed_credential_async.assert_called_once_with(client_id="test-client-id") + assert result == mock_credential + class TestAppConfigClientMethods: """Test cases for client creation methods in AppConfig class.""" diff --git a/src/tests/backend/v4/magentic_agents/common/test_lifecycle.py b/src/tests/backend/v4/magentic_agents/common/test_lifecycle.py index 25a33dfc..129b7213 100644 --- a/src/tests/backend/v4/magentic_agents/common/test_lifecycle.py +++ b/src/tests/backend/v4/magentic_agents/common/test_lifecycle.py @@ -171,7 +171,9 @@ async def test_open_method_success(self): mock_mcp_tool = AsyncMock() with patch('backend.v4.magentic_agents.common.lifecycle.AsyncExitStack', return_value=mock_stack): - with patch('backend.v4.magentic_agents.common.lifecycle.DefaultAzureCredential', return_value=mock_creds): + with patch('backend.v4.magentic_agents.common.lifecycle.config') as mock_config: + mock_config.get_azure_credential_async.return_value = mock_creds + mock_config.AZURE_CLIENT_ID = "test-client-id" with patch('backend.v4.magentic_agents.common.lifecycle.AgentsClient', return_value=mock_client): with patch('backend.v4.magentic_agents.common.lifecycle.MCPStreamableHTTPTool', return_value=mock_mcp_tool): with patch.object(base, '_after_open', new_callable=AsyncMock) as mock_after_open: @@ -182,6 +184,7 @@ async def test_open_method_success(self): assert base._stack is mock_stack assert base.creds is mock_creds assert base.client is mock_client + mock_config.get_azure_credential_async.assert_called_once_with("test-client-id") mock_after_open.assert_called_once() mock_agent_registry.register_agent.assert_called_once_with(base) @@ -207,7 +210,9 @@ async def test_open_method_registration_failure(self): mock_client = AsyncMock() with patch('backend.v4.magentic_agents.common.lifecycle.AsyncExitStack', return_value=mock_stack): - with patch('backend.v4.magentic_agents.common.lifecycle.DefaultAzureCredential', return_value=mock_creds): + with patch('backend.v4.magentic_agents.common.lifecycle.config') as mock_config: + mock_config.get_azure_credential_async.return_value = mock_creds + mock_config.AZURE_CLIENT_ID = "test-client-id" with patch('backend.v4.magentic_agents.common.lifecycle.AgentsClient', return_value=mock_client): with patch.object(base, '_after_open', new_callable=AsyncMock): mock_agent_registry.register_agent.side_effect = Exception("Registration failed") @@ -216,6 +221,7 @@ async def test_open_method_registration_failure(self): result = await base.open() assert result is base + mock_config.get_azure_credential_async.assert_called_once_with("test-client-id") mock_agent_registry.register_agent.assert_called_once_with(base) @pytest.mark.asyncio