diff --git a/infra/main.bicep b/infra/main.bicep index a029af2c..31852aa0 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -291,8 +291,25 @@ module applicationInsights 'br/public:avm/res/insights/component:0.7.0' = if (en retentionInDays: 365 kind: 'web' disableIpMasking: false - disableLocalAuth: true flowType: 'Bluefield' + // WAF aligned configuration for Private Networking - block public ingestion/query + publicNetworkAccessForIngestion: enablePrivateNetworking ? 'Disabled' : 'Enabled' + publicNetworkAccessForQuery: enablePrivateNetworking ? 'Disabled' : 'Enabled' + } +} + +// ========== Data Collection Endpoint (DCE) ========== // +// Required for Azure Monitor Private Link - provides private ingestion and configuration endpoints +// Per: https://learn.microsoft.com/en-us/azure/azure-monitor/fundamentals/private-link-configure +module dataCollectionEndpoint 'br/public:avm/res/insights/data-collection-endpoint:0.5.0' = if (enablePrivateNetworking && enableMonitoring) { + name: take('avm.res.insights.data-collection-endpoint.${solutionSuffix}', 64) + params: { + name: 'dce-${solutionSuffix}' + location: location + kind: 'Windows' + publicNetworkAccess: 'Disabled' + tags: allTags + enableTelemetry: enableTelemetry } } @@ -320,6 +337,10 @@ var privateDnsZones = [ 'privatelink.vaultcore.azure.net' 'privatelink.blob.${environment().suffixes.storage}' 'privatelink.file.${environment().suffixes.storage}' + 'privatelink.monitor.azure.com' // Azure Monitor global endpoints (App Insights, DCE) + 'privatelink.oms.opinsights.azure.com' // Log Analytics OMS endpoints + 'privatelink.ods.opinsights.azure.com' // Log Analytics ODS ingestion endpoints + 'privatelink.agentsvc.azure-automation.net' // Agent service automation endpoints ] // DNS Zone Index Constants @@ -331,6 +352,10 @@ var dnsZoneIndex = { keyVault: 4 storageBlob: 5 storageFile: 6 + monitor: 7 + oms: 8 + ods: 9 + agentSvc: 10 } // =================================================== @@ -356,6 +381,76 @@ module avmPrivateDnsZones 'br/public:avm/res/network/private-dns-zone:0.8.0' = [ } ] +// ========== Azure Monitor Private Link Scope (AMPLS) ========== // +// Step 1: Create AMPLS +// Step 2: Connect Azure Monitor resources (LAW, Application Insights, DCE) to the AMPLS +// Step 3: Connect AMPLS to a private endpoint with required DNS zones +// Per: https://learn.microsoft.com/en-us/azure/azure-monitor/fundamentals/private-link-configure +module azureMonitorPrivateLinkScope 'br/public:avm/res/insights/private-link-scope:0.6.0' = if (enablePrivateNetworking) { + name: take('avm.res.insights.private-link-scope.${solutionSuffix}', 64) + #disable-next-line no-unnecessary-dependson + dependsOn: [logAnalyticsWorkspace, applicationInsights, dataCollectionEndpoint, virtualNetwork] + params: { + name: 'ampls-${solutionSuffix}' + location: 'global' + // Access mode: PrivateOnly ensures all ingestion and queries go through private link + accessModeSettings: { + ingestionAccessMode: 'PrivateOnly' + queryAccessMode: 'PrivateOnly' + } + // Step 2: Connect Azure Monitor resources to the AMPLS as scoped resources + scopedResources: concat([ + { + name: 'scoped-law' + linkedResourceId: logAnalyticsWorkspaceResourceId + } + ], enableMonitoring ? [ + { + name: 'scoped-appi' + linkedResourceId: applicationInsights!.outputs.resourceId + } + { + name: 'scoped-dce' + linkedResourceId: dataCollectionEndpoint!.outputs.resourceId + } + ] : []) + // Step 3: Connect AMPLS to a private endpoint + // The private endpoint requires 5 DNS zones per documentation: + // - privatelink.monitor.azure.com (App Insights + DCE global endpoints) + // - privatelink.oms.opinsights.azure.com (Log Analytics OMS) + // - privatelink.ods.opinsights.azure.com (Log Analytics ODS ingestion) + // - privatelink.agentsvc.azure-automation.net (Agent service automation) + // - privatelink.blob.core.windows.net (Agent solution packs storage) + privateEndpoints: [ + { + name: 'pep-ampls-${solutionSuffix}' + subnetResourceId: virtualNetwork!.outputs.pepsSubnetResourceId + privateDnsZoneGroup: { + privateDnsZoneGroupConfigs: [ + { + privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.monitor]!.outputs.resourceId + } + { + privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.oms]!.outputs.resourceId + } + { + privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.ods]!.outputs.resourceId + } + { + privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.agentSvc]!.outputs.resourceId + } + { + privateDnsZoneResourceId: avmPrivateDnsZones[dnsZoneIndex.storageBlob]!.outputs.resourceId + } + ] + } + } + ] + tags: allTags + enableTelemetry: enableTelemetry + } +} + // Azure Bastion Host var bastionHostName = 'bas-${solutionSuffix}' module bastionHost 'br/public:avm/res/network/bastion-host:0.8.0' = if (enablePrivateNetworking) { @@ -437,6 +532,7 @@ module windowsVmDataCollectionRules 'br/public:avm/res/insights/data-collection- location: dataCollectionRulesLocation dataCollectionRuleProperties: { kind: 'Windows' + dataCollectionEndpointResourceId: dataCollectionEndpoint!.outputs.resourceId dataSources: { performanceCounters: [ { @@ -495,26 +591,6 @@ module windowsVmDataCollectionRules 'br/public:avm/res/insights/data-collection- name: 'perfCounterDataSource60' } ] - windowsEventLogs: [ - { - name: 'SecurityAuditEvents' - streams: [ - 'Microsoft-WindowsEvent' - ] - eventLogName: 'Security' - eventTypes: [ - { - eventType: 'Audit Success' - } - { - eventType: 'Audit Failure' - } - ] - xPathQueries: [ - 'Security!*[System[(EventID=4624 or EventID=4625)]]' - ] - } - ] } destinations: { logAnalytics: [ @@ -532,8 +608,6 @@ module windowsVmDataCollectionRules 'br/public:avm/res/insights/data-collection- destinations: [ 'la-${dataCollectionRulesResourceName}' ] - transformKql: 'source' - outputStream: 'Microsoft-Perf' } ] } diff --git a/src/backend/api/api_routes.py b/src/backend/api/api_routes.py index e3c6a372..bff54d75 100644 --- a/src/backend/api/api_routes.py +++ b/src/backend/api/api_routes.py @@ -3,8 +3,6 @@ # Standard library import asyncio import io -import logging -import os import zipfile from typing import Optional @@ -14,9 +12,6 @@ from api.status_updates import app_connection_manager, close_connection # Third-party -# Azure Monitor OpenTelemetry integration is currently causing issues with OpenAI calls in process_batch_async, needs further investigation, commenting out for now -# from azure.monitor.opentelemetry import configure_azure_monitor - from common.logger.app_logger import AppLogger from common.services.batch_service import BatchService @@ -40,21 +35,6 @@ router = APIRouter() logger = AppLogger("APIRoutes") -# Check if the Application Insights Instrumentation Key is set in the environment variables -instrumentation_key = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") -if instrumentation_key: - # Configure Application Insights if the Instrumentation Key is found - # commenting below line as configure_azure_monitor is causing issues with OpenAI calls in process_batch_async, needs further investigation - # configure_azure_monitor(connection_string=instrumentation_key) - logging.info( - "Application Insights configured with the provided Instrumentation Key" - ) -else: - # Log a warning if the Instrumentation Key is not found - logging.warning( - "No Application Insights Instrumentation Key found. Skipping configuration" - ) - def record_exception_to_trace(e): """Record exception to the current OpenTelemetry trace span.""" diff --git a/src/backend/api/event_utils.py b/src/backend/api/event_utils.py index 97c0a196..11905a8a 100644 --- a/src/backend/api/event_utils.py +++ b/src/backend/api/event_utils.py @@ -3,17 +3,69 @@ import os # Third-party -from azure.monitor.events.extension import track_event +from applicationinsights import TelemetryClient +from applicationinsights.channel import SynchronousQueue, SynchronousSender, TelemetryChannel from dotenv import load_dotenv load_dotenv() +# Global telemetry client (initialized once) +_telemetry_client = None + + +def _get_telemetry_client(): + """Get or create the Application Insights telemetry client.""" + global _telemetry_client + + if _telemetry_client is None: + connection_string = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") + if connection_string: + try: + # Extract instrumentation key from connection string + # Format: InstrumentationKey=xxx;IngestionEndpoint=https://... + parts = dict(part.split('=', 1) for part in connection_string.split(';') if '=' in part) + instrumentation_key = parts.get('InstrumentationKey') + + if instrumentation_key: + # Create a synchronous channel for immediate sending + sender = SynchronousSender() + queue = SynchronousQueue(sender) + channel = TelemetryChannel(None, queue) + + _telemetry_client = TelemetryClient(instrumentation_key, channel) + logging.info("Application Insights TelemetryClient initialized successfully") + else: + logging.error("Could not extract InstrumentationKey from connection string") + except Exception as e: + logging.error(f"Failed to initialize TelemetryClient: {e}") + + return _telemetry_client + def track_event_if_configured(event_name: str, event_data: dict): + """Track a custom event to Application Insights customEvents table. + + This uses the Application Insights SDK TelemetryClient which properly + sends custom events to the customEvents table in Application Insights. + """ instrumentation_key = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") if instrumentation_key: - track_event(event_name, event_data) + try: + client = _get_telemetry_client() + if client: + # Convert all values to strings to ensure compatibility + properties = {k: str(v) for k, v in event_data.items()} + + # Track the custom event + client.track_event(event_name, properties=properties) + client.flush() # Ensure immediate sending + + logging.debug(f"Tracked custom event: {event_name} with data: {event_data}") + else: + logging.warning("TelemetryClient not available, custom event not tracked") + except Exception as e: + logging.error(f"Failed to track event {event_name}: {e}") else: logging.warning( f"Skipping track_event for {event_name} as Application Insights is not configured" diff --git a/src/backend/app.py b/src/backend/app.py index 0f877fc7..58ee1a29 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -5,6 +5,8 @@ from api.api_routes import router as backend_router +from azure.monitor.opentelemetry.exporter import AzureMonitorLogExporter, AzureMonitorTraceExporter + from common.config.config import app_config from common.logger.app_logger import AppLogger @@ -15,6 +17,14 @@ from helper.azure_credential_utils import get_azure_credential +from opentelemetry import trace +from opentelemetry._logs import set_logger_provider +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent # pylint: disable=E0611 from sql_agents.agent_manager import clear_sql_agents, set_sql_agents @@ -46,6 +56,11 @@ for logger_name in AZURE_LOGGING_PACKAGES: logging.getLogger(logger_name).setLevel(getattr(logging, AZURE_PACKAGE_LOGGING_LEVEL, logging.WARNING)) +# Suppress noisy OpenTelemetry and Azure Monitor logs +# logging.getLogger("opentelemetry.sdk").setLevel(logging.ERROR) +# logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) +# logging.getLogger("azure.monitor.opentelemetry.exporter.export._base").setLevel(logging.WARNING) + logger = AppLogger("app") # Global variables for agents @@ -119,6 +134,59 @@ def create_app() -> FastAPI: allow_headers=["*"], ) + # Configure Azure Monitor and instrument FastAPI for OpenTelemetry + # This must happen AFTER app creation but BEFORE route registration + instrumentation_key = os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING") + if instrumentation_key: + # SOLUTION: Use manual telemetry setup instead of configure_azure_monitor + # This gives us precise control over what gets instrumented, avoiding interference + # with Semantic Kernel's async generators while still tracking Azure SDK calls + + # Set up Azure Monitor exporter for traces + azure_trace_exporter = AzureMonitorTraceExporter(connection_string=instrumentation_key) + + # Create a tracer provider and add the Azure Monitor exporter + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(BatchSpanProcessor(azure_trace_exporter)) + + # Set the global tracer provider + trace.set_tracer_provider(tracer_provider) + + # Set up Azure Monitor exporter for logs (appears in traces table) + azure_log_exporter = AzureMonitorLogExporter(connection_string=instrumentation_key) + + # Create a logger provider and add the Azure Monitor exporter + logger_provider = LoggerProvider() + logger_provider.add_log_record_processor(BatchLogRecordProcessor(azure_log_exporter)) + set_logger_provider(logger_provider) + + # Attach OpenTelemetry handler to Python's root logger + handler = LoggingHandler(logger_provider=logger_provider) + logging.getLogger().addHandler(handler) + + # Instrument ONLY FastAPI for HTTP request/response tracing + # This is safe because it only wraps HTTP handlers, not internal async operations + FastAPIInstrumentor.instrument_app( + app, + excluded_urls="socket,ws", # Exclude WebSocket URLs to reduce noise + tracer_provider=tracer_provider + ) + + # Optional: Add manual spans in your code for Azure SDK operations using: + # from opentelemetry import trace + # tracer = trace.get_tracer(__name__) + # with tracer.start_as_current_span("operation_name"): + # # your Azure SDK call here + + logger.logger.info("Application Insights configured with selective instrumentation") + logger.logger.info("✓ FastAPI HTTP tracing enabled") + logger.logger.info("✓ Python logging export to Application Insights enabled") + logger.logger.info("✓ Manual span support enabled for Azure SDK operations") + logger.logger.info("✓ Custom events via OpenTelemetry enabled") + logger.logger.info("✓ Semantic Kernel async generators unaffected") + else: + logger.logger.warning("No Application Insights connection string found. Telemetry disabled.") + # Include routers with /api prefix app.include_router(backend_router, prefix="/api", tags=["backend"]) # app.include_router(agents_router, prefix="/api/agents", tags=["agents"]) diff --git a/src/backend/common/telemetry/__init__.py b/src/backend/common/telemetry/__init__.py new file mode 100644 index 00000000..1e725b70 --- /dev/null +++ b/src/backend/common/telemetry/__init__.py @@ -0,0 +1,17 @@ +"""Telemetry utilities for Application Insights integration.""" + +from common.telemetry.telemetry_helper import ( + add_span_attributes, + get_tracer, + trace_context, + trace_operation, + trace_sync_context, +) + +__all__ = [ + "trace_operation", + "trace_context", + "trace_sync_context", + "get_tracer", + "add_span_attributes", +] diff --git a/src/backend/common/telemetry/telemetry_helper.py b/src/backend/common/telemetry/telemetry_helper.py new file mode 100644 index 00000000..e1f87900 --- /dev/null +++ b/src/backend/common/telemetry/telemetry_helper.py @@ -0,0 +1,160 @@ +"""Helper utilities for adding telemetry spans to Azure SDK operations. + +This module provides decorators and context managers for adding OpenTelemetry +spans to Azure SDK calls (CosmosDB, Blob Storage, etc.) without interfering +with Semantic Kernel's async generators. + +Example usage: + from common.telemetry.telemetry_helper import trace_operation + + @trace_operation("cosmosdb_query") + async def query_items(self, query: str): + # Your CosmosDB query here + pass +""" + +import asyncio +import functools +from contextlib import asynccontextmanager, contextmanager +from typing import Optional + +from opentelemetry import trace +from opentelemetry.trace import Status, StatusCode + + +def get_tracer(name: str = __name__): + """Get a tracer instance for the given name.""" + return trace.get_tracer(name) + + +def trace_operation(operation_name: str, attributes: Optional[dict] = None): + """Decorator to add telemetry span to a function or method. + + Args: + operation_name: Name of the operation for the span + attributes: Optional dictionary of attributes to add to the span + + Example: + @trace_operation("batch_processing", {"service": "sql_agents"}) + async def process_batch(batch_id: str): + # Your code here + pass + """ + def decorator(func): + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + tracer = get_tracer(func.__module__) + with tracer.start_as_current_span(operation_name) as span: + # Add custom attributes if provided + if attributes: + for key, value in attributes.items(): + span.set_attribute(key, str(value)) + + # Add function arguments as attributes (optional, for debugging) + span.set_attribute("function", func.__name__) + + try: + result = await func(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + tracer = get_tracer(func.__module__) + with tracer.start_as_current_span(operation_name) as span: + if attributes: + for key, value in attributes.items(): + span.set_attribute(key, str(value)) + + span.set_attribute("function", func.__name__) + + try: + result = func(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + + # Return appropriate wrapper based on function type + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +@asynccontextmanager +async def trace_context(operation_name: str, attributes: Optional[dict] = None): + """Async context manager for adding telemetry span to a code block. + + Args: + operation_name: Name of the operation for the span + attributes: Optional dictionary of attributes to add to the span + + Example: + async with trace_context("cosmosdb_batch_query", {"batch_id": batch_id}): + results = await database.query_items(query) + # Your code here + """ + tracer = get_tracer() + with tracer.start_as_current_span(operation_name) as span: + if attributes: + for key, value in attributes.items(): + span.set_attribute(key, str(value)) + + try: + yield span + span.set_status(Status(StatusCode.OK)) + except Exception as e: + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + + +@contextmanager +def trace_sync_context(operation_name: str, attributes: Optional[dict] = None): + """Sync context manager for adding telemetry span to a code block. + + Args: + operation_name: Name of the operation for the span + attributes: Optional dictionary of attributes to add to the span + + Example: + with trace_sync_context("blob_upload", {"file_name": file_name}): + blob_client.upload_blob(data) + """ + tracer = get_tracer() + with tracer.start_as_current_span(operation_name) as span: + if attributes: + for key, value in attributes.items(): + span.set_attribute(key, str(value)) + + try: + yield span + span.set_status(Status(StatusCode.OK)) + except Exception as e: + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + + +def add_span_attributes(attributes: dict): + """Add attributes to the current span. + + Args: + attributes: Dictionary of attributes to add + + Example: + add_span_attributes({"user_id": user_id, "batch_id": batch_id}) + """ + span = trace.get_current_span() + if span and span.is_recording(): + for key, value in attributes.items(): + span.set_attribute(key, str(value)) diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index a8ccfcfd..84b96d5c 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -44,10 +44,11 @@ starlette aiortc opentelemetry-exporter-otlp-proto-grpc opentelemetry-exporter-otlp-proto-http -azure-monitor-events-extension +applicationinsights opentelemetry-sdk==1.39.0 opentelemetry-api==1.39.0 opentelemetry-semantic-conventions==0.60b0 opentelemetry-instrumentation==0.60b0 +opentelemetry-instrumentation-fastapi==0.60b0 azure-monitor-opentelemetry==1.8.6 azure-ai-projects==1.0.0 diff --git a/src/backend/sql_agents/process_batch.py b/src/backend/sql_agents/process_batch.py index 192fa022..1da3467f 100644 --- a/src/backend/sql_agents/process_batch.py +++ b/src/backend/sql_agents/process_batch.py @@ -17,6 +17,7 @@ ) from common.services.batch_service import BatchService from common.storage.blob_factory import BlobStorageFactory +from common.telemetry import trace_context from fastapi import HTTPException @@ -37,114 +38,120 @@ async def process_batch_async( batch_id: str, convert_from: str = "informix", convert_to: str = "tsql" ): """Central batch processing function to process each file in the batch""" - logger.info("Processing batch: %s", batch_id) - storage = await BlobStorageFactory.get_storage() - batch_service = BatchService() - await batch_service.initialize_database() - - try: - batch_files = await batch_service.database.get_batch_files(batch_id) - if not batch_files: - raise HTTPException(status_code=404, detail="Batch not found") - # Retrieve list of file paths - await batch_service.update_batch(batch_id, ProcessStatus.IN_PROGRESS) - except Exception as exc: - logger.error("Error updating batch status. %s", exc) - - # Get the global SQL agents instance - sql_agents = get_sql_agents() - if not sql_agents: - logger.error("SQL agents not initialized. Application may not have started properly.") - await batch_service.update_batch(batch_id, ProcessStatus.FAILED) - return - - # Update agent configuration for this batch's conversion requirements - await update_agent_config(convert_from, convert_to) - - # Walk through each file name and retrieve it from blob storage - # Send file to the agents for processing - # Send status update to the client of type in progress, completed, or failed - for file in batch_files: - # Get the file from blob storage + # Add telemetry span for the entire batch processing operation + async with trace_context("process_batch", { + "batch_id": batch_id, + "convert_from": convert_from, + "convert_to": convert_to + }): + logger.info("Processing batch: %s", batch_id) + storage = await BlobStorageFactory.get_storage() + batch_service = BatchService() + await batch_service.initialize_database() + try: - file_record = FileRecord.fromdb(file) - # Update the file status + batch_files = await batch_service.database.get_batch_files(batch_id) + if not batch_files: + raise HTTPException(status_code=404, detail="Batch not found") + # Retrieve list of file paths + await batch_service.update_batch(batch_id, ProcessStatus.IN_PROGRESS) + except Exception as exc: + logger.error("Error updating batch status. %s", exc) + + # Get the global SQL agents instance + sql_agents = get_sql_agents() + if not sql_agents: + logger.error("SQL agents not initialized. Application may not have started properly.") + await batch_service.update_batch(batch_id, ProcessStatus.FAILED) + return + + # Update agent configuration for this batch's conversion requirements + await update_agent_config(convert_from, convert_to) + + # Walk through each file name and retrieve it from blob storage + # Send file to the agents for processing + # Send status update to the client of type in progress, completed, or failed + for file in batch_files: + # Get the file from blob storage try: - file_record.status = ProcessStatus.IN_PROGRESS - await batch_service.update_file_record(file_record) + file_record = FileRecord.fromdb(file) + # Update the file status + try: + file_record.status = ProcessStatus.IN_PROGRESS + await batch_service.update_file_record(file_record) + except Exception as exc: + logger.error("Error updating file status. %s", exc) + + sql_in_file = await storage.get_file(file_record.blob_path) + + # split into base validation routine + # Check if the file is a valid text file <-- + if not is_text(sql_in_file): + logger.error("File is not a valid text file. Skipping.") + # insert data base write to file record stating invalid file + await batch_service.create_file_log( + str(file_record.file_id), + "File is not a valid text file. Skipping.", + "", + LogType.ERROR, + AgentType.ALL, + AuthorRole.ASSISTANT, + ) + # send status update to the client of type failed + send_status_update( + status=FileProcessUpdate( + file_record.batch_id, + file_record.file_id, + ProcessStatus.COMPLETED, + file_result=FileResult.ERROR, + ), + ) + file_record.file_result = FileResult.ERROR + file_record.status = ProcessStatus.COMPLETED + file_record.error_count = 1 + await batch_service.update_file_record(file_record) + continue + else: + logger.info("sql_in_file: %s", sql_in_file) + + # Convert the file + converted_query = await convert_script( + sql_in_file, + file_record, + batch_service, + sql_agents, + ) + if converted_query: + # Add RAI disclaimer to the converted query + converted_query = add_rai_disclaimer(converted_query) + await batch_service.create_candidate( + file["file_id"], converted_query + ) + else: + await batch_service.update_file_counts(file["file_id"]) + except UnicodeDecodeError as ucde: + logger.error("Error decoding file: %s", file) + logger.error("Error decoding file. %s", ucde) + await process_error(ucde, file_record, batch_service) + except ServiceResponseException as sre: + logger.error(file) + logger.error("Error processing file. %s", sre) + # insert data base write to file record stating invalid file + await process_error(sre, file_record, batch_service) except Exception as exc: - logger.error("Error updating file status. %s", exc) - - sql_in_file = await storage.get_file(file_record.blob_path) - - # split into base validation routine - # Check if the file is a valid text file <-- - if not is_text(sql_in_file): - logger.error("File is not a valid text file. Skipping.") + logger.error(file) + logger.error("Error processing file. %s", exc) # insert data base write to file record stating invalid file - await batch_service.create_file_log( - str(file_record.file_id), - "File is not a valid text file. Skipping.", - "", - LogType.ERROR, - AgentType.ALL, - AuthorRole.ASSISTANT, - ) - # send status update to the client of type failed - send_status_update( - status=FileProcessUpdate( - file_record.batch_id, - file_record.file_id, - ProcessStatus.COMPLETED, - file_result=FileResult.ERROR, - ), - ) - file_record.file_result = FileResult.ERROR - file_record.status = ProcessStatus.COMPLETED - file_record.error_count = 1 - await batch_service.update_file_record(file_record) - continue - else: - logger.info("sql_in_file: %s", sql_in_file) - - # Convert the file - converted_query = await convert_script( - sql_in_file, - file_record, - batch_service, - sql_agents, - ) - if converted_query: - # Add RAI disclaimer to the converted query - converted_query = add_rai_disclaimer(converted_query) - await batch_service.create_candidate( - file["file_id"], converted_query - ) - else: - await batch_service.update_file_counts(file["file_id"]) - except UnicodeDecodeError as ucde: - logger.error("Error decoding file: %s", file) - logger.error("Error decoding file. %s", ucde) - await process_error(ucde, file_record, batch_service) - except ServiceResponseException as sre: - logger.error(file) - logger.error("Error processing file. %s", sre) - # insert data base write to file record stating invalid file - await process_error(sre, file_record, batch_service) + await process_error(exc, file_record, batch_service) + + # Update batch status to completed or failed + try: + await batch_service.batch_files_final_update(batch_id) + await batch_service.update_batch(batch_id, ProcessStatus.COMPLETED) except Exception as exc: - logger.error(file) - logger.error("Error processing file. %s", exc) - # insert data base write to file record stating invalid file - await process_error(exc, file_record, batch_service) - - # Update batch status to completed or failed - try: - await batch_service.batch_files_final_update(batch_id) - await batch_service.update_batch(batch_id, ProcessStatus.COMPLETED) - except Exception as exc: - await batch_service.update_batch(batch_id, ProcessStatus.FAILED) - logger.error("Error updating batch status. %s", exc) - logger.info("Batch processing complete.") + await batch_service.update_batch(batch_id, ProcessStatus.FAILED) + logger.error("Error updating batch status. %s", exc) + logger.info("Batch processing complete.") async def process_error( diff --git a/src/tests/backend/api/event_utils_test.py b/src/tests/backend/api/event_utils_test.py index d4495aff..9e6e2b0f 100644 --- a/src/tests/backend/api/event_utils_test.py +++ b/src/tests/backend/api/event_utils_test.py @@ -1,7 +1,7 @@ """Tests for event_utils module.""" import os -from unittest.mock import patch +from unittest.mock import MagicMock, patch from backend.api.event_utils import track_event_if_configured @@ -11,36 +11,50 @@ class TestTrackEventIfConfigured: def test_track_event_with_instrumentation_key(self): """Test tracking event when instrumentation key is set.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "test-key"}): - with patch("backend.api.event_utils.track_event") as mock_track: + connection_string = "InstrumentationKey=test-key;IngestionEndpoint=https://test.com" + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": connection_string}): + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + track_event_if_configured("TestEvent", {"key": "value"}) - mock_track.assert_called_once_with("TestEvent", {"key": "value"}) + mock_client.track_event.assert_called_once_with("TestEvent", properties={"key": "value"}) + mock_client.flush.assert_called_once() def test_track_event_without_instrumentation_key(self): """Test tracking event when instrumentation key is not set.""" with patch.dict(os.environ, {}, clear=True): # Remove the key if it exists os.environ.pop("APPLICATIONINSIGHTS_CONNECTION_STRING", None) - with patch("backend.api.event_utils.track_event") as mock_track: + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: with patch("backend.api.event_utils.logging.warning") as mock_warning: track_event_if_configured("TestEvent", {"key": "value"}) - mock_track.assert_not_called() + mock_get_client.assert_not_called() mock_warning.assert_called_once() def test_track_event_with_empty_data(self): """Test tracking event with empty data.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "test-key"}): - with patch("backend.api.event_utils.track_event") as mock_track: + connection_string = "InstrumentationKey=test-key;IngestionEndpoint=https://test.com" + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": connection_string}): + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + track_event_if_configured("TestEvent", {}) - mock_track.assert_called_once_with("TestEvent", {}) + mock_client.track_event.assert_called_once_with("TestEvent", properties={}) + mock_client.flush.assert_called_once() def test_track_event_with_complex_data(self): """Test tracking event with complex data.""" - with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "test-key"}): - with patch("backend.api.event_utils.track_event") as mock_track: + connection_string = "InstrumentationKey=test-key;IngestionEndpoint=https://test.com" + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": connection_string}): + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client + complex_data = { "batch_id": "test-batch", "file_count": 10, @@ -50,4 +64,38 @@ def test_track_event_with_complex_data(self): track_event_if_configured("ComplexEvent", complex_data) - mock_track.assert_called_once_with("ComplexEvent", complex_data) + # Values are converted to strings in the actual implementation + expected_properties = { + "batch_id": "test-batch", + "file_count": "10", + "status": "completed", + "nested": "{'key': 'value'}", + } + + mock_client.track_event.assert_called_once_with("ComplexEvent", properties=expected_properties) + mock_client.flush.assert_called_once() + + def test_track_event_client_returns_none(self): + """Test tracking event when client initialization fails.""" + connection_string = "InstrumentationKey=test-key;IngestionEndpoint=https://test.com" + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": connection_string}): + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: + mock_get_client.return_value = None + with patch("backend.api.event_utils.logging.warning") as mock_warning: + track_event_if_configured("TestEvent", {"key": "value"}) + + mock_warning.assert_called_once() + + def test_track_event_with_exception(self): + """Test tracking event when an exception occurs.""" + connection_string = "InstrumentationKey=test-key;IngestionEndpoint=https://test.com" + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": connection_string}): + with patch("backend.api.event_utils._get_telemetry_client") as mock_get_client: + mock_client = MagicMock() + mock_client.track_event.side_effect = Exception("Test error") + mock_get_client.return_value = mock_client + + with patch("backend.api.event_utils.logging.error") as mock_error: + track_event_if_configured("TestEvent", {"key": "value"}) + + mock_error.assert_called_once()