From 743ed373d0740c309c5fdf8f1ca798a829b3856f Mon Sep 17 00:00:00 2001 From: gayathrivijayakumar Date: Mon, 5 Jan 2026 20:24:30 +0530 Subject: [PATCH 01/22] Lookups firstcut --- backend/lookup/__init__.py | 0 backend/lookup/apps.py | 11 + backend/lookup/constants.py | 28 + backend/lookup/exceptions.py | 60 ++ backend/lookup/integrations/__init__.py | 8 + .../integrations/file_storage_client.py | 86 ++ backend/lookup/integrations/llm_provider.py | 311 ++++++ .../integrations/llmwhisperer_client.py | 334 +++++++ backend/lookup/integrations/redis_cache.py | 384 ++++++++ backend/lookup/integrations/storage_client.py | 281 ++++++ .../integrations/unstract_llm_client.py | 113 +++ backend/lookup/migrations/0001_initial.py | 886 +++++++++++++++++ backend/lookup/migrations/__init__.py | 1 + backend/lookup/models/__init__.py | 24 + backend/lookup/models/lookup_data_source.py | 192 ++++ .../lookup/models/lookup_execution_audit.py | 134 +++ backend/lookup/models/lookup_index_manager.py | 188 ++++ .../lookup/models/lookup_profile_manager.py | 157 +++ backend/lookup/models/lookup_project.py | 177 ++++ .../lookup/models/lookup_prompt_template.py | 178 ++++ .../models/prompt_studio_lookup_link.py | 128 +++ backend/lookup/serializers.py | 333 +++++++ backend/lookup/services/__init__.py | 29 + backend/lookup/services/audit_logger.py | 310 ++++++ .../services/document_indexing_service.py | 141 +++ backend/lookup/services/enrichment_merger.py | 174 ++++ backend/lookup/services/indexing_service.py | 471 +++++++++ backend/lookup/services/llm_cache.py | 186 ++++ backend/lookup/services/lookup_executor.py | 285 ++++++ .../lookup/services/lookup_index_helper.py | 190 ++++ .../lookup/services/lookup_orchestrator.py | 306 ++++++ backend/lookup/services/mock_clients.py | 157 +++ .../lookup/services/reference_data_loader.py | 267 +++++ backend/lookup/services/variable_resolver.py | 147 +++ backend/lookup/tests/__init__.py | 1 + backend/lookup/tests/test_api/__init__.py | 3 + .../tests/test_api/test_execution_api.py | 309 ++++++ .../lookup/tests/test_api/test_linking_api.py | 260 +++++ .../test_api/test_profile_manager_api.py | 397 ++++++++ .../lookup/tests/test_api/test_project_api.py | 215 ++++ .../tests/test_api/test_template_api.py | 175 ++++ .../tests/test_integrations/__init__.py | 3 + .../test_integrations/test_llm_integration.py | 239 +++++ .../test_llmwhisperer_integration.py | 297 ++++++ .../test_redis_cache_integration.py | 272 ++++++ .../test_storage_integration.py | 247 +++++ backend/lookup/tests/test_migrations.py | 286 ++++++ .../tests/test_services/test_audit_logger.py | 455 +++++++++ .../test_services/test_enrichment_merger.py | 547 +++++++++++ .../tests/test_services/test_llm_cache.py | 322 ++++++ .../test_services/test_lookup_executor.py | 418 ++++++++ .../test_services/test_lookup_orchestrator.py | 521 ++++++++++ .../test_reference_data_loader.py | 565 +++++++++++ .../lookup/tests/test_variable_resolver.py | 289 ++++++ backend/lookup/urls.py | 34 + backend/lookup/views.py | 924 ++++++++++++++++++ frontend/src/assets/lookups.svg | 6 + .../combined-output/CombinedOutput.css | 43 +- .../combined-output/CombinedOutput.jsx | 96 +- .../custom-tools/combined-output/JsonView.jsx | 80 +- .../combined-output/JsonViewBody.jsx | 38 + .../CreateProjectModal.jsx | 92 ++ .../components/lookups/debug-tab/DebugTab.css | 8 + .../components/lookups/debug-tab/DebugTab.jsx | 249 +++++ .../ExecutionHistoryTab.css | 17 + .../ExecutionHistoryTab.jsx | 283 ++++++ .../linked-projects-tab/LinkedProjectsTab.css | 17 + 
.../linked-projects-tab/LinkedProjectsTab.jsx | 376 +++++++ .../ProfileFormModal.css | 17 + .../ProfileFormModal.jsx | 433 ++++++++ .../ProfileManagementTab.css | 42 + .../ProfileManagementTab.jsx | 293 ++++++ .../project-detail/LookUpProjectDetail.css | 39 + .../project-detail/LookUpProjectDetail.jsx | 321 ++++++ .../project-list/LookUpProjectList.css | 15 + .../project-list/LookUpProjectList.jsx | 278 ++++++ .../reference-data-tab/ReferenceDataTab.css | 17 + .../reference-data-tab/ReferenceDataTab.jsx | 342 +++++++ .../lookups/template-tab/TemplateTab.css | 8 + .../lookups/template-tab/TemplateTab.jsx | 391 ++++++++ .../navigations/side-nav-bar/SideNavBar.jsx | 9 + .../src/layouts/menu-layout/MenuLayout.css | 9 +- frontend/src/pages/LookUpsPage.jsx | 23 + frontend/src/routes/useLookUpsRoutes.js | 13 + frontend/src/routes/useMainAppRoutes.js | 39 +- lookup/services/__init__.py | 0 lookup/tests/test_services/__init__.py | 0 87 files changed, 17010 insertions(+), 40 deletions(-) create mode 100644 backend/lookup/__init__.py create mode 100644 backend/lookup/apps.py create mode 100644 backend/lookup/constants.py create mode 100644 backend/lookup/exceptions.py create mode 100644 backend/lookup/integrations/__init__.py create mode 100644 backend/lookup/integrations/file_storage_client.py create mode 100644 backend/lookup/integrations/llm_provider.py create mode 100644 backend/lookup/integrations/llmwhisperer_client.py create mode 100644 backend/lookup/integrations/redis_cache.py create mode 100644 backend/lookup/integrations/storage_client.py create mode 100644 backend/lookup/integrations/unstract_llm_client.py create mode 100644 backend/lookup/migrations/0001_initial.py create mode 100644 backend/lookup/migrations/__init__.py create mode 100644 backend/lookup/models/__init__.py create mode 100644 backend/lookup/models/lookup_data_source.py create mode 100644 backend/lookup/models/lookup_execution_audit.py create mode 100644 backend/lookup/models/lookup_index_manager.py create mode 100644 backend/lookup/models/lookup_profile_manager.py create mode 100644 backend/lookup/models/lookup_project.py create mode 100644 backend/lookup/models/lookup_prompt_template.py create mode 100644 backend/lookup/models/prompt_studio_lookup_link.py create mode 100644 backend/lookup/serializers.py create mode 100644 backend/lookup/services/__init__.py create mode 100644 backend/lookup/services/audit_logger.py create mode 100644 backend/lookup/services/document_indexing_service.py create mode 100644 backend/lookup/services/enrichment_merger.py create mode 100644 backend/lookup/services/indexing_service.py create mode 100644 backend/lookup/services/llm_cache.py create mode 100644 backend/lookup/services/lookup_executor.py create mode 100644 backend/lookup/services/lookup_index_helper.py create mode 100644 backend/lookup/services/lookup_orchestrator.py create mode 100644 backend/lookup/services/mock_clients.py create mode 100644 backend/lookup/services/reference_data_loader.py create mode 100644 backend/lookup/services/variable_resolver.py create mode 100644 backend/lookup/tests/__init__.py create mode 100644 backend/lookup/tests/test_api/__init__.py create mode 100644 backend/lookup/tests/test_api/test_execution_api.py create mode 100644 backend/lookup/tests/test_api/test_linking_api.py create mode 100644 backend/lookup/tests/test_api/test_profile_manager_api.py create mode 100644 backend/lookup/tests/test_api/test_project_api.py create mode 100644 backend/lookup/tests/test_api/test_template_api.py create mode 100644 
backend/lookup/tests/test_integrations/__init__.py create mode 100644 backend/lookup/tests/test_integrations/test_llm_integration.py create mode 100644 backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py create mode 100644 backend/lookup/tests/test_integrations/test_redis_cache_integration.py create mode 100644 backend/lookup/tests/test_integrations/test_storage_integration.py create mode 100644 backend/lookup/tests/test_migrations.py create mode 100644 backend/lookup/tests/test_services/test_audit_logger.py create mode 100644 backend/lookup/tests/test_services/test_enrichment_merger.py create mode 100644 backend/lookup/tests/test_services/test_llm_cache.py create mode 100644 backend/lookup/tests/test_services/test_lookup_executor.py create mode 100644 backend/lookup/tests/test_services/test_lookup_orchestrator.py create mode 100644 backend/lookup/tests/test_services/test_reference_data_loader.py create mode 100644 backend/lookup/tests/test_variable_resolver.py create mode 100644 backend/lookup/urls.py create mode 100644 backend/lookup/views.py create mode 100644 frontend/src/assets/lookups.svg create mode 100644 frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx create mode 100644 frontend/src/components/lookups/debug-tab/DebugTab.css create mode 100644 frontend/src/components/lookups/debug-tab/DebugTab.jsx create mode 100644 frontend/src/components/lookups/execution-history-tab/ExecutionHistoryTab.css create mode 100644 frontend/src/components/lookups/execution-history-tab/ExecutionHistoryTab.jsx create mode 100644 frontend/src/components/lookups/linked-projects-tab/LinkedProjectsTab.css create mode 100644 frontend/src/components/lookups/linked-projects-tab/LinkedProjectsTab.jsx create mode 100644 frontend/src/components/lookups/profile-management-tab/ProfileFormModal.css create mode 100644 frontend/src/components/lookups/profile-management-tab/ProfileFormModal.jsx create mode 100644 frontend/src/components/lookups/profile-management-tab/ProfileManagementTab.css create mode 100644 frontend/src/components/lookups/profile-management-tab/ProfileManagementTab.jsx create mode 100644 frontend/src/components/lookups/project-detail/LookUpProjectDetail.css create mode 100644 frontend/src/components/lookups/project-detail/LookUpProjectDetail.jsx create mode 100644 frontend/src/components/lookups/project-list/LookUpProjectList.css create mode 100644 frontend/src/components/lookups/project-list/LookUpProjectList.jsx create mode 100644 frontend/src/components/lookups/reference-data-tab/ReferenceDataTab.css create mode 100644 frontend/src/components/lookups/reference-data-tab/ReferenceDataTab.jsx create mode 100644 frontend/src/components/lookups/template-tab/TemplateTab.css create mode 100644 frontend/src/components/lookups/template-tab/TemplateTab.jsx create mode 100644 frontend/src/pages/LookUpsPage.jsx create mode 100644 frontend/src/routes/useLookUpsRoutes.js create mode 100644 lookup/services/__init__.py create mode 100644 lookup/tests/test_services/__init__.py diff --git a/backend/lookup/__init__.py b/backend/lookup/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/lookup/apps.py b/backend/lookup/apps.py new file mode 100644 index 0000000000..ddaf2e194a --- /dev/null +++ b/backend/lookup/apps.py @@ -0,0 +1,11 @@ +"""Lookup app configuration.""" + +from django.apps import AppConfig + + +class LookupConfig(AppConfig): + """Configuration for the Lookup application.""" + + default_auto_field = 
"django.db.models.BigAutoField" + name = "lookup" + verbose_name = "Look-Up System" diff --git a/backend/lookup/constants.py b/backend/lookup/constants.py new file mode 100644 index 0000000000..f41cc62bfc --- /dev/null +++ b/backend/lookup/constants.py @@ -0,0 +1,28 @@ +"""Constants for the Look-Up system.""" + + +class LookupProfileManagerKeys: + """Keys used in LookupProfileManager serialization.""" + + CREATED_BY = "created_by" + MODIFIED_BY = "modified_by" + LOOKUP_PROJECT = "lookup_project" + PROFILE_NAME = "profile_name" + LLM = "llm" + VECTOR_STORE = "vector_store" + EMBEDDING_MODEL = "embedding_model" + X2TEXT = "x2text" + CHUNK_SIZE = "chunk_size" + CHUNK_OVERLAP = "chunk_overlap" + SIMILARITY_TOP_K = "similarity_top_k" + IS_DEFAULT = "is_default" + REINDEX = "reindex" + + +class LookupProfileManagerErrors: + """Error messages for LookupProfileManager operations.""" + + SERIALIZATION_FAILED = "Data serialization failed." + PROFILE_NAME_EXISTS = "A profile with this name already exists for this project." + DUPLICATE_API = "It appears that a duplicate call may have been made." + NO_DEFAULT_PROFILE = "No default profile found for this project." diff --git a/backend/lookup/exceptions.py b/backend/lookup/exceptions.py new file mode 100644 index 0000000000..70d09a4bb3 --- /dev/null +++ b/backend/lookup/exceptions.py @@ -0,0 +1,60 @@ +"""Custom exceptions for the Look-Up system. + +This module defines custom exceptions specific to the Look-Up functionality. +""" + + +class LookupError(Exception): + """Base exception for Look-Up system errors.""" + + pass + + +class ExtractionNotCompleteError(LookupError): + """Raised when attempting to use reference data before extraction is complete. + + This exception is raised when trying to load reference data for a project + where one or more data sources have not completed extraction processing. + """ + + def __init__(self, failed_files=None): + """Initialize the exception. + + Args: + failed_files: List of file names that failed or are pending extraction + """ + self.failed_files = failed_files or [] + message = "Reference data extraction not complete" + if failed_files: + message += f" for files: {', '.join(failed_files)}" + super().__init__(message) + + +class TemplateNotFoundError(LookupError): + """Raised when a Look-Up project has no associated template. + + This exception is raised when attempting to execute a Look-Up + that doesn't have a prompt template configured. + """ + + pass + + +class ParseError(LookupError): + """Raised when LLM response cannot be parsed. + + This exception is raised when the LLM returns a response that + cannot be parsed as valid JSON or doesn't match expected format. + """ + + pass + + +class DefaultProfileError(LookupError): + """Raised when default profile is not found for a Look-Up project. + + This exception is raised when attempting to get the default profile + for a Look-Up project that doesn't have one configured. + """ + + pass diff --git a/backend/lookup/integrations/__init__.py b/backend/lookup/integrations/__init__.py new file mode 100644 index 0000000000..53383fb10c --- /dev/null +++ b/backend/lookup/integrations/__init__.py @@ -0,0 +1,8 @@ +"""Integration modules for Look-Up functionality. + +This package contains integrations with external services: +- Object Storage (S3-compatible) +- LLM Providers (OpenAI, Anthropic, etc.) 
+- LLMWhisperer (Document extraction) +- Redis Cache +""" diff --git a/backend/lookup/integrations/file_storage_client.py b/backend/lookup/integrations/file_storage_client.py new file mode 100644 index 0000000000..28870c19e8 --- /dev/null +++ b/backend/lookup/integrations/file_storage_client.py @@ -0,0 +1,86 @@ +"""File Storage Client for Look-Up reference data. + +This module provides integration with Unstract's file storage +for loading extracted reference data content. +""" + +import logging + +from utils.file_storage.constants import FileStorageKeys + +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper + +logger = logging.getLogger(__name__) + + +class FileStorageClient: + """Storage client implementation using Unstract's file storage. + + This client uses the actual platform file storage to read + extracted reference data content. + """ + + def __init__(self): + """Initialize the file storage client.""" + self.fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + def get(self, path: str) -> str: + """Retrieve file content from storage. + + Args: + path: Storage path to the file + + Returns: + File content as string + + Raises: + FileNotFoundError: If file doesn't exist + Exception: If reading fails + """ + try: + if not self.fs_instance.exists(path): + logger.error(f"File not found: {path}") + raise FileNotFoundError(f"File not found: {path}") + + # Use read() method with text mode + content = self.fs_instance.read(path, mode="r", encoding="utf-8") + logger.debug(f"Read {len(content)} chars from {path}") + return content + + except FileNotFoundError: + raise + except Exception as e: + logger.error(f"Failed to read file {path}: {e}") + raise Exception(f"Failed to read file: {str(e)}") + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists + """ + return self.fs_instance.exists(path) + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage (alias for get). + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + try: + return self.get(path) + except FileNotFoundError: + return None + except Exception as e: + logger.warning(f"Error reading {path}: {e}") + return None diff --git a/backend/lookup/integrations/llm_provider.py b/backend/lookup/integrations/llm_provider.py new file mode 100644 index 0000000000..ca16d9b235 --- /dev/null +++ b/backend/lookup/integrations/llm_provider.py @@ -0,0 +1,311 @@ +"""LLM Provider integration for Look-Up enrichment. + +This module provides integration with various LLM providers +(OpenAI, Anthropic, etc.) for generating enrichment data. +""" + +import json +import logging +import os +import time +from typing import Any + +from ..protocols import LLMClient + +logger = logging.getLogger(__name__) + + +class UnstractLLMClient(LLMClient): + """Implementation of LLMClient using Unstract's LLM abstraction. + + This client integrates with the Unstract platform's LLM providers + to generate enrichment data for Look-Ups. + """ + + def __init__(self, provider: str | None = None, model: str | None = None): + """Initialize the LLM client. 
+ + Args: + provider: LLM provider name (e.g., 'openai', 'anthropic') + model: Model name (e.g., 'gpt-4', 'claude-2') + """ + self.default_provider = provider or os.getenv( + "LOOKUP_DEFAULT_LLM_PROVIDER", "openai" + ) + self.default_model = model or os.getenv("LOOKUP_DEFAULT_LLM_MODEL", "gpt-4") + + # Import Unstract LLM utilities + try: + from unstract.llmbox import LLMBox + from unstract.llmbox.llm import LLM + + self.LLMBox = LLMBox + self.LLM = LLM + self.llm_available = True + except ImportError: + logger.warning("Unstract LLMBox not available, using fallback implementation") + self.llm_available = False + + def _get_llm_instance(self, config: dict[str, Any]): + """Get an LLM instance based on configuration. + + Args: + config: LLM configuration + + Returns: + LLM instance + """ + if not self.llm_available: + raise RuntimeError("LLM integration not available") + + provider = config.get("provider", self.default_provider) + model = config.get("model", self.default_model) + + # Create LLM instance using Unstract's LLMBox + # This would integrate with the actual Unstract LLM abstraction + # For now, we'll use a simplified approach + + # Map provider to Unstract's LLM types + provider_map = { + "openai": "OpenAI", + "anthropic": "Anthropic", + "azure_openai": "AzureOpenAI", + "gemini": "Gemini", + "vertex_ai": "VertexAI", + } + + llm_provider = provider_map.get(provider, "OpenAI") + + # Create configuration for the provider + llm_config = { + "provider": llm_provider, + "model": model, + "temperature": config.get("temperature", 0.7), + "max_tokens": config.get("max_tokens", 1000), + } + + # Add API keys based on provider + if provider == "openai": + llm_config["api_key"] = os.getenv("OPENAI_API_KEY") + elif provider == "anthropic": + llm_config["api_key"] = os.getenv("ANTHROPIC_API_KEY") + elif provider == "azure_openai": + llm_config["api_key"] = os.getenv("AZURE_OPENAI_API_KEY") + llm_config["endpoint"] = os.getenv("AZURE_OPENAI_ENDPOINT") + + # For now, return a mock implementation + # In production, this would create actual LLM instance + return llm_config + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate LLM response for Look-Up enrichment. + + Args: + prompt: The prompt text with resolved variables + config: LLM configuration (provider, model, temperature, etc.) + timeout: Request timeout in seconds + + Returns: + JSON-formatted string with enrichment data + + Raises: + RuntimeError: If LLM call fails + """ + start_time = time.time() + + try: + if self.llm_available: + # Use actual Unstract LLM integration + llm_config = self._get_llm_instance(config) + + # In production, this would make actual LLM call + # For now, simulate the call + response = self._simulate_llm_call(prompt, llm_config) + else: + # Fallback implementation + response = self._fallback_generate(prompt, config) + + # Validate response is JSON + try: + json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from response + response = self._extract_json(response) + + elapsed_time = time.time() - start_time + logger.info(f"LLM generation completed in {elapsed_time:.2f}s") + + return response + + except Exception as e: + logger.error(f"LLM generation failed: {e}") + raise RuntimeError(f"LLM generation failed: {str(e)}") + + def _simulate_llm_call(self, prompt: str, config: dict[str, Any]) -> str: + """Simulate LLM call for development/testing. + + In production, this would be replaced with actual LLM API calls. 
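As an aside, a minimal sketch of what the non-simulated path could look like. It mirrors the unstract_llm_client.py module added later in this same patch (same SDK LLM import, same complete() call and response access); the call_real_llm helper name and the adapter_id/adapter_metadata parameters are illustrative, not part of the patch.

import json
from unstract.sdk1.llm import LLM  # same SDK class used by unstract_llm_client.py in this patch

def call_real_llm(prompt: str, adapter_id: str, adapter_metadata: dict) -> str:
    """Call a configured LLM adapter and return its JSON response text (sketch only)."""
    llm = LLM(adapter_id=adapter_id, adapter_metadata=adapter_metadata)
    response = llm.complete(prompt)       # same call shape as unstract_llm_client.py
    text = response["response"].text      # response dict shape per that module
    json.loads(text)                      # enrichment responses are expected to be JSON;
    return text                           # the patch's client falls back to _extract_json on JSONDecodeError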
+ """ + # Simulate some processing + time.sleep(0.5) + + # Generate mock response based on prompt content + if "vendor" in prompt.lower(): + return json.dumps( + { + "canonical_vendor": "Sample Vendor Inc.", + "vendor_category": "SaaS", + "vendor_type": "Software", + "confidence": 0.95, + } + ) + elif "product" in prompt.lower(): + return json.dumps( + { + "product_name": "Sample Product", + "product_category": "Enterprise Software", + "product_type": "Cloud Service", + "confidence": 0.92, + } + ) + else: + return json.dumps({"enriched_data": "Sample enrichment", "confidence": 0.88}) + + def _fallback_generate(self, prompt: str, config: dict[str, Any]) -> str: + """Fallback generation when LLM integration is not available. + + This is primarily for testing and development. + """ + logger.warning("Using fallback LLM generation") + + # Simple pattern matching for testing + response_data = { + "status": "fallback", + "message": "LLM integration not available", + "confidence": 0.5, + } + + # Add some basic enrichment based on prompt + if "vendor" in prompt.lower(): + response_data["canonical_vendor"] = "Unknown Vendor" + response_data["vendor_category"] = "Unknown" + elif "product" in prompt.lower(): + response_data["product_name"] = "Unknown Product" + response_data["product_category"] = "Unknown" + + return json.dumps(response_data) + + def _extract_json(self, response: str) -> str: + """Extract JSON from LLM response if wrapped in text. + + Args: + response: Raw LLM response + + Returns: + Extracted JSON string + """ + # Try to find JSON in the response + import re + + # Look for JSON object pattern + json_pattern = r"\{[^{}]*\}" + matches = re.findall(json_pattern, response) + + if matches: + # Try to parse each match + for match in matches: + try: + json.loads(match) + return match + except json.JSONDecodeError: + continue + + # If no valid JSON found, create a basic response + return json.dumps( + { + "raw_response": response[:500], # Truncate if too long + "confidence": 0.3, + "warning": "Could not extract structured data", + } + ) + + def validate_response(self, response: str) -> bool: + """Validate that the LLM response is properly formatted. + + Args: + response: LLM response string + + Returns: + True if valid JSON with required fields + """ + try: + data = json.loads(response) + + # Check for confidence score + if "confidence" not in data: + logger.warning("Response missing confidence score") + return False + + # Check confidence is valid + confidence = data.get("confidence", 0) + if not (0 <= confidence <= 1): + logger.warning(f"Invalid confidence score: {confidence}") + return False + + return True + + except json.JSONDecodeError: + logger.error("Response is not valid JSON") + return False + + def get_token_count(self, text: str, model: str = None) -> int: + """Estimate token count for the given text. 
+ + Args: + text: Input text + model: Model name for accurate counting + + Returns: + Estimated token count + """ + # Simple estimation: ~4 characters per token + # In production, use tiktoken or model-specific tokenizer + return len(text) // 4 + + +class OpenAILLMClient(UnstractLLMClient): + """OpenAI-specific LLM client implementation.""" + + def __init__(self): + """Initialize OpenAI client.""" + super().__init__(provider="openai", model="gpt-4") + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate using OpenAI API.""" + # Override config with OpenAI defaults + config = { + **config, + "provider": "openai", + "model": config.get("model", "gpt-4"), + "temperature": config.get("temperature", 0.7), + } + return super().generate(prompt, config, timeout) + + +class AnthropicLLMClient(UnstractLLMClient): + """Anthropic-specific LLM client implementation.""" + + def __init__(self): + """Initialize Anthropic client.""" + super().__init__(provider="anthropic", model="claude-2") + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate using Anthropic API.""" + # Override config with Anthropic defaults + config = { + **config, + "provider": "anthropic", + "model": config.get("model", "claude-2"), + "temperature": config.get("temperature", 0.7), + } + return super().generate(prompt, config, timeout) diff --git a/backend/lookup/integrations/llmwhisperer_client.py b/backend/lookup/integrations/llmwhisperer_client.py new file mode 100644 index 0000000000..5445196e97 --- /dev/null +++ b/backend/lookup/integrations/llmwhisperer_client.py @@ -0,0 +1,334 @@ +"""LLMWhisperer integration for document text extraction. + +This module provides integration with LLMWhisperer service +for extracting text from various document formats. +""" + +import logging +import time +from enum import Enum +from typing import Any + +import requests +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class ExtractionStatus(Enum): + """Status of document extraction.""" + + PENDING = "pending" + PROCESSING = "processing" + COMPLETE = "complete" + FAILED = "failed" + NOT_REQUIRED = "not_required" + + +class LLMWhispererClient: + """Client for integrating with LLMWhisperer document extraction service. + + LLMWhisperer extracts text from PDFs, images, and other document formats + for use as reference data in Look-Ups. + """ + + def __init__(self): + """Initialize LLMWhisperer client with configuration.""" + self.base_url = getattr( + settings, "LLMWHISPERER_BASE_URL", "https://api.llmwhisperer.com" + ) + self.api_key = getattr(settings, "LLMWHISPERER_API_KEY", "") + + if not self.api_key: + logger.warning("LLMWhisperer API key not configured") + + self.session = requests.Session() + self.session.headers.update( + { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + ) + + def extract_text( + self, + file_content: bytes, + file_name: str, + extraction_config: dict[str, Any] | None = None, + ) -> tuple[str, str]: + """Extract text from a document using LLMWhisperer. 
+ + Args: + file_content: File content as bytes + file_name: Original file name + extraction_config: Optional extraction configuration + + Returns: + Tuple of (extraction_id, status) + """ + if not self.api_key: + logger.error("LLMWhisperer API key not configured") + return "", ExtractionStatus.FAILED.value + + try: + # Prepare extraction request + config = extraction_config or self._get_default_config() + + # Create extraction job + url = f"{self.base_url}/v1/extract" + + # Prepare multipart form data + files = {"file": (file_name, file_content)} + + # Add configuration as form data + data = { + "processing_mode": config.get("processing_mode", "ocr"), + "output_format": config.get("output_format", "text"), + "page_separator": config.get("page_separator", "\n---\n"), + "force_text_processing": str( + config.get("force_text_processing", True) + ).lower(), + "line_splitter": config.get("line_splitter", "line"), + "horizontal_stretch": str(config.get("horizontal_stretch", 1.0)), + "vertical_stretch": str(config.get("vertical_stretch", 1.0)), + } + + # Remove JSON content type for multipart + headers = {"Authorization": f"Bearer {self.api_key}"} + + response = requests.post( + url, files=files, data=data, headers=headers, timeout=30 + ) + + if response.status_code == 200: + result = response.json() + extraction_id = result.get("extraction_id", "") + logger.info(f"Started extraction job: {extraction_id}") + return extraction_id, ExtractionStatus.PROCESSING.value + else: + logger.error( + f"Extraction failed with status {response.status_code}: {response.text}" + ) + return "", ExtractionStatus.FAILED.value + + except requests.exceptions.RequestException as e: + logger.error(f"LLMWhisperer extraction request failed: {e}") + return "", ExtractionStatus.FAILED.value + except Exception as e: + logger.error(f"Unexpected error during extraction: {e}") + return "", ExtractionStatus.FAILED.value + + def check_extraction_status(self, extraction_id: str) -> tuple[str, str | None]: + """Check the status of an extraction job. + + Args: + extraction_id: Extraction job ID + + Returns: + Tuple of (status, extracted_text) + """ + if not extraction_id: + return ExtractionStatus.FAILED.value, None + + try: + url = f"{self.base_url}/v1/status/{extraction_id}" + + response = self.session.get(url, timeout=10) + + if response.status_code == 200: + result = response.json() + status = result.get("status", "unknown") + + if status == "completed": + # Get extracted text + text = self._get_extracted_text(extraction_id) + return ExtractionStatus.COMPLETE.value, text + elif status == "processing": + return ExtractionStatus.PROCESSING.value, None + elif status == "failed": + error_msg = result.get("error", "Unknown error") + logger.error(f"Extraction {extraction_id} failed: {error_msg}") + return ExtractionStatus.FAILED.value, None + else: + logger.warning(f"Unknown extraction status: {status}") + return ExtractionStatus.PENDING.value, None + else: + logger.error(f"Status check failed with code {response.status_code}") + return ExtractionStatus.FAILED.value, None + + except Exception as e: + logger.error(f"Error checking extraction status: {e}") + return ExtractionStatus.FAILED.value, None + + def _get_extracted_text(self, extraction_id: str) -> str | None: + """Retrieve extracted text for a completed job. 
+ + Args: + extraction_id: Extraction job ID + + Returns: + Extracted text or None + """ + try: + url = f"{self.base_url}/v1/result/{extraction_id}" + + response = self.session.get(url, timeout=30) + + if response.status_code == 200: + # Response is the extracted text + return response.text + else: + logger.error(f"Failed to get extraction result: {response.status_code}") + return None + + except Exception as e: + logger.error(f"Error retrieving extracted text: {e}") + return None + + def wait_for_extraction( + self, extraction_id: str, max_wait_seconds: int = 300, poll_interval: int = 5 + ) -> tuple[str, str | None]: + """Wait for extraction to complete with polling. + + Args: + extraction_id: Extraction job ID + max_wait_seconds: Maximum time to wait + poll_interval: Seconds between status checks + + Returns: + Tuple of (final_status, extracted_text) + """ + start_time = time.time() + + while time.time() - start_time < max_wait_seconds: + status, text = self.check_extraction_status(extraction_id) + + if status == ExtractionStatus.COMPLETE.value: + logger.info(f"Extraction {extraction_id} completed successfully") + return status, text + elif status == ExtractionStatus.FAILED.value: + logger.error(f"Extraction {extraction_id} failed") + return status, None + elif status == ExtractionStatus.PROCESSING.value: + logger.debug(f"Extraction {extraction_id} still processing...") + time.sleep(poll_interval) + else: + logger.warning(f"Unexpected status for {extraction_id}: {status}") + time.sleep(poll_interval) + + logger.error( + f"Extraction {extraction_id} timed out after {max_wait_seconds} seconds" + ) + return ExtractionStatus.FAILED.value, None + + def extract_and_wait( + self, + file_content: bytes, + file_name: str, + extraction_config: dict[str, Any] | None = None, + max_wait_seconds: int = 300, + ) -> tuple[bool, str | None]: + """Extract text and wait for completion. + + Args: + file_content: File content + file_name: File name + extraction_config: Extraction configuration + max_wait_seconds: Maximum wait time + + Returns: + Tuple of (success, extracted_text) + """ + # Start extraction + extraction_id, status = self.extract_text( + file_content, file_name, extraction_config + ) + + if status == ExtractionStatus.FAILED.value: + return False, None + + # Wait for completion + final_status, text = self.wait_for_extraction(extraction_id, max_wait_seconds) + + return final_status == ExtractionStatus.COMPLETE.value, text + + def _get_default_config(self) -> dict[str, Any]: + """Get default extraction configuration. + + Returns: + Default configuration dictionary + """ + return { + "processing_mode": "ocr", # 'ocr' or 'text' + "output_format": "text", # 'text' or 'markdown' + "page_separator": "\n---\n", + "force_text_processing": True, + "line_splitter": "line", # 'line' or 'paragraph' + "horizontal_stretch": 1.0, + "vertical_stretch": 1.0, + "pages": "", # Empty for all pages + "timeout": 300, + "store_metadata": False, + } + + def is_extraction_needed(self, file_name: str) -> bool: + """Check if extraction is needed for the file type. 
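For illustration only (not part of the patch), a sketch of how a caller might combine the methods of this client: skip extraction for plain-text uploads, otherwise pick a per-file configuration and poll until completion. The extract_reference_text helper and the plain-text decode fallback are assumptions.

def extract_reference_text(
    client: LLMWhispererClient, file_name: str, file_content: bytes
) -> str | None:
    """Return extracted text, or decoded content when no extraction is needed (sketch only)."""
    if not client.is_extraction_needed(file_name):
        return file_content.decode("utf-8", errors="replace")
    config = client.get_extraction_config_for_file(file_name)
    ok, text = client.extract_and_wait(
        file_content, file_name, extraction_config=config, max_wait_seconds=300
    )
    return text if ok else None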
+ + Args: + file_name: File name with extension + + Returns: + True if extraction is needed + """ + # File types that need extraction + extractable_extensions = { + ".pdf", + ".png", + ".jpg", + ".jpeg", + ".tiff", + ".bmp", + ".docx", + ".doc", + ".pptx", + ".ppt", + ".xlsx", + ".xls", + } + + # Check file extension + import os + + _, ext = os.path.splitext(file_name.lower()) + return ext in extractable_extensions + + def get_extraction_config_for_file(self, file_name: str) -> dict[str, Any]: + """Get optimal extraction configuration based on file type. + + Args: + file_name: File name with extension + + Returns: + Extraction configuration + """ + import os + + _, ext = os.path.splitext(file_name.lower()) + + config = self._get_default_config() + + # Optimize based on file type + if ext in [".pdf"]: + config["processing_mode"] = "ocr" + config["force_text_processing"] = True + elif ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]: + config["processing_mode"] = "ocr" + config["force_text_processing"] = False + elif ext in [".docx", ".doc"]: + config["processing_mode"] = "text" + config["output_format"] = "markdown" + elif ext in [".xlsx", ".xls"]: + config["processing_mode"] = "text" + config["line_splitter"] = "paragraph" + + return config diff --git a/backend/lookup/integrations/redis_cache.py b/backend/lookup/integrations/redis_cache.py new file mode 100644 index 0000000000..976164a6ac --- /dev/null +++ b/backend/lookup/integrations/redis_cache.py @@ -0,0 +1,384 @@ +"""Redis cache implementation for Look-Up LLM responses. + +This module provides a Redis-based cache for storing and retrieving +LLM responses to improve performance and reduce API costs. +""" + +import hashlib +import logging +import time +from typing import Any + +try: + import redis + from redis.exceptions import RedisError + + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class RedisLLMCache: + """Redis-based cache for LLM responses. + + Provides persistent caching with TTL support and + automatic failover to in-memory cache if Redis is unavailable. + """ + + def __init__( + self, + ttl_hours: int = 24, + key_prefix: str = "lookup:llm:", + fallback_to_memory: bool = True, + ): + """Initialize Redis cache. + + Args: + ttl_hours: Time-to-live for cache entries in hours + key_prefix: Prefix for all cache keys + fallback_to_memory: Use in-memory cache if Redis unavailable + """ + self.ttl_seconds = ttl_hours * 3600 + self.key_prefix = key_prefix + self.fallback_to_memory = fallback_to_memory + + # Initialize Redis connection + self.redis_client = self._init_redis() + + # Fallback in-memory cache + if self.fallback_to_memory: + from ..services.llm_cache import LLMResponseCache + + self.memory_cache = LLMResponseCache(ttl_hours) + else: + self.memory_cache = None + + def _init_redis(self) -> Any | None: + """Initialize Redis connection. 
+ + Returns: + Redis client or None if unavailable + """ + if not REDIS_AVAILABLE: + logger.warning("Redis package not installed, using fallback cache") + return None + + try: + # Get Redis configuration from settings + redis_host = getattr(settings, "REDIS_HOST", "localhost") + redis_port = getattr(settings, "REDIS_PORT", 6379) + redis_db = getattr(settings, "REDIS_CACHE_DB", 1) + redis_password = getattr(settings, "REDIS_PASSWORD", None) + + # Create Redis client + client = redis.Redis( + host=redis_host, + port=redis_port, + db=redis_db, + password=redis_password, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + ) + + # Test connection + client.ping() + logger.info(f"Connected to Redis at {redis_host}:{redis_port}/{redis_db}") + return client + + except Exception as e: + logger.warning(f"Failed to connect to Redis: {e}") + return None + + def generate_cache_key(self, prompt: str, reference_data: str) -> str: + """Generate a cache key from prompt and reference data. + + Args: + prompt: The resolved prompt + reference_data: The reference data content + + Returns: + SHA256 hash as cache key + """ + combined = f"{prompt}{reference_data}" + hash_obj = hashlib.sha256(combined.encode("utf-8")) + return f"{self.key_prefix}{hash_obj.hexdigest()}" + + def get(self, key: str) -> str | None: + """Get cached response. + + Args: + key: Cache key + + Returns: + Cached response or None if not found/expired + """ + # Try Redis first + if self.redis_client: + try: + value = self.redis_client.get(key) + if value: + logger.debug(f"Redis cache hit for key: {key[:20]}...") + self._update_stats("hits") + return value + else: + logger.debug(f"Redis cache miss for key: {key[:20]}...") + self._update_stats("misses") + except RedisError as e: + logger.error(f"Redis get error: {e}") + # Fall through to memory cache + + # Fallback to memory cache + if self.memory_cache: + # Remove prefix for memory cache + memory_key = key.replace(self.key_prefix, "") + value = self.memory_cache.get(memory_key) + if value: + logger.debug(f"Memory cache hit for key: {key[:20]}...") + return value + + return None + + def set(self, key: str, value: str, ttl: int | None = None) -> bool: + """Set cache value. + + Args: + key: Cache key + value: Response to cache + ttl: Optional TTL override in seconds + + Returns: + True if successful + """ + ttl = ttl or self.ttl_seconds + + # Try Redis first + if self.redis_client: + try: + self.redis_client.setex(name=key, time=ttl, value=value) + logger.debug(f"Cached to Redis with key: {key[:20]}... (TTL: {ttl}s)") + self._update_stats("sets") + return True + except RedisError as e: + logger.error(f"Redis set error: {e}") + # Fall through to memory cache + + # Fallback to memory cache + if self.memory_cache: + # Remove prefix for memory cache + memory_key = key.replace(self.key_prefix, "") + self.memory_cache.set(memory_key, value) + logger.debug(f"Cached to memory with key: {key[:20]}...") + return True + + return False + + def delete(self, key: str) -> bool: + """Delete a cache entry. 
+ + Args: + key: Cache key + + Returns: + True if deleted + """ + deleted = False + + # Try Redis + if self.redis_client: + try: + result = self.redis_client.delete(key) + if result > 0: + deleted = True + logger.debug(f"Deleted from Redis: {key[:20]}...") + except RedisError as e: + logger.error(f"Redis delete error: {e}") + + # Also delete from memory cache + if self.memory_cache: + memory_key = key.replace(self.key_prefix, "") + if memory_key in self.memory_cache.cache: + del self.memory_cache.cache[memory_key] + deleted = True + logger.debug(f"Deleted from memory: {key[:20]}...") + + return deleted + + def clear_pattern(self, pattern: str) -> int: + """Clear all cache entries matching a pattern. + + Args: + pattern: Pattern to match (e.g., "lookup:llm:project_*") + + Returns: + Number of entries cleared + """ + count = 0 + + # Clear from Redis + if self.redis_client: + try: + # Use SCAN to find matching keys + cursor = 0 + while True: + cursor, keys = self.redis_client.scan( + cursor=cursor, match=pattern, count=100 + ) + if keys: + count += self.redis_client.delete(*keys) + if cursor == 0: + break + + logger.info(f"Cleared {count} entries from Redis matching: {pattern}") + except RedisError as e: + logger.error(f"Redis clear pattern error: {e}") + + # Clear from memory cache + if self.memory_cache: + # Remove prefix from pattern + memory_pattern = pattern.replace(self.key_prefix, "") + keys_to_delete = [ + k + for k in self.memory_cache.cache.keys() + if self._match_pattern(k, memory_pattern) + ] + for key in keys_to_delete: + del self.memory_cache.cache[key] + count += 1 + + logger.info(f"Cleared {len(keys_to_delete)} entries from memory") + + return count + + def _match_pattern(self, key: str, pattern: str) -> bool: + """Simple pattern matching for memory cache. + + Args: + key: Key to check + pattern: Pattern with * wildcards + + Returns: + True if matches + """ + import fnmatch + + return fnmatch.fnmatch(key, pattern) + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary of cache statistics + """ + stats = { + "backend": "redis" if self.redis_client else "memory", + "ttl_hours": self.ttl_seconds // 3600, + "key_prefix": self.key_prefix, + } + + # Get Redis stats + if self.redis_client: + try: + info = self.redis_client.info("stats") + stats["redis"] = { + "total_connections": info.get("total_connections_received", 0), + "keyspace_hits": info.get("keyspace_hits", 0), + "keyspace_misses": info.get("keyspace_misses", 0), + "hit_rate": self._calculate_hit_rate( + info.get("keyspace_hits", 0), info.get("keyspace_misses", 0) + ), + } + + # Get key count + dbinfo = self.redis_client.info("keyspace") + db_key = f"db{self.redis_client.connection_pool.connection_kwargs.get('db', 0)}" + if db_key in dbinfo: + stats["redis"]["total_keys"] = dbinfo[db_key].get("keys", 0) + + except RedisError as e: + logger.error(f"Failed to get Redis stats: {e}") + stats["redis"] = {"error": str(e)} + + # Get memory cache stats + if self.memory_cache: + stats["memory"] = { + "entries": len(self.memory_cache.cache), + "size_estimate": sum( + len(k) + len(v[0]) for k, v in self.memory_cache.cache.items() + ), + } + + return stats + + def _calculate_hit_rate(self, hits: int, misses: int) -> float: + """Calculate cache hit rate.""" + total = hits + misses + return (hits / total) if total > 0 else 0.0 + + def _update_stats(self, stat_type: str) -> None: + """Update internal statistics. 
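For illustration only (not part of the patch), how an executor could consult this cache before an LLM call, using just the methods defined on RedisLLMCache; the cached_generate helper and the llm_client parameter are stand-ins for whatever client the executor actually holds.

def cached_generate(
    cache: RedisLLMCache, llm_client, prompt: str, reference_data: str, config: dict
) -> str:
    """Return a cached LLM response when available, otherwise generate and cache it (sketch only)."""
    key = cache.generate_cache_key(prompt, reference_data)
    response = cache.get(key)
    if response is None:
        response = llm_client.generate(prompt, config)  # LLMClient.generate(prompt, config)
        cache.set(key, response)                        # stored with the configured TTL
    return response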
+ + Args: + stat_type: Type of stat to update ('hits', 'misses', 'sets') + """ + # This could be extended to track more detailed statistics + pass + + def cleanup_expired(self) -> int: + """Clean up expired entries. + + For Redis, this happens automatically with TTL. + For memory cache, we clean up lazily. + + Returns: + Number of entries cleaned up + """ + count = 0 + + # Redis handles expiration automatically + if self.redis_client: + logger.debug("Redis handles expiration automatically") + + # Clean memory cache + if self.memory_cache: + current_time = time.time() + keys_to_delete = [] + + for key, (value, expiry) in list(self.memory_cache.cache.items()): + if current_time >= expiry: + keys_to_delete.append(key) + + for key in keys_to_delete: + del self.memory_cache.cache[key] + count += 1 + + if count > 0: + logger.info(f"Cleaned up {count} expired entries from memory cache") + + return count + + def warmup(self, project_id: str, preload_data: dict[str, str]) -> int: + """Warm up cache with preloaded data. + + Args: + project_id: Project ID for namespacing + preload_data: Dictionary of prompt->response to preload + + Returns: + Number of entries preloaded + """ + count = 0 + + for prompt, response in preload_data.items(): + # Generate key with project namespace + key = f"{self.key_prefix}project:{project_id}:{hashlib.md5(prompt.encode()).hexdigest()}" + + if self.set(key, response): + count += 1 + + logger.info(f"Warmed up cache with {count} entries for project {project_id}") + return count diff --git a/backend/lookup/integrations/storage_client.py b/backend/lookup/integrations/storage_client.py new file mode 100644 index 0000000000..816cf72f9c --- /dev/null +++ b/backend/lookup/integrations/storage_client.py @@ -0,0 +1,281 @@ +"""Object Storage integration for Look-Up reference data. + +This module provides integration with PERMANENT_REMOTE_STORAGE +for storing and retrieving reference data files. +""" + +import logging +from pathlib import Path +from typing import Any +from uuid import UUID + +from utils.file_storage.constants import FileStorageKeys + +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper + +from ..protocols import StorageClient + +logger = logging.getLogger(__name__) + + +class RemoteStorageClient(StorageClient): + """Implementation of StorageClient using PERMANENT_REMOTE_STORAGE. + + This client integrates with the Unstract file storage system + to store and retrieve Look-Up reference data. + """ + + def __init__(self, base_path: str = "lookup/reference_data"): + """Initialize the remote storage client. + + Args: + base_path: Base path for Look-Up data in storage + """ + self.base_path = base_path + self.fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + def _get_project_path(self, project_id: UUID) -> str: + """Get the storage path for a project. + + Args: + project_id: Project UUID + + Returns: + Storage path string + """ + return str(Path(self.base_path) / str(project_id)) + + def _get_file_path(self, project_id: UUID, filename: str) -> str: + """Get the full storage path for a file. + + Args: + project_id: Project UUID + filename: File name + + Returns: + Full storage path + """ + return str(Path(self._get_project_path(project_id)) / filename) + + def upload(self, path: str, content: bytes) -> bool: + """Upload content to remote storage. 
+ + Args: + path: Storage path + content: File content as bytes + + Returns: + True if successful, False otherwise + """ + try: + # Ensure parent directory exists + parent_dir = str(Path(path).parent) + self.fs_instance.mkdir(parent_dir, create_parents=True) + + # Write content + self.fs_instance.write(path=path, mode="wb", data=content) + logger.info(f"Successfully uploaded file to {path}") + return True + + except Exception as e: + logger.error(f"Failed to upload file to {path}: {e}") + return False + + def download(self, path: str) -> bytes | None: + """Download content from remote storage. + + Args: + path: Storage path + + Returns: + File content as bytes or None if not found + """ + try: + if not self.exists(path): + logger.warning(f"File not found at path: {path}") + return None + + # Read content + content = self.fs_instance.read(path=path, mode="rb") + logger.info(f"Successfully downloaded file from {path}") + return content + + except Exception as e: + logger.error(f"Failed to download file from {path}: {e}") + return None + + def delete(self, path: str) -> bool: + """Delete content from remote storage. + + Args: + path: Storage path + + Returns: + True if deleted, False otherwise + """ + try: + if not self.exists(path): + logger.warning(f"File not found for deletion: {path}") + return False + + self.fs_instance.delete(path) + logger.info(f"Successfully deleted file at {path}") + return True + + except Exception as e: + logger.error(f"Failed to delete file at {path}: {e}") + return False + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists, False otherwise + """ + try: + return self.fs_instance.exists(path) + except Exception as e: + logger.error(f"Failed to check existence of {path}: {e}") + return False + + def list_files(self, prefix: str) -> list[str]: + """List files with given prefix. + + Args: + prefix: Path prefix + + Returns: + List of matching paths + """ + try: + # List all files in the directory + files = self.fs_instance.listdir(prefix) + return [str(Path(prefix) / f) for f in files if not f.startswith(".")] + + except Exception as e: + logger.error(f"Failed to list files with prefix {prefix}: {e}") + return [] + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage. + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + content = self.download(path) + if content: + try: + return content.decode("utf-8") + except UnicodeDecodeError: + logger.error(f"Failed to decode text content from {path}") + return None + + def save_text_content(self, path: str, text: str) -> bool: + """Save text content to storage. + + Args: + path: Storage path + text: Text content + + Returns: + True if successful + """ + return self.upload(path, text.encode("utf-8")) + + def upload_reference_data( + self, + project_id: UUID, + filename: str, + content: bytes, + metadata: dict[str, Any] | None = None, + ) -> str: + """Upload reference data for a Look-Up project. 
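A small usage sketch (not part of the patch) of the reference-data helpers defined further down in this class: upload a file for a project and read the stored text back. The filename, CSV bytes, and generated UUID are placeholders.

from uuid import uuid4

def round_trip_reference_data(storage: RemoteStorageClient) -> str | None:
    """Upload a reference file for a project, then read it back as text (sketch only)."""
    project_id = uuid4()  # placeholder project UUID
    path = storage.upload_reference_data(
        project_id,
        "vendors.csv",
        b"name,category\nAcme,SaaS\n",
        metadata={"source": "example"},
    )
    # Files land under lookup/reference_data/<project_id>/<filename> per _get_file_path()
    return storage.get_reference_data(project_id, "vendors.csv") if path else None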
+ + Args: + project_id: Project UUID + filename: Original filename + content: File content + metadata: Optional metadata + + Returns: + Storage path of uploaded file + """ + # Generate storage path + storage_path = self._get_file_path(project_id, filename) + + # Upload file + if self.upload(storage_path, content): + # Store metadata if provided + if metadata: + meta_path = f"{storage_path}.meta.json" + import json + + meta_content = json.dumps(metadata, indent=2) + self.save_text_content(meta_path, meta_content) + + return storage_path + else: + raise Exception(f"Failed to upload reference data to {storage_path}") + + def get_reference_data(self, project_id: UUID, filename: str) -> str | None: + """Get reference data text content. + + Args: + project_id: Project UUID + filename: File name + + Returns: + Text content or None + """ + storage_path = self._get_file_path(project_id, filename) + return self.get_text_content(storage_path) + + def list_project_files(self, project_id: UUID) -> list[str]: + """List all files for a project. + + Args: + project_id: Project UUID + + Returns: + List of file paths + """ + project_path = self._get_project_path(project_id) + return self.list_files(project_path) + + def delete_project_data(self, project_id: UUID) -> bool: + """Delete all data for a project. + + Args: + project_id: Project UUID + + Returns: + True if successful + """ + try: + project_path = self._get_project_path(project_id) + files = self.list_project_files(project_id) + + # Delete all files + for file_path in files: + self.delete(file_path) + + # Delete directory + if self.exists(project_path): + self.fs_instance.rmdir(project_path) + + logger.info(f"Deleted all data for project {project_id}") + return True + + except Exception as e: + logger.error(f"Failed to delete project data: {e}") + return False diff --git a/backend/lookup/integrations/unstract_llm_client.py b/backend/lookup/integrations/unstract_llm_client.py new file mode 100644 index 0000000000..48baf8484b --- /dev/null +++ b/backend/lookup/integrations/unstract_llm_client.py @@ -0,0 +1,113 @@ +"""Unstract LLM Client for Look-Up enrichment. + +This module provides integration with Unstract's LLM abstraction +using the SDK's LLM class for generating enrichment data. +""" + +import json +import logging +from typing import Any, Protocol + +from adapter_processor_v2.models import AdapterInstance + +from unstract.sdk1.llm import LLM + +logger = logging.getLogger(__name__) + + +class LLMClient(Protocol): + """Protocol for LLM client abstraction.""" + + def generate(self, prompt: str, config: dict[str, Any]) -> str: + """Generate LLM response for the prompt.""" + ... + + +class UnstractLLMClient(LLMClient): + """LLM client implementation using Unstract's SDK LLM class. + + This client uses the actual LLM adapters configured in the platform + to generate enrichment data for Look-Ups. + """ + + def __init__(self, adapter_instance: AdapterInstance): + """Initialize the LLM client with an adapter instance. + + Args: + adapter_instance: The AdapterInstance model object containing + the LLM configuration + """ + self.adapter_instance = adapter_instance + self.adapter_id = adapter_instance.adapter_id + self.adapter_metadata = adapter_instance.metadata # Decrypted metadata + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate LLM response for Look-Up enrichment. + + Args: + prompt: The prompt text with resolved variables and reference data + config: Additional LLM configuration (temperature, etc.) 
+ timeout: Request timeout in seconds + + Returns: + JSON-formatted string with enrichment data + + Raises: + RuntimeError: If LLM call fails + """ + try: + # Create LLM instance using SDK + llm = LLM(adapter_id=self.adapter_id, adapter_metadata=self.adapter_metadata) + + # Call the LLM + logger.debug(f"Calling LLM with prompt length: {len(prompt)}") + response = llm.complete(prompt) + + # Extract the response text + response_text = response["response"].text + + logger.debug(f"LLM response: {response_text[:500]}...") + + # Validate it's valid JSON + try: + json.loads(response_text) + return response_text + except json.JSONDecodeError: + # Try to extract JSON from response + return self._extract_json(response_text) + + except Exception as e: + logger.error(f"LLM generation failed: {e}") + raise RuntimeError(f"LLM generation failed: {str(e)}") + + def _extract_json(self, response: str) -> str: + """Extract JSON from LLM response if wrapped in text. + + Args: + response: Raw LLM response + + Returns: + Extracted JSON string + """ + # Look for JSON object pattern (handles nested objects) + # Try to find content between first { and last } + start_idx = response.find("{") + end_idx = response.rfind("}") + + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + potential_json = response[start_idx : end_idx + 1] + try: + json.loads(potential_json) + return potential_json + except json.JSONDecodeError: + pass + + # If no valid JSON found, create a basic response + logger.warning(f"Could not extract JSON from response: {response[:200]}") + return json.dumps( + { + "raw_response": response[:500], + "confidence": 0.3, + "warning": "Could not extract structured data from LLM response", + } + ) diff --git a/backend/lookup/migrations/0001_initial.py b/backend/lookup/migrations/0001_initial.py new file mode 100644 index 0000000000..da583747de --- /dev/null +++ b/backend/lookup/migrations/0001_initial.py @@ -0,0 +1,886 @@ +# Generated by Django - Squashed migration for lookup app + +import uuid + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + """Squashed initial migration for the lookup app. + + This migration creates all lookup-related models: + - LookupProject: Main project configuration + - LookupPromptTemplate: Prompt templates with variable detection + - LookupDataSource: Reference data file management with versioning + - LookupProfileManager: Adapter profiles for indexing/querying + - LookupIndexManager: Index tracking for vector DB + - PromptStudioLookupLink: Links between PS projects and Lookups + - LookupExecutionAudit: Execution history and metrics + """ + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("account_v2", "0001_initial"), + ("adapter_processor_v2", "0001_initial"), + ] + + operations = [ + # 1. 
Create LookupProject model + migrations.CreateModel( + name="LookupProject", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "name", + models.CharField( + help_text="Name of the Look-Up project", max_length=255 + ), + ), + ( + "description", + models.TextField( + blank=True, + help_text="Description of the Look-Up project's purpose", + null=True, + ), + ), + ( + "lookup_type", + models.CharField( + choices=[("static_data", "Static Data")], + default="static_data", + help_text="Type of Look-Up (only static_data for POC)", + max_length=50, + ), + ), + ( + "reference_data_type", + models.CharField( + choices=[ + ("vendor_catalog", "Vendor Catalog"), + ("product_catalog", "Product Catalog"), + ("customer_database", "Customer Database"), + ("pricing_data", "Pricing Data"), + ("compliance_rules", "Compliance Rules"), + ("custom", "Custom"), + ], + help_text="Category of reference data being stored", + max_length=50, + ), + ), + ( + "is_active", + models.BooleanField( + default=True, help_text="Whether this project is active" + ), + ), + ( + "metadata", + models.JSONField( + blank=True, + default=dict, + help_text="Additional metadata for the project", + ), + ), + ( + "llm_provider", + models.CharField( + blank=True, + choices=[ + ("openai", "OpenAI"), + ("anthropic", "Anthropic"), + ("azure", "Azure OpenAI"), + ("custom", "Custom Provider"), + ], + help_text="LLM provider to use for matching", + max_length=50, + null=True, + ), + ), + ( + "llm_model", + models.CharField( + blank=True, + help_text="Specific model name (e.g., gpt-4-turbo, claude-3-opus)", + max_length=100, + null=True, + ), + ), + ( + "llm_config", + models.JSONField( + blank=True, + default=dict, + help_text="Additional LLM configuration (temperature, max_tokens, etc.)", + ), + ), + ( + "created_by", + models.ForeignKey( + help_text="User who created this project", + on_delete=django.db.models.deletion.RESTRICT, + related_name="created_lookup_projects", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="lookup_projects", + to="account_v2.organization", + ), + ), + ], + options={ + "verbose_name": "Look-Up Project", + "verbose_name_plural": "Look-Up Projects", + "db_table": "lookup_projects", + "ordering": ["-created_at"], + }, + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["organization"], name="lookup_proj_organiz_idx"), + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["created_by"], name="lookup_proj_created_idx"), + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["modified_at"], name="lookup_proj_modifie_idx"), + ), + # 2. 
Create LookupPromptTemplate model + migrations.CreateModel( + name="LookupPromptTemplate", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "name", + models.CharField( + help_text="Template name for identification", max_length=255 + ), + ), + ( + "template_text", + models.TextField(help_text="Template with {{variable}} placeholders"), + ), + ( + "llm_config", + models.JSONField( + blank=True, + default=dict, + help_text="LLM configuration including adapter_id", + ), + ), + ( + "is_active", + models.BooleanField( + default=True, help_text="Whether this template is active" + ), + ), + ( + "variable_mappings", + models.JSONField( + blank=True, + default=dict, + help_text="Optional documentation of variable mappings", + ), + ), + ( + "created_by", + models.ForeignKey( + help_text="User who created this template", + on_delete=django.db.models.deletion.RESTRICT, + related_name="created_lookup_templates", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "project", + models.OneToOneField( + help_text="Parent Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="prompt_template_link", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Look-Up Prompt Template", + "verbose_name_plural": "Look-Up Prompt Templates", + "db_table": "lookup_prompt_templates", + "ordering": ["-modified_at"], + }, + ), + # 3. Add template FK to LookupProject (after template is created) + migrations.AddField( + model_name="lookupproject", + name="template", + field=models.ForeignKey( + blank=True, + help_text="Prompt template for this project", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="projects", + to="lookup.lookupprompttemplate", + ), + ), + # 4. 
Create LookupDataSource model + migrations.CreateModel( + name="LookupDataSource", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "file_name", + models.CharField(help_text="Original filename", max_length=255), + ), + ("file_path", models.TextField(help_text="Path in object storage")), + ("file_size", models.BigIntegerField(help_text="File size in bytes")), + ( + "file_type", + models.CharField( + choices=[ + ("pdf", "PDF"), + ("xlsx", "Excel"), + ("csv", "CSV"), + ("docx", "Word"), + ("txt", "Text"), + ("json", "JSON"), + ], + help_text="Type of file", + max_length=50, + ), + ), + ( + "extracted_content_path", + models.TextField( + blank=True, + help_text="Path to extracted text in object storage", + null=True, + ), + ), + ( + "extraction_status", + models.CharField( + choices=[ + ("pending", "Pending"), + ("processing", "Processing"), + ("completed", "Completed"), + ("failed", "Failed"), + ], + default="pending", + help_text="Status of text extraction", + max_length=20, + ), + ), + ( + "extraction_error", + models.TextField( + blank=True, + help_text="Error details if extraction failed", + null=True, + ), + ), + ( + "version_number", + models.IntegerField( + default=1, + help_text="Version number of this data source (auto-incremented)", + ), + ), + ( + "is_latest", + models.BooleanField( + default=True, help_text="Whether this is the latest version" + ), + ), + ( + "project", + models.ForeignKey( + help_text="Parent Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="data_sources", + to="lookup.lookupproject", + ), + ), + ( + "uploaded_by", + models.ForeignKey( + help_text="User who uploaded this file", + on_delete=django.db.models.deletion.RESTRICT, + related_name="uploaded_lookup_data", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Look-Up Data Source", + "verbose_name_plural": "Look-Up Data Sources", + "db_table": "lookup_data_sources", + "ordering": ["-version_number"], + "unique_together": {("project", "version_number")}, + }, + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index(fields=["project"], name="lookup_data_project_idx"), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index( + fields=["project", "is_latest"], name="lookup_data_proj_latest_idx" + ), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index(fields=["created_at"], name="lookup_data_created_idx"), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index( + fields=["extraction_status"], name="lookup_data_extract_idx" + ), + ), + # 5. 
Create LookupProfileManager model + migrations.CreateModel( + name="LookupProfileManager", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "profile_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("profile_name", models.TextField(db_comment="Name of the profile")), + ( + "chunk_size", + models.IntegerField( + db_comment="Size of text chunks for indexing", default=1000 + ), + ), + ( + "chunk_overlap", + models.IntegerField( + db_comment="Overlap between consecutive chunks", default=200 + ), + ), + ( + "similarity_top_k", + models.IntegerField( + db_comment="Number of top similar chunks to retrieve", default=5 + ), + ), + ( + "is_default", + models.BooleanField( + db_comment="Whether this is the default profile for the project", + default=False, + ), + ), + ( + "reindex", + models.BooleanField( + db_comment="Flag to trigger re-indexing of reference data", + default=False, + ), + ), + ( + "lookup_project", + models.ForeignKey( + db_comment="Look-Up project this profile belongs to", + on_delete=django.db.models.deletion.CASCADE, + related_name="profiles", + to="lookup.lookupproject", + ), + ), + ( + "vector_store", + models.ForeignKey( + db_comment="Vector database adapter for storing embeddings", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_vector_store", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "embedding_model", + models.ForeignKey( + db_comment="Embedding model adapter for generating vectors", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_embedding_model", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "llm", + models.ForeignKey( + db_comment="LLM adapter for query processing and response generation", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_llm", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "x2text", + models.ForeignKey( + db_comment="X2Text adapter for extracting text from various file formats", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_x2text", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "created_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_profile_managers_created", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "modified_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_profile_managers_modified", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Lookup Profile Manager", + "verbose_name_plural": "Lookup Profile Managers", + "db_table": "lookup_profile_manager", + }, + ), + migrations.AddConstraint( + model_name="lookupprofilemanager", + constraint=models.UniqueConstraint( + fields=("lookup_project", "profile_name"), + name="unique_lookup_project_profile_name_index", + ), + ), + # 6. 
Create LookupIndexManager model + migrations.CreateModel( + name="LookupIndexManager", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "index_manager_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "raw_index_id", + models.CharField( + blank=True, + db_comment="Raw index ID for vector DB", + editable=False, + max_length=255, + null=True, + ), + ), + ( + "index_ids_history", + models.JSONField( + blank=False, + db_comment="List of all index IDs created for this data source", + default=list, + null=False, + ), + ), + ( + "extraction_status", + models.JSONField( + blank=False, + db_comment='Extraction status per X2Text config: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str}}', + default=dict, + null=False, + ), + ), + ( + "status", + models.JSONField( + blank=False, + db_comment="Extraction and indexing status", + default=dict, + null=False, + ), + ), + ( + "data_source", + models.ForeignKey( + db_comment="Reference data source being indexed", + editable=False, + on_delete=django.db.models.deletion.CASCADE, + related_name="index_managers", + to="lookup.lookupdatasource", + ), + ), + ( + "profile_manager", + models.ForeignKey( + blank=True, + db_comment="Profile used for indexing this data source", + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="index_managers", + to="lookup.lookupprofilemanager", + ), + ), + ( + "created_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_index_managers_created", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "modified_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_index_managers_modified", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Lookup Index Manager", + "verbose_name_plural": "Lookup Index Managers", + "db_table": "lookup_index_manager", + }, + ), + migrations.AddConstraint( + model_name="lookupindexmanager", + constraint=models.UniqueConstraint( + fields=("data_source", "profile_manager"), + name="unique_data_source_profile_manager_index", + ), + ), + # 7. 
Create PromptStudioLookupLink model + migrations.CreateModel( + name="PromptStudioLookupLink", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "prompt_studio_project_id", + models.UUIDField(help_text="UUID of the Prompt Studio project"), + ), + ( + "execution_order", + models.PositiveIntegerField( + default=0, + help_text="Order in which this Look-Up executes (lower numbers first)", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ( + "lookup_project", + models.ForeignKey( + help_text="Linked Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="ps_links", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Prompt Studio Look-Up Link", + "verbose_name_plural": "Prompt Studio Look-Up Links", + "db_table": "prompt_studio_lookup_links", + "ordering": ["execution_order", "created_at"], + "unique_together": {("prompt_studio_project_id", "lookup_project")}, + }, + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["prompt_studio_project_id"], name="ps_lookup_link_ps_proj_idx" + ), + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["lookup_project"], name="ps_lookup_link_lookup_idx" + ), + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["prompt_studio_project_id", "execution_order"], + name="ps_lookup_link_order_idx", + ), + ), + # 8. Create LookupExecutionAudit model + migrations.CreateModel( + name="LookupExecutionAudit", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "prompt_studio_project_id", + models.UUIDField( + blank=True, + help_text="Associated Prompt Studio project if applicable", + null=True, + ), + ), + ( + "execution_id", + models.UUIDField( + help_text="Groups all Look-Ups in a single execution batch" + ), + ), + ( + "input_data", + models.JSONField( + help_text="Input data from Prompt Studio extraction" + ), + ), + ( + "reference_data_version", + models.IntegerField(help_text="Version of reference data used"), + ), + ( + "enriched_output", + models.JSONField( + blank=True, help_text="Enrichment data produced", null=True + ), + ), + ( + "llm_provider", + models.CharField(help_text="LLM provider used", max_length=50), + ), + ( + "llm_model", + models.CharField(help_text="LLM model used", max_length=100), + ), + ("llm_prompt", models.TextField(help_text="Full prompt sent to LLM")), + ( + "llm_response", + models.TextField(blank=True, help_text="Raw LLM response", null=True), + ), + ( + "llm_response_cached", + models.BooleanField( + default=False, help_text="Whether response was from cache" + ), + ), + ( + "execution_time_ms", + models.IntegerField( + blank=True, + help_text="Total execution time in milliseconds", + null=True, + ), + ), + ( + "llm_call_time_ms", + models.IntegerField( + blank=True, + help_text="LLM API call time in milliseconds", + null=True, + ), + ), + ( + "status", + models.CharField( + choices=[ + ("success", "Success"), + ("partial", "Partial Success"), + ("failed", "Failed"), + ], + help_text="Execution status", + max_length=20, + ), + ), + ( + "error_message", + models.TextField( + blank=True, help_text="Error details if failed", null=True + ), + ), + ( + "confidence_score", + models.DecimalField( + blank=True, + decimal_places=2, + help_text="Confidence score from LLM 
(0.00 to 1.00)", + max_digits=3, + null=True, + ), + ), + ( + "executed_at", + models.DateTimeField( + auto_now_add=True, help_text="When the execution occurred" + ), + ), + ( + "lookup_project", + models.ForeignKey( + help_text="Look-Up project that was executed", + on_delete=django.db.models.deletion.CASCADE, + related_name="execution_audits", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Look-Up Execution Audit", + "verbose_name_plural": "Look-Up Execution Audits", + "db_table": "lookup_execution_audit", + "ordering": ["-executed_at"], + }, + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["lookup_project"], name="lookup_audit_proj_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["execution_id"], name="lookup_audit_exec_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["executed_at"], name="lookup_audit_time_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["status"], name="lookup_audit_status_idx"), + ), + ] diff --git a/backend/lookup/migrations/__init__.py b/backend/lookup/migrations/__init__.py new file mode 100644 index 0000000000..b98f22cc20 --- /dev/null +++ b/backend/lookup/migrations/__init__.py @@ -0,0 +1 @@ +# Lookup migrations package diff --git a/backend/lookup/models/__init__.py b/backend/lookup/models/__init__.py new file mode 100644 index 0000000000..4c196dc030 --- /dev/null +++ b/backend/lookup/models/__init__.py @@ -0,0 +1,24 @@ +"""Look-Up system models.""" + +from .lookup_data_source import LookupDataSource, LookupDataSourceManager +from .lookup_execution_audit import LookupExecutionAudit +from .lookup_index_manager import LookupIndexManager +from .lookup_profile_manager import LookupProfileManager +from .lookup_project import LookupProject +from .lookup_prompt_template import LookupPromptTemplate +from .prompt_studio_lookup_link import ( + PromptStudioLookupLink, + PromptStudioLookupLinkManager, +) + +__all__ = [ + "LookupProject", + "LookupDataSource", + "LookupDataSourceManager", + "LookupPromptTemplate", + "LookupProfileManager", + "LookupIndexManager", + "PromptStudioLookupLink", + "PromptStudioLookupLinkManager", + "LookupExecutionAudit", +] diff --git a/backend/lookup/models/lookup_data_source.py b/backend/lookup/models/lookup_data_source.py new file mode 100644 index 0000000000..49c1ea0011 --- /dev/null +++ b/backend/lookup/models/lookup_data_source.py @@ -0,0 +1,192 @@ +"""LookupDataSource model for managing reference data versions.""" + +import uuid + +from django.contrib.auth import get_user_model +from django.db import models +from django.db.models.signals import pre_save +from django.dispatch import receiver +from utils.models.base_model import BaseModel + +User = get_user_model() + + +class LookupDataSourceManager(models.Manager): + """Custom manager for LookupDataSource.""" + + def get_latest_for_project(self, project_id: uuid.UUID): + """Get the latest data sources for a project. + + Args: + project_id: UUID of the lookup project + + Returns: + QuerySet of latest data sources + """ + return self.filter(project_id=project_id, is_latest=True) + + def get_ready_for_project(self, project_id: uuid.UUID): + """Get all completed latest data sources for a project. 
+ + Args: + project_id: UUID of the lookup project + + Returns: + QuerySet of completed latest data sources + """ + return self.get_latest_for_project(project_id).filter( + extraction_status="completed" + ) + + +class LookupDataSource(BaseModel): + """Represents a reference data source with version management. + + Each upload creates a new version, with automatic version numbering + and latest flag management. + """ + + EXTRACTION_STATUS_CHOICES = [ + ("pending", "Pending"), + ("processing", "Processing"), + ("completed", "Completed"), + ("failed", "Failed"), + ] + + FILE_TYPE_CHOICES = [ + ("pdf", "PDF"), + ("xlsx", "Excel"), + ("csv", "CSV"), + ("docx", "Word"), + ("txt", "Text"), + ("json", "JSON"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="data_sources", + help_text="Parent Look-Up project", + ) + + # File Information + file_name = models.CharField(max_length=255, help_text="Original filename") + file_path = models.TextField(help_text="Path in object storage") + file_size = models.BigIntegerField(help_text="File size in bytes") + file_type = models.CharField( + max_length=50, choices=FILE_TYPE_CHOICES, help_text="Type of file" + ) + + # Extracted Content + extracted_content_path = models.TextField( + blank=True, null=True, help_text="Path to extracted text in object storage" + ) + extraction_status = models.CharField( + max_length=20, + choices=EXTRACTION_STATUS_CHOICES, + default="pending", + help_text="Status of text extraction", + ) + extraction_error = models.TextField( + blank=True, null=True, help_text="Error details if extraction failed" + ) + + # Version Management + version_number = models.IntegerField( + default=1, help_text="Version number of this data source (auto-incremented)" + ) + is_latest = models.BooleanField( + default=True, help_text="Whether this is the latest version" + ) + + # Upload Information + uploaded_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="uploaded_lookup_data", + help_text="User who uploaded this file", + ) + + objects = LookupDataSourceManager() + + class Meta: + """Model metadata.""" + + db_table = "lookup_data_sources" + ordering = ["-version_number"] + unique_together = [["project", "version_number"]] + verbose_name = "Look-Up Data Source" + verbose_name_plural = "Look-Up Data Sources" + indexes = [ + models.Index(fields=["project"]), + models.Index(fields=["project", "is_latest"]), + models.Index(fields=["created_at"]), + models.Index(fields=["extraction_status"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"{self.project.name} - v{self.version_number} - {self.file_name}" + + def get_file_size_display(self) -> str: + """Get human-readable file size. + + Returns: + Formatted file size string (e.g., "51.2 KB") + """ + size = self.file_size + for unit in ["B", "KB", "MB", "GB"]: + if size < 1024.0: + return f"{size:.1f} {unit}" + size /= 1024.0 + return f"{size:.1f} TB" + + @property + def is_extraction_complete(self) -> bool: + """Check if extraction is successfully completed.""" + return self.extraction_status == "completed" + + def get_extracted_content(self) -> str | None: + """Load extracted content from object storage. + + Returns: + Extracted text content or None if not available. + + Note: + This is a placeholder - actual implementation will + load from object storage using extracted_content_path. 
+ """ + if not self.extracted_content_path or not self.is_extraction_complete: + return None + + # TODO: Implement actual object storage retrieval + # For now, return a placeholder + return f"[Content from {self.extracted_content_path}]" + + +@receiver(pre_save, sender=LookupDataSource) +def auto_increment_version_and_update_latest(sender, instance, **kwargs): + """Signal to auto-increment version number and manage is_latest flag. + + This signal: + 1. Auto-increments version_number if not set + 2. Marks all previous versions as not latest + """ + # If this is a new instance (not updating existing) + if not instance.pk: + # Get the highest version number for this project + max_version = LookupDataSource.objects.filter(project=instance.project).aggregate( + max_version=models.Max("version_number") + )["max_version"] + + # Always auto-increment version for new instances + instance.version_number = (max_version or 0) + 1 + + # Mark all previous versions as not latest + LookupDataSource.objects.filter(project=instance.project, is_latest=True).update( + is_latest=False + ) + + # Ensure new version is marked as latest + instance.is_latest = True diff --git a/backend/lookup/models/lookup_execution_audit.py b/backend/lookup/models/lookup_execution_audit.py new file mode 100644 index 0000000000..6bda095155 --- /dev/null +++ b/backend/lookup/models/lookup_execution_audit.py @@ -0,0 +1,134 @@ +"""LookupExecutionAudit model for tracking Look-Up execution history.""" + +import uuid +from decimal import Decimal + +from django.db import models + + +class LookupExecutionAudit(models.Model): + """Audit log for Look-Up executions. + + Tracks all execution attempts with detailed metadata for debugging + and performance monitoring. + """ + + STATUS_CHOICES = [ + ("success", "Success"), + ("partial", "Partial Success"), + ("failed", "Failed"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + # Execution Context + lookup_project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="execution_audits", + help_text="Look-Up project that was executed", + ) + prompt_studio_project_id = models.UUIDField( + null=True, blank=True, help_text="Associated Prompt Studio project if applicable" + ) + execution_id = models.UUIDField( + help_text="Groups all Look-Ups in a single execution batch" + ) + + # Input/Output + input_data = models.JSONField(help_text="Input data from Prompt Studio extraction") + reference_data_version = models.IntegerField( + help_text="Version of reference data used" + ) + enriched_output = models.JSONField( + null=True, blank=True, help_text="Enrichment data produced" + ) + + # LLM Details + llm_provider = models.CharField(max_length=50, help_text="LLM provider used") + llm_model = models.CharField(max_length=100, help_text="LLM model used") + llm_prompt = models.TextField(help_text="Full prompt sent to LLM") + llm_response = models.TextField(null=True, blank=True, help_text="Raw LLM response") + llm_response_cached = models.BooleanField( + default=False, help_text="Whether response was from cache" + ) + + # Performance Metrics + execution_time_ms = models.IntegerField( + null=True, blank=True, help_text="Total execution time in milliseconds" + ) + llm_call_time_ms = models.IntegerField( + null=True, blank=True, help_text="LLM API call time in milliseconds" + ) + + # Status & Errors + status = models.CharField( + max_length=20, choices=STATUS_CHOICES, help_text="Execution status" + ) + error_message = models.TextField( + 
null=True, blank=True, help_text="Error details if failed" + ) + confidence_score = models.DecimalField( + max_digits=3, + decimal_places=2, + null=True, + blank=True, + validators=[], + help_text="Confidence score from LLM (0.00 to 1.00)", + ) + + # Timestamps + executed_at = models.DateTimeField( + auto_now_add=True, help_text="When the execution occurred" + ) + + class Meta: + """Model metadata.""" + + db_table = "lookup_execution_audit" + ordering = ["-executed_at"] + verbose_name = "Look-Up Execution Audit" + verbose_name_plural = "Look-Up Execution Audits" + indexes = [ + models.Index(fields=["lookup_project"]), + models.Index(fields=["execution_id"]), + models.Index(fields=["executed_at"]), + models.Index(fields=["status"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"{self.lookup_project.name} - {self.status} - {self.executed_at}" + + @property + def was_successful(self) -> bool: + """Check if execution was successful.""" + return self.status in ["success", "partial"] + + @property + def execution_duration_seconds(self) -> float: + """Get execution duration in seconds.""" + if self.execution_time_ms: + return self.execution_time_ms / 1000.0 + return 0.0 + + def clean(self): + """Validate the audit record.""" + super().clean() + from django.core.exceptions import ValidationError + + # Validate confidence score range + if self.confidence_score is not None: + if not (Decimal("0.00") <= self.confidence_score <= Decimal("1.00")): + raise ValidationError( + f"Confidence score must be between 0.00 and 1.00, " + f"got {self.confidence_score}" + ) + + # Ensure error_message is provided for failed status + if self.status == "failed" and not self.error_message: + raise ValidationError("Error message is required for failed executions") + + # Ensure enriched_output is provided for success status + if self.status == "success" and not self.enriched_output: + raise ValidationError("Enriched output is required for successful executions") diff --git a/backend/lookup/models/lookup_index_manager.py b/backend/lookup/models/lookup_index_manager.py new file mode 100644 index 0000000000..3d483a7a3a --- /dev/null +++ b/backend/lookup/models/lookup_index_manager.py @@ -0,0 +1,188 @@ +"""LookupIndexManager model for tracking indexed reference data.""" + +import logging +import uuid + +from account_v2.models import User +from django.db import models +from django.db.models.signals import pre_delete +from django.dispatch import receiver +from utils.models.base_model import BaseModel +from utils.user_context import UserContext + +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.vector_db import VectorDB + +logger = logging.getLogger(__name__) + + +class LookupIndexManager(BaseModel): + """Model to store indexing details for Look-Up reference data. + + Tracks which data sources have been indexed with which profile, + stores vector DB index IDs, and manages extraction status. + + Follows the same pattern as Prompt Studio's IndexManager. 
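+
+    On deletion, a pre_delete signal (perform_vector_db_cleanup, defined below)
+    removes every ID recorded in index_ids_history from the profile's vector store.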
+ """ + + index_manager_id = models.UUIDField( + primary_key=True, default=uuid.uuid4, editable=False + ) + + # Reference to the data source being indexed + data_source = models.ForeignKey( + "LookupDataSource", + on_delete=models.CASCADE, + related_name="index_managers", + editable=False, + null=False, + blank=False, + db_comment="Reference data source being indexed", + ) + + # Reference to the profile used for indexing + profile_manager = models.ForeignKey( + "LookupProfileManager", + on_delete=models.SET_NULL, + related_name="index_managers", + editable=False, + null=True, + blank=True, + db_comment="Profile used for indexing this data source", + ) + + # Vector DB index ID for this data source (raw index) + raw_index_id = models.CharField( + max_length=255, + db_comment="Raw index ID for vector DB", + editable=False, + null=True, + blank=True, + ) + + # History of all index IDs (for cleanup on deletion) + index_ids_history = models.JSONField( + db_comment="List of all index IDs created for this data source", + default=list, + null=False, + blank=False, + ) + + # Extraction status per X2Text configuration + # Format: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str|null}} + extraction_status = models.JSONField( + db_comment='Extraction status per X2Text config: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str}}', + default=dict, + null=False, + blank=False, + ) + + # Overall extraction and indexing status (legacy field, kept for compatibility) + # Format: {"extracted": bool, "indexed": bool, "error": str|null} + status = models.JSONField( + db_comment="Extraction and indexing status", null=False, blank=False, default=dict + ) + + # Audit fields + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_index_managers_created", + null=True, + blank=True, + editable=False, + ) + + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_index_managers_modified", + null=True, + blank=True, + editable=False, + ) + + class Meta: + verbose_name = "Lookup Index Manager" + verbose_name_plural = "Lookup Index Managers" + db_table = "lookup_index_manager" + constraints = [ + models.UniqueConstraint( + fields=["data_source", "profile_manager"], + name="unique_data_source_profile_manager_index", + ), + ] + + def __str__(self): + return f"Index for {self.data_source.file_name} with {self.profile_manager.profile_name if self.profile_manager else 'No Profile'}" + + +def delete_from_vector_db(index_ids_history, vector_db_instance_id): + """Delete index IDs from vector database. 
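+
+    Errors while deleting an individual index ID are logged and skipped so the
+    remaining IDs are still processed.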
+ + Args: + index_ids_history: List of index IDs to delete + vector_db_instance_id: UUID of the vector DB adapter instance + """ + try: + from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import ( + PromptIdeBaseTool, + ) + + organization_identifier = UserContext.get_organization_identifier() + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=organization_identifier) + + vector_db = VectorDB( + tool=util, + adapter_instance_id=vector_db_instance_id, + ) + + for index_id in index_ids_history: + logger.debug(f"Deleting from VectorDB - index id: {index_id}") + try: + vector_db.delete(ref_doc_id=index_id) + except Exception as e: + # Log error and continue with the next index id + logger.error(f"Error deleting index: {index_id} - {e}") + + except Exception as e: + logger.error(f"Error in delete_from_vector_db: {e}", exc_info=True) + + +# Signal to perform vector DB cleanup on deletion +@receiver(pre_delete, sender=LookupIndexManager) +def perform_vector_db_cleanup(sender, instance, **kwargs): + """Signal handler to clean up vector DB entries when index is deleted. + + This ensures that when a LookupIndexManager is deleted (e.g., when + a data source is deleted or re-indexed), the corresponding vectors + are removed from the vector database. + """ + logger.debug( + f"Performing vector DB cleanup for data source: " + f"{instance.data_source.file_name}" + ) + + try: + # Get the index_ids_history to clean up from the vector db + index_ids_history = instance.index_ids_history + + if not index_ids_history: + logger.debug("No index IDs to clean up") + return + + if instance.profile_manager and instance.profile_manager.vector_store: + vector_db_instance_id = str(instance.profile_manager.vector_store.id) + delete_from_vector_db(index_ids_history, vector_db_instance_id) + else: + logger.warning( + f"Cannot cleanup vector DB: missing profile or vector store " + f"for data source {instance.data_source.file_name}" + ) + + except Exception as e: + logger.warning( + f"Error during vector DB cleanup for data source " + f"{instance.data_source.file_name}: {e}", + exc_info=True, + ) diff --git a/backend/lookup/models/lookup_profile_manager.py b/backend/lookup/models/lookup_profile_manager.py new file mode 100644 index 0000000000..6d1472d23c --- /dev/null +++ b/backend/lookup/models/lookup_profile_manager.py @@ -0,0 +1,157 @@ +"""LookupProfileManager model for managing adapter profiles in Look-Up projects.""" + +import uuid + +from account_v2.models import User +from adapter_processor_v2.models import AdapterInstance +from django.db import models +from utils.models.base_model import BaseModel + +from lookup.exceptions import DefaultProfileError + + +class LookupProfileManager(BaseModel): + """Model to store adapter configuration profiles for Look-Up projects. + + Each profile defines the set of adapters (X2Text, Embedding, VectorDB, LLM) + to use for text extraction, indexing, and lookup operations. + + Follows the same pattern as Prompt Studio's ProfileManager. 
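+
+    profile_name is unique within a project, and get_default_profile() returns
+    the profile flagged is_default, raising DefaultProfileError when none exists.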
+ """ + + profile_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + profile_name = models.TextField( + blank=False, null=False, db_comment="Name of the profile" + ) + + # Foreign key to LookupProject + lookup_project = models.ForeignKey( + "LookupProject", + on_delete=models.CASCADE, + related_name="profiles", + db_comment="Look-Up project this profile belongs to", + ) + + # Required Adapters - All must be configured + vector_store = models.ForeignKey( + AdapterInstance, + db_comment="Vector database adapter for storing embeddings", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_vector_store", + ) + + embedding_model = models.ForeignKey( + AdapterInstance, + db_comment="Embedding model adapter for generating vectors", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_embedding_model", + ) + + llm = models.ForeignKey( + AdapterInstance, + db_comment="LLM adapter for query processing and response generation", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_llm", + ) + + x2text = models.ForeignKey( + AdapterInstance, + db_comment="X2Text adapter for extracting text from various file formats", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_x2text", + ) + + # Configuration fields + chunk_size = models.IntegerField( + default=1000, + null=False, + blank=False, + db_comment="Size of text chunks for indexing", + ) + + chunk_overlap = models.IntegerField( + default=200, + null=False, + blank=False, + db_comment="Overlap between consecutive chunks", + ) + + similarity_top_k = models.IntegerField( + default=5, + null=False, + blank=False, + db_comment="Number of top similar chunks to retrieve", + ) + + # Flags + is_default = models.BooleanField( + default=False, db_comment="Whether this is the default profile for the project" + ) + + reindex = models.BooleanField( + default=False, db_comment="Flag to trigger re-indexing of reference data" + ) + + # Audit fields + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_profile_managers_created", + null=True, + blank=True, + editable=False, + ) + + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_profile_managers_modified", + null=True, + blank=True, + editable=False, + ) + + class Meta: + verbose_name = "Lookup Profile Manager" + verbose_name_plural = "Lookup Profile Managers" + db_table = "lookup_profile_manager" + constraints = [ + models.UniqueConstraint( + fields=["lookup_project", "profile_name"], + name="unique_lookup_project_profile_name_index", + ), + ] + + def __str__(self): + return f"{self.profile_name} ({self.lookup_project.name})" + + @staticmethod + def get_default_profile(project): + """Get the default profile for a Look-Up project. 
+ + Args: + project: LookupProject instance + + Returns: + LookupProfileManager: The default profile + + Raises: + DefaultProfileError: If no default profile exists + """ + try: + return LookupProfileManager.objects.get( + lookup_project=project, is_default=True + ) + except LookupProfileManager.DoesNotExist: + raise DefaultProfileError( + f"No default profile found for project {project.name}" + ) diff --git a/backend/lookup/models/lookup_project.py b/backend/lookup/models/lookup_project.py new file mode 100644 index 0000000000..14b2d4d671 --- /dev/null +++ b/backend/lookup/models/lookup_project.py @@ -0,0 +1,177 @@ +"""LookupProject model for Static Data-based Look-Ups.""" + +import uuid + +from django.contrib.auth import get_user_model +from django.db import models +from django.urls import reverse +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationMixin, +) + +User = get_user_model() + + +class LookupProject(DefaultOrganizationMixin, BaseModel): + """Represents a Look-Up project for static data-based enrichment. + + This model stores the configuration for a Look-Up project including + LLM settings and organization association. + """ + + LOOKUP_TYPE_CHOICES = [ + ("static_data", "Static Data"), + ] + + REFERENCE_DATA_TYPE_CHOICES = [ + ("vendor_catalog", "Vendor Catalog"), + ("product_catalog", "Product Catalog"), + ("customer_database", "Customer Database"), + ("pricing_data", "Pricing Data"), + ("compliance_rules", "Compliance Rules"), + ("custom", "Custom"), + ] + + LLM_PROVIDER_CHOICES = [ + ("openai", "OpenAI"), + ("anthropic", "Anthropic"), + ("azure", "Azure OpenAI"), + ("custom", "Custom Provider"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + name = models.CharField(max_length=255, help_text="Name of the Look-Up project") + description = models.TextField( + blank=True, null=True, help_text="Description of the Look-Up project's purpose" + ) + lookup_type = models.CharField( + max_length=50, + choices=LOOKUP_TYPE_CHOICES, + default="static_data", + help_text="Type of Look-Up (only static_data for POC)", + ) + reference_data_type = models.CharField( + max_length=50, + choices=REFERENCE_DATA_TYPE_CHOICES, + help_text="Category of reference data being stored", + ) + + # Template and status + template = models.ForeignKey( + "LookupPromptTemplate", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="projects", + help_text="Prompt template for this project", + ) + is_active = models.BooleanField( + default=True, help_text="Whether this project is active" + ) + metadata = models.JSONField( + default=dict, blank=True, help_text="Additional metadata for the project" + ) + + # LLM Configuration + llm_provider = models.CharField( + max_length=50, + choices=LLM_PROVIDER_CHOICES, + null=True, + blank=True, + help_text="LLM provider to use for matching", + ) + llm_model = models.CharField( + max_length=100, + null=True, + blank=True, + help_text="Specific model name (e.g., gpt-4-turbo, claude-3-opus)", + ) + llm_config = models.JSONField( + default=dict, + blank=True, + help_text="Additional LLM configuration (temperature, max_tokens, etc.)", + ) + + # Ownership + created_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="created_lookup_projects", + help_text="User who created this project", + ) + + # Note: created_at and modified_at are inherited from BaseModel + # Note: organization ForeignKey is inherited from DefaultOrganizationMixin + + class 
Meta: + """Model metadata.""" + + db_table = "lookup_projects" + ordering = ["-created_at"] + verbose_name = "Look-Up Project" + verbose_name_plural = "Look-Up Projects" + indexes = [ + models.Index(fields=["organization"]), + models.Index(fields=["created_by"]), + models.Index(fields=["modified_at"]), + ] + + def __str__(self) -> str: + """String representation of the project.""" + return self.name + + def get_absolute_url(self) -> str: + """Get the URL for this project's detail view.""" + return reverse("lookup:project-detail", kwargs={"pk": self.pk}) + + @property + def is_ready(self) -> bool: + """Check if the project has reference data ready for use. + + Returns: + True if all reference data is extracted and ready, False otherwise. + """ + if not hasattr(self, "data_sources"): + return False + + latest_sources = self.data_sources.filter(is_latest=True) + if not latest_sources.exists(): + return False + + return all(source.extraction_status == "completed" for source in latest_sources) + + def get_latest_reference_version(self) -> int | None: + """Get the latest version number of reference data. + + Returns: + Latest version number or None if no data sources exist. + """ + if not hasattr(self, "data_sources"): + return None + + latest = self.data_sources.filter(is_latest=True).first() + return latest.version_number if latest else None + + def clean(self): + """Validate model fields.""" + super().clean() + from django.core.exceptions import ValidationError + + # Validate LLM provider + valid_providers = [choice[0] for choice in self.LLM_PROVIDER_CHOICES] + if self.llm_provider not in valid_providers: + raise ValidationError( + f"Invalid LLM provider: {self.llm_provider}. " + f"Must be one of: {', '.join(valid_providers)}" + ) + + # Validate LLM config structure + if self.llm_config and not isinstance(self.llm_config, dict): + raise ValidationError("LLM config must be a dictionary") + + # Validate lookup_type (only static_data for POC) + if self.lookup_type != "static_data": + raise ValidationError( + "Only 'static_data' lookup type is supported in this POC" + ) diff --git a/backend/lookup/models/lookup_prompt_template.py b/backend/lookup/models/lookup_prompt_template.py new file mode 100644 index 0000000000..62bd7aae5d --- /dev/null +++ b/backend/lookup/models/lookup_prompt_template.py @@ -0,0 +1,178 @@ +"""LookupPromptTemplate model for managing prompt templates.""" + +import re +import uuid + +from account_v2.models import User +from django.core.exceptions import ValidationError +from django.db import models +from utils.models.base_model import BaseModel + + +class LookupPromptTemplate(BaseModel): + """Represents a prompt template with variable detection and validation. + + Each Look-Up project has one template that defines how to construct + the LLM prompt with {{variable}} placeholders. 
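+
+    Example template (illustrative):
+        "Match {{input_data.vendor}} against {{reference_data}} and return
+        the best match as JSON."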
+ """ + + VARIABLE_PATTERN = r"\{\{([^}]+)\}\}" + RESERVED_PREFIXES = ["_", "_lookup_"] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + project = models.OneToOneField( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="prompt_template_link", + help_text="Parent Look-Up project", + ) + + # Template Configuration + name = models.CharField(max_length=255, help_text="Template name for identification") + template_text = models.TextField(help_text="Template with {{variable}} placeholders") + llm_config = models.JSONField( + default=dict, blank=True, help_text="LLM configuration including adapter_id" + ) + is_active = models.BooleanField( + default=True, help_text="Whether this template is active" + ) + created_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="created_lookup_templates", + help_text="User who created this template", + ) + variable_mappings = models.JSONField( + default=dict, blank=True, help_text="Optional documentation of variable mappings" + ) + + class Meta: + """Model metadata.""" + + db_table = "lookup_prompt_templates" + ordering = ["-modified_at"] + verbose_name = "Look-Up Prompt Template" + verbose_name_plural = "Look-Up Prompt Templates" + + def __str__(self) -> str: + """String representation.""" + return f"Template for {self.project.name}" + + def detect_variables(self) -> list[str]: + """Detect all {{variable}} references in the template. + + Returns: + List of unique variable paths found in the template. + + Example: + Template: "Match {{input_data.vendor}} from {{reference_data}}" + Returns: ["input_data.vendor", "reference_data"] + """ + if not self.template_text: + return [] + + matches = re.findall(self.VARIABLE_PATTERN, self.template_text) + # Strip whitespace and deduplicate + unique_vars = list({m.strip() for m in matches}) + return sorted(unique_vars) + + def validate_syntax(self) -> bool: + """Validate template syntax for matching braces. + + Returns: + True if syntax is valid, False otherwise. + + Raises: + ValidationError: If syntax is invalid. + """ + # Check for unmatched opening braces + open_count = self.template_text.count("{{") + close_count = self.template_text.count("}}") + + if open_count != close_count: + raise ValidationError( + f"Mismatched braces in template: {open_count} opening, " + f"{close_count} closing" + ) + + # Check for nested braces (not allowed) + if re.search(r"\{\{[^}]*\{\{", self.template_text): + raise ValidationError("Nested variable placeholders are not allowed") + + return True + + def validate_reserved_keywords(self) -> bool: + """Check that template doesn't use reserved keywords. + + Returns: + True if no reserved keywords are used. + + Raises: + ValidationError: If reserved keywords are found. + """ + variables = self.detect_variables() + + for var in variables: + # Check if variable starts with reserved prefixes + for prefix in self.RESERVED_PREFIXES: + if var.startswith(prefix): + raise ValidationError( + f"Variable '{var}' uses reserved prefix '{prefix}'. " + f"Reserved prefixes: {', '.join(self.RESERVED_PREFIXES)}" + ) + + # Check if trying to write to reserved fields + if "=" in var or var.endswith("_metadata"): + raise ValidationError( + f"Variable '{var}' appears to be trying to set a value. " + f"Variables should only reference existing data." + ) + + return True + + def get_variable_info(self) -> dict: + """Get detailed information about detected variables. + + Returns: + Dictionary with variable paths and their types/documentation. 
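+
+        Example:
+            For "Match {{input_data.vendor}} from {{reference_data}}" with no
+            variable_mappings set, this returns:
+            {
+                "input_data.vendor": {"root": "input_data", "path": "vendor",
+                                      "depth": 2, "description": "No description"},
+                "reference_data": {"root": "reference_data", "path": "",
+                                   "depth": 1, "description": "No description"},
+            }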
+ """ + variables = self.detect_variables() + info = {} + + for var in variables: + parts = var.split(".") + if len(parts) > 0: + root = parts[0] + path = ".".join(parts[1:]) if len(parts) > 1 else "" + + info[var] = { + "root": root, + "path": path, + "depth": len(parts), + "description": self.variable_mappings.get(var, "No description"), + } + + return info + + def clean(self): + """Validate the template on save.""" + super().clean() + + if not self.template_text: + raise ValidationError("Template text cannot be empty") + + try: + self.validate_syntax() + self.validate_reserved_keywords() + except ValidationError as e: + raise e + + # Warn if no variables detected (might be intentional) + if not self.detect_variables(): + # This is just a warning, not an error + pass + + @property + def variable_count(self) -> int: + """Get the count of unique variables in the template.""" + return len(self.detect_variables()) diff --git a/backend/lookup/models/prompt_studio_lookup_link.py b/backend/lookup/models/prompt_studio_lookup_link.py new file mode 100644 index 0000000000..0a03cd7f63 --- /dev/null +++ b/backend/lookup/models/prompt_studio_lookup_link.py @@ -0,0 +1,128 @@ +"""PromptStudioLookupLink model for linking PS projects with Look-Up projects.""" + +import uuid + +from django.db import models + + +class PromptStudioLookupLinkManager(models.Manager): + """Custom manager for PromptStudioLookupLink.""" + + def get_links_for_ps_project(self, ps_project_id: uuid.UUID): + """Get all Look-Up links for a Prompt Studio project, ordered by execution order. + + Args: + ps_project_id: UUID of the Prompt Studio project + + Returns: + QuerySet of links ordered by execution_order + """ + return self.filter(prompt_studio_project_id=ps_project_id).order_by( + "execution_order", "created_at" + ) + + def get_lookup_projects_for_ps(self, ps_project_id: uuid.UUID): + """Get all Look-Up projects linked to a Prompt Studio project. + + Args: + ps_project_id: UUID of the Prompt Studio project + + Returns: + QuerySet of LookupProject instances + """ + from .lookup_project import LookupProject + + link_ids = self.filter(prompt_studio_project_id=ps_project_id).values_list( + "lookup_project_id", flat=True + ) + + return LookupProject.objects.filter(id__in=link_ids) + + +class PromptStudioLookupLink(models.Model): + """Many-to-many relationship between Prompt Studio projects and Look-Up projects. + + Manages the linking and execution order of Look-Up projects within + a Prompt Studio extraction pipeline. 
+ """ + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + # Note: We're using UUIDs for now since PS project model isn't defined yet + # In production, this should be a ForeignKey to the actual PS project model + prompt_studio_project_id = models.UUIDField( + help_text="UUID of the Prompt Studio project" + ) + + lookup_project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="ps_links", + help_text="Linked Look-Up project", + ) + + execution_order = models.PositiveIntegerField( + default=0, help_text="Order in which this Look-Up executes (lower numbers first)" + ) + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + + objects = PromptStudioLookupLinkManager() + + class Meta: + """Model metadata.""" + + db_table = "prompt_studio_lookup_links" + ordering = ["execution_order", "created_at"] + unique_together = [["prompt_studio_project_id", "lookup_project"]] + verbose_name = "Prompt Studio Look-Up Link" + verbose_name_plural = "Prompt Studio Look-Up Links" + indexes = [ + models.Index(fields=["prompt_studio_project_id"]), + models.Index(fields=["lookup_project"]), + models.Index(fields=["prompt_studio_project_id", "execution_order"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"PS Project {self.prompt_studio_project_id} → {self.lookup_project.name}" + + def save(self, *args, **kwargs): + """Override save to auto-assign execution_order if not set. + + Auto-assigns the next available execution order number for the + Prompt Studio project if not explicitly provided. + """ + if self.execution_order == 0 and not self.pk: + # Get the maximum execution order for this PS project + max_order = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.prompt_studio_project_id + ).aggregate(max_order=models.Max("execution_order"))["max_order"] + + # Set to max + 1, or 1 if no existing links + self.execution_order = (max_order or 0) + 1 + + super().save(*args, **kwargs) + + def clean(self): + """Validate the link.""" + super().clean() + + # In production, we would validate that both projects + # belong to the same organization here + # For now, we'll just ensure the lookup project is ready + + if self.lookup_project and not self.lookup_project.is_ready: + # This is a warning, not an error - allow linking but warn + # In production, this might log a warning + pass + + @property + def is_enabled(self) -> bool: + """Check if this link is enabled and ready for execution. + + Returns: + True if the linked Look-Up project is ready for use. + """ + return self.lookup_project.is_ready if self.lookup_project else False diff --git a/backend/lookup/serializers.py b/backend/lookup/serializers.py new file mode 100644 index 0000000000..be0aff118f --- /dev/null +++ b/backend/lookup/serializers.py @@ -0,0 +1,333 @@ +"""Django REST Framework serializers for Look-Up API. + +This module provides serializers for all Look-Up models +to support RESTful API operations. 
+""" + +import logging +from typing import Any + +from adapter_processor_v2.adapter_processor import AdapterProcessor +from rest_framework import serializers + +from backend.serializers import AuditSerializer + +from .constants import LookupProfileManagerKeys +from .models import ( + LookupDataSource, + LookupExecutionAudit, + LookupProfileManager, + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, +) + +logger = logging.getLogger(__name__) + + +class LookupPromptTemplateSerializer(serializers.ModelSerializer): + """Serializer for LookupPromptTemplate model.""" + + class Meta: + model = LookupPromptTemplate + fields = [ + "id", + "project", + "name", + "template_text", + "llm_config", + "is_active", + "created_by", + "created_at", + "modified_at", + ] + read_only_fields = ["id", "created_by", "created_at", "modified_at"] + + def validate_template_text(self, value: str) -> str: + """Validate that template text contains required placeholders.""" + if "{{reference_data}}" not in value: + raise serializers.ValidationError( + "Template must contain {{reference_data}} placeholder" + ) + return value + + def validate_llm_config(self, value: dict[str, Any]) -> dict[str, Any]: + """Validate LLM configuration structure.""" + # Accept either new format (adapter_id) or legacy format (provider + model) + has_adapter_id = "adapter_id" in value + has_legacy = "provider" in value and "model" in value + + if not has_adapter_id and not has_legacy: + raise serializers.ValidationError( + "llm_config must contain either 'adapter_id' (recommended) " + "or both 'provider' and 'model' fields" + ) + return value + + +class LookupDataSourceSerializer(serializers.ModelSerializer): + """Serializer for LookupDataSource model.""" + + extraction_status_display = serializers.CharField( + source="get_extraction_status_display", read_only=True + ) + + class Meta: + model = LookupDataSource + fields = [ + "id", + "project", + "file_name", + "file_path", + "file_size", + "file_type", + "extracted_content_path", + "extraction_status", + "extraction_status_display", + "extraction_error", + "version_number", + "is_latest", + "uploaded_by", + "created_at", + "modified_at", + ] + read_only_fields = [ + "id", + "version_number", + "is_latest", + "uploaded_by", + "created_at", + "modified_at", + "extraction_status_display", + ] + + +class LookupProjectSerializer(serializers.ModelSerializer): + """Serializer for LookupProject model.""" + + template = LookupPromptTemplateSerializer(read_only=True) + template_id = serializers.PrimaryKeyRelatedField( + queryset=LookupPromptTemplate.objects.all(), + source="template", + write_only=True, + allow_null=True, + required=False, + ) + data_source_count = serializers.SerializerMethodField() + latest_version = serializers.SerializerMethodField() + + class Meta: + model = LookupProject + fields = [ + "id", + "name", + "description", + "reference_data_type", + "template", + "template_id", + "is_active", + "data_source_count", + "latest_version", + "metadata", + "created_by", + "created_at", + "modified_at", + ] + read_only_fields = [ + "id", + "data_source_count", + "latest_version", + "created_by", + "created_at", + "modified_at", + ] + + def get_data_source_count(self, obj) -> int: + """Get count of data sources for this project.""" + return obj.data_sources.filter(is_latest=True).count() + + def get_latest_version(self, obj) -> int: + """Get latest version number of data sources.""" + latest = obj.data_sources.filter(is_latest=True).first() + return latest.version_number if 
latest else 0 + + +class PromptStudioLookupLinkSerializer(serializers.ModelSerializer): + """Serializer for linking Look-Ups to Prompt Studio projects.""" + + lookup_project_name = serializers.CharField( + source="lookup_project.name", read_only=True + ) + + class Meta: + model = PromptStudioLookupLink + fields = [ + "id", + "prompt_studio_project_id", + "lookup_project", + "lookup_project_name", + "created_at", + ] + read_only_fields = ["id", "lookup_project_name", "created_at"] + + def validate(self, attrs): + """Validate that the link doesn't already exist.""" + ps_project_id = attrs.get("prompt_studio_project_id") + lookup_project = attrs.get("lookup_project") + + # Check if this combination already exists + existing = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id, lookup_project=lookup_project + ).exists() + + if existing and not self.instance: # Only check for creation, not update + raise serializers.ValidationError( + "This Look-Up is already linked to the Prompt Studio project" + ) + + return attrs + + +class LookupExecutionAuditSerializer(serializers.ModelSerializer): + """Serializer for execution audit records.""" + + lookup_project_name = serializers.CharField( + source="lookup_project.name", read_only=True + ) + + class Meta: + model = LookupExecutionAudit + fields = [ + "id", + "lookup_project", + "lookup_project_name", + "prompt_studio_project_id", + "execution_id", + "input_data", + "enriched_output", + "reference_data_version", + "llm_provider", + "llm_model", + "llm_prompt", + "llm_response", + "llm_response_cached", + "execution_time_ms", + "llm_call_time_ms", + "status", + "error_message", + "confidence_score", + "executed_at", + ] + read_only_fields = fields # All fields are read-only for audit records + + +class LookupExecutionRequestSerializer(serializers.Serializer): + """Serializer for Look-Up execution requests.""" + + input_data = serializers.JSONField(help_text="Input data for variable resolution") + lookup_project_ids = serializers.ListField( + child=serializers.UUIDField(), + required=False, + help_text="List of Look-Up project IDs to execute", + ) + use_cache = serializers.BooleanField( + default=True, help_text="Whether to use cached LLM responses" + ) + timeout_seconds = serializers.IntegerField( + default=30, + min_value=1, + max_value=300, + help_text="Timeout for execution in seconds", + ) + + +class LookupExecutionResponseSerializer(serializers.Serializer): + """Serializer for Look-Up execution responses.""" + + lookup_enrichment = serializers.JSONField( + help_text="Merged enrichment data from all Look-Ups" + ) + _lookup_metadata = serializers.JSONField( + help_text="Execution metadata including timing and status" + ) + + +class ReferenceDataUploadSerializer(serializers.Serializer): + """Serializer for reference data upload requests.""" + + file = serializers.FileField(help_text="Reference data file to upload") + extract_text = serializers.BooleanField( + default=True, help_text="Whether to extract text from the file" + ) + metadata = serializers.JSONField( + required=False, default=dict, help_text="Additional metadata for the data source" + ) + + +class BulkLinkSerializer(serializers.Serializer): + """Serializer for bulk linking operations.""" + + prompt_studio_project_id = serializers.UUIDField(help_text="Prompt Studio project ID") + lookup_project_ids = serializers.ListField( + child=serializers.UUIDField(), help_text="List of Look-Up project IDs to link" + ) + unlink = serializers.BooleanField( + default=False, 
help_text="If true, unlink instead of link" + ) + + +class TemplateValidationSerializer(serializers.Serializer): + """Serializer for template validation requests.""" + + template_text = serializers.CharField(help_text="Template text to validate") + sample_data = serializers.JSONField( + required=False, help_text="Sample input data for testing variable resolution" + ) + sample_reference = serializers.CharField( + required=False, help_text="Sample reference data for testing" + ) + + +class LookupProfileManagerSerializer(AuditSerializer): + """Serializer for LookupProfileManager model. + + Follows the same pattern as Prompt Studio's ProfileManagerSerializer. + Expands adapter UUIDs to full adapter details in the response. + """ + + class Meta: + model = LookupProfileManager + fields = "__all__" + + def to_representation(self, instance): + """Expand adapter UUIDs to full adapter details. + + This converts the FK references to AdapterInstance objects + into full adapter details including adapter_name, adapter_type, etc. + """ + rep: dict[str, str] = super().to_representation(instance) + + # Expand each adapter FK to full details + llm = rep.get(LookupProfileManagerKeys.LLM) + embedding = rep.get(LookupProfileManagerKeys.EMBEDDING_MODEL) + vector_db = rep.get(LookupProfileManagerKeys.VECTOR_STORE) + x2text = rep.get(LookupProfileManagerKeys.X2TEXT) + + if llm: + rep[LookupProfileManagerKeys.LLM] = ( + AdapterProcessor.get_adapter_instance_by_id(llm) + ) + if embedding: + rep[LookupProfileManagerKeys.EMBEDDING_MODEL] = ( + AdapterProcessor.get_adapter_instance_by_id(embedding) + ) + if vector_db: + rep[LookupProfileManagerKeys.VECTOR_STORE] = ( + AdapterProcessor.get_adapter_instance_by_id(vector_db) + ) + if x2text: + rep[LookupProfileManagerKeys.X2TEXT] = ( + AdapterProcessor.get_adapter_instance_by_id(x2text) + ) + + return rep diff --git a/backend/lookup/services/__init__.py b/backend/lookup/services/__init__.py new file mode 100644 index 0000000000..1e7332ee1c --- /dev/null +++ b/backend/lookup/services/__init__.py @@ -0,0 +1,29 @@ +"""Service layer implementations for the Look-Up system. + +This package contains the core business logic and service classes +for the Static Data-based Look-Ups feature. +""" + +from .audit_logger import AuditLogger +from .document_indexing_service import LookupDocumentIndexingService +from .enrichment_merger import EnrichmentMerger +from .indexing_service import IndexingService +from .llm_cache import LLMResponseCache +from .lookup_executor import LookUpExecutor +from .lookup_index_helper import LookupIndexHelper +from .lookup_orchestrator import LookUpOrchestrator +from .reference_data_loader import ReferenceDataLoader +from .variable_resolver import VariableResolver + +__all__ = [ + "AuditLogger", + "LookupDocumentIndexingService", + "EnrichmentMerger", + "IndexingService", + "LookupIndexHelper", + "LLMResponseCache", + "LookUpExecutor", + "LookUpOrchestrator", + "ReferenceDataLoader", + "VariableResolver", +] diff --git a/backend/lookup/services/audit_logger.py b/backend/lookup/services/audit_logger.py new file mode 100644 index 0000000000..3259e4c1d6 --- /dev/null +++ b/backend/lookup/services/audit_logger.py @@ -0,0 +1,310 @@ +"""Audit Logger implementation for tracking Look-Up executions. + +This module provides functionality to log all Look-Up execution details +to the database for debugging, monitoring, and compliance purposes. 
+""" + +import logging +from decimal import Decimal +from typing import Any +from uuid import UUID + +from lookup.models import LookupExecutionAudit, LookupProject + +logger = logging.getLogger(__name__) + + +class AuditLogger: + """Logs Look-Up execution details to lookup_execution_audit table. + + This class provides methods to record all aspects of Look-Up executions + including inputs, outputs, performance metrics, and errors for audit + trail and debugging purposes. + """ + + def log_execution( + self, + execution_id: str, + lookup_project_id: UUID, + prompt_studio_project_id: UUID | None, + input_data: dict[str, Any], + reference_data_version: int, + llm_provider: str, + llm_model: str, + llm_prompt: str, + llm_response: str | None, + enriched_output: dict[str, Any] | None, + status: str, # 'success', 'partial', 'failed' + confidence_score: float | None = None, + execution_time_ms: int | None = None, + llm_call_time_ms: int | None = None, + llm_response_cached: bool = False, + error_message: str | None = None, + ) -> LookupExecutionAudit | None: + """Log execution to database. + + Records comprehensive details about a Look-Up execution including + the input data, LLM interaction, output, and performance metrics. + + Args: + execution_id: UUID of the orchestration execution + lookup_project_id: UUID of the Look-Up project + prompt_studio_project_id: Optional UUID of PS project + input_data: Input data used for variable resolution + reference_data_version: Version of reference data used + llm_provider: LLM provider name (e.g., 'openai') + llm_model: LLM model name (e.g., 'gpt-4') + llm_prompt: Final resolved prompt sent to LLM + llm_response: Raw response from LLM + enriched_output: Parsed enrichment data + status: Execution status ('success', 'partial', 'failed') + confidence_score: Optional confidence score (0.0-1.0) + execution_time_ms: Total execution time in milliseconds + llm_call_time_ms: Time spent calling LLM in milliseconds + llm_response_cached: Whether response was from cache + error_message: Error message if execution failed + + Returns: + Created LookupExecutionAudit instance or None if logging fails + + Example: + >>> logger = AuditLogger() + >>> audit = logger.log_execution( + ... execution_id='abc-123', + ... lookup_project_id=project_id, + ... status='success', + ... ... + ... 
) + """ + try: + # Get the Look-Up project + try: + lookup_project = LookupProject.objects.get(id=lookup_project_id) + except LookupProject.DoesNotExist: + logger.error(f"Look-Up project {lookup_project_id} not found for audit") + return None + + # Convert confidence score to Decimal if provided + if confidence_score is not None: + confidence_score = Decimal(str(confidence_score)) + + # Create audit record + audit = LookupExecutionAudit.objects.create( + lookup_project=lookup_project, + prompt_studio_project_id=prompt_studio_project_id, + execution_id=execution_id, + input_data=input_data, + reference_data_version=reference_data_version, + enriched_output=enriched_output, + llm_provider=llm_provider, + llm_model=llm_model, + llm_prompt=llm_prompt, + llm_response=llm_response, + llm_response_cached=llm_response_cached, + execution_time_ms=execution_time_ms, + llm_call_time_ms=llm_call_time_ms, + status=status, + error_message=error_message, + confidence_score=confidence_score, + ) + + logger.debug( + f"Logged execution audit {audit.id} for Look-Up {lookup_project.name} " + f"(execution {execution_id})" + ) + + return audit + + except Exception as e: + # Log error but don't fail the execution + logger.exception( + f"Failed to log execution audit for {lookup_project_id}: {str(e)}" + ) + return None + + def log_success( + self, execution_id: str, project_id: UUID, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging successful execution. + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_success( + ... execution_id="abc-123", + ... project_id=project_id, + ... input_data={"vendor": "Slack"}, + ... enriched_output={"canonical_vendor": "Slack"}, + ... confidence_score=0.92, + ... ) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="success", + **kwargs, + ) + + def log_failure( + self, execution_id: str, project_id: UUID, error: str, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging failed execution. + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + error: Error message describing the failure + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_failure( + ... execution_id="abc-123", + ... project_id=project_id, + ... error="LLM timeout after 30 seconds", + ... input_data={"vendor": "Slack"}, + ... ) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="failed", + error_message=error, + **kwargs, + ) + + def log_partial( + self, execution_id: str, project_id: UUID, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging partial success. + + Used when some enrichment was achieved but with issues + (e.g., low confidence, incomplete data). + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_partial( + ... execution_id="abc-123", + ... project_id=project_id, + ... input_data={"vendor": "Unknown Corp"}, + ... 
enriched_output={"canonical_vendor": "Unknown"}, + ... confidence_score=0.35, + ... error_message="Low confidence match", + ... ) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="partial", + **kwargs, + ) + + def get_execution_history(self, execution_id: str, limit: int = 100) -> list: + """Retrieve audit records for a specific execution. + + Args: + execution_id: UUID of the orchestration execution + limit: Maximum number of records to return + + Returns: + List of LookupExecutionAudit instances + + Example: + >>> history = logger.get_execution_history("abc-123") + >>> for audit in history: + ... print(f"{audit.lookup_project.name}: {audit.status}") + """ + try: + return list( + LookupExecutionAudit.objects.filter(execution_id=execution_id) + .select_related("lookup_project") + .order_by("executed_at")[:limit] + ) + except Exception as e: + logger.error(f"Failed to retrieve execution history: {str(e)}") + return [] + + def get_project_stats(self, project_id: UUID, limit: int = 1000) -> dict[str, Any]: + """Get execution statistics for a Look-Up project. + + Args: + project_id: UUID of the Look-Up project + limit: Maximum number of records to analyze + + Returns: + Dictionary with statistics including success rate, + average execution time, cache hit rate, etc. + + Example: + >>> stats = logger.get_project_stats(project_id) + >>> print(f"Success rate: {stats['success_rate']:.1%}") + """ + try: + audits = LookupExecutionAudit.objects.filter( + lookup_project_id=project_id + ).order_by("-executed_at")[:limit] + + total = len(audits) + if total == 0: + return { + "total_executions": 0, + "success_rate": 0.0, + "avg_execution_time_ms": 0, + "cache_hit_rate": 0.0, + "avg_confidence": 0.0, + } + + successful = sum(1 for a in audits if a.status == "success") + cached = sum(1 for a in audits if a.llm_response_cached) + + exec_times = [ + a.execution_time_ms for a in audits if a.execution_time_ms is not None + ] + avg_exec_time = sum(exec_times) / len(exec_times) if exec_times else 0 + + confidence_scores = [ + float(a.confidence_score) + for a in audits + if a.confidence_score is not None + ] + avg_confidence = ( + sum(confidence_scores) / len(confidence_scores) + if confidence_scores + else 0.0 + ) + + return { + "total_executions": total, + "success_rate": successful / total if total > 0 else 0.0, + "avg_execution_time_ms": int(avg_exec_time), + "cache_hit_rate": cached / total if total > 0 else 0.0, + "avg_confidence": avg_confidence, + "successful": successful, + "failed": sum(1 for a in audits if a.status == "failed"), + "partial": sum(1 for a in audits if a.status == "partial"), + } + + except Exception as e: + logger.error(f"Failed to get project stats: {str(e)}") + return { + "total_executions": 0, + "success_rate": 0.0, + "avg_execution_time_ms": 0, + "cache_hit_rate": 0.0, + "avg_confidence": 0.0, + } diff --git a/backend/lookup/services/document_indexing_service.py b/backend/lookup/services/document_indexing_service.py new file mode 100644 index 0000000000..5da68e8542 --- /dev/null +++ b/backend/lookup/services/document_indexing_service.py @@ -0,0 +1,141 @@ +"""Document Indexing Service for Lookup projects. + +This service manages indexing status tracking using cache to prevent +duplicate indexing operations and track in-progress indexing. + +Based on Prompt Studio's DocumentIndexingService pattern. 
+""" + +import logging + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +class LookupDocumentIndexingService: + """Cache-based service to track document indexing status. + + Prevents duplicate indexing and tracks in-progress operations. + """ + + # Cache key format: lookup_indexing:{org_id}:{user_id}:{doc_id_key} + CACHE_KEY_PREFIX = "lookup_indexing" + CACHE_TIMEOUT = 3600 # 1 hour + + # Status values + STATUS_INDEXING = "INDEXING" # Currently being indexed + + @staticmethod + def _get_cache_key(org_id: str, user_id: str, doc_id_key: str) -> str: + """Generate cache key for document indexing status. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key (hash of indexing parameters) + + Returns: + Cache key string + """ + return f"{LookupDocumentIndexingService.CACHE_KEY_PREFIX}:{org_id}:{user_id}:{doc_id_key}" + + @staticmethod + def is_document_indexing(org_id: str, user_id: str, doc_id_key: str) -> bool: + """Check if document is currently being indexed. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + + Returns: + True if document is being indexed, False otherwise + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + status = cache.get(cache_key) + + if status == LookupDocumentIndexingService.STATUS_INDEXING: + logger.debug(f"Document {doc_id_key} is currently being indexed") + return True + + return False + + @staticmethod + def set_document_indexing(org_id: str, user_id: str, doc_id_key: str) -> None: + """Mark document as being indexed. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cache.set( + cache_key, + LookupDocumentIndexingService.STATUS_INDEXING, + LookupDocumentIndexingService.CACHE_TIMEOUT, + ) + logger.debug(f"Marked document {doc_id_key} as being indexed") + + @staticmethod + def mark_document_indexed( + org_id: str, user_id: str, doc_id_key: str, doc_id: str + ) -> None: + """Mark document as indexed with final doc_id. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + doc_id: Final document ID from indexing service + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + # Store the final doc_id instead of status + cache.set(cache_key, doc_id, LookupDocumentIndexingService.CACHE_TIMEOUT) + logger.debug(f"Marked document {doc_id_key} as indexed with ID {doc_id}") + + @staticmethod + def get_indexed_document_id(org_id: str, user_id: str, doc_id_key: str) -> str | None: + """Get indexed document ID if already indexed. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + + Returns: + Document ID if indexed, None otherwise + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cached_value = cache.get(cache_key) + + # Return doc_id only if it's not the "INDEXING" status + if cached_value and cached_value != LookupDocumentIndexingService.STATUS_INDEXING: + logger.debug(f"Document {doc_id_key} already indexed with ID {cached_value}") + return cached_value + + return None + + @staticmethod + def clear_indexing_status(org_id: str, user_id: str, doc_id_key: str) -> None: + """Clear indexing status from cache (e.g., on error). 
+ + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cache.delete(cache_key) + logger.debug(f"Cleared indexing status for document {doc_id_key}") diff --git a/backend/lookup/services/enrichment_merger.py b/backend/lookup/services/enrichment_merger.py new file mode 100644 index 0000000000..47da11b079 --- /dev/null +++ b/backend/lookup/services/enrichment_merger.py @@ -0,0 +1,174 @@ +"""Enrichment Merger implementation for combining multiple Look-Up results. + +This module provides functionality to merge enrichment data from multiple +Look-Up executions with confidence-based conflict resolution. +""" + +from typing import Any + + +class EnrichmentMerger: + """Merges enrichments from multiple Look-Ups with conflict resolution. + + When multiple Look-Ups run on the same input data, they may produce + overlapping enrichment fields. This class handles merging those results + and resolving conflicts based on confidence scores. + """ + + def merge(self, enrichments: list[dict[str, Any]]) -> dict[str, Any]: + """Merge enrichments with conflict resolution. + + Combines enrichment data from multiple Look-Up executions, + resolving conflicts based on confidence scores when the same + field appears in multiple enrichments. + + Args: + enrichments: List of dicts with structure: + { + 'project_id': UUID, + 'project_name': str, + 'data': Dict[str, Any], # Enrichment fields + 'confidence': Optional[float], # 0.0-1.0 + 'execution_time_ms': int, + 'cached': bool + } + + Returns: + Dictionary containing: + - data: Merged enrichment data + - conflicts_resolved: Number of conflicts resolved + - enrichment_details: Per-lookup metadata + + Example: + >>> merger = EnrichmentMerger() + >>> enrichments = [ + ... { + ... "project_id": uuid1, + ... "project_name": "Vendor Matcher", + ... "data": {"vendor": "Slack", "category": "SaaS"}, + ... "confidence": 0.95, + ... "execution_time_ms": 1234, + ... "cached": False, + ... }, + ... { + ... "project_id": uuid2, + ... "project_name": "Product Classifier", + ... "data": {"category": "Communication", "type": "Software"}, + ... "confidence": 0.80, + ... "execution_time_ms": 567, + ... "cached": True, + ... }, + ... ] + >>> result = merger.merge(enrichments) + >>> print(result["data"]) + {'vendor': 'Slack', 'category': 'SaaS', 'type': 'Software'} + >>> print(result["conflicts_resolved"]) + 1 + """ + merged_data = {} + field_sources = {} # Track which lookup contributed each field + conflicts_resolved = 0 + enrichment_details = [] + + # Process each enrichment + for enrichment in enrichments: + project_id = enrichment.get("project_id") + project_name = enrichment.get("project_name", "Unknown") + data = enrichment.get("data", {}) + confidence = enrichment.get("confidence") + execution_time_ms = enrichment.get("execution_time_ms", 0) + cached = enrichment.get("cached", False) + + fields_added = [] + + # Process each field in the enrichment + for field_name, field_value in data.items(): + field_entry = { + "lookup_id": project_id, + "lookup_name": project_name, + "confidence": confidence, + "value": field_value, + } + + if field_name not in field_sources: + # No conflict, add the field + field_sources[field_name] = field_entry + merged_data[field_name] = field_value + fields_added.append(field_name) + else: + # Conflict! 
Resolve it + existing = field_sources[field_name] + winner = self._resolve_conflict(field_name, existing, field_entry) + + # Check if resolution changed the winner + if winner["lookup_id"] != existing["lookup_id"]: + # New winner, update the merged data + field_sources[field_name] = winner + merged_data[field_name] = winner["value"] + fields_added.append(field_name) + conflicts_resolved += 1 + elif winner["lookup_id"] == field_entry["lookup_id"]: + # Current enrichment won but was not originally there + conflicts_resolved += 1 + + # Track enrichment details + enrichment_details.append( + { + "lookup_project_id": str(project_id) if project_id else None, + "lookup_project_name": project_name, + "confidence": confidence, + "cached": cached, + "execution_time_ms": execution_time_ms, + "fields_added": fields_added, + } + ) + + return { + "data": merged_data, + "conflicts_resolved": conflicts_resolved, + "enrichment_details": enrichment_details, + } + + def _resolve_conflict(self, field_name: str, existing: dict, new: dict) -> dict: + """Resolve conflict for a single field. + + Uses confidence scores to determine which value to keep. + When confidence scores are equal or absent, uses first-complete-wins + strategy (keeps the existing value). + + Args: + field_name: Name of the field with conflict (for context) + existing: Dict with {lookup_id, lookup_name, confidence, value} + new: Dict with {lookup_id, lookup_name, confidence, value} + + Returns: + Winner dict with same structure + + Resolution rules: + 1. Both have confidence: higher confidence wins + 2. Equal confidence: first-complete wins (existing) + 3. One has confidence: confidence one wins + 4. Neither has confidence: first-complete wins (existing) + """ + existing_confidence = existing.get("confidence") + new_confidence = new.get("confidence") + + # Case 1: Both have confidence scores + if existing_confidence is not None and new_confidence is not None: + if new_confidence > existing_confidence: + return new + else: + # Equal or existing is higher: keep existing (first-complete-wins) + return existing + + # Case 2: Only new has confidence + elif existing_confidence is None and new_confidence is not None: + return new + + # Case 3: Only existing has confidence + elif existing_confidence is not None and new_confidence is None: + return existing + + # Case 4: Neither has confidence - first-complete-wins + else: + return existing diff --git a/backend/lookup/services/indexing_service.py b/backend/lookup/services/indexing_service.py new file mode 100644 index 0000000000..6dd427f5c5 --- /dev/null +++ b/backend/lookup/services/indexing_service.py @@ -0,0 +1,471 @@ +"""Service for indexing reference data using configured profiles. + +This service implements the actual indexing workflow by calling external +extraction and indexing services via the PromptTool SDK, following the +same pattern as Prompt Studio's indexing implementation. 
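+
+The usual entry point is the classmethod sketched below; project_id is a
+placeholder UUID and, when org_id/user_id are omitted, they are taken from
+the current UserContext:
+
+    >>> results = IndexingService.index_with_default_profile(project_id)
+    >>> sorted(results.keys())
+    ['errors', 'failed', 'success', 'total']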
+""" + +import json +import logging +import os +from typing import Any + +from django.conf import settings +from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import ( + PromptIdeBaseTool, +) +from utils.file_storage.constants import FileStorageKeys +from utils.user_context import UserContext + +from lookup.models import LookupDataSource, LookupProfileManager +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.exceptions import SdkError +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper +from unstract.sdk1.prompt import PromptTool +from unstract.sdk1.utils.indexing import IndexingUtils +from unstract.sdk1.utils.tool import ToolUtils + +from .document_indexing_service import LookupDocumentIndexingService +from .lookup_index_helper import LookupIndexHelper + +logger = logging.getLogger(__name__) + + +class IndexingService: + """Service to orchestrate indexing of reference data. + + Uses PromptTool SDK to call external extraction and indexing services, + similar to Prompt Studio's implementation but adapted for Lookup projects. + """ + + def __init__(self, profile: LookupProfileManager): + """Initialize indexing service with profile configuration. + + Args: + profile: LookupProfileManager instance with adapter configuration + """ + self.profile = profile + self.chunk_size = profile.chunk_size + self.chunk_overlap = profile.chunk_overlap + self.similarity_top_k = profile.similarity_top_k + + # Adapters from profile + self.llm = profile.llm + self.embedding_model = profile.embedding_model + self.vector_store = profile.vector_store + self.x2text = profile.x2text + + @staticmethod + def extract_text( + data_source: LookupDataSource, + profile: LookupProfileManager, + org_id: str, + run_id: str = None, + ) -> str: + """Extract text from data source using X2Text adapter via external service. 
+ + Args: + data_source: LookupDataSource instance to extract + profile: LookupProfileManager with X2Text adapter configuration + org_id: Organization ID + run_id: Optional run ID for tracking + + Returns: + Extracted text content + + Raises: + SdkError: If extraction service fails + """ + # Generate X2Text config hash for tracking + metadata = profile.x2text.metadata or {} + x2text_config_hash = ToolUtils.hash_str(json.dumps(metadata, sort_keys=True)) + + # Check if already extracted + is_extracted = LookupIndexHelper.check_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, # Lookup doesn't need highlighting + ) + + # Get file storage instance + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + logger.info(f"File storage instance type: {type(fs_instance)}") + logger.info( + f"File storage config: {fs_instance.fs if hasattr(fs_instance, 'fs') else 'N/A'}" + ) + + # Construct file paths + file_path = data_source.file_path + logger.info(f"Data source file_path from DB: {file_path}") + logger.info( + f"Storage type: {StorageType.PERMANENT}, env: {FileStorageKeys.PERMANENT_REMOTE_STORAGE}" + ) + + directory, filename = os.path.split(file_path) + extract_file_path = os.path.join( + directory, "extract", os.path.splitext(filename)[0] + ".txt" + ) + + logger.info(f"Constructed paths - directory: {directory}, filename: {filename}") + logger.info(f"Extract file path: {extract_file_path}") + + if is_extracted: + try: + extracted_text = fs_instance.read(path=extract_file_path, mode="r") + logger.info(f"Extracted text found for {filename}, reading from file") + return extracted_text + except FileNotFoundError as e: + logger.warning( + f"File not found for extraction: {extract_file_path}. {e}. " + "Continuing with extraction..." + ) + + # Call extraction service via PromptTool SDK + usage_kwargs = {"run_id": run_id, "file_name": filename} + payload = { + "x2text_instance_id": str(profile.x2text.id), + "file_path": file_path, + "enable_highlight": False, + "usage_kwargs": usage_kwargs.copy(), + "run_id": run_id, + "execution_source": "ide", + "output_file_path": extract_file_path, + } + + logger.info( + f"Extraction payload: x2text_id={payload['x2text_instance_id']}, " + f"file_path={payload['file_path']}, " + f"execution_source={payload['execution_source']}, " + f"output_file_path={payload['output_file_path']}" + ) + + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=None, + ) + extracted_text = responder.extract(payload=payload) + + # Mark extraction success in IndexManager + success = LookupIndexHelper.mark_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, + ) + + if not success: + logger.warning( + f"Failed to mark extraction success for data source {data_source.id}. " + "Extraction completed but status not saved." 
+ ) + + # Update the data source extraction_status field for UI display + data_source.extraction_status = "completed" + data_source.save(update_fields=["extraction_status"]) + + logger.info(f"Successfully extracted text from {filename}") + return extracted_text + + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + + # Mark extraction failure in IndexManager + LookupIndexHelper.mark_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, + extracted=False, + error_message=msg, + ) + + # Update the data source extraction_status field for UI display + data_source.extraction_status = "failed" + data_source.extraction_error = msg + data_source.save(update_fields=["extraction_status", "extraction_error"]) + + raise Exception(f"Failed to extract '{filename}': {msg}") from e + + @staticmethod + def index_data_source( + data_source: LookupDataSource, + profile: LookupProfileManager, + org_id: str, + user_id: str, + extracted_text: str, + run_id: str = None, + reindex: bool = True, + ) -> str: + """Index extracted text using profile's adapters via external indexing service. + + Args: + data_source: LookupDataSource instance + profile: LookupProfileManager with adapter configuration + org_id: Organization ID + user_id: User ID + extracted_text: Pre-extracted text content + run_id: Optional run ID for tracking + reindex: Whether to reindex if already indexed + + Returns: + Document ID from indexing service + + Raises: + SdkError: If indexing service fails + """ + # Skip indexing if chunk_size is 0 (full context mode) + if profile.chunk_size == 0: + # Generate doc_id for tracking + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + doc_id = IndexingUtils.generate_index_key( + vector_db=str(profile.vector_store.id), + embedding=str(profile.embedding_model.id), + x2text=str(profile.x2text.id), + chunk_size=str(profile.chunk_size), + chunk_overlap=str(profile.chunk_overlap), + file_path=data_source.file_path, + file_hash=None, + fs=fs_instance, + tool=util, + ) + + # Update index manager without actual indexing + LookupIndexHelper.handle_index_manager( + data_source_id=str(data_source.id), + profile_manager=profile, + doc_id=doc_id, + ) + + logger.info("Skipping vector DB indexing since chunk size is 0") + return doc_id + + # Get adapter IDs + embedding_model = str(profile.embedding_model.id) + vector_db = str(profile.vector_store.id) + x2text_adapter = str(profile.x2text.id) + + # Construct file paths + directory, filename = os.path.split(data_source.file_path) + file_path = os.path.join( + directory, "extract", os.path.splitext(filename)[0] + ".txt" + ) + + # Generate index key + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + doc_id_key = IndexingUtils.generate_index_key( + vector_db=vector_db, + embedding=embedding_model, + x2text=x2text_adapter, + chunk_size=str(profile.chunk_size), + chunk_overlap=str(profile.chunk_overlap), + file_path=data_source.file_path, + file_hash=None, + fs=fs_instance, + tool=util, + ) + + try: + usage_kwargs = {"run_id": run_id, "file_name": filename} + + # Check if already 
indexed (unless reindexing) + if not reindex: + indexed_doc_id = LookupDocumentIndexingService.get_indexed_document_id( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + if indexed_doc_id: + logger.info(f"Document {filename} already indexed: {indexed_doc_id}") + return indexed_doc_id + + # Check if currently being indexed + if LookupDocumentIndexingService.is_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ): + raise Exception(f"Document {filename} is currently being indexed") + + # Mark as being indexed + LookupDocumentIndexingService.set_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + + logger.info(f"Invoking indexing service for: {doc_id_key}") + + # Build payload for indexing service + payload = { + "tool_id": str(data_source.project.id), # Use project ID as tool ID + "embedding_instance_id": embedding_model, + "vector_db_instance_id": vector_db, + "x2text_instance_id": x2text_adapter, + "file_path": file_path, + "file_hash": None, + "chunk_overlap": profile.chunk_overlap, + "chunk_size": profile.chunk_size, + "reindex": reindex, + "enable_highlight": False, + "usage_kwargs": usage_kwargs.copy(), + "extracted_text": extracted_text, + "run_id": run_id, + "execution_source": "ide", + } + + # Call indexing service via PromptTool SDK + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=None, + ) + doc_id = responder.index(payload=payload) + + # Update index manager with doc_id + LookupIndexHelper.handle_index_manager( + data_source_id=str(data_source.id), + profile_manager=profile, + doc_id=doc_id, + ) + + # Mark as indexed in cache + LookupDocumentIndexingService.mark_document_indexed( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id + ) + + logger.info(f"Successfully indexed {filename} with doc_id: {doc_id}") + return doc_id + + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + raise Exception(f"Failed to index '{filename}': {msg}") from e + + except Exception as e: + logger.error(f"Error indexing {filename}: {e}", exc_info=True) + # Clear indexing status on error + LookupDocumentIndexingService.clear_indexing_status( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + raise + + @classmethod + def index_with_default_profile( + cls, project_id: str, org_id: str = None, user_id: str = None + ) -> dict[str, Any]: + """Index all completed data sources using the project's default profile. 
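+
+        Extraction and indexing errors are collected per data source in the
+        returned summary, so a single failing file does not abort the run for
+        the remaining data sources.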
+ + Args: + project_id: UUID of the lookup project + org_id: Organization ID (if None, gets from UserContext) + user_id: User ID (if None, gets from UserContext) + + Returns: + Dict with indexing results summary + + Raises: + DefaultProfileError: If no default profile exists for project + ValueError: If project not found + """ + from lookup.models import LookupProject + + # Get context if not provided + if org_id is None: + org_id = UserContext.get_organization_identifier() + if user_id is None: + user_id = UserContext.get_user_id() + + try: + project = LookupProject.objects.get(id=project_id) + except LookupProject.DoesNotExist: + raise ValueError(f"Project {project_id} not found") + + # Get default profile + default_profile = LookupProfileManager.get_default_profile(project) + + # Get all data sources (extraction will be done as part of indexing) + data_sources = LookupDataSource.objects.filter(project_id=project_id).order_by( + "-version_number" + ) + + logger.info(f"Found {data_sources.count()} data sources for project {project_id}") + + # Log each data source status + for ds in data_sources: + logger.info(f" - {ds.file_name}: extraction_status={ds.extraction_status}") + + results = { + "total": data_sources.count(), + "success": 0, + "failed": 0, + "errors": [], + } + + for data_source in data_sources: + try: + logger.info( + f"Indexing data source {data_source.id}: {data_source.file_name}" + ) + + # Extract text + extracted_text = cls.extract_text( + data_source=data_source, + profile=default_profile, + org_id=org_id, + run_id=None, + ) + + # Index the extracted text + doc_id = cls.index_data_source( + data_source=data_source, + profile=default_profile, + org_id=org_id, + user_id=user_id, + extracted_text=extracted_text, + run_id=None, + reindex=True, + ) + + results["success"] += 1 + logger.info( + f"Successfully indexed {data_source.file_name} with doc_id: {doc_id}" + ) + + except Exception as e: + results["failed"] += 1 + error_msg = str(e) + results["errors"].append( + { + "data_source_id": str(data_source.id), + "file_name": data_source.file_name, + "error": error_msg, + } + ) + logger.error(f"Failed to index {data_source.file_name}: {error_msg}") + + logger.info( + f"Indexing complete for project {project_id}: " + f"{results['success']} successful, {results['failed']} failed" + ) + + return results diff --git a/backend/lookup/services/llm_cache.py b/backend/lookup/services/llm_cache.py new file mode 100644 index 0000000000..e7e46bf0bd --- /dev/null +++ b/backend/lookup/services/llm_cache.py @@ -0,0 +1,186 @@ +"""LLM Response Cache implementation for caching LLM API responses. + +This module provides an in-memory cache with TTL (Time-To-Live) support +for storing and retrieving LLM responses based on prompt and reference data. +""" + +import hashlib +import time + + +class LLMResponseCache: + """In-memory cache for LLM responses with TTL expiration. + + This cache is used to store LLM API responses to avoid redundant + API calls for identical prompts and reference data combinations. + Uses SHA256 hashing for cache key generation and supports TTL-based + expiration for automatic cleanup. + """ + + def __init__(self, ttl_hours: int = 24): + """Initialize the LLM response cache. + + Args: + ttl_hours: Time-to-live in hours (default 24). + Cached entries expire after this duration. + """ + self.cache: dict[str, tuple[str, float]] = {} + self.ttl_seconds = ttl_hours * 3600 + + def get(self, key: str) -> str | None: + """Get cached response if not expired. 
+ + Performs lazy cleanup by removing expired entries when accessed. + + Args: + key: Cache key (generated by generate_cache_key) + + Returns: + Cached response string if valid and not expired, None otherwise + + Example: + >>> cache = LLMResponseCache(ttl_hours=1) + >>> cache.set("key123", "response") + >>> cache.get("key123") + 'response' + """ + if key not in self.cache: + return None + + response, expiry = self.cache[key] + current_time = time.time() + + if current_time >= expiry: + # Entry has expired, remove it (lazy cleanup) + del self.cache[key] + return None + + return response + + def set(self, key: str, response: str) -> None: + """Cache response with TTL. + + Args: + key: Cache key (generated by generate_cache_key) + response: LLM response to cache + + Example: + >>> cache = LLMResponseCache() + >>> key = cache.generate_cache_key("prompt", "ref_data") + >>> cache.set(key, "LLM response") + """ + expiry = time.time() + self.ttl_seconds + self.cache[key] = (response, expiry) + + def invalidate(self, key: str) -> bool: + """Remove specific key from cache. + + Args: + key: Cache key to invalidate + + Returns: + True if key was removed, False if key didn't exist + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key123", "response") + >>> cache.invalidate("key123") + True + """ + if key in self.cache: + del self.cache[key] + return True + return False + + def invalidate_all(self) -> int: + """Clear entire cache. + + Returns: + Count of invalidated entries + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key1", "response1") + >>> cache.set("key2", "response2") + >>> cache.invalidate_all() + 2 + """ + count = len(self.cache) + self.cache.clear() + return count + + def generate_cache_key(self, prompt: str, reference_data: str) -> str: + r"""Generate SHA256 hash from prompt + reference data. + + Creates a deterministic cache key based on the prompt and + reference data combination. Same inputs always produce the + same key. + + Args: + prompt: The resolved prompt text + reference_data: The reference data text + + Returns: + 64-character hexadecimal SHA256 hash + + Example: + >>> cache = LLMResponseCache() + >>> key = cache.generate_cache_key("Match vendor", "Slack\nGoogle") + >>> len(key) + 64 + """ + combined = f"{prompt}{reference_data}" + hash_obj = hashlib.sha256(combined.encode("utf-8")) + return hash_obj.hexdigest() + + def get_stats(self) -> dict[str, int]: + """Return cache statistics. + + Analyzes the current cache state and returns counts of + total, expired, and valid entries. + + Returns: + Dictionary with keys: 'total', 'expired', 'valid' + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key1", "response1") + >>> stats = cache.get_stats() + >>> stats["total"] + 1 + """ + current_time = time.time() + total = len(self.cache) + expired = sum( + 1 for _, (_, expiry) in self.cache.items() if current_time >= expiry + ) + valid = total - expired + + return {"total": total, "expired": expired, "valid": valid} + + def cleanup_expired(self) -> int: + """Remove expired entries. + + Performs a full scan of the cache and removes all entries + that have exceeded their TTL. Can be called periodically + for cache maintenance. 
+ + Returns: + Count of removed entries + + Example: + >>> cache = LLMResponseCache(ttl_hours=0) # Immediate expiry + >>> cache.set("key1", "response1") + >>> time.sleep(0.1) + >>> cache.cleanup_expired() + 1 + """ + current_time = time.time() + keys_to_remove = [ + key for key, (_, expiry) in self.cache.items() if current_time >= expiry + ] + + for key in keys_to_remove: + del self.cache[key] + + return len(keys_to_remove) diff --git a/backend/lookup/services/lookup_executor.py b/backend/lookup/services/lookup_executor.py new file mode 100644 index 0000000000..ebb70a9303 --- /dev/null +++ b/backend/lookup/services/lookup_executor.py @@ -0,0 +1,285 @@ +"""Look-Up Executor implementation for single Look-Up execution. + +This module provides functionality to execute a single Look-Up project +against input data, including variable resolution, LLM calling, and +response caching. +""" + +import json +import logging +import time +from typing import Any, Protocol + +from lookup.exceptions import ( + ExtractionNotCompleteError, + ParseError, + TemplateNotFoundError, +) +from lookup.models import LookupProject + +logger = logging.getLogger(__name__) + + +class LLMClient(Protocol): + """Protocol for LLM client abstraction.""" + + def generate(self, prompt: str, config: dict[str, Any]) -> str: + """Generate LLM response for the prompt.""" + ... + + +class LookUpExecutor: + """Executes a single Look-Up project against input data. + + This class handles the complete execution flow of a Look-Up: + loading reference data, resolving variables in the prompt template, + calling the LLM, caching responses, and parsing the results. + """ + + def __init__( + self, + variable_resolver, # Class, not instance + cache_manager, + reference_loader, + llm_client: LLMClient, + ): + """Initialize the Look-Up executor. + + Args: + variable_resolver: VariableResolver class (not instance) + cache_manager: LLMResponseCache instance + reference_loader: ReferenceDataLoader instance + llm_client: LLM provider client implementing LLMClient protocol + """ + self.variable_resolver_class = variable_resolver + self.cache = cache_manager + self.ref_loader = reference_loader + self.llm_client = llm_client + + def execute( + self, lookup_project: LookupProject, input_data: dict[str, Any] + ) -> dict[str, Any]: + """Execute single Look-Up. + + Performs the complete Look-Up execution including variable resolution, + LLM calling with caching, and response parsing. + + Args: + lookup_project: The Look-Up project to execute + input_data: Input data containing variables to resolve + + Returns: + Dictionary containing: + - status: 'success' or 'failed' + - project_id: UUID of the project + - project_name: Name of the project + - data: Enrichment data (if success) + - confidence: Confidence score 0.0-1.0 (if available) + - cached: Whether response was from cache + - execution_time_ms: Time taken in milliseconds + - error: Error message (if failed) + + Example: + >>> executor = LookUpExecutor(...) + >>> result = executor.execute(project, {"vendor": "Slack"}) + >>> if result["status"] == "success": + ... 
print(result["data"]) + {'canonical_vendor': 'Slack Technologies', 'confidence': 0.92} + """ + start_time = time.time() + + try: + # Step 1: Load reference data + try: + reference_data_dict = self.ref_loader.load_latest_for_project( + lookup_project.id + ) + reference_data = reference_data_dict["content"] + except ExtractionNotCompleteError as e: + return self._failed_result( + lookup_project, + f"Reference data extraction not complete: {str(e)}", + start_time, + ) + except Exception as e: + return self._failed_result( + lookup_project, f"Failed to load reference data: {str(e)}", start_time + ) + + # Step 2: Load prompt template + try: + template = lookup_project.template + if not template: + raise TemplateNotFoundError("No template configured") + template_text = template.template_text + except (AttributeError, TemplateNotFoundError) as e: + return self._failed_result( + lookup_project, f"Missing prompt template: {str(e)}", start_time + ) + + # Step 3: Resolve variables + logger.info(f"Input data received: {input_data}") + logger.info(f"Reference data length: {len(reference_data)} chars") + logger.info(f"Template text: {template_text[:200]}...") + resolver = self.variable_resolver_class(input_data, reference_data) + resolved_prompt = resolver.resolve(template_text) + logger.info(f"Resolved prompt: {resolved_prompt[:500]}...") + + # Step 4: Check cache (if caching is enabled) + cache_key = None + cached_response = None + if self.cache: + cache_key = self.cache.generate_cache_key(resolved_prompt, reference_data) + cached_response = self.cache.get(cache_key) + + if cached_response: + # Cache hit - parse and return + enrichment_data, confidence = self._parse_llm_response( + cached_response + ) + return self._success_result( + lookup_project, + enrichment_data, + confidence, + cached=True, + execution_time_ms=0, # No execution time for cached response + ) + + # Step 5: Call LLM (cache miss or caching disabled) + try: + llm_start = time.time() + llm_response = self.llm_client.generate( + resolved_prompt, lookup_project.llm_config or {} + ) + llm_time_ms = int((time.time() - llm_start) * 1000) + + # Store in cache (if caching is enabled) + if self.cache and cache_key: + self.cache.set(cache_key, llm_response) + + except TimeoutError as e: + return self._failed_result( + lookup_project, f"LLM request timed out: {str(e)}", start_time + ) + except Exception as e: + return self._failed_result( + lookup_project, f"LLM request failed: {str(e)}", start_time + ) + + # Step 6: Parse response + try: + enrichment_data, confidence = self._parse_llm_response(llm_response) + except ParseError as e: + return self._failed_result( + lookup_project, f"Failed to parse LLM response: {str(e)}", start_time + ) + + # Step 7: Return result + return self._success_result( + lookup_project, + enrichment_data, + confidence, + cached=False, + execution_time_ms=llm_time_ms, + ) + + except Exception as e: + # Catch-all for unexpected errors + logger.exception(f"Unexpected error executing Look-Up {lookup_project.id}") + return self._failed_result( + lookup_project, f"Unexpected error: {str(e)}", start_time + ) + + def _parse_llm_response(self, response_text: str) -> tuple[dict, float | None]: + """Parse LLM response to extract enrichment data. + + Attempts to parse the LLM response as JSON and extract + enrichment fields and optional confidence score. 
+ + Args: + response_text: Raw text response from the LLM + + Returns: + Tuple of (enrichment_data, confidence) + - enrichment_data: Dictionary of extracted fields + - confidence: Optional confidence score (0.0-1.0) + + Raises: + ParseError: If response cannot be parsed as valid JSON + + Example: + >>> response = '{"vendor": "Slack", "confidence": 0.92}' + >>> data, conf = executor._parse_llm_response(response) + >>> print(data) + {'vendor': 'Slack'} + >>> print(conf) + 0.92 + """ + try: + # Try direct JSON parse + parsed = json.loads(response_text) + + if not isinstance(parsed, dict): + raise ParseError(f"Expected JSON object, got {type(parsed).__name__}") + + # Extract confidence if present + confidence = None + if "confidence" in parsed: + confidence = parsed.pop("confidence") + + # Validate confidence is a number between 0 and 1 + if isinstance(confidence, (int, float)): + confidence = float(confidence) + if not 0.0 <= confidence <= 1.0: + logger.warning( + f"Confidence {confidence} outside valid range [0.0, 1.0]" + ) + confidence = max(0.0, min(1.0, confidence)) # Clamp to range + else: + logger.warning(f"Invalid confidence type: {type(confidence)}") + confidence = None + + # Remaining fields are the enrichment data + enrichment_data = parsed + + return enrichment_data, confidence + + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse LLM response as JSON: {e}") + raise ParseError(f"Invalid JSON response: {str(e)}") + except Exception as e: + logger.warning(f"Unexpected error parsing LLM response: {e}") + raise ParseError(f"Parse error: {str(e)}") + + def _success_result( + self, + project: LookupProject, + data: dict[str, Any], + confidence: float | None, + cached: bool, + execution_time_ms: int, + ) -> dict[str, Any]: + """Build success result dictionary.""" + return { + "status": "success", + "project_id": str(project.id), # Convert UUID to string for JSON + "project_name": project.name, + "data": data, + "confidence": confidence, + "cached": cached, + "execution_time_ms": execution_time_ms, + } + + def _failed_result( + self, project: LookupProject, error: str, start_time: float + ) -> dict[str, Any]: + """Build failed result dictionary.""" + execution_time_ms = int((time.time() - start_time) * 1000) + return { + "status": "failed", + "project_id": str(project.id), # Convert UUID to string for JSON + "project_name": project.name, + "error": error, + "execution_time_ms": execution_time_ms, + "cached": False, + } diff --git a/backend/lookup/services/lookup_index_helper.py b/backend/lookup/services/lookup_index_helper.py new file mode 100644 index 0000000000..d22a7d86b3 --- /dev/null +++ b/backend/lookup/services/lookup_index_helper.py @@ -0,0 +1,190 @@ +"""Helper service for Lookup IndexManager operations. + +Based on Prompt Studio's PromptStudioIndexHelper pattern. +""" + +import logging + +from django.db import transaction + +from lookup.models import LookupIndexManager + +logger = logging.getLogger(__name__) + + +class LookupIndexHelper: + """Helper class for LookupIndexManager CRUD operations.""" + + @staticmethod + @transaction.atomic + def handle_index_manager( + data_source_id: str, + profile_manager, + doc_id: str, + ) -> LookupIndexManager: + """Create or update LookupIndexManager with doc_id. 
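+
+        The doc_id is stored as raw_index_id and also appended to
+        index_ids_history so that earlier index identifiers remain traceable.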
+ + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + doc_id: Document ID returned from indexing service + + Returns: + LookupIndexManager instance + """ + from lookup.models import LookupDataSource + + try: + data_source = LookupDataSource.objects.get(pk=data_source_id) + + # Get or create index manager for this data source + profile combination + index_manager, created = LookupIndexManager.objects.get_or_create( + data_source=data_source, + profile_manager=profile_manager, + defaults={ + "raw_index_id": doc_id, + "index_ids_history": [doc_id], + "status": {"indexed": True, "error": None}, + }, + ) + + if not created: + # Update existing index manager + index_manager.raw_index_id = doc_id + + # Add to history if not already present + if doc_id not in index_manager.index_ids_history: + index_manager.index_ids_history.append(doc_id) + + # Update status + index_manager.status = {"indexed": True, "error": None} + + index_manager.save() + logger.debug(f"Updated index manager for data source {data_source_id}") + else: + logger.debug(f"Created index manager for data source {data_source_id}") + + return index_manager + + except LookupDataSource.DoesNotExist: + logger.error(f"Data source {data_source_id} not found") + raise + except Exception as e: + logger.error(f"Error handling index manager: {e}", exc_info=True) + raise + + @staticmethod + def check_extraction_status( + data_source_id: str, + profile_manager, + x2text_config_hash: str, + enable_highlight: bool = False, + ) -> bool: + """Check if extraction already completed for this configuration. + + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + x2text_config_hash: Hash of X2Text adapter configuration + enable_highlight: Whether highlighting is enabled + + Returns: + True if extraction completed with same settings, False otherwise + """ + try: + index_manager = LookupIndexManager.objects.get( + data_source_id=data_source_id, profile_manager=profile_manager + ) + + extraction_status = index_manager.extraction_status.get( + x2text_config_hash, {} + ) + + if not extraction_status.get("extracted", False): + return False + + # Check if highlight setting matches + stored_highlight = extraction_status.get("enable_highlight", False) + if stored_highlight != enable_highlight: + logger.debug( + f"Highlight setting mismatch: stored={stored_highlight}, " + f"requested={enable_highlight}" + ) + return False + + logger.debug(f"Extraction already completed for {x2text_config_hash}") + return True + + except LookupIndexManager.DoesNotExist: + logger.debug(f"No index manager found for data source {data_source_id}") + return False + except Exception as e: + logger.error(f"Error checking extraction status: {e}", exc_info=True) + return False + + @staticmethod + @transaction.atomic + def mark_extraction_status( + data_source_id: str, + profile_manager, + x2text_config_hash: str, + enable_highlight: bool = False, + extracted: bool = True, + error_message: str = None, + ) -> bool: + """Mark extraction success or failure in IndexManager. 
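+
+        The status entry is keyed by x2text_config_hash, so a later extraction
+        is skipped only when the same X2Text configuration (and highlight
+        setting) is requested again.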
+ + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + x2text_config_hash: Hash of X2Text adapter configuration + enable_highlight: Whether highlighting is enabled + extracted: Whether extraction succeeded + error_message: Error message if extraction failed + + Returns: + True if status marked successfully, False otherwise + """ + from lookup.models import LookupDataSource + + try: + data_source = LookupDataSource.objects.get(pk=data_source_id) + + # Get or create index manager + index_manager, created = LookupIndexManager.objects.get_or_create( + data_source=data_source, + profile_manager=profile_manager, + defaults={"extraction_status": {}, "status": {}}, + ) + + # Update extraction status for this configuration + index_manager.extraction_status[x2text_config_hash] = { + "extracted": extracted, + "enable_highlight": enable_highlight, + "error": error_message, + } + + # Also update overall status + if extracted: + index_manager.status["extracted"] = True + index_manager.status["error"] = None + else: + index_manager.status["extracted"] = False + index_manager.status["error"] = error_message + + index_manager.save() + + status_text = "success" if extracted else "failure" + logger.debug( + f"Marked extraction {status_text} for data source {data_source_id}, " + f"config {x2text_config_hash}" + ) + + return True + + except LookupDataSource.DoesNotExist: + logger.error(f"Data source {data_source_id} not found") + return False + except Exception as e: + logger.error(f"Error marking extraction status: {e}", exc_info=True) + return False diff --git a/backend/lookup/services/lookup_orchestrator.py b/backend/lookup/services/lookup_orchestrator.py new file mode 100644 index 0000000000..9ee00942d6 --- /dev/null +++ b/backend/lookup/services/lookup_orchestrator.py @@ -0,0 +1,306 @@ +"""Look-Up Orchestrator implementation for parallel execution. + +This module provides functionality to execute multiple Look-Up projects +in parallel and merge their results into a single enriched output. +""" + +import logging +import time +import uuid +from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed +from datetime import UTC, datetime +from typing import Any + +from lookup.models import LookupProject +from lookup.services.enrichment_merger import EnrichmentMerger +from lookup.services.lookup_executor import LookUpExecutor + +logger = logging.getLogger(__name__) + + +class LookUpOrchestrator: + """Orchestrates parallel execution of multiple Look-Up projects. + + This class manages the concurrent execution of multiple Look-Up projects, + handles timeouts, collects results, and merges them into a single + enriched output using the EnrichmentMerger. + """ + + def __init__( + self, + executor: LookUpExecutor, + merger: EnrichmentMerger, + config: dict[str, Any] = None, + ): + """Initialize the Look-Up orchestrator. 
+ + Args: + executor: LookUpExecutor instance for single Look-Up execution + merger: EnrichmentMerger instance for combining results + config: Configuration dictionary with optional keys: + - max_concurrent_executions: Maximum parallel executions (default 10) + - queue_timeout_seconds: Overall queue timeout (default 120) + - execution_timeout_seconds: Per-execution timeout (default 30) + """ + self.executor = executor + self.merger = merger + + config = config or {} + self.max_concurrent = config.get("max_concurrent_executions", 10) + self.queue_timeout = config.get("queue_timeout_seconds", 120) + self.execution_timeout = config.get("execution_timeout_seconds", 30) + + logger.info( + f"Orchestrator initialized with max_concurrent={self.max_concurrent}, " + f"queue_timeout={self.queue_timeout}s, " + f"execution_timeout={self.execution_timeout}s" + ) + + def execute_lookups( + self, input_data: dict[str, Any], lookup_projects: list[LookupProject] + ) -> dict[str, Any]: + """Execute all Look-Ups in parallel and merge results. + + Submits all Look-Up projects for parallel execution, collects + the results, and merges successful enrichments into a single + output. Handles timeouts and failures gracefully. + + Args: + input_data: Input data to enrich + lookup_projects: List of Look-Up projects to execute + + Returns: + Dictionary containing: + - lookup_enrichment: Merged enrichment data + - _lookup_metadata: Execution metadata including: + - execution_id: Unique ID for this execution + - executed_at: ISO8601 timestamp + - total_execution_time_ms: Total time in milliseconds + - lookups_executed: Number of Look-Ups attempted + - lookups_successful: Number of successful executions + - lookups_failed: Number of failed executions + - conflicts_resolved: Number of field conflicts resolved + - enrichments: List of individual enrichment results + + Example: + >>> orchestrator = LookUpOrchestrator(executor, merger) + >>> projects = [vendor_lookup, product_lookup] + >>> result = orchestrator.execute_lookups({"vendor": "Slack"}, projects) + >>> print(result["lookup_enrichment"]) + {'canonical_vendor': 'Slack', 'product_type': 'SaaS'} + >>> print(result["_lookup_metadata"]["lookups_successful"]) + 2 + """ + execution_id = str(uuid.uuid4()) + start_time = time.time() + executed_at = datetime.now(UTC).isoformat() + + logger.info( + f"Starting orchestrated execution {execution_id} for " + f"{len(lookup_projects)} Look-Up projects" + ) + + if not lookup_projects: + # No Look-Ups to execute + return self._empty_result(execution_id, executed_at, start_time) + + successful_enrichments = [] + failed_lookups = [] + timeout_count = 0 + + # Execute Look-Ups in parallel + with ThreadPoolExecutor(max_workers=self.max_concurrent) as thread_executor: + # Submit all tasks + futures = { + thread_executor.submit( + self._execute_single, execution_id, input_data, lookup_project + ): lookup_project + for lookup_project in lookup_projects + } + + logger.debug(f"Submitted {len(futures)} Look-Up tasks for parallel execution") + + # Collect results with timeout + try: + for future in as_completed(futures, timeout=self.queue_timeout): + lookup_project = futures[future] + try: + result = future.result(timeout=self.execution_timeout) + + if result["status"] == "success": + successful_enrichments.append(result) + logger.debug( + f"Look-Up {lookup_project.name} completed successfully" + ) + else: + failed_lookups.append(result) + logger.warning( + f"Look-Up {lookup_project.name} failed: {result.get('error')}" + ) + + except 
TimeoutError: + # Individual execution timeout + timeout_count += 1 + logger.error( + f"Look-Up {lookup_project.name} timed out after " + f"{self.execution_timeout}s" + ) + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Execution timeout ({self.execution_timeout}s)", + "execution_time_ms": self.execution_timeout * 1000, + "cached": False, + } + ) + + except Exception as e: + # Unexpected error in future.result() + logger.exception( + f"Unexpected error getting result for {lookup_project.name}" + ) + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Execution error: {str(e)}", + "execution_time_ms": 0, + "cached": False, + } + ) + + except TimeoutError: + # Overall queue timeout + logger.error( + f"Queue timeout after {self.queue_timeout}s, " + f"some Look-Ups may not have completed" + ) + # Cancel remaining futures + for future in futures: + if not future.done(): + future.cancel() + lookup_project = futures[future] + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Queue timeout ({self.queue_timeout}s)", + "execution_time_ms": 0, + "cached": False, + } + ) + + # Merge successful enrichments + if successful_enrichments: + merge_result = self.merger.merge(successful_enrichments) + merged_data = merge_result["data"] + conflicts_resolved = merge_result["conflicts_resolved"] + # enrichment_details = merge_result["enrichment_details"] + else: + # No successful enrichments + merged_data = {} + conflicts_resolved = 0 + + # Calculate execution time + total_execution_time_ms = int((time.time() - start_time) * 1000) + + # Combine all enrichment results (successful and failed) + all_enrichments = successful_enrichments + failed_lookups + + logger.info( + f"Orchestration {execution_id} completed: " + f"{len(successful_enrichments)} successful, " + f"{len(failed_lookups)} failed, " + f"{timeout_count} timeouts, " + f"{conflicts_resolved} conflicts resolved, " + f"total time {total_execution_time_ms}ms" + ) + + return { + "lookup_enrichment": merged_data, + "_lookup_metadata": { + "execution_id": execution_id, + "executed_at": executed_at, + "total_execution_time_ms": total_execution_time_ms, + "lookups_executed": len(lookup_projects), + "lookups_successful": len(successful_enrichments), + "lookups_failed": len(failed_lookups), + "conflicts_resolved": conflicts_resolved, + "enrichments": all_enrichments, + }, + } + + def _execute_single( + self, execution_id: str, input_data: dict[str, Any], lookup_project: LookupProject + ) -> dict[str, Any]: + """Execute a single Look-Up project. + + Wrapper around the executor to add execution context and + handle any unexpected errors. 
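+
+        On any executor exception a synthetic failure payload (status,
+        project_id, project_name, error, execution_time_ms, cached) is
+        returned instead of propagating, so a single failing Look-Up cannot
+        abort the whole batch.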
+ + Args: + execution_id: ID of the orchestration execution + input_data: Input data for enrichment + lookup_project: Look-Up project to execute + + Returns: + Enrichment result dictionary from the executor + """ + try: + logger.debug( + f"Executing Look-Up {lookup_project.name} for execution {execution_id}" + ) + + # Execute the Look-Up + result = self.executor.execute(lookup_project, input_data) + + # Add execution context + result["execution_id"] = execution_id + + return result + + except Exception as e: + # Catch any unexpected errors from the executor + logger.exception(f"Unexpected error executing Look-Up {lookup_project.name}") + return { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Unexpected error: {str(e)}", + "execution_time_ms": 0, + "cached": False, + "execution_id": execution_id, + } + + def _empty_result( + self, execution_id: str, executed_at: str, start_time: float + ) -> dict[str, Any]: + """Build result for empty Look-Up list. + + Args: + execution_id: Execution ID + executed_at: Execution timestamp + start_time: Start time for calculating duration + + Returns: + Empty result dictionary with metadata + """ + total_execution_time_ms = int((time.time() - start_time) * 1000) + + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "execution_id": execution_id, + "executed_at": executed_at, + "total_execution_time_ms": total_execution_time_ms, + "lookups_executed": 0, + "lookups_successful": 0, + "lookups_failed": 0, + "conflicts_resolved": 0, + "enrichments": [], + }, + } diff --git a/backend/lookup/services/mock_clients.py b/backend/lookup/services/mock_clients.py new file mode 100644 index 0000000000..ee0838d153 --- /dev/null +++ b/backend/lookup/services/mock_clients.py @@ -0,0 +1,157 @@ +"""Mock implementations of LLM and Storage clients for testing. + +These are temporary implementations for testing the API layer +before integrating with real services in Phase 4. +""" + +import json +import random +import time +from typing import Any + + +class MockLLMClient: + """Mock LLM client for testing Look-Up execution. + + Generates synthetic responses for testing purposes. + """ + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate a mock LLM response. + + Returns JSON-formatted enrichment data with random confidence. 
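+
+        Example (only the deterministic field is shown; confidence and
+        categories are randomised):
+            >>> client = MockLLMClient()
+            >>> data = json.loads(client.generate("vendor: Slack Technologies", {}))
+            >>> data["canonical_vendor"]
+            'Slack'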
+ """ + # Simulate processing time + time.sleep(random.uniform(0.1, 0.5)) + + # Extract vendor name from prompt if available + vendor = "Unknown" + if "vendor" in prompt.lower(): + # Try to extract vendor name from prompt + lines = prompt.split("\n") + for line in lines: + if "vendor" in line.lower() and ":" in line: + vendor = line.split(":")[-1].strip() + break + + # Generate mock enrichment data + confidence = random.uniform(0.6, 0.98) + enrichment_data = { + "canonical_vendor": self._canonicalize(vendor), + "vendor_category": random.choice( + ["SaaS", "Infrastructure", "Security", "Analytics"] + ), + "vendor_type": random.choice(["Software", "Service", "Platform"]), + "confidence": round(confidence, 2), + } + + return json.dumps(enrichment_data) + + def _canonicalize(self, vendor: str) -> str: + """Mock canonicalization of vendor names.""" + # Simple mock canonicalization + mappings = { + "Slack Technologies": "Slack", + "Microsoft Corp": "Microsoft", + "Amazon Web Services": "AWS", + "Google Cloud Platform": "GCP", + "International Business Machines": "IBM", + } + return mappings.get(vendor, vendor) + + +class MockStorageClient: + """Mock storage client for testing reference data operations. + + Stores data in memory for testing purposes. + """ + + def __init__(self): + """Initialize in-memory storage.""" + self.storage = {} + + def upload(self, path: str, content: bytes) -> bool: + """Upload content to mock storage. + + Args: + path: Storage path + content: File content + + Returns: + True if successful + """ + self.storage[path] = content + return True + + def download(self, path: str) -> bytes | None: + """Download content from mock storage. + + Args: + path: Storage path + + Returns: + File content or None if not found + """ + return self.storage.get(path) + + def delete(self, path: str) -> bool: + """Delete content from mock storage. + + Args: + path: Storage path + + Returns: + True if deleted, False if not found + """ + if path in self.storage: + del self.storage[path] + return True + return False + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists + """ + return path in self.storage + + def list_files(self, prefix: str) -> list: + """List files with given prefix. + + Args: + prefix: Path prefix + + Returns: + List of matching paths + """ + return [path for path in self.storage.keys() if path.startswith(prefix)] + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage. + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + content = self.download(path) + if content: + return content.decode("utf-8") + return None + + def save_text_content(self, path: str, text: str) -> bool: + """Save text content to storage. + + Args: + path: Storage path + text: Text content + + Returns: + True if successful + """ + return self.upload(path, text.encode("utf-8")) diff --git a/backend/lookup/services/reference_data_loader.py b/backend/lookup/services/reference_data_loader.py new file mode 100644 index 0000000000..30799b4570 --- /dev/null +++ b/backend/lookup/services/reference_data_loader.py @@ -0,0 +1,267 @@ +"""Reference Data Loader implementation for loading and concatenating reference data. + +This module provides functionality to load reference data from object storage +and concatenate multiple sources into a single text for LLM processing. 
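+
+Storage access goes through a minimal StorageClient protocol (a single
+get() method), keeping the loader decoupled from any specific object store
+implementation.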
+""" + +from typing import Any, Protocol +from uuid import UUID + +from django.db.models import QuerySet + +from lookup.exceptions import ExtractionNotCompleteError +from lookup.models import LookupDataSource + + +class StorageClient(Protocol): + """Protocol for object storage client abstraction. + + Any storage client implementation must provide a get() method + that retrieves file content by path. + """ + + def get(self, path: str) -> str: + """Retrieve file content from storage.""" + ... + + +class ReferenceDataLoader: + """Loads and concatenates reference data from object storage. + + This class handles loading reference data files that have been + extracted from uploaded documents and stored in object storage. + It ensures all files are properly extracted before loading and + concatenates multiple sources in the order they were uploaded. + """ + + def __init__(self, storage_client: StorageClient): + """Initialize the reference data loader. + + Args: + storage_client: Object storage client (abstraction). + Must implement the StorageClient protocol. + """ + self.storage = storage_client + + def load_latest_for_project(self, project_id: UUID) -> dict[str, Any]: + r"""Load latest reference data for a project. + + Retrieves the most recent version of reference data for the specified + project. Ensures all data sources have completed extraction before + loading and concatenating their content. + + Args: + project_id: UUID of the Look-Up project + + Returns: + Dictionary containing: + - version: Version number of the reference data + - content: Concatenated text from all files + - files: List of metadata about source files + - total_size: Total size in bytes + + Raises: + ExtractionNotCompleteError: If any data source extraction is incomplete + LookupDataSource.DoesNotExist: If no data sources found + + Example: + >>> loader = ReferenceDataLoader(storage) + >>> data = loader.load_latest_for_project(project_id) + >>> print(data["version"]) + 3 + >>> print(data["content"][:50]) + '=== File: vendors.csv ===\n\nSlack\nMicrosoft...' + """ + # Get latest version data sources + data_sources = LookupDataSource.objects.filter( + project_id=project_id, is_latest=True + ).order_by("created_at") + + if not data_sources.exists(): + raise LookupDataSource.DoesNotExist( + f"No data sources found for project {project_id}" + ) + + # Validate all extractions are complete + is_complete, failed_files = self.validate_extraction_complete(data_sources) + if not is_complete: + raise ExtractionNotCompleteError(failed_files) + + # Get version number from first source (all should have same version) + version_number = data_sources.first().version_number + + # Concatenate content from all sources + content = self.concatenate_sources(data_sources) + + # Build file metadata + files = [] + total_size = 0 + for source in data_sources: + files.append( + { + "id": str(source.id), + "name": source.file_name, + "size": source.file_size, + "type": source.file_type, + "uploaded_at": source.created_at.isoformat(), + } + ) + total_size += source.file_size + + return { + "version": version_number, + "content": content, + "files": files, + "total_size": total_size, + } + + def load_specific_version(self, project_id: UUID, version: int) -> dict[str, Any]: + """Load specific version of reference data. + + Retrieves a specific version of reference data for the project, + regardless of whether it's the latest version. 
+ + Args: + project_id: UUID of the Look-Up project + version: Version number to load + + Returns: + Dictionary with same structure as load_latest_for_project() + + Raises: + ExtractionNotCompleteError: If any data source extraction is incomplete + LookupDataSource.DoesNotExist: If version not found + + Example: + >>> loader = ReferenceDataLoader(storage) + >>> data = loader.load_specific_version(project_id, 2) + >>> print(data["version"]) + 2 + """ + # Get specific version data sources + data_sources = LookupDataSource.objects.filter( + project_id=project_id, version_number=version + ).order_by("created_at") + + if not data_sources.exists(): + raise LookupDataSource.DoesNotExist( + f"Version {version} not found for project {project_id}" + ) + + # Validate all extractions are complete + is_complete, failed_files = self.validate_extraction_complete(data_sources) + if not is_complete: + raise ExtractionNotCompleteError(failed_files) + + # Concatenate content from all sources + content = self.concatenate_sources(data_sources) + + # Build file metadata + files = [] + total_size = 0 + for source in data_sources: + files.append( + { + "id": str(source.id), + "name": source.file_name, + "size": source.file_size, + "type": source.file_type, + "uploaded_at": source.created_at.isoformat(), + } + ) + total_size += source.file_size + + return { + "version": version, + "content": content, + "files": files, + "total_size": total_size, + } + + def concatenate_sources(self, data_sources: QuerySet) -> str: + """Concatenate extracted content from multiple sources in upload order. + + Loads the extracted content for each data source from object storage + and concatenates them with file headers for clarity. + + Args: + data_sources: QuerySet of LookupDataSource objects, + should be ordered by created_at + + Returns: + Concatenated string with all file contents separated by headers + + Example: + >>> content = loader.concatenate_sources(sources) + >>> print(content) + === File: vendors.csv === + + Slack + Microsoft + Google + + === File: products.txt === + + Slack Workspace + Microsoft Teams + """ + contents = [] + + for source in data_sources: + # Add file header + header = f"=== File: {source.file_name} ===" + + # Load content from storage + # First try extracted_content_path, then fall back to original file_path + # for text-based files (CSV, TXT, JSON) + content_path = source.extracted_content_path + if not content_path: + # For text files, use the original file path + text_file_types = ["csv", "txt", "json"] + if source.file_type in text_file_types: + content_path = source.file_path + + if content_path: + try: + file_content = self.storage.get(content_path) + except Exception as e: + # If storage fails, include error in output + file_content = f"[Error loading file: {str(e)}]" + else: + file_content = "[No content path available]" + + # Combine header and content + contents.append(f"{header}\n\n{file_content}") + + # Join all contents with double newline separator + return "\n\n".join(contents) + + def validate_extraction_complete( + self, data_sources: QuerySet + ) -> tuple[bool, list[str]]: + """Check if all sources have completed extraction. + + Verifies that all data sources in the queryset have successfully + completed the extraction process. 
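+
+        A source counts as complete only when its extraction_status equals
+        "completed"; every other status is reported back in failed_files.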
+ + Args: + data_sources: QuerySet of LookupDataSource objects + + Returns: + Tuple of: + - all_complete: True if all extractions complete, False otherwise + - failed_files: List of filenames that are not complete + + Example: + >>> is_complete, failed = loader.validate_extraction_complete(sources) + >>> if not is_complete: + ... print(f"Waiting for: {', '.join(failed)}") + """ + failed_files = [] + + for source in data_sources: + if source.extraction_status != "completed": + failed_files.append(source.file_name) + + all_complete = len(failed_files) == 0 + return all_complete, failed_files diff --git a/backend/lookup/services/variable_resolver.py b/backend/lookup/services/variable_resolver.py new file mode 100644 index 0000000000..75861a1f14 --- /dev/null +++ b/backend/lookup/services/variable_resolver.py @@ -0,0 +1,147 @@ +"""Variable resolver for template variable replacement.""" + +import json +import re +from typing import Any + + +class VariableResolver: + """Resolves {{variable}} placeholders in prompt templates with actual values. + + Supports dot notation for nested field access and handles complex + data types (dicts, lists) by converting them to JSON. + """ + + VARIABLE_PATTERN = r"\{\{([^}]*)\}\}" + + def __init__(self, input_data: dict[str, Any], reference_data: str): + """Initialize the variable resolver with context data. + + Args: + input_data: Extracted data from Prompt Studio + reference_data: Concatenated text from all reference files + """ + self.context = {"input_data": input_data, "reference_data": reference_data} + + def resolve(self, template: str) -> str: + r"""Replace all {{variable}} references in template with actual values. + + Args: + template: Prompt template with {{variable}} placeholders + + Returns: + Resolved prompt with variables replaced + + Example: + >>> resolver = VariableResolver( + ... {"vendor": "Slack Inc"}, "Slack\nMicrosoft\nGoogle" + ... ) + >>> template = "Match {{input_data.vendor}} against: {{reference_data}}" + >>> resolver.resolve(template) + 'Match Slack Inc against: Slack\nMicrosoft\nGoogle' + """ + + def replacer(match): + variable_path = match.group(1).strip() + return str(self._get_nested_value(variable_path)) + + return re.sub(self.VARIABLE_PATTERN, replacer, template) + + def detect_variables(self, template: str) -> list[str]: + """Extract all {{variable}} references from template. + + Args: + template: Template text to analyze + + Returns: + List of unique variable paths found in template + + Example: + >>> resolver = VariableResolver({}, "") + >>> template = "Match {{input_data.vendor}} from {{reference_data}} and {{input_data.vendor}}" + >>> resolver.detect_variables(template) + ['input_data.vendor', 'reference_data'] + """ + if not template: + return [] + + matches = re.findall(self.VARIABLE_PATTERN, template) + # Strip whitespace and deduplicate + unique_vars = list({m.strip() for m in matches}) + return sorted(unique_vars) + + def _get_nested_value(self, path: str) -> Any: + """Get value from context using dot notation path. 
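+
+        Integer path segments index into lists (e.g. "input_data.items.0"),
+        dict and list values are serialised to pretty-printed JSON, and a
+        missing or None value resolves to an empty string.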
+ + Args: + path: Dot-separated path (e.g., "input_data.vendor.name") + + Returns: + Value at path, or empty string if not found + + Examples: + >>> resolver = VariableResolver({"vendor": {"name": "Slack"}}, "") + >>> resolver._get_nested_value("input_data.vendor.name") + 'Slack' + >>> resolver._get_nested_value("input_data.missing") + '' + """ + if not path: + return "" + + keys = path.split(".") + value = self.context + + for key in keys: + if isinstance(value, dict): + value = value.get(key, "") + elif isinstance(value, list): + # Try to convert key to integer for list indexing + try: + index = int(key) + value = value[index] if 0 <= index < len(value) else "" + except (ValueError, IndexError): + return "" + else: + return "" + + # If value is complex object, return JSON representation + if isinstance(value, (dict, list)): + return json.dumps(value, indent=2, ensure_ascii=False) + + # Handle None values + if value is None: + return "" + + return value + + def validate_variables(self, template: str) -> dict[str, bool]: + """Validate that all variables in template can be resolved. + + Args: + template: Template to validate + + Returns: + Dictionary mapping variable paths to their availability status + """ + variables = self.detect_variables(template) + validation = {} + + for var in variables: + value = self._get_nested_value(var) + # Consider empty string as not available (could be missing) + validation[var] = value != "" + + return validation + + def get_missing_variables(self, template: str) -> list[str]: + """Get list of variables that cannot be resolved. + + Args: + template: Template to check + + Returns: + List of variable paths that resolve to empty/missing values + """ + validation = self.validate_variables(template) + return [var for var, available in validation.items() if not available] diff --git a/backend/lookup/tests/__init__.py b/backend/lookup/tests/__init__.py new file mode 100644 index 0000000000..291608f561 --- /dev/null +++ b/backend/lookup/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Look-Up system.""" diff --git a/backend/lookup/tests/test_api/__init__.py b/backend/lookup/tests/test_api/__init__.py new file mode 100644 index 0000000000..58150d0e27 --- /dev/null +++ b/backend/lookup/tests/test_api/__init__.py @@ -0,0 +1,3 @@ +""" +API tests for Look-Up functionality. +""" diff --git a/backend/lookup/tests/test_api/test_execution_api.py b/backend/lookup/tests/test_api/test_execution_api.py new file mode 100644 index 0000000000..042c35f057 --- /dev/null +++ b/backend/lookup/tests/test_api/test_execution_api.py @@ -0,0 +1,309 @@ +""" +Tests for Look-Up execution and debug API endpoints. 
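+
+Covers the Prompt Studio debug execution endpoint and the read-only
+execution audit endpoints (listing, filtering, retrieval and statistics).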
+""" + +import uuid +from decimal import Decimal +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import ( + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, + LookupExecutionAudit +) + +User = get_user_model() + + +class LookupExecutionAPITest(TestCase): + """Test cases for Look-Up execution API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\n{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create projects + self.lookup1 = LookupProject.objects.create( + name="Vendor Lookup", + description="Vendor enrichment", + reference_data_type="vendor_catalog", + template=self.template, + created_by=self.user + ) + self.lookup2 = LookupProject.objects.create( + name="Product Lookup", + description="Product enrichment", + reference_data_type="product_catalog", + template=self.template, + created_by=self.user + ) + + # Create PS project and links + self.ps_project_id = uuid.uuid4() + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + @patch('lookup.views.LookUpOrchestrator') + @patch('lookup.views.LookUpExecutor') + def test_debug_with_ps_project(self, mock_executor_class, mock_orchestrator_class): + """Test debug execution with PS project context.""" + # Mock the execution + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.return_value = { + 'lookup_enrichment': { + 'canonical_vendor': 'Test Vendor', + 'vendor_category': 'SaaS', + 'product_type': 'Software' + }, + '_lookup_metadata': { + 'lookups_executed': 2, + 'successful_lookups': 2, + 'failed_lookups': 0, + 'execution_time_ms': 250, + 'conflicts_resolved': 0 + } + } + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'input_data': { + 'vendor_name': 'Test Vendor Inc', + 'product_id': 'PROD-123' + } + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('lookup_enrichment', response.data) + self.assertIn('canonical_vendor', response.data['lookup_enrichment']) + self.assertEqual(response.data['_lookup_metadata']['lookups_executed'], 2) + + # Verify orchestrator was called with correct projects + mock_orchestrator.execute_lookups.assert_called_once() + call_args = mock_orchestrator.execute_lookups.call_args + self.assertEqual(len(call_args.kwargs['lookup_projects']), 2) + + def test_debug_without_ps_project_id(self): + """Test debug endpoint requires PS project ID.""" + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = {'input_data': {}} + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + 
self.assertIn('prompt_studio_project_id is required', response.data['error']) + + def test_debug_with_no_linked_lookups(self): + """Test debug with PS project that has no linked Look-Ups.""" + unlinked_ps_id = uuid.uuid4() + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(unlinked_ps_id), + 'input_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['message'], 'No Look-Ups linked to this Prompt Studio project') + self.assertEqual(response.data['lookup_enrichment'], {}) + self.assertEqual(response.data['_lookup_metadata']['lookups_executed'], 0) + + @patch('lookup.views.LookUpOrchestrator') + def test_debug_with_execution_error(self, mock_orchestrator_class): + """Test debug endpoint handles execution errors gracefully.""" + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.side_effect = Exception("Test error") + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'input_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + self.assertIn('Test error', response.data['error']) + + +class LookupAuditAPITest(TestCase): + """Test cases for execution audit API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template and project + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + self.lookup = LookupProject.objects.create( + name="Test Lookup", + description="Test", + reference_data_type="vendor_catalog", + template=self.template, + created_by=self.user + ) + + # Create audit records + self.execution_id = str(uuid.uuid4()) + + self.audit1 = LookupExecutionAudit.objects.create( + lookup_project=self.lookup, + prompt_studio_project_id=uuid.uuid4(), + execution_id=self.execution_id, + input_data={'vendor': 'Test1'}, + enriched_output={'canonical_vendor': 'Test'}, + reference_data_version=1, + llm_provider='openai', + llm_model='gpt-4', + llm_prompt='Test prompt', + llm_response='{"canonical_vendor": "Test"}', + llm_response_cached=False, + execution_time_ms=150, + llm_call_time_ms=100, + status='success', + confidence_score=Decimal('0.95') + ) + + self.audit2 = LookupExecutionAudit.objects.create( + lookup_project=self.lookup, + prompt_studio_project_id=uuid.uuid4(), + execution_id=str(uuid.uuid4()), + input_data={'vendor': 'Test2'}, + status='failure', + error_message='LLM timeout' + ) + + def test_list_audit_records(self): + """Test listing all audit records.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_audits_by_lookup_project(self): + """Test filtering audits by Look-Up project.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url, {'lookup_project_id': str(self.lookup.id)}) + + self.assertEqual(response.status_code, 
status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_audits_by_execution_id(self): + """Test filtering audits by execution ID.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url, {'execution_id': self.execution_id}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['execution_id'], self.execution_id) + + def test_filter_audits_by_status(self): + """Test filtering audits by status.""" + url = reverse('lookup:executionaudit-list') + + # Get successful executions + response = self.client.get(url, {'status': 'success'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['status'], 'success') + + # Get failed executions + response = self.client.get(url, {'status': 'failure'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['status'], 'failure') + + def test_retrieve_audit_record(self): + """Test retrieving a specific audit record.""" + url = reverse('lookup:executionaudit-detail', args=[self.audit1.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['execution_id'], self.execution_id) + self.assertEqual(response.data['status'], 'success') + + def test_audit_records_are_readonly(self): + """Test that audit records cannot be modified.""" + url = reverse('lookup:executionaudit-detail', args=[self.audit1.id]) + + # Try to update + response = self.client.patch(url, {'status': 'modified'}, format='json') + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + # Try to delete + response = self.client.delete(url) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + @patch('lookup.views.AuditLogger') + def test_get_statistics(self, mock_logger_class): + """Test getting execution statistics.""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + mock_logger.get_project_stats.return_value = { + 'total_executions': 100, + 'success_rate': 0.95, + 'avg_execution_time_ms': 150.5, + 'cache_hit_rate': 0.30, + 'avg_confidence_score': 0.92 + } + + url = reverse('lookup:executionaudit-statistics') + response = self.client.get(url, {'lookup_project_id': str(self.lookup.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_executions'], 100) + self.assertEqual(response.data['success_rate'], 0.95) + + def test_statistics_requires_project_id(self): + """Test that statistics endpoint requires project ID.""" + url = reverse('lookup:executionaudit-statistics') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('lookup_project_id is required', response.data['error']) diff --git a/backend/lookup/tests/test_api/test_linking_api.py b/backend/lookup/tests/test_api/test_linking_api.py new file mode 100644 index 0000000000..69442b8985 --- /dev/null +++ b/backend/lookup/tests/test_api/test_linking_api.py @@ -0,0 +1,260 @@ +""" +Tests for Prompt Studio Look-Up linking API endpoints. 
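+
+Covers link creation, duplicate rejection, filtering by either side of the
+link, deletion, and the bulk link/unlink endpoint (including idempotency).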
+""" + +import uuid +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import ( + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink +) + +User = get_user_model() + + +class PromptStudioLinkingAPITest(TestCase): + """Test cases for PS Look-Up linking API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create Look-Up projects + self.lookup1 = LookupProject.objects.create( + name="Lookup 1", + description="First lookup", + reference_data_type="vendor_catalog", + template=self.template, + created_by=self.user + ) + self.lookup2 = LookupProject.objects.create( + name="Lookup 2", + description="Second lookup", + reference_data_type="product_catalog", + template=self.template, + created_by=self.user + ) + + # Create PS project ID + self.ps_project_id = uuid.uuid4() + + def test_create_link(self): + """Test creating a link between PS project and Look-Up.""" + url = reverse('lookup:lookuplink-list') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project': str(self.lookup1.id) + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['lookup_project_name'], 'Lookup 1') + self.assertTrue( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ).exists() + ) + + def test_create_duplicate_link(self): + """Test that duplicate links are rejected.""" + # Create first link + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + # Try to create duplicate + url = reverse('lookup:lookuplink-list') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project': str(self.lookup1.id) + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('already linked', str(response.data)) + + def test_list_links(self): + """Test listing all links.""" + # Create links + link1 = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + link2 = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + url = reverse('lookup:lookuplink-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_links_by_ps_project(self): + """Test filtering links by PS project ID.""" + # Create links for different PS projects + ps_project_id_2 = uuid.uuid4() + + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=ps_project_id_2, + lookup_project=self.lookup2 + ) + + url = 
reverse('lookup:lookuplink-list') + response = self.client.get(url, {'prompt_studio_project_id': str(self.ps_project_id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['lookup_project_name'], 'Lookup 1') + + def test_filter_links_by_lookup_project(self): + """Test filtering links by Look-Up project ID.""" + # Create links + ps_project_id_2 = uuid.uuid4() + + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=ps_project_id_2, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-list') + response = self.client.get(url, {'lookup_project_id': str(self.lookup1.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_delete_link(self): + """Test deleting a link.""" + link = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-detail', args=[link.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertFalse( + PromptStudioLookupLink.objects.filter(id=link.id).exists() + ) + + def test_bulk_link(self): + """Test bulk linking multiple Look-Ups to a PS project.""" + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': [str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': False + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertTrue(response.data['results'][0]['linked']) + self.assertTrue(response.data['results'][1]['linked']) + + # Verify links were created + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 2 + ) + + def test_bulk_unlink(self): + """Test bulk unlinking Look-Ups from a PS project.""" + # Create links first + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': [str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertTrue(response.data['results'][0]['unlinked']) + self.assertTrue(response.data['results'][1]['unlinked']) + + # Verify links were removed + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 0 + ) + + def test_bulk_link_idempotent(self): + """Test that bulk link is idempotent.""" + # Create one link first + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': 
[str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': False + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertFalse(response.data['results'][0]['linked']) # Already existed + self.assertTrue(response.data['results'][1]['linked']) # Newly created + + # Still only 2 links total + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 2 + ) diff --git a/backend/lookup/tests/test_api/test_profile_manager_api.py b/backend/lookup/tests/test_api/test_profile_manager_api.py new file mode 100644 index 0000000000..4884bb53f7 --- /dev/null +++ b/backend/lookup/tests/test_api/test_profile_manager_api.py @@ -0,0 +1,397 @@ +""" +Tests for LookupProfileManager API endpoints. +""" + +import uuid +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupProject, LookupProfileManager + +User = get_user_model() + + +class LookupProfileManagerAPITest(TestCase): + """Test cases for LookupProfileManager API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create lookup project + self.project = LookupProject.objects.create( + name="Test Lookup Project", + description="Test Description", + reference_data_type="vendor_catalog", + created_by=self.user + ) + + # Mock adapter instances (UUIDs) + self.mock_llm_id = str(uuid.uuid4()) + self.mock_embedding_id = str(uuid.uuid4()) + self.mock_vector_db_id = str(uuid.uuid4()) + self.mock_x2text_id = str(uuid.uuid4()) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_create_profile(self, mock_get_adapter): + """Test creating a new profile.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + url = reverse('lookup:lookupprofile-list') + data = { + 'profile_name': 'Default Profile', + 'lookup_project': str(self.project.id), + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + 'x2text': self.mock_x2text_id, + 'chunk_size': 1000, + 'chunk_overlap': 200, + 'similarity_top_k': 5, + 'is_default': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['profile_name'], 'Default Profile') + self.assertTrue(response.data['is_default']) + self.assertEqual(LookupProfileManager.objects.count(), 1) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_duplicate_profile_name(self, mock_get_adapter): + """Test that duplicate profile names for same project are rejected.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + # Create first profile + LookupProfileManager.objects.create( + profile_name='Test Profile', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + 
x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + # Try to create duplicate + url = reverse('lookup:lookupprofile-list') + data = { + 'profile_name': 'Test Profile', # Same name + 'lookup_project': str(self.project.id), # Same project + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + 'x2text': self.mock_x2text_id, + } + + response = self.client.post(url, data, format='json') + + # Should fail due to unique constraint + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_list_profiles(self, mock_get_by_id, mock_get_adapter): + """Test listing all profiles.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create test profiles + LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + url = reverse('lookup:lookupprofile-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_filter_by_project(self, mock_get_by_id, mock_get_adapter): + """Test filtering profiles by project.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create another project + project2 = LookupProject.objects.create( + name="Project 2", + description="Description 2", + reference_data_type="product_catalog", + created_by=self.user + ) + + # Create profiles for different projects + LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=project2, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + # Filter by project 1 + url = reverse('lookup:lookupprofile-list') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + 
self.assertEqual(response.data['results'][0]['profile_name'], 'Profile 1') + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_get_default_profile(self, mock_get_by_id, mock_get_adapter): + """Test getting the default profile for a project.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create profiles + profile1 = LookupProfileManager.objects.create( + profile_name='Non-Default', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=False, + created_by=self.user + ) + + profile2 = LookupProfileManager.objects.create( + profile_name='Default Profile', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=True, + created_by=self.user + ) + + # Get default profile + url = reverse('lookup:lookupprofile-default') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['profile_name'], 'Default Profile') + self.assertTrue(response.data['is_default']) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_set_default_profile(self, mock_get_by_id, mock_get_adapter): + """Test setting a profile as default.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create two profiles + profile1 = LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=True, + created_by=self.user + ) + + profile2 = LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=False, + created_by=self.user + ) + + # Set profile2 as default + url = reverse('lookup:lookupprofile-set-default', args=[profile2.profile_id]) + response = self.client.post(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.data['is_default']) + + # Verify profile1 is no longer default + profile1.refresh_from_db() + self.assertFalse(profile1.is_default) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_update_profile(self, mock_get_by_id, mock_get_adapter): + """Test updating a profile.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + 
mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create profile + profile = LookupProfileManager.objects.create( + profile_name='Original Name', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + chunk_size=1000, + created_by=self.user + ) + + # Update profile + url = reverse('lookup:lookupprofile-detail', args=[profile.profile_id]) + data = { + 'chunk_size': 2000, + 'chunk_overlap': 300, + 'similarity_top_k': 10 + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['chunk_size'], 2000) + self.assertEqual(response.data['chunk_overlap'], 300) + self.assertEqual(response.data['similarity_top_k'], 10) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_delete_profile(self, mock_get_adapter): + """Test deleting a profile.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + # Create profile + profile = LookupProfileManager.objects.create( + profile_name='To Delete', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + url = reverse('lookup:lookupprofile-detail', args=[profile.profile_id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(LookupProfileManager.objects.count(), 0) + + def test_get_default_profile_no_default_exists(self): + """Test getting default profile when none exists.""" + url = reverse('lookup:lookupprofile-default') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + # Should return 404 when no default profile exists + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_required_adapters(self): + """Test that all 4 adapters are required.""" + url = reverse('lookup:lookupprofile-list') + + # Missing x2text adapter + data = { + 'profile_name': 'Incomplete Profile', + 'lookup_project': str(self.project.id), + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + # Missing x2text + } + + response = self.client.post(url, data, format='json') + + # Should fail validation + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('x2text', str(response.data)) diff --git a/backend/lookup/tests/test_api/test_project_api.py b/backend/lookup/tests/test_api/test_project_api.py new file mode 100644 index 0000000000..f384d23f92 --- /dev/null +++ b/backend/lookup/tests/test_api/test_project_api.py @@ -0,0 +1,215 @@ +""" +Tests for Look-Up Project API endpoints. 
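+
+Covers project CRUD, Look-Up execution, reference data upload and
+data source listing/filtering.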
+""" + +import json +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupProject, LookupPromptTemplate, LookupDataSource + +User = get_user_model() + + +class LookupProjectAPITest(TestCase): + """Test cases for Look-Up Project API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\nReference: {{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create project + self.project = LookupProject.objects.create( + name="Test Project", + description="Test Description", + reference_data_type="vendor_catalog", + template=self.template, + created_by=self.user + ) + + def test_list_projects(self): + """Test listing all projects.""" + url = reverse('lookup:lookupproject-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Project') + + def test_create_project(self): + """Test creating a new project.""" + url = reverse('lookup:lookupproject-list') + data = { + 'name': 'New Project', + 'description': 'New Description', + 'reference_data_type': 'product_catalog', + 'template_id': str(self.template.id), + 'is_active': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['name'], 'New Project') + self.assertEqual(LookupProject.objects.count(), 2) + + def test_retrieve_project(self): + """Test retrieving a specific project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Test Project') + self.assertIn('template', response.data) + + def test_update_project(self): + """Test updating a project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + data = { + 'name': 'Updated Project', + 'description': 'Updated Description', + 'reference_data_type': self.project.reference_data_type, + 'is_active': False + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Updated Project') + self.assertFalse(response.data['is_active']) + + def test_delete_project(self): + """Test deleting a project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(LookupProject.objects.count(), 0) + + @patch('lookup.views.LookUpOrchestrator') + def test_execute_project(self, mock_orchestrator_class): + """Test executing a Look-Up project.""" + # Mock the orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.return_value = { + 
'lookup_enrichment': { + 'canonical_vendor': 'Test Vendor', + 'vendor_category': 'SaaS' + }, + '_lookup_metadata': { + 'lookups_executed': 1, + 'successful_lookups': 1, + 'execution_time_ms': 150 + } + } + + url = reverse('lookup:lookupproject-execute', args=[self.project.id]) + data = { + 'input_data': {'vendor_name': 'Test Vendor Inc'}, + 'use_cache': True, + 'timeout_seconds': 30 + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('lookup_enrichment', response.data) + self.assertIn('_lookup_metadata', response.data) + + def test_execute_project_without_auth(self): + """Test that unauthenticated requests are rejected.""" + self.client.force_authenticate(user=None) + url = reverse('lookup:lookupproject-execute', args=[self.project.id]) + data = {'input_data': {}} + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_filter_projects_by_active_status(self): + """Test filtering projects by active status.""" + # Create inactive project + LookupProject.objects.create( + name="Inactive Project", + description="Inactive", + reference_data_type="vendor_catalog", + is_active=False, + created_by=self.user + ) + + url = reverse('lookup:lookupproject-list') + response = self.client.get(url, {'is_active': 'true'}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Project') + + def test_upload_reference_data(self): + """Test uploading reference data.""" + url = reverse('lookup:lookupproject-upload-reference-data', args=[self.project.id]) + + # Create a mock file + from django.core.files.uploadedfile import SimpleUploadedFile + file_content = b"vendor1,category1\nvendor2,category2" + file = SimpleUploadedFile("vendors.csv", file_content, content_type="text/csv") + + data = { + 'file': file, + 'extract_text': True, + 'metadata': json.dumps({'source': 'manual_upload'}) + } + + response = self.client.post(url, data, format='multipart') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertIn('source_file_path', response.data) + self.assertEqual(response.data['extraction_status'], 'pending') + + def test_list_data_sources(self): + """Test listing data sources for a project.""" + # Create data sources + LookupDataSource.objects.create( + project=self.project, + source_file_path="test/file1.csv", + extraction_status='complete', + version=1, + is_latest=False + ) + LookupDataSource.objects.create( + project=self.project, + source_file_path="test/file2.csv", + extraction_status='complete', + version=2, + is_latest=True + ) + + url = reverse('lookup:lookupproject-data-sources', args=[self.project.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + + # Test filtering by is_latest + response = self.client.get(url, {'is_latest': 'true'}) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['version'], 2) diff --git a/backend/lookup/tests/test_api/test_template_api.py b/backend/lookup/tests/test_api/test_template_api.py new file mode 100644 index 0000000000..0883049ffa --- /dev/null +++ b/backend/lookup/tests/test_api/test_template_api.py @@ -0,0 +1,175 @@ +""" +Tests for Look-Up Template API endpoints. 
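+
+Covers template CRUD, validation of the {{reference_data}} placeholder and
+LLM config, the validate endpoint, and filtering by active status.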
+""" + +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupPromptTemplate + +User = get_user_model() + + +class LookupTemplateAPITest(TestCase): + """Test cases for Look-Up Template API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\nReference: {{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + def test_list_templates(self): + """Test listing all templates.""" + url = reverse('lookup:lookuptemplate-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Template') + + def test_create_template(self): + """Test creating a new template.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'New Template', + 'template_text': 'Product: {{product_name}}\n{{reference_data}}', + 'llm_config': {'provider': 'anthropic', 'model': 'claude-2'}, + 'is_active': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['name'], 'New Template') + self.assertEqual(LookupPromptTemplate.objects.count(), 2) + + def test_create_template_without_reference_placeholder(self): + """Test that template without {{reference_data}} is rejected.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'Invalid Template', + 'template_text': 'Product: {{product_name}}', # Missing {{reference_data}} + 'llm_config': {'provider': 'openai', 'model': 'gpt-4'} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('template_text', response.data) + + def test_create_template_invalid_llm_config(self): + """Test that template with invalid LLM config is rejected.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'Invalid Config Template', + 'template_text': '{{reference_data}}', + 'llm_config': {'model': 'gpt-4'} # Missing provider + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('llm_config', response.data) + + def test_update_template(self): + """Test updating a template.""" + url = reverse('lookup:lookuptemplate-detail', args=[self.template.id]) + data = { + 'name': 'Updated Template', + 'is_active': False + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Updated Template') + self.assertFalse(response.data['is_active']) + + def test_delete_template(self): + """Test deleting a template.""" + url = reverse('lookup:lookuptemplate-detail', args=[self.template.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + 
self.assertEqual(LookupPromptTemplate.objects.count(), 0) + + @patch('lookup.views.VariableResolver') + def test_validate_template(self, mock_resolver_class): + """Test template validation endpoint.""" + # Mock the variable resolver + mock_resolver = MagicMock() + mock_resolver_class.return_value = mock_resolver + mock_resolver.resolve.return_value = "Resolved template text" + mock_resolver.get_all_variables.return_value = {'vendor_name', 'product_id'} + + url = reverse('lookup:lookuptemplate-validate') + data = { + 'template_text': 'Vendor: {{vendor_name}}\nProduct: {{product_id}}\n{{reference_data}}', + 'sample_data': {'vendor_name': 'Test Vendor', 'product_id': '123'}, + 'sample_reference': 'Sample reference data' + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.data['valid']) + self.assertIn('resolved_template', response.data) + self.assertIn('variables_found', response.data) + self.assertEqual(set(response.data['variables_found']), {'vendor_name', 'product_id'}) + + def test_validate_template_with_error(self): + """Test template validation with error.""" + url = reverse('lookup:lookuptemplate-validate') + data = { + 'template_text': 'Invalid: {{unclosed_variable', # Invalid template + 'sample_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertFalse(response.data['valid']) + self.assertIn('error', response.data) + + def test_filter_templates_by_active_status(self): + """Test filtering templates by active status.""" + # Create inactive template + LookupPromptTemplate.objects.create( + name="Inactive Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + is_active=False, + created_by=self.user + ) + + url = reverse('lookup:lookuptemplate-list') + + # Get only active templates + response = self.client.get(url, {'is_active': 'true'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Template') + + # Get only inactive templates + response = self.client.get(url, {'is_active': 'false'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Inactive Template') diff --git a/backend/lookup/tests/test_integrations/__init__.py b/backend/lookup/tests/test_integrations/__init__.py new file mode 100644 index 0000000000..c0565ab6fe --- /dev/null +++ b/backend/lookup/tests/test_integrations/__init__.py @@ -0,0 +1,3 @@ +""" +Integration tests for Look-Up external service integrations. +""" diff --git a/backend/lookup/tests/test_integrations/test_llm_integration.py b/backend/lookup/tests/test_integrations/test_llm_integration.py new file mode 100644 index 0000000000..620d9cd6a7 --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_llm_integration.py @@ -0,0 +1,239 @@ +""" +Tests for LLM provider integration. 
+""" + +import json +from unittest.mock import patch + +from django.test import TestCase + +from ...integrations.llm_provider import ( + UnstractLLMClient, + OpenAILLMClient, + AnthropicLLMClient +) + + +class UnstractLLMClientTest(TestCase): + """Test cases for UnstractLLMClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Patch environment variables + self.env_patcher = patch('lookup.integrations.llm_provider.os.getenv') + self.mock_getenv = self.env_patcher.start() + self.mock_getenv.side_effect = self._mock_getenv + + # Initialize client + self.client = UnstractLLMClient() + + def tearDown(self): + """Clean up patches.""" + self.env_patcher.stop() + + def _mock_getenv(self, key, default=None): + """Mock environment variables.""" + env_vars = { + 'LOOKUP_DEFAULT_LLM_PROVIDER': 'openai', + 'LOOKUP_DEFAULT_LLM_MODEL': 'gpt-4', + 'OPENAI_API_KEY': 'test-openai-key', + 'ANTHROPIC_API_KEY': 'test-anthropic-key', + 'AZURE_OPENAI_API_KEY': 'test-azure-key', + 'AZURE_OPENAI_ENDPOINT': 'https://test.azure.com' + } + return env_vars.get(key, default) + + def test_initialization(self): + """Test client initialization.""" + # Test default initialization + client = UnstractLLMClient() + self.assertEqual(client.default_provider, 'openai') + self.assertEqual(client.default_model, 'gpt-4') + + # Test custom initialization + client = UnstractLLMClient(provider='anthropic', model='claude-2') + self.assertEqual(client.default_provider, 'anthropic') + self.assertEqual(client.default_model, 'claude-2') + + def test_generate_with_valid_json(self): + """Test generation with valid JSON response.""" + prompt = "Extract vendor information" + config = { + 'provider': 'openai', + 'model': 'gpt-4', + 'temperature': 0.7 + } + + # Since we don't have actual LLM integration, test fallback + response = self.client.generate(prompt, config) + + # Verify response is valid JSON + data = json.loads(response) + self.assertIsInstance(data, dict) + + # Should have confidence score + if 'confidence' in data: + self.assertTrue(0 <= data['confidence'] <= 1) + + def test_generate_with_timeout(self): + """Test generation respects timeout.""" + import time + start = time.time() + + response = self.client.generate( + "Test prompt", + {'provider': 'openai'}, + timeout=1 + ) + + elapsed = time.time() - start + + # Should complete within reasonable time + self.assertLess(elapsed, 2) + self.assertIsNotNone(response) + + def test_extract_json_from_text(self): + """Test JSON extraction from mixed text.""" + # Test with embedded JSON + text = "Here is the result: {\"vendor\": \"Test Corp\", \"confidence\": 0.9} end of response" + result = self.client._extract_json(text) + + data = json.loads(result) + self.assertEqual(data['vendor'], 'Test Corp') + self.assertEqual(data['confidence'], 0.9) + + # Test with no valid JSON + text = "No JSON here" + result = self.client._extract_json(text) + + data = json.loads(result) + self.assertIn('raw_response', data) + self.assertIn('warning', data) + + def test_validate_response(self): + """Test response validation.""" + # Valid response + valid_response = json.dumps({ + 'vendor': 'Test', + 'confidence': 0.85 + }) + self.assertTrue(self.client.validate_response(valid_response)) + + # Missing confidence + no_confidence = json.dumps({'vendor': 'Test'}) + self.assertFalse(self.client.validate_response(no_confidence)) + + # Invalid confidence + bad_confidence = json.dumps({'confidence': 1.5}) + self.assertFalse(self.client.validate_response(bad_confidence)) + + # Invalid JSON + not_json = "not 
json" + self.assertFalse(self.client.validate_response(not_json)) + + def test_get_token_count(self): + """Test token counting estimation.""" + # Test basic estimation + text = "This is a test prompt with some content" + count = self.client.get_token_count(text) + + # Should be roughly len/4 + expected = len(text) // 4 + self.assertAlmostEqual(count, expected, delta=2) + + # Test empty text + self.assertEqual(self.client.get_token_count(""), 0) + + def test_fallback_generation(self): + """Test fallback generation when LLM unavailable.""" + # Force fallback mode + self.client.llm_available = False + + response = self.client.generate( + "vendor extraction prompt", + {'provider': 'openai'} + ) + + # Should return valid JSON + data = json.loads(response) + self.assertEqual(data['status'], 'fallback') + self.assertIn('canonical_vendor', data) + + def test_simulate_llm_call(self): + """Test simulated LLM call.""" + # Test vendor prompt + vendor_response = self.client._simulate_llm_call( + "Extract vendor information", + {'provider': 'openai'} + ) + vendor_data = json.loads(vendor_response) + self.assertIn('canonical_vendor', vendor_data) + self.assertIn('vendor_category', vendor_data) + + # Test product prompt + product_response = self.client._simulate_llm_call( + "Extract product details", + {'provider': 'openai'} + ) + product_data = json.loads(product_response) + self.assertIn('product_name', product_data) + self.assertIn('product_category', product_data) + + +class OpenAILLMClientTest(TestCase): + """Test cases for OpenAI-specific client.""" + + def test_openai_client_initialization(self): + """Test OpenAI client initialization.""" + client = OpenAILLMClient() + self.assertEqual(client.default_provider, 'openai') + self.assertEqual(client.default_model, 'gpt-4') + + @patch('lookup.integrations.llm_provider.UnstractLLMClient.generate') + def test_openai_generate(self, mock_generate): + """Test OpenAI-specific generation.""" + mock_generate.return_value = '{"result": "test"}' + + client = OpenAILLMClient() + response = client.generate( + "test prompt", + {'temperature': 0.5} + ) + + # Verify OpenAI config was used + mock_generate.assert_called_once() + call_args = mock_generate.call_args[0] + config = call_args[1] + + self.assertEqual(config['provider'], 'openai') + self.assertEqual(config['temperature'], 0.5) + + +class AnthropicLLMClientTest(TestCase): + """Test cases for Anthropic-specific client.""" + + def test_anthropic_client_initialization(self): + """Test Anthropic client initialization.""" + client = AnthropicLLMClient() + self.assertEqual(client.default_provider, 'anthropic') + self.assertEqual(client.default_model, 'claude-2') + + @patch('lookup.integrations.llm_provider.UnstractLLMClient.generate') + def test_anthropic_generate(self, mock_generate): + """Test Anthropic-specific generation.""" + mock_generate.return_value = '{"result": "test"}' + + client = AnthropicLLMClient() + response = client.generate( + "test prompt", + {'temperature': 0.8} + ) + + # Verify Anthropic config was used + mock_generate.assert_called_once() + call_args = mock_generate.call_args[0] + config = call_args[1] + + self.assertEqual(config['provider'], 'anthropic') + self.assertEqual(config['model'], 'claude-2') + self.assertEqual(config['temperature'], 0.8) diff --git a/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py b/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py new file mode 100644 index 0000000000..a03773159f --- /dev/null +++ 
b/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py @@ -0,0 +1,297 @@ +""" +Tests for LLMWhisperer document extraction integration. +""" + +from unittest.mock import patch, Mock + +from django.test import TestCase +from django.conf import settings + +from ...integrations.llmwhisperer_client import ( + LLMWhispererClient, + ExtractionStatus +) + + +class LLMWhispererClientTest(TestCase): + """Test cases for LLMWhispererClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock settings + self.settings_patcher = patch.multiple( + settings, + LLMWHISPERER_BASE_URL='https://test.llmwhisperer.com', + LLMWHISPERER_API_KEY='test-api-key' + ) + self.settings_patcher.start() + + # Mock requests + self.requests_patcher = patch('lookup.integrations.llmwhisperer_client.requests') + self.mock_requests = self.requests_patcher.start() + + # Initialize client + self.client = LLMWhispererClient() + + def tearDown(self): + """Clean up patches.""" + self.settings_patcher.stop() + self.requests_patcher.stop() + + def test_initialization(self): + """Test client initialization.""" + self.assertEqual(self.client.base_url, 'https://test.llmwhisperer.com') + self.assertEqual(self.client.api_key, 'test-api-key') + self.assertIsNotNone(self.client.session) + + def test_extract_text_success(self): + """Test successful text extraction.""" + # Mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'extraction_id': 'test-id-123', + 'status': 'processing' + } + self.mock_requests.post.return_value = mock_response + + # Test extraction + file_content = b"PDF content here" + file_name = "test.pdf" + + extraction_id, status = self.client.extract_text( + file_content, + file_name + ) + + # Verify + self.assertEqual(extraction_id, 'test-id-123') + self.assertEqual(status, ExtractionStatus.PROCESSING.value) + + # Check API call + self.mock_requests.post.assert_called_once() + call_args = self.mock_requests.post.call_args + + self.assertIn('/v1/extract', call_args[0][0]) + self.assertIn('files', call_args[1]) + self.assertIn('data', call_args[1]) + + def test_extract_text_failure(self): + """Test extraction failure.""" + # Mock failed response + mock_response = Mock() + mock_response.status_code = 400 + mock_response.text = "Bad request" + self.mock_requests.post.return_value = mock_response + + # Test extraction + extraction_id, status = self.client.extract_text( + b"content", + "test.pdf" + ) + + # Verify + self.assertEqual(extraction_id, "") + self.assertEqual(status, ExtractionStatus.FAILED.value) + + def test_check_extraction_status_complete(self): + """Test checking completed extraction status.""" + # Mock status response + mock_status_response = Mock() + mock_status_response.status_code = 200 + mock_status_response.json.return_value = { + 'status': 'completed', + 'extraction_id': 'test-id' + } + + # Mock result response + mock_result_response = Mock() + mock_result_response.status_code = 200 + mock_result_response.text = "Extracted text content" + + # Set up session mock + self.client.session.get = Mock(side_effect=[ + mock_status_response, + mock_result_response + ]) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.COMPLETE.value) + self.assertEqual(text, "Extracted text content") + + def test_check_extraction_status_processing(self): + """Test checking processing extraction status.""" + # Mock response + mock_response = Mock() + 
mock_response.status_code = 200 + mock_response.json.return_value = { + 'status': 'processing' + } + self.client.session.get = Mock(return_value=mock_response) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.PROCESSING.value) + self.assertIsNone(text) + + def test_check_extraction_status_failed(self): + """Test checking failed extraction status.""" + # Mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'status': 'failed', + 'error': 'Extraction error' + } + self.client.session.get = Mock(return_value=mock_response) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.FAILED.value) + self.assertIsNone(text) + + @patch('time.sleep') + def test_wait_for_extraction_success(self, mock_sleep): + """Test waiting for extraction completion.""" + # Mock status checks + mock_responses = [ + (ExtractionStatus.PROCESSING.value, None), + (ExtractionStatus.PROCESSING.value, None), + (ExtractionStatus.COMPLETE.value, "Extracted text") + ] + + self.client.check_extraction_status = Mock( + side_effect=mock_responses + ) + + # Test wait + status, text = self.client.wait_for_extraction( + 'test-id', + max_wait_seconds=60, + poll_interval=5 + ) + + # Verify + self.assertEqual(status, ExtractionStatus.COMPLETE.value) + self.assertEqual(text, "Extracted text") + self.assertEqual(self.client.check_extraction_status.call_count, 3) + + @patch('time.sleep') + @patch('time.time') + def test_wait_for_extraction_timeout(self, mock_time, mock_sleep): + """Test extraction timeout.""" + # Mock time to simulate timeout + mock_time.side_effect = [0, 10, 20, 35, 40] # Exceeds 30 second limit + + self.client.check_extraction_status = Mock( + return_value=(ExtractionStatus.PROCESSING.value, None) + ) + + # Test wait with short timeout + status, text = self.client.wait_for_extraction( + 'test-id', + max_wait_seconds=30, + poll_interval=5 + ) + + # Verify + self.assertEqual(status, ExtractionStatus.FAILED.value) + self.assertIsNone(text) + + def test_extract_and_wait(self): + """Test combined extract and wait.""" + # Mock extraction + self.client.extract_text = Mock( + return_value=('test-id', ExtractionStatus.PROCESSING.value) + ) + + # Mock wait + self.client.wait_for_extraction = Mock( + return_value=(ExtractionStatus.COMPLETE.value, "Extracted text") + ) + + # Test + success, text = self.client.extract_and_wait( + b"content", + "test.pdf" + ) + + # Verify + self.assertTrue(success) + self.assertEqual(text, "Extracted text") + + def test_is_extraction_needed(self): + """Test checking if extraction is needed.""" + # Files that need extraction + extractable = [ + 'document.pdf', + 'image.png', + 'photo.jpg', + 'scan.tiff', + 'presentation.pptx', + 'spreadsheet.xlsx' + ] + + for filename in extractable: + self.assertTrue( + self.client.is_extraction_needed(filename), + f"{filename} should need extraction" + ) + + # Files that don't need extraction + non_extractable = [ + 'data.json', + 'script.py', + 'text.txt', + 'config.yml' + ] + + for filename in non_extractable: + self.assertFalse( + self.client.is_extraction_needed(filename), + f"{filename} should not need extraction" + ) + + def test_get_extraction_config_for_file(self): + """Test getting extraction config based on file type.""" + # Test PDF config + pdf_config = self.client.get_extraction_config_for_file('test.pdf') + 
self.assertEqual(pdf_config['processing_mode'], 'ocr') + self.assertTrue(pdf_config['force_text_processing']) + + # Test image config + img_config = self.client.get_extraction_config_for_file('test.jpg') + self.assertEqual(img_config['processing_mode'], 'ocr') + self.assertFalse(img_config['force_text_processing']) + + # Test Word config + doc_config = self.client.get_extraction_config_for_file('test.docx') + self.assertEqual(doc_config['processing_mode'], 'text') + self.assertEqual(doc_config['output_format'], 'markdown') + + # Test Excel config + xls_config = self.client.get_extraction_config_for_file('test.xlsx') + self.assertEqual(xls_config['processing_mode'], 'text') + self.assertEqual(xls_config['line_splitter'], 'paragraph') + + def test_default_config(self): + """Test default extraction configuration.""" + config = self.client._get_default_config() + + # Verify required fields + self.assertIn('processing_mode', config) + self.assertIn('output_format', config) + self.assertIn('page_separator', config) + self.assertIn('timeout', config) + + # Verify defaults + self.assertEqual(config['processing_mode'], 'ocr') + self.assertEqual(config['output_format'], 'text') + self.assertEqual(config['timeout'], 300) diff --git a/backend/lookup/tests/test_integrations/test_redis_cache_integration.py b/backend/lookup/tests/test_integrations/test_redis_cache_integration.py new file mode 100644 index 0000000000..de1d203a99 --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_redis_cache_integration.py @@ -0,0 +1,272 @@ +""" +Tests for Redis cache integration. +""" + +from unittest.mock import MagicMock, patch + +from django.test import TestCase +from django.conf import settings + +from ...integrations.redis_cache import RedisLLMCache + + +class RedisLLMCacheTest(TestCase): + """Test cases for RedisLLMCache.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock Redis availability + self.redis_patcher = patch('lookup.integrations.redis_cache.REDIS_AVAILABLE', True) + self.redis_patcher.start() + + # Mock Redis client + self.redis_client_patcher = patch('lookup.integrations.redis_cache.redis.Redis') + self.mock_redis_class = self.redis_client_patcher.start() + + # Create mock Redis instance + self.mock_redis = MagicMock() + self.mock_redis_class.return_value = self.mock_redis + self.mock_redis.ping.return_value = True + + # Mock settings + self.settings_patcher = patch.multiple( + settings, + REDIS_HOST='localhost', + REDIS_PORT=6379, + REDIS_CACHE_DB=1, + REDIS_PASSWORD='test-pass' + ) + self.settings_patcher.start() + + # Initialize cache + self.cache = RedisLLMCache(ttl_hours=24) + + def tearDown(self): + """Clean up patches.""" + self.redis_patcher.stop() + self.redis_client_patcher.stop() + self.settings_patcher.stop() + + def test_initialization_with_redis(self): + """Test cache initialization with Redis available.""" + self.assertIsNotNone(self.cache.redis_client) + self.mock_redis.ping.assert_called_once() + self.assertEqual(self.cache.ttl_seconds, 24 * 3600) + self.assertEqual(self.cache.key_prefix, "lookup:llm:") + + def test_initialization_without_redis(self): + """Test cache initialization when Redis unavailable.""" + # Mock Redis connection failure + self.mock_redis.ping.side_effect = Exception("Connection failed") + + # Reinitialize cache + cache = RedisLLMCache(fallback_to_memory=True) + + # Should fall back to memory cache + self.assertIsNone(cache.redis_client) + self.assertIsNotNone(cache.memory_cache) + + def test_generate_cache_key(self): + """Test cache key 
generation.""" + prompt = "Test prompt" + reference = "Reference data" + + key = self.cache.generate_cache_key(prompt, reference) + + # Should be prefixed and hashed + self.assertTrue(key.startswith("lookup:llm:")) + self.assertEqual(len(key), len("lookup:llm:") + 64) # SHA256 hex length + + def test_set_and_get(self): + """Test setting and getting cache values.""" + # Test set + key = "lookup:llm:test-key" + value = '{"result": "test"}' + + result = self.cache.set(key, value) + self.assertTrue(result) + + # Verify Redis setex called + self.mock_redis.setex.assert_called_once_with( + name=key, + time=24 * 3600, + value=value + ) + + # Test get + self.mock_redis.get.return_value = value + retrieved = self.cache.get(key) + + self.assertEqual(retrieved, value) + self.mock_redis.get.assert_called_once_with(key) + + def test_get_cache_miss(self): + """Test cache miss.""" + self.mock_redis.get.return_value = None + + result = self.cache.get("lookup:llm:nonexistent") + + self.assertIsNone(result) + + def test_delete(self): + """Test deleting cache entries.""" + key = "lookup:llm:test-key" + self.mock_redis.delete.return_value = 1 + + result = self.cache.delete(key) + + self.assertTrue(result) + self.mock_redis.delete.assert_called_once_with(key) + + def test_delete_nonexistent(self): + """Test deleting non-existent entry.""" + self.mock_redis.delete.return_value = 0 + + result = self.cache.delete("lookup:llm:nonexistent") + + # Redis returns 0 for non-existent keys + self.assertFalse(result) + + def test_clear_pattern(self): + """Test clearing entries by pattern.""" + # Mock SCAN operation + self.mock_redis.scan.side_effect = [ + (100, ["lookup:llm:project1:key1", "lookup:llm:project1:key2"]), + (0, ["lookup:llm:project1:key3"]) + ] + self.mock_redis.delete.return_value = 3 + + count = self.cache.clear_pattern("lookup:llm:project1:*") + + self.assertEqual(count, 3) + self.mock_redis.delete.assert_called() + + def test_fallback_to_memory_cache(self): + """Test fallback to memory cache when Redis fails.""" + # Make Redis operations fail + from redis.exceptions import RedisError + self.mock_redis.get.side_effect = RedisError("Connection lost") + self.mock_redis.setex.side_effect = RedisError("Connection lost") + + # Cache should have memory fallback + self.assertIsNotNone(self.cache.memory_cache) + + # Test set with fallback + key = "lookup:llm:test" + value = "test-value" + + result = self.cache.set(key, value) + self.assertTrue(result) # Should succeed with memory cache + + # Test get with fallback + retrieved = self.cache.get(key) + # Memory cache uses key without prefix + self.assertEqual(retrieved, value) + + def test_get_stats(self): + """Test getting cache statistics.""" + # Mock Redis info + self.mock_redis.info.side_effect = [ + { # stats info + 'total_connections_received': 100, + 'keyspace_hits': 80, + 'keyspace_misses': 20 + }, + { # keyspace info + 'db1': {'keys': 50, 'expires': 45} + } + ] + + stats = self.cache.get_stats() + + # Verify stats structure + self.assertEqual(stats['backend'], 'redis') + self.assertEqual(stats['ttl_hours'], 24) + self.assertIn('redis', stats) + self.assertEqual(stats['redis']['keyspace_hits'], 80) + self.assertEqual(stats['redis']['hit_rate'], 0.8) + + def test_cleanup_expired(self): + """Test cleanup of expired entries.""" + # Redis handles expiration automatically + count = self.cache.cleanup_expired() + + # Should return 0 for Redis (automatic expiry) + self.assertEqual(count, 0) + + def test_warmup(self): + """Test cache warmup.""" + project_id = 
"test-project" + preload_data = { + "prompt1": "response1", + "prompt2": "response2" + } + + count = self.cache.warmup(project_id, preload_data) + + self.assertEqual(count, 2) + self.assertEqual(self.mock_redis.setex.call_count, 2) + + def test_custom_ttl(self): + """Test setting cache with custom TTL.""" + key = "lookup:llm:test" + value = "test-value" + custom_ttl = 3600 # 1 hour + + self.cache.set(key, value, ttl=custom_ttl) + + # Verify custom TTL was used + self.mock_redis.setex.assert_called_with( + name=key, + time=custom_ttl, + value=value + ) + + def test_hit_rate_calculation(self): + """Test cache hit rate calculation.""" + # Test with hits and misses + rate = self.cache._calculate_hit_rate(80, 20) + self.assertEqual(rate, 0.8) + + # Test with no data + rate = self.cache._calculate_hit_rate(0, 0) + self.assertEqual(rate, 0.0) + + def test_pattern_matching(self): + """Test pattern matching for memory cache.""" + # Test exact match + self.assertTrue(self.cache._match_pattern("key1", "key1")) + + # Test wildcard match + self.assertTrue(self.cache._match_pattern("project:key1", "project:*")) + self.assertTrue(self.cache._match_pattern("project:subkey:value", "project:*:value")) + + # Test non-match + self.assertFalse(self.cache._match_pattern("other:key", "project:*")) + + +class RedisUnavailableTest(TestCase): + """Test cases when Redis is not available.""" + + def setUp(self): + """Set up test without Redis.""" + # Mock Redis as unavailable + self.redis_patcher = patch('lookup.integrations.redis_cache.REDIS_AVAILABLE', False) + self.redis_patcher.start() + + def tearDown(self): + """Clean up patches.""" + self.redis_patcher.stop() + + def test_initialization_without_redis_package(self): + """Test when redis package is not installed.""" + cache = RedisLLMCache() + + # Should initialize without Redis + self.assertIsNone(cache.redis_client) + self.assertIsNotNone(cache.memory_cache) + + # Should still be functional with memory cache + key = cache.generate_cache_key("test", "data") + cache.set(key, "value") + self.assertEqual(cache.get(key), "value") diff --git a/backend/lookup/tests/test_integrations/test_storage_integration.py b/backend/lookup/tests/test_integrations/test_storage_integration.py new file mode 100644 index 0000000000..5939a2fc7e --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_storage_integration.py @@ -0,0 +1,247 @@ +""" +Tests for object storage integration. 
+""" + +import uuid +from unittest.mock import MagicMock, patch + +from django.test import TestCase + +from ...integrations.storage_client import RemoteStorageClient + + +class RemoteStorageClientTest(TestCase): + """Test cases for RemoteStorageClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the FileStorage instance + self.mock_fs_patcher = patch('lookup.integrations.storage_client.EnvHelper.get_storage') + self.mock_get_storage = self.mock_fs_patcher.start() + + # Create mock file storage + self.mock_fs = MagicMock() + self.mock_get_storage.return_value = self.mock_fs + + # Initialize client + self.client = RemoteStorageClient(base_path="test/lookup") + self.project_id = uuid.uuid4() + + def tearDown(self): + """Clean up patches.""" + self.mock_fs_patcher.stop() + + def test_upload_success(self): + """Test successful file upload.""" + # Setup mock + self.mock_fs.write.return_value = None # Successful write + + # Test upload + content = b"test content" + path = "test/file.txt" + + result = self.client.upload(path, content) + + # Verify + self.assertTrue(result) + self.mock_fs.mkdir.assert_called_once() + self.mock_fs.write.assert_called_once_with( + path=path, + mode="wb", + data=content + ) + + def test_upload_failure(self): + """Test file upload failure.""" + # Setup mock to raise exception + self.mock_fs.write.side_effect = Exception("Storage error") + + # Test upload + result = self.client.upload("test/file.txt", b"content") + + # Verify + self.assertFalse(result) + + def test_download_success(self): + """Test successful file download.""" + # Setup mock + expected_content = b"test content" + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = expected_content + + # Test download + content = self.client.download("test/file.txt") + + # Verify + self.assertEqual(content, expected_content) + self.mock_fs.read.assert_called_once_with( + path="test/file.txt", + mode="rb" + ) + + def test_download_file_not_found(self): + """Test download when file doesn't exist.""" + # Setup mock + self.mock_fs.exists.return_value = False + + # Test download + content = self.client.download("nonexistent.txt") + + # Verify + self.assertIsNone(content) + self.mock_fs.read.assert_not_called() + + def test_delete_success(self): + """Test successful file deletion.""" + # Setup mock + self.mock_fs.exists.return_value = True + self.mock_fs.delete.return_value = None + + # Test delete + result = self.client.delete("test/file.txt") + + # Verify + self.assertTrue(result) + self.mock_fs.delete.assert_called_once_with("test/file.txt") + + def test_delete_file_not_found(self): + """Test delete when file doesn't exist.""" + # Setup mock + self.mock_fs.exists.return_value = False + + # Test delete + result = self.client.delete("nonexistent.txt") + + # Verify + self.assertFalse(result) + self.mock_fs.delete.assert_not_called() + + def test_exists_check(self): + """Test file existence check.""" + # Test existing file + self.mock_fs.exists.return_value = True + self.assertTrue(self.client.exists("test/file.txt")) + + # Test non-existing file + self.mock_fs.exists.return_value = False + self.assertFalse(self.client.exists("nonexistent.txt")) + + def test_list_files(self): + """Test listing files with prefix.""" + # Setup mock + self.mock_fs.listdir.return_value = ["file1.txt", "file2.txt", ".hidden"] + + # Test list + files = self.client.list_files("test/prefix") + + # Verify - hidden files should be excluded + self.assertEqual(len(files), 2) + self.assertIn("test/prefix/file1.txt", 
files) + self.assertIn("test/prefix/file2.txt", files) + self.assertNotIn("test/prefix/.hidden", files) + + def test_text_content_operations(self): + """Test text content save and retrieve.""" + # Test save + text = "Hello, World!" + self.mock_fs.write.return_value = None + + result = self.client.save_text_content("test.txt", text) + self.assertTrue(result) + self.mock_fs.write.assert_called_with( + path="test.txt", + mode="wb", + data=text.encode('utf-8') + ) + + # Test get + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = text.encode('utf-8') + + retrieved = self.client.get_text_content("test.txt") + self.assertEqual(retrieved, text) + + def test_upload_reference_data(self): + """Test uploading reference data with metadata.""" + # Setup + content = b"reference data" + filename = "vendors.csv" + metadata = {"source": "manual", "version": 1} + + # Mock JSON encoding for metadata + import json + expected_meta = json.dumps(metadata, indent=2) + + # Test upload + path = self.client.upload_reference_data( + self.project_id, + filename, + content, + metadata + ) + + # Verify + expected_path = f"test/lookup/{self.project_id}/{filename}" + self.assertEqual(path, expected_path) + + # Check main file upload + call_args = [call for call in self.mock_fs.write.call_args_list + if call[1]['path'] == expected_path] + self.assertEqual(len(call_args), 1) + + # Check metadata upload + meta_path = f"{expected_path}.meta.json" + meta_calls = [call for call in self.mock_fs.write.call_args_list + if call[1]['path'] == meta_path] + self.assertEqual(len(meta_calls), 1) + + def test_get_reference_data(self): + """Test retrieving reference data.""" + # Setup + expected_data = "reference content" + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = expected_data.encode('utf-8') + + # Test + data = self.client.get_reference_data(self.project_id, "data.txt") + + # Verify + self.assertEqual(data, expected_data) + expected_path = f"test/lookup/{self.project_id}/data.txt" + self.mock_fs.read.assert_called_with(path=expected_path, mode="rb") + + def test_list_project_files(self): + """Test listing all files for a project.""" + # Setup + self.mock_fs.listdir.return_value = ["file1.csv", "file2.json"] + + # Test + files = self.client.list_project_files(self.project_id) + + # Verify + expected_prefix = f"test/lookup/{self.project_id}" + self.assertEqual(len(files), 2) + self.assertIn(f"{expected_prefix}/file1.csv", files) + self.assertIn(f"{expected_prefix}/file2.json", files) + + def test_delete_project_data(self): + """Test deleting all project data.""" + # Setup + project_files = ["file1.csv", "file2.json"] + self.mock_fs.listdir.return_value = project_files + self.mock_fs.exists.return_value = True + + # Test + result = self.client.delete_project_data(self.project_id) + + # Verify + self.assertTrue(result) + + # Check files were deleted + expected_prefix = f"test/lookup/{self.project_id}" + for filename in project_files: + expected_path = f"{expected_prefix}/{filename}" + self.mock_fs.delete.assert_any_call(expected_path) + + # Check directory was removed + self.mock_fs.rmdir.assert_called_once_with(expected_prefix) diff --git a/backend/lookup/tests/test_migrations.py b/backend/lookup/tests/test_migrations.py new file mode 100644 index 0000000000..496891a816 --- /dev/null +++ b/backend/lookup/tests/test_migrations.py @@ -0,0 +1,286 @@ +"""Tests for Look-Up system database migrations.""" + +import pytest +from django.db import connection +from 
django.db.migrations.executor import MigrationExecutor
+from django.test import TransactionTestCase
+
+
+class TestLookupMigrations(TransactionTestCase):
+    """Test suite for Look-Up database migrations."""
+
+    @property
+    def app(self):
+        return "lookup"
+
+    @property
+    def migrate_from(self):
+        # Migration target before the initial migration (unmigrated state)
+        return [(self.app, None)]
+
+    @property
+    def migrate_to(self):
+        return [(self.app, "0001_initial")]
+
+    def setUp(self):
+        """Roll the test database back to the pre-migration state."""
+        executor = MigrationExecutor(connection)
+        old_apps = executor.loader.project_state(self.migrate_from).apps
+
+        # Reverse to the unmigrated state
+        executor.migrate(self.migrate_from)
+
+        self.apps_before = old_apps
+
+    def test_migration_can_be_applied(self):
+        """Test that the migration can be applied successfully."""
+        executor = MigrationExecutor(connection)
+
+        # Apply migrations
+        executor.migrate(self.migrate_to)
+
+        # Check tables exist
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """
+                SELECT table_name FROM information_schema.tables
+                WHERE table_name IN (
+                    'lookup_projects',
+                    'lookup_data_sources',
+                    'lookup_prompt_templates',
+                    'prompt_studio_lookup_links'
+                )
+                """
+            )
+            tables = [row[0] for row in cursor.fetchall()]
+
+        assert len(tables) == 4, f"Expected 4 tables, found {len(tables)}: {tables}"
+        assert "lookup_projects" in tables
+        assert "lookup_data_sources" in tables
+        assert "lookup_prompt_templates" in tables
+        assert "prompt_studio_lookup_links" in tables
+
+    def test_migration_can_be_reversed(self):
+        """Test that the migration can be reversed successfully."""
+        executor = MigrationExecutor(connection)
+
+        # Apply migration
+        executor.migrate(self.migrate_to)
+
+        # Reverse migration
+        executor.migrate(self.migrate_from)
+
+        # Check tables don't exist
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """
+                SELECT table_name FROM information_schema.tables
+                WHERE table_name IN (
+                    'lookup_projects',
+                    'lookup_data_sources',
+                    'lookup_prompt_templates',
+                    'prompt_studio_lookup_links'
+                )
+                """
+            )
+            tables = [row[0] for row in cursor.fetchall()]
+
+        assert (
+            len(tables) == 0
+        ), f"Tables should not exist after reversal, found: {tables}"
+
+    def test_constraints_are_enforced(self):
+        """Test that all database constraints are properly enforced."""
+        from django.contrib.auth import get_user_model
+        from account.models import Organization
+        from lookup.models import LookupProject, LookupDataSource
+
+        User = get_user_model()
+
+        # Create test user and organization
+        org = Organization.objects.create(name="Test Org")
+        user = User.objects.create_user(
+            username="testuser", email="test@example.com", password="password"
+        )
+
+        # Test CHECK constraint on lookup_type
+        with pytest.raises(Exception):
+            LookupProject.objects.create(
+                name="Invalid Type Project",
+                lookup_type="invalid_type",
+                llm_provider="openai",
+                llm_model="gpt-4",
+                organization=org,
+                created_by=user,
+            )
+
+        # Test CHECK constraint on extraction_status
+        project = LookupProject.objects.create(
+            name="Valid Project",
+            lookup_type="static_data",
+            llm_provider="openai",
+            llm_model="gpt-4",
+            organization=org,
+            created_by=user,
+        )
+
+        with pytest.raises(Exception):
+            LookupDataSource.objects.create(
+                project=project,
+                file_name="test.pdf",
+                
file_path="/path/to/file.pdf", + file_size=1024, + file_type="pdf", + extraction_status="invalid_status", + uploaded_by=user, + ) + + def test_version_trigger_functionality(self): + """Test that the version management trigger works correctly.""" + from django.contrib.auth import get_user_model + from account.models import Organization + from lookup.models import LookupProject, LookupDataSource + + User = get_user_model() + + # Create test data + org = Organization.objects.create(name="Test Org") + user = User.objects.create_user( + username="testuser", email="test@example.com", password="password" + ) + + project = LookupProject.objects.create( + name="Version Test Project", + lookup_type="static_data", + llm_provider="openai", + llm_model="gpt-4", + organization=org, + created_by=user, + ) + + # Create first data source + ds1 = LookupDataSource.objects.create( + project=project, + file_name="v1.pdf", + file_path="/path/v1.pdf", + file_size=1024, + file_type="pdf", + uploaded_by=user, + ) + + # Verify first version + ds1.refresh_from_db() + assert ds1.version_number == 1 + assert ds1.is_latest is True + + # Create second data source + ds2 = LookupDataSource.objects.create( + project=project, + file_name="v2.pdf", + file_path="/path/v2.pdf", + file_size=2048, + file_type="pdf", + uploaded_by=user, + ) + + # Verify second version + ds2.refresh_from_db() + assert ds2.version_number == 2 + assert ds2.is_latest is True + + # Verify first version is no longer latest + ds1.refresh_from_db() + assert ds1.is_latest is False + + def test_indexes_are_created(self): + """Test that all required indexes are created.""" + with connection.cursor() as cursor: + # Check for index existence on lookup_projects + cursor.execute( + """ + SELECT indexname FROM pg_indexes + WHERE tablename = 'lookup_projects' + AND indexname IN ( + 'idx_lookup_proj_org', + 'idx_lookup_proj_created_by', + 'idx_lookup_proj_updated' + ) + """ + ) + project_indexes = [row[0] for row in cursor.fetchall()] + + assert len(project_indexes) == 3 + + with connection.cursor() as cursor: + # Check for index existence on lookup_data_sources + cursor.execute( + """ + SELECT indexname FROM pg_indexes + WHERE tablename = 'lookup_data_sources' + AND indexname IN ( + 'idx_lookup_ds_project', + 'idx_lookup_ds_latest', + 'idx_lookup_ds_created', + 'idx_lookup_ds_status' + ) + """ + ) + ds_indexes = [row[0] for row in cursor.fetchall()] + + assert len(ds_indexes) == 4 + + def test_foreign_key_constraints(self): + """Test that foreign key relationships work correctly.""" + from django.contrib.auth import get_user_model + from account.models import Organization + from lookup.models import ( + LookupProject, + LookupDataSource, + LookupPromptTemplate, + ) + + User = get_user_model() + + # Create test data + org = Organization.objects.create(name="Test Org") + user = User.objects.create_user( + username="testuser", email="test@example.com", password="password" + ) + + project = LookupProject.objects.create( + name="FK Test Project", + lookup_type="static_data", + llm_provider="openai", + llm_model="gpt-4", + organization=org, + created_by=user, + ) + + # Test cascade delete + data_source = LookupDataSource.objects.create( + project=project, + file_name="test.pdf", + file_path="/path/test.pdf", + file_size=1024, + file_type="pdf", + uploaded_by=user, + ) + + template = LookupPromptTemplate.objects.create( + project=project, + template_text="Test: {{input_data}}", + ) + + # Delete project should cascade + project_id = project.id + project.delete() + + # 
Verify cascaded deletes + assert not LookupDataSource.objects.filter(id=data_source.id).exists() + assert not LookupPromptTemplate.objects.filter(id=template.id).exists() diff --git a/backend/lookup/tests/test_services/test_audit_logger.py b/backend/lookup/tests/test_services/test_audit_logger.py new file mode 100644 index 0000000000..f7f138b193 --- /dev/null +++ b/backend/lookup/tests/test_services/test_audit_logger.py @@ -0,0 +1,455 @@ +""" +Tests for Audit Logger implementation. + +This module tests the AuditLogger class including logging executions, +convenience methods, and statistics retrieval. +""" + +import uuid +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from lookup.services.audit_logger import AuditLogger + + +class TestAuditLogger: + """Test cases for AuditLogger class.""" + + @pytest.fixture + def audit_logger(self): + """Create an AuditLogger instance.""" + return AuditLogger() + + @pytest.fixture + def mock_project(self): + """Create a mock LookupProject.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Test Look-Up" + return project + + @pytest.fixture + def execution_params(self, mock_project): + """Create standard execution parameters.""" + return { + 'execution_id': str(uuid.uuid4()), + 'lookup_project_id': mock_project.id, + 'prompt_studio_project_id': uuid.uuid4(), + 'input_data': {'vendor': 'Slack Technologies'}, + 'reference_data_version': 2, + 'llm_provider': 'openai', + 'llm_model': 'gpt-4', + 'llm_prompt': 'Match vendor Slack Technologies...', + 'llm_response': '{"canonical_vendor": "Slack", "confidence": 0.92}', + 'enriched_output': {'canonical_vendor': 'Slack'}, + 'status': 'success', + 'confidence_score': 0.92, + 'execution_time_ms': 1234, + 'llm_call_time_ms': 890, + 'llm_response_cached': False, + 'error_message': None + } + + # ========== Basic Logging Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_successful_logging( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test successful execution logging.""" + # Setup mocks + mock_project_model.objects.get.return_value = mock_project + mock_audit_instance = MagicMock() + mock_audit_instance.id = uuid.uuid4() + mock_audit_model.objects.create.return_value = mock_audit_instance + + # Log execution + result = audit_logger.log_execution(**execution_params) + + # Verify project was fetched + mock_project_model.objects.get.assert_called_once_with( + id=execution_params['lookup_project_id'] + ) + + # Verify audit was created with correct params + mock_audit_model.objects.create.assert_called_once() + create_call = mock_audit_model.objects.create.call_args + kwargs = create_call.kwargs + + assert kwargs['lookup_project'] == mock_project + assert kwargs['execution_id'] == execution_params['execution_id'] + assert kwargs['status'] == 'success' + assert kwargs['llm_provider'] == 'openai' + assert kwargs['llm_model'] == 'gpt-4' + assert kwargs['confidence_score'] == Decimal('0.92') + + # Verify return value + assert result == mock_audit_instance + + @patch('lookup.services.audit_logger.LookupProject') + def test_project_not_found( + self, mock_project_model, audit_logger, execution_params + ): + """Test handling when Look-Up project doesn't exist.""" + # Make project lookup fail + mock_project_model.objects.get.side_effect = mock_project_model.DoesNotExist() + + # Log execution + result = 
audit_logger.log_execution(**execution_params) + + # Should return None and not raise exception + assert result is None + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_database_error_handling( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test handling of database errors during logging.""" + # Setup mocks + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.side_effect = Exception("Database error") + + # Log execution - should not raise exception + result = audit_logger.log_execution(**execution_params) + + # Should return None + assert result is None + + # ========== Convenience Method Tests ========== + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_success(self, mock_log_execution, audit_logger): + """Test log_success convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + + audit_logger.log_success( + execution_id=execution_id, + project_id=project_id, + input_data={'test': 'data'}, + confidence_score=0.85 + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='success', + input_data={'test': 'data'}, + confidence_score=0.85 + ) + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_failure(self, mock_log_execution, audit_logger): + """Test log_failure convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + error_msg = "LLM timeout" + + audit_logger.log_failure( + execution_id=execution_id, + project_id=project_id, + error=error_msg, + input_data={'test': 'data'} + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='failed', + error_message=error_msg, + input_data={'test': 'data'} + ) + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_partial(self, mock_log_execution, audit_logger): + """Test log_partial convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + + audit_logger.log_partial( + execution_id=execution_id, + project_id=project_id, + confidence_score=0.35, + error_message='Low confidence' + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='partial', + confidence_score=0.35, + error_message='Low confidence' + ) + + # ========== Data Validation Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_confidence_score_conversion( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test that confidence score is properly converted to Decimal.""" + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.return_value = MagicMock() + + # Test with float confidence + execution_params['confidence_score'] = 0.456789 + audit_logger.log_execution(**execution_params) + + create_call = mock_audit_model.objects.create.call_args + assert create_call.kwargs['confidence_score'] == Decimal('0.456789') + + # Test with None confidence + execution_params['confidence_score'] = None + audit_logger.log_execution(**execution_params) + + create_call = mock_audit_model.objects.create.call_args + assert create_call.kwargs['confidence_score'] is None + + 
@patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_optional_fields( + self, mock_project_model, mock_audit_model, + audit_logger, mock_project + ): + """Test logging with minimal required fields.""" + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.return_value = MagicMock() + + # Minimal parameters + minimal_params = { + 'execution_id': str(uuid.uuid4()), + 'lookup_project_id': mock_project.id, + 'prompt_studio_project_id': None, + 'input_data': {}, + 'reference_data_version': 1, + 'llm_provider': 'openai', + 'llm_model': 'gpt-4', + 'llm_prompt': 'test prompt', + 'llm_response': None, + 'enriched_output': None, + 'status': 'failed' + } + + result = audit_logger.log_execution(**minimal_params) + + # Should still create audit record + mock_audit_model.objects.create.assert_called_once() + + # ========== History Retrieval Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_execution_history(self, mock_audit_model, audit_logger): + """Test retrieving execution history.""" + execution_id = str(uuid.uuid4()) + + # Create mock audit records + mock_audits = [] + for i in range(3): + audit = MagicMock() + audit.lookup_project.name = f"Look-Up {i+1}" + audit.status = 'success' if i < 2 else 'failed' + mock_audits.append(audit) + + # Setup mock query + mock_queryset = MagicMock() + mock_queryset.select_related.return_value = mock_queryset + mock_queryset.order_by.return_value = mock_queryset + mock_queryset.__getitem__.return_value = mock_audits + mock_audit_model.objects.filter.return_value = mock_queryset + + # Get history + result = audit_logger.get_execution_history(execution_id, limit=10) + + # Verify query + mock_audit_model.objects.filter.assert_called_once_with( + execution_id=execution_id + ) + mock_queryset.select_related.assert_called_once_with('lookup_project') + mock_queryset.order_by.assert_called_once_with('executed_at') + + # Check result + assert len(result) == 3 + assert result[0].status == 'success' + assert result[2].status == 'failed' + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_execution_history_error_handling( + self, mock_audit_model, audit_logger + ): + """Test error handling in get_execution_history.""" + mock_audit_model.objects.filter.side_effect = Exception("Database error") + + result = audit_logger.get_execution_history('test-id') + + # Should return empty list on error + assert result == [] + + # ========== Statistics Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_project_stats(self, mock_audit_model, audit_logger): + """Test getting project statistics.""" + project_id = uuid.uuid4() + + # Create mock audit records + mock_audits = [] + + # 3 successful executions + for i in range(3): + audit = MagicMock() + audit.status = 'success' + audit.execution_time_ms = 1000 + i * 100 + audit.llm_response_cached = (i == 0) # First one cached + audit.confidence_score = Decimal(f'0.{80 + i}') + mock_audits.append(audit) + + # 1 failed execution + audit = MagicMock() + audit.status = 'failed' + audit.execution_time_ms = 500 + audit.llm_response_cached = False + audit.confidence_score = None + mock_audits.append(audit) + + # 1 partial execution + audit = MagicMock() + audit.status = 'partial' + audit.execution_time_ms = 800 + audit.llm_response_cached = False + audit.confidence_score = Decimal('0.40') + mock_audits.append(audit) + + # Setup 
mock query
+        mock_queryset = MagicMock()
+        mock_queryset.order_by.return_value = mock_queryset
+        mock_queryset.__getitem__.return_value = mock_audits
+        mock_audit_model.objects.filter.return_value = mock_queryset
+
+        # Get stats
+        stats = audit_logger.get_project_stats(project_id, limit=100)
+
+        # Verify stats
+        assert stats['total_executions'] == 5
+        assert stats['successful'] == 3
+        assert stats['failed'] == 1
+        assert stats['partial'] == 1
+        assert stats['success_rate'] == 0.6  # 3/5
+        assert stats['cache_hit_rate'] == 0.2  # 1/5
+        assert stats['avg_execution_time_ms'] == 920  # (1000+1100+1200+500+800)/5
+        # avg_confidence = (0.80 + 0.81 + 0.82 + 0.40) / 4 = 0.7075
+        assert abs(stats['avg_confidence'] - 0.7075) < 0.001
+
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    def test_get_project_stats_empty(self, mock_audit_model, audit_logger):
+        """Test getting stats for project with no executions."""
+        mock_queryset = MagicMock()
+        mock_queryset.order_by.return_value = mock_queryset
+        mock_queryset.__getitem__.return_value = []
+        mock_audit_model.objects.filter.return_value = mock_queryset
+
+        stats = audit_logger.get_project_stats(uuid.uuid4())
+
+        assert stats['total_executions'] == 0
+        assert stats['success_rate'] == 0.0
+        assert stats['avg_execution_time_ms'] == 0
+        assert stats['cache_hit_rate'] == 0.0
+        assert stats['avg_confidence'] == 0.0
+
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    def test_get_project_stats_error_handling(
+        self, mock_audit_model, audit_logger
+    ):
+        """Test error handling in get_project_stats."""
+        mock_audit_model.objects.filter.side_effect = Exception("Database error")
+
+        stats = audit_logger.get_project_stats(uuid.uuid4())
+
+        # Should return zero stats on error
+        assert stats['total_executions'] == 0
+        assert stats['success_rate'] == 0.0
+
+    # ========== Integration Tests ==========
+
+    @patch('lookup.services.audit_logger.logger')
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    @patch('lookup.services.audit_logger.LookupProject')
+    def test_logging_messages(
+        self, mock_project_model, mock_audit_model, mock_logger,
+        audit_logger, execution_params, mock_project
+    ):
+        """Test that appropriate log messages are generated."""
+        mock_project_model.objects.get.return_value = mock_project
+        mock_audit_instance = MagicMock()
+        mock_audit_instance.id = uuid.uuid4()
+        mock_audit_model.objects.create.return_value = mock_audit_instance
+
+        audit_logger.log_execution(**execution_params)
+
+        # Should log debug message on success
+        mock_logger.debug.assert_called()
+        debug_message = mock_logger.debug.call_args[0][0]
+        assert 'Logged execution audit' in debug_message
+        assert mock_project.name in debug_message
+
+    @patch('lookup.services.audit_logger.logger')
+    @patch('lookup.services.audit_logger.LookupProject')
+    def test_error_logging(
+        self, mock_project_model, mock_logger,
+        audit_logger, execution_params
+    ):
+        """Test that errors are properly logged."""
+        mock_project_model.objects.get.side_effect = Exception("Database connection lost")
+
+        audit_logger.log_execution(**execution_params)
+
+        # Should log exception
+        mock_logger.exception.assert_called()
+        error_message = mock_logger.exception.call_args[0][0]
+        assert 'Failed to log execution audit' in error_message
+
+    def test_real_world_scenario(self, audit_logger):
+        """Test realistic usage scenario with mock objects."""
+        # This would normally require Django test database
+        # For now, just verify the interface works correctly
+
+        execution_id = 
str(uuid.uuid4()) + project_id = uuid.uuid4() + + # Log various execution types + with patch('lookup.services.audit_logger.AuditLogger.log_execution') as mock_log: + mock_log.return_value = MagicMock() + + # Success + audit_logger.log_success( + execution_id=execution_id, + project_id=project_id, + input_data={'vendor': 'Slack'}, + enriched_output={'canonical': 'Slack'}, + confidence_score=0.95 + ) + + # Failure + audit_logger.log_failure( + execution_id=execution_id, + project_id=project_id, + error='Timeout', + input_data={'vendor': 'Unknown'} + ) + + # Partial + audit_logger.log_partial( + execution_id=execution_id, + project_id=project_id, + confidence_score=0.30 + ) + + # Verify all three logs were attempted + assert mock_log.call_count == 3 diff --git a/backend/lookup/tests/test_services/test_enrichment_merger.py b/backend/lookup/tests/test_services/test_enrichment_merger.py new file mode 100644 index 0000000000..bed77e2bf7 --- /dev/null +++ b/backend/lookup/tests/test_services/test_enrichment_merger.py @@ -0,0 +1,547 @@ +""" +Tests for Enrichment Merger implementation. + +This module tests the EnrichmentMerger class including merging logic, +conflict resolution, and metadata tracking. +""" + +import uuid + +import pytest + +from lookup.services.enrichment_merger import EnrichmentMerger + + +class TestEnrichmentMerger: + """Test cases for EnrichmentMerger class.""" + + @pytest.fixture + def merger(self): + """Create an EnrichmentMerger instance.""" + return EnrichmentMerger() + + @pytest.fixture + def sample_enrichments(self): + """Create sample enrichments for testing.""" + return [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Vendor Matcher', + 'data': { + 'canonical_vendor': 'Slack', + 'vendor_category': 'SaaS' + }, + 'confidence': 0.95, + 'execution_time_ms': 1234, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Product Classifier', + 'data': { + 'product_type': 'Software', + 'license_model': 'Subscription' + }, + 'confidence': 0.88, + 'execution_time_ms': 567, + 'cached': True + } + ] + + # ========== No Conflicts Tests ========== + + def test_merge_no_conflicts(self, merger, sample_enrichments): + """Test merging enrichments with no overlapping fields.""" + result = merger.merge(sample_enrichments) + + # Check merged data has all fields + assert 'canonical_vendor' in result['data'] + assert 'vendor_category' in result['data'] + assert 'product_type' in result['data'] + assert 'license_model' in result['data'] + + # Check values are correct + assert result['data']['canonical_vendor'] == 'Slack' + assert result['data']['vendor_category'] == 'SaaS' + assert result['data']['product_type'] == 'Software' + assert result['data']['license_model'] == 'Subscription' + + # Check no conflicts were resolved + assert result['conflicts_resolved'] == 0 + + # Check enrichment details + assert len(result['enrichment_details']) == 2 + assert result['enrichment_details'][0]['lookup_project_name'] == 'Vendor Matcher' + assert result['enrichment_details'][0]['fields_added'] == ['canonical_vendor', 'vendor_category'] + assert result['enrichment_details'][1]['lookup_project_name'] == 'Product Classifier' + assert result['enrichment_details'][1]['fields_added'] == ['product_type', 'license_model'] + + def test_merge_empty_enrichments(self, merger): + """Test merging empty list of enrichments.""" + result = merger.merge([]) + + assert result['data'] == {} + assert result['conflicts_resolved'] == 0 + assert result['enrichment_details'] == [] + + def 
test_merge_single_enrichment(self, merger): + """Test merging with only one enrichment.""" + enrichment = { + 'project_id': uuid.uuid4(), + 'project_name': 'Solo Lookup', + 'data': {'field1': 'value1', 'field2': 'value2'}, + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + } + + result = merger.merge([enrichment]) + + assert result['data'] == {'field1': 'value1', 'field2': 'value2'} + assert result['conflicts_resolved'] == 0 + assert len(result['enrichment_details']) == 1 + assert result['enrichment_details'][0]['fields_added'] == ['field1', 'field2'] + + # ========== Confidence-Based Resolution Tests ========== + + def test_higher_confidence_wins(self, merger): + """Test that higher confidence value wins in conflict.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Low Confidence', + 'data': {'category': 'Communication'}, + 'confidence': 0.80, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'High Confidence', + 'data': {'category': 'Collaboration'}, + 'confidence': 0.95, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Higher confidence should win + assert result['data']['category'] == 'Collaboration' + assert result['conflicts_resolved'] == 1 + + # Check which lookup contributed the field + details = result['enrichment_details'] + assert details[0]['fields_added'] == [] # Lost the conflict + assert details[1]['fields_added'] == ['category'] # Won the conflict + + def test_equal_confidence_first_wins(self, merger): + """Test that first-complete wins when confidence is equal.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First', + 'data': {'status': 'active'}, + 'confidence': 0.90, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second', + 'data': {'status': 'inactive'}, + 'confidence': 0.90, # Same confidence + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # First should win (first-complete-wins) + assert result['data']['status'] == 'active' + assert result['conflicts_resolved'] == 0 # No resolution needed, kept existing + + def test_confidence_beats_no_confidence(self, merger): + """Test that enrichment with confidence beats one without.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'No Confidence', + 'data': {'vendor': 'Microsoft'}, + 'confidence': None, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Has Confidence', + 'data': {'vendor': 'Slack'}, + 'confidence': 0.75, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Confidence should win + assert result['data']['vendor'] == 'Slack' + assert result['conflicts_resolved'] == 1 + + def test_no_confidence_first_wins(self, merger): + """Test first-complete wins when neither has confidence.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First No Conf', + 'data': {'region': 'US'}, + 'confidence': None, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second No Conf', + 'data': {'region': 'EU'}, + 'confidence': None, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # First should win + assert result['data']['region'] == 'US' + assert result['conflicts_resolved'] == 0 + + # ========== Multiple Conflicts Tests 
========== + + def test_multiple_conflicts_same_enrichments(self, merger): + """Test resolving multiple field conflicts between same enrichments.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Enrichment A', + 'data': { + 'field1': 'A1', + 'field2': 'A2', + 'field3': 'A3' + }, + 'confidence': 0.70, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Enrichment B', + 'data': { + 'field1': 'B1', # Conflict + 'field2': 'B2', # Conflict + 'field4': 'B4' # No conflict + }, + 'confidence': 0.85, + 'execution_time_ms': 200, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Higher confidence (B) should win conflicts + assert result['data']['field1'] == 'B1' + assert result['data']['field2'] == 'B2' + assert result['data']['field3'] == 'A3' # Only from A + assert result['data']['field4'] == 'B4' # Only from B + + assert result['conflicts_resolved'] == 2 + + def test_three_way_conflicts(self, merger): + """Test resolving conflicts among three enrichments.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First', + 'data': {'category': 'Cat1', 'type': 'Type1'}, + 'confidence': 0.60, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second', + 'data': {'category': 'Cat2', 'vendor': 'Vendor2'}, + 'confidence': 0.80, + 'execution_time_ms': 200, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Third', + 'data': {'category': 'Cat3', 'type': 'Type3'}, + 'confidence': 0.75, + 'execution_time_ms': 300, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Second should win category (0.80 confidence) + assert result['data']['category'] == 'Cat2' + # Third should win type (0.75 > 0.60) + assert result['data']['type'] == 'Type3' + # Vendor only from Second + assert result['data']['vendor'] == 'Vendor2' + + assert result['conflicts_resolved'] == 2 # category and type conflicts + + # ========== Metadata Tracking Tests ========== + + def test_enrichment_details_tracking(self, merger): + """Test that enrichment details are correctly tracked.""" + project_id1 = uuid.uuid4() + project_id2 = uuid.uuid4() + + enrichments = [ + { + 'project_id': project_id1, + 'project_name': 'Lookup 1', + 'data': {'field1': 'value1'}, + 'confidence': 0.9, + 'execution_time_ms': 1500, + 'cached': True + }, + { + 'project_id': project_id2, + 'project_name': 'Lookup 2', + 'data': {'field2': 'value2'}, + 'confidence': 0.85, + 'execution_time_ms': 800, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + details = result['enrichment_details'] + assert len(details) == 2 + + # Check first enrichment details + assert details[0]['lookup_project_id'] == str(project_id1) + assert details[0]['lookup_project_name'] == 'Lookup 1' + assert details[0]['confidence'] == 0.9 + assert details[0]['cached'] is True + assert details[0]['execution_time_ms'] == 1500 + assert details[0]['fields_added'] == ['field1'] + + # Check second enrichment details + assert details[1]['lookup_project_id'] == str(project_id2) + assert details[1]['lookup_project_name'] == 'Lookup 2' + assert details[1]['confidence'] == 0.85 + assert details[1]['cached'] is False + assert details[1]['execution_time_ms'] == 800 + assert details[1]['fields_added'] == ['field2'] + + def test_fields_added_with_conflicts(self, merger): + """Test fields_added tracking when conflicts are resolved.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 
'project_name': 'Winner', + 'data': {'shared': 'win_value', 'unique1': 'value1'}, + 'confidence': 0.95, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Loser', + 'data': {'shared': 'lose_value', 'unique2': 'value2'}, + 'confidence': 0.50, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + details = result['enrichment_details'] + # Winner should have both its fields + assert 'shared' in details[0]['fields_added'] + assert 'unique1' in details[0]['fields_added'] + + # Loser should only have its unique field + assert 'shared' not in details[1]['fields_added'] + assert 'unique2' in details[1]['fields_added'] + + # ========== Edge Cases Tests ========== + + def test_missing_optional_fields(self, merger): + """Test handling of enrichments with missing optional fields.""" + enrichments = [ + { + 'project_id': None, # Missing ID + 'project_name': 'No ID Lookup', + 'data': {'field1': 'value1'}, + # Missing confidence + # Missing execution_time_ms + # Missing cached + }, + { + # Minimal valid enrichment + 'data': {'field2': 'value2'} + } + ] + + result = merger.merge(enrichments) + + assert result['data']['field1'] == 'value1' + assert result['data']['field2'] == 'value2' + assert result['conflicts_resolved'] == 0 + + # Check defaults are handled + details = result['enrichment_details'] + assert details[0]['lookup_project_id'] is None + assert details[0]['confidence'] is None + assert details[0]['execution_time_ms'] == 0 # Default + assert details[0]['cached'] is False # Default + + assert details[1]['lookup_project_name'] == 'Unknown' # Default + + def test_empty_data_fields(self, merger): + """Test handling enrichments with empty data.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Empty Data', + 'data': {}, # Empty + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Has Data', + 'data': {'field': 'value'}, + 'confidence': 0.8, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + assert result['data'] == {'field': 'value'} + assert result['conflicts_resolved'] == 0 + assert result['enrichment_details'][0]['fields_added'] == [] + assert result['enrichment_details'][1]['fields_added'] == ['field'] + + def test_complex_value_types(self, merger): + """Test merging with complex value types (lists, dicts).""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Complex Types', + 'data': { + 'tags': ['tag1', 'tag2'], + 'metadata': {'key': 'value'}, + 'count': 42 + }, + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'More Complex', + 'data': { + 'tags': ['tag3', 'tag4'], # Conflict - different list + 'settings': {'option': True} + }, + 'confidence': 0.95, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Higher confidence wins for tags + assert result['data']['tags'] == ['tag3', 'tag4'] + assert result['data']['metadata'] == {'key': 'value'} + assert result['data']['count'] == 42 + assert result['data']['settings'] == {'option': True} + assert result['conflicts_resolved'] == 1 + + # ========== Integration Tests ========== + + def test_real_world_scenario(self, merger): + """Test a realistic scenario with multiple lookups.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Vendor Standardization', + 
'data': { + 'canonical_vendor': 'Slack Technologies', + 'vendor_id': 'SLACK-001', + 'vendor_category': 'Communication' + }, + 'confidence': 0.92, + 'execution_time_ms': 1200, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Product Mapping', + 'data': { + 'product_name': 'Slack Workspace', + 'product_sku': 'SLK-WS-ENT', + 'vendor_category': 'Collaboration', # Conflict + 'license_type': 'Per User' + }, + 'confidence': 0.88, + 'execution_time_ms': 850, + 'cached': True + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Cost Center Assignment', + 'data': { + 'cost_center': 'CC-IT-001', + 'department': 'Information Technology', + 'budget_category': 'Software' + }, + 'confidence': None, # No confidence score + 'execution_time_ms': 450, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Check all fields are present + expected_fields = [ + 'canonical_vendor', 'vendor_id', 'vendor_category', + 'product_name', 'product_sku', 'license_type', + 'cost_center', 'department', 'budget_category' + ] + for field in expected_fields: + assert field in result['data'] + + # vendor_category conflict: 0.92 > 0.88, first wins + assert result['data']['vendor_category'] == 'Communication' + + # Should have 1 conflict resolved + assert result['conflicts_resolved'] == 1 + + # Check enrichment details + assert len(result['enrichment_details']) == 3 + assert result['enrichment_details'][0]['fields_added'] == [ + 'canonical_vendor', 'vendor_id', 'vendor_category' + ] + # Product mapping lost vendor_category conflict + assert 'vendor_category' not in result['enrichment_details'][1]['fields_added'] + assert 'product_name' in result['enrichment_details'][1]['fields_added'] diff --git a/backend/lookup/tests/test_services/test_llm_cache.py b/backend/lookup/tests/test_services/test_llm_cache.py new file mode 100644 index 0000000000..422a72dad4 --- /dev/null +++ b/backend/lookup/tests/test_services/test_llm_cache.py @@ -0,0 +1,322 @@ +""" +Tests for LLM Response Cache implementation. + +This module tests the LLMResponseCache class including basic operations, +TTL expiration, cache key generation, and cache management functionality. 
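+
+Example usage (a minimal sketch inferred from the tests below; the exact
+constructor and method signatures are assumptions based only on how they are
+exercised here, and prompt/reference_data/llm_response are placeholders):
+
+    cache = LLMResponseCache(ttl_hours=1)
+    key = cache.generate_cache_key(prompt, reference_data)
+    if cache.get(key) is None:
+        cache.set(key, llm_response)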
+""" + +import time +from unittest.mock import patch + +import pytest + +from lookup.services.llm_cache import LLMResponseCache + + +class TestLLMResponseCache: + """Test cases for LLMResponseCache class.""" + + @pytest.fixture + def cache(self): + """Create a fresh cache instance for each test.""" + return LLMResponseCache(ttl_hours=1) + + @pytest.fixture + def short_ttl_cache(self): + """Create cache with very short TTL for expiration testing.""" + # 0.001 hours = 3.6 seconds + return LLMResponseCache(ttl_hours=0.001) + + # ========== Basic Operations Tests ========== + + def test_set_stores_value(self, cache): + """Test that set() correctly stores a value.""" + cache.set("test_key", "test_response") + assert "test_key" in cache.cache + stored_value, _ = cache.cache["test_key"] + assert stored_value == "test_response" + + def test_get_retrieves_stored_value(self, cache): + """Test that get() retrieves a stored value.""" + cache.set("test_key", "test_response") + result = cache.get("test_key") + assert result == "test_response" + + def test_get_returns_none_for_missing_key(self, cache): + """Test that get() returns None for non-existent key.""" + result = cache.get("nonexistent_key") + assert result is None + + def test_set_overwrites_existing_value(self, cache): + """Test that set() overwrites existing values.""" + cache.set("test_key", "first_response") + cache.set("test_key", "second_response") + result = cache.get("test_key") + assert result == "second_response" + + # ========== TTL Expiration Tests ========== + + def test_get_returns_value_before_expiration(self, cache): + """Test that get() returns value before TTL expiration.""" + cache.set("test_key", "test_response") + # Should still be valid after immediate retrieval + result = cache.get("test_key") + assert result == "test_response" + + def test_get_returns_none_after_expiration(self, short_ttl_cache): + """Test that get() returns None after TTL expiration.""" + short_ttl_cache.set("test_key", "test_response") + # Wait for expiration (TTL is 3.6 seconds) + time.sleep(4) + result = short_ttl_cache.get("test_key") + assert result is None + + def test_expired_entry_removed_on_access(self, short_ttl_cache): + """Test that expired entries are lazily removed on access.""" + short_ttl_cache.set("test_key", "test_response") + assert "test_key" in short_ttl_cache.cache + + # Wait for expiration + time.sleep(4) + result = short_ttl_cache.get("test_key") + + assert result is None + assert "test_key" not in short_ttl_cache.cache + + @patch('time.time') + def test_ttl_calculation(self, mock_time, cache): + """Test correct TTL calculation using mocked time.""" + # Set initial time + mock_time.return_value = 1000.0 + cache.set("test_key", "test_response") + + # Verify expiry is set correctly (1 hour = 3600 seconds) + _, expiry = cache.cache["test_key"] + assert expiry == 4600.0 # 1000 + 3600 + + # Move time forward but not past expiry + mock_time.return_value = 4599.0 + assert cache.get("test_key") == "test_response" + + # Move time past expiry + mock_time.return_value = 4601.0 + assert cache.get("test_key") is None + + # ========== Cache Key Generation Tests ========== + + def test_cache_key_generation_deterministic(self, cache): + """Test that same inputs generate same cache key.""" + prompt = "Match vendor {{input_data.vendor}}" + ref_data = "Slack\nGoogle\nMicrosoft" + + key1 = cache.generate_cache_key(prompt, ref_data) + key2 = cache.generate_cache_key(prompt, ref_data) + + assert key1 == key2 + + def 
test_different_prompt_different_key(self, cache): + """Test that different prompts generate different keys.""" + ref_data = "Slack\nGoogle\nMicrosoft" + + key1 = cache.generate_cache_key("Prompt 1", ref_data) + key2 = cache.generate_cache_key("Prompt 2", ref_data) + + assert key1 != key2 + + def test_different_ref_data_different_key(self, cache): + """Test that different reference data generates different keys.""" + prompt = "Match vendor {{input_data.vendor}}" + + key1 = cache.generate_cache_key(prompt, "Slack\nGoogle") + key2 = cache.generate_cache_key(prompt, "Slack\nMicrosoft") + + assert key1 != key2 + + def test_cache_key_is_valid_sha256(self, cache): + """Test that cache key is valid SHA256 hex (64 characters).""" + key = cache.generate_cache_key("test prompt", "test ref data") + + assert len(key) == 64 + assert all(c in '0123456789abcdef' for c in key) + + def test_cache_key_handles_unicode(self, cache): + """Test cache key generation with Unicode characters.""" + prompt = "Match vendor: Müller GmbH" + ref_data = "Zürich AG\n北京公司\n東京株式会社" + + key = cache.generate_cache_key(prompt, ref_data) + assert len(key) == 64 + + # ========== Cache Management Tests ========== + + def test_invalidate_removes_specific_key(self, cache): + """Test that invalidate() removes specific key.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + result = cache.invalidate("key1") + + assert result is True + assert cache.get("key1") is None + assert cache.get("key2") == "response2" + + def test_invalidate_returns_false_for_missing_key(self, cache): + """Test that invalidate() returns False for non-existent key.""" + result = cache.invalidate("nonexistent_key") + assert result is False + + def test_invalidate_all_clears_cache(self, cache): + """Test that invalidate_all() clears entire cache.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + cache.set("key3", "response3") + + count = cache.invalidate_all() + + assert count == 3 + assert len(cache.cache) == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + assert cache.get("key3") is None + + def test_invalidate_all_returns_zero_for_empty_cache(self, cache): + """Test that invalidate_all() returns 0 for empty cache.""" + count = cache.invalidate_all() + assert count == 0 + + # ========== Statistics Tests ========== + + def test_get_stats_with_valid_entries(self, cache): + """Test get_stats() with all valid entries.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + stats = cache.get_stats() + + assert stats['total'] == 2 + assert stats['valid'] == 2 + assert stats['expired'] == 0 + + @patch('time.time') + def test_get_stats_with_mixed_entries(self, mock_time, cache): + """Test get_stats() with mix of valid and expired entries.""" + # Set initial time + mock_time.return_value = 1000.0 + + cache.set("key1", "response1") + cache.set("key2", "response2") + + # Manually expire one entry + cache.cache["key1"] = ("response1", 999.0) # Already expired + + stats = cache.get_stats() + + assert stats['total'] == 2 + assert stats['valid'] == 1 + assert stats['expired'] == 1 + + def test_get_stats_empty_cache(self, cache): + """Test get_stats() with empty cache.""" + stats = cache.get_stats() + + assert stats['total'] == 0 + assert stats['valid'] == 0 + assert stats['expired'] == 0 + + # ========== Cleanup Tests ========== + + @patch('time.time') + def test_cleanup_expired_removes_expired_entries(self, mock_time, cache): + """Test cleanup_expired() removes only expired entries.""" + # 
Set initial time + mock_time.return_value = 1000.0 + + cache.set("key1", "response1") + cache.set("key2", "response2") + cache.set("key3", "response3") + + # Manually expire two entries + cache.cache["key1"] = ("response1", 999.0) # Expired + cache.cache["key2"] = ("response2", 999.0) # Expired + # key3 remains valid (expiry at 4600.0) + + removed_count = cache.cleanup_expired() + + assert removed_count == 2 + assert len(cache.cache) == 1 + assert "key3" in cache.cache + assert "key1" not in cache.cache + assert "key2" not in cache.cache + + def test_cleanup_expired_empty_cache(self, cache): + """Test cleanup_expired() with empty cache.""" + removed_count = cache.cleanup_expired() + assert removed_count == 0 + + def test_cleanup_expired_no_expired_entries(self, cache): + """Test cleanup_expired() when no entries are expired.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + removed_count = cache.cleanup_expired() + + assert removed_count == 0 + assert len(cache.cache) == 2 + + # ========== Integration Tests ========== + + def test_end_to_end_caching_workflow(self, cache): + """Test complete caching workflow.""" + # Generate cache key + prompt = "Match vendor Slack India" + ref_data = "Slack\nGoogle\nMicrosoft" + cache_key = cache.generate_cache_key(prompt, ref_data) + + # Initially no cached value + assert cache.get(cache_key) is None + + # Store response + llm_response = '{"canonical_vendor": "Slack", "confidence": 0.92}' + cache.set(cache_key, llm_response) + + # Retrieve cached value + cached = cache.get(cache_key) + assert cached == llm_response + + # Check stats + stats = cache.get_stats() + assert stats['valid'] == 1 + + # Invalidate + removed = cache.invalidate(cache_key) + assert removed is True + assert cache.get(cache_key) is None + + def test_concurrent_operations(self, cache): + """Test cache handles multiple operations correctly.""" + # Add multiple entries + for i in range(10): + key = f"key_{i}" + response = f"response_{i}" + cache.set(key, response) + + # Verify all entries exist + for i in range(10): + key = f"key_{i}" + assert cache.get(key) == f"response_{i}" + + # Invalidate some entries + for i in range(0, 10, 2): # Even indices + cache.invalidate(f"key_{i}") + + # Verify correct entries remain + for i in range(10): + key = f"key_{i}" + if i % 2 == 0: + assert cache.get(key) is None + else: + assert cache.get(key) == f"response_{i}" + + # Clear all + count = cache.invalidate_all() + assert count == 5 # Only odd indices remained diff --git a/backend/lookup/tests/test_services/test_lookup_executor.py b/backend/lookup/tests/test_services/test_lookup_executor.py new file mode 100644 index 0000000000..a8b663caba --- /dev/null +++ b/backend/lookup/tests/test_services/test_lookup_executor.py @@ -0,0 +1,418 @@ +""" +Tests for Look-Up Executor implementation. + +This module tests the LookUpExecutor class including execution flow, +caching, error handling, and response parsing. 
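+
+Sketch of the flow these tests exercise (constructor keyword names mirror the
+fixtures below and are assumptions about the real wiring, not documented API):
+
+    executor = LookUpExecutor(
+        variable_resolver=variable_resolver_cls,
+        cache_manager=llm_cache,
+        reference_loader=reference_loader,
+        llm_client=llm_client,
+    )
+    result = executor.execute(project, {"vendor": "Slack India Pvt Ltd"})
+    # result carries 'status', 'data', 'confidence', 'cached', 'execution_time_ms'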
+""" + +import json +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from lookup.exceptions import ( + ExtractionNotCompleteError, + ParseError, +) +from lookup.services.lookup_executor import LookUpExecutor + + +class TestLookUpExecutor: + """Test cases for LookUpExecutor class.""" + + @pytest.fixture + def mock_variable_resolver(self): + """Create a mock VariableResolver class.""" + mock_class = Mock() + mock_instance = Mock() + mock_instance.resolve.return_value = "Resolved prompt text" + mock_class.return_value = mock_instance + return mock_class + + @pytest.fixture + def mock_cache(self): + """Create a mock LLMResponseCache.""" + cache = MagicMock() + cache.generate_cache_key.return_value = "cache_key_123" + cache.get.return_value = None # Default to cache miss + return cache + + @pytest.fixture + def mock_ref_loader(self): + """Create a mock ReferenceDataLoader.""" + loader = MagicMock() + loader.load_latest_for_project.return_value = { + 'version': 1, + 'content': "Reference data content", + 'files': [], + 'total_size': 1000 + } + return loader + + @pytest.fixture + def mock_llm_client(self): + """Create a mock LLM client.""" + client = MagicMock() + client.generate.return_value = '{"canonical_vendor": "Slack", "confidence": 0.92}' + return client + + @pytest.fixture + def mock_project(self): + """Create a mock LookupProject.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Test Look-Up" + project.llm_config = {'temperature': 0.7} + + # Create mock template + template = MagicMock() + template.template_text = "Match {{input_data.vendor}} with {{reference_data}}" + project.template = template + + return project + + @pytest.fixture + def executor(self, mock_variable_resolver, mock_cache, mock_ref_loader, mock_llm_client): + """Create a LookUpExecutor instance with mocked dependencies.""" + return LookUpExecutor( + variable_resolver=mock_variable_resolver, + cache_manager=mock_cache, + reference_loader=mock_ref_loader, + llm_client=mock_llm_client + ) + + @pytest.fixture + def sample_input_data(self): + """Create sample input data.""" + return { + 'vendor': 'Slack India Pvt Ltd', + 'invoice_amount': 5000 + } + + # ========== Successful Execution Tests ========== + + def test_successful_execution(self, executor, mock_project, sample_input_data, + mock_variable_resolver, mock_cache, mock_llm_client): + """Test complete successful execution flow.""" + result = executor.execute(mock_project, sample_input_data) + + # Check success + assert result['status'] == 'success' + assert result['project_id'] == mock_project.id + assert result['project_name'] == 'Test Look-Up' + + # Check enrichment data + assert result['data'] == {'canonical_vendor': 'Slack'} + assert result['confidence'] == 0.92 + assert result['cached'] is False + assert result['execution_time_ms'] > 0 + + # Verify variable resolver was called + mock_variable_resolver.assert_called_once_with( + sample_input_data, "Reference data content" + ) + mock_variable_resolver.return_value.resolve.assert_called_once_with( + "Match {{input_data.vendor}} with {{reference_data}}" + ) + + # Verify cache was checked and set + mock_cache.get.assert_called_once_with("cache_key_123") + mock_cache.set.assert_called_once_with( + "cache_key_123", + '{"canonical_vendor": "Slack", "confidence": 0.92}' + ) + + # Verify LLM was called + mock_llm_client.generate.assert_called_once_with( + "Resolved prompt text", + {'temperature': 0.7} + ) + + def test_confidence_extraction(self, executor, mock_project, 
sample_input_data, + mock_llm_client): + """Test confidence score extraction from response.""" + mock_llm_client.generate.return_value = '{"field": "value", "confidence": 0.85}' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['confidence'] == 0.85 + assert result['data'] == {'field': 'value'} + + def test_no_confidence_in_response(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling response without confidence score.""" + mock_llm_client.generate.return_value = '{"field": "value"}' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['confidence'] is None + assert result['data'] == {'field': 'value'} + + # ========== Cache Tests ========== + + def test_cache_hit(self, executor, mock_project, sample_input_data, + mock_cache, mock_llm_client): + """Test execution with cache hit.""" + # Set up cache hit + mock_cache.get.return_value = '{"cached_field": "cached_value", "confidence": 0.88}' + + result = executor.execute(mock_project, sample_input_data) + + # Check result + assert result['status'] == 'success' + assert result['data'] == {'cached_field': 'cached_value'} + assert result['confidence'] == 0.88 + assert result['cached'] is True + assert result['execution_time_ms'] == 0 + + # Verify LLM was NOT called + mock_llm_client.generate.assert_not_called() + + # Verify cache.set was NOT called + mock_cache.set.assert_not_called() + + def test_cache_miss(self, executor, mock_project, sample_input_data, + mock_cache, mock_llm_client): + """Test execution with cache miss.""" + # Cache miss (default setup) + mock_cache.get.return_value = None + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['cached'] is False + + # Verify LLM was called + mock_llm_client.generate.assert_called_once() + + # Verify result was cached + mock_cache.set.assert_called_once() + + # ========== Error Handling Tests ========== + + def test_reference_data_not_ready(self, executor, mock_project, sample_input_data, + mock_ref_loader): + """Test handling of incomplete extraction.""" + mock_ref_loader.load_latest_for_project.side_effect = ExtractionNotCompleteError( + ['file1.csv', 'file2.txt'] + ) + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Reference data extraction not complete' in result['error'] + assert 'file1.csv' in result['error'] + assert 'file2.txt' in result['error'] + + def test_missing_template(self, executor, mock_project, sample_input_data): + """Test handling of missing template.""" + mock_project.template = None + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Missing prompt template' in result['error'] + + def test_llm_timeout(self, executor, mock_project, sample_input_data, mock_llm_client): + """Test handling of LLM timeout.""" + mock_llm_client.generate.side_effect = TimeoutError("Request timed out after 30s") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'LLM request timed out' in result['error'] + assert '30s' in result['error'] + + def test_llm_generic_error(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of generic LLM errors.""" + mock_llm_client.generate.side_effect = Exception("API key invalid") + + result = 
executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'LLM request failed' in result['error'] + assert 'API key invalid' in result['error'] + + def test_parse_error_invalid_json(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of invalid JSON response.""" + mock_llm_client.generate.return_value = "Not valid JSON" + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to parse LLM response' in result['error'] + + def test_parse_error_not_object(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of non-object JSON response.""" + mock_llm_client.generate.return_value = '["array", "not", "object"]' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to parse LLM response' in result['error'] + + def test_reference_loader_error(self, executor, mock_project, sample_input_data, + mock_ref_loader): + """Test handling of reference loader errors.""" + mock_ref_loader.load_latest_for_project.side_effect = Exception("Storage unavailable") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to load reference data' in result['error'] + assert 'Storage unavailable' in result['error'] + + # ========== Variable Resolution Tests ========== + + def test_variable_resolution(self, executor, mock_project, sample_input_data, + mock_variable_resolver): + """Test that variables are correctly resolved.""" + result = executor.execute(mock_project, sample_input_data) + + # Verify resolver was instantiated with correct data + mock_variable_resolver.assert_called_once_with( + sample_input_data, + "Reference data content" + ) + + # Verify resolve was called with template + mock_variable_resolver.return_value.resolve.assert_called_once_with( + "Match {{input_data.vendor}} with {{reference_data}}" + ) + + # ========== Response Parsing Tests ========== + + def test_parse_llm_response_with_confidence(self, executor): + """Test parsing response with confidence score.""" + response = '{"field1": "value1", "field2": "value2", "confidence": 0.95}' + + data, confidence = executor._parse_llm_response(response) + + assert data == {'field1': 'value1', 'field2': 'value2'} + assert confidence == 0.95 + + def test_parse_llm_response_without_confidence(self, executor): + """Test parsing response without confidence score.""" + response = '{"field1": "value1", "field2": "value2"}' + + data, confidence = executor._parse_llm_response(response) + + assert data == {'field1': 'value1', 'field2': 'value2'} + assert confidence is None + + def test_parse_llm_response_invalid_confidence(self, executor): + """Test handling of invalid confidence values.""" + # Confidence outside range (should be clamped) + response = '{"field": "value", "confidence": 1.5}' + + with patch('lookup.services.lookup_executor.logger') as mock_logger: + data, confidence = executor._parse_llm_response(response) + + assert data == {'field': 'value'} + assert confidence == 1.0 # Clamped to max + mock_logger.warning.assert_called() + + def test_parse_llm_response_non_numeric_confidence(self, executor): + """Test handling of non-numeric confidence.""" + response = '{"field": "value", "confidence": "high"}' + + with patch('lookup.services.lookup_executor.logger') as mock_logger: + data, confidence = executor._parse_llm_response(response) + + assert data == 
{'field': 'value'} + assert confidence is None + mock_logger.warning.assert_called() + + def test_parse_llm_response_invalid_json(self, executor): + """Test parsing invalid JSON raises ParseError.""" + response = "This is not JSON" + + with pytest.raises(ParseError) as exc_info: + executor._parse_llm_response(response) + + assert "Invalid JSON response" in str(exc_info.value) + + # ========== Integration Tests ========== + + def test_end_to_end_execution(self, executor, mock_project, sample_input_data): + """Test complete execution flow with realistic data.""" + result = executor.execute(mock_project, sample_input_data) + + # Basic assertions + assert result['status'] == 'success' + assert 'data' in result + assert 'confidence' in result + assert 'cached' in result + assert 'execution_time_ms' in result + + def test_execution_with_empty_input(self, executor, mock_project): + """Test execution with empty input data.""" + result = executor.execute(mock_project, {}) + + # Should still execute successfully + assert result['status'] == 'success' + + def test_execution_time_tracking(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test that execution time is properly tracked.""" + # Add a small delay to LLM call + def delayed_generate(*args, **kwargs): + import time + time.sleep(0.01) # 10ms delay + return '{"result": "data"}' + + mock_llm_client.generate.side_effect = delayed_generate + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['execution_time_ms'] >= 10 # At least 10ms + + def test_failed_execution_time_tracking(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test that execution time is tracked even on failure.""" + mock_llm_client.generate.side_effect = Exception("Error") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert result['execution_time_ms'] >= 0 + + @patch('lookup.services.lookup_executor.logger') + def test_unexpected_error_logging(self, mock_logger, executor, mock_project, + sample_input_data): + """Test that unexpected errors are logged.""" + # Create an error that will trigger the catch-all + mock_project.template.template_text = None # Will cause AttributeError + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Unexpected error' in result['error'] + mock_logger.exception.assert_called() + + def test_complex_llm_response(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of complex nested LLM response.""" + complex_response = json.dumps({ + "vendor": { + "canonical_name": "Slack Technologies", + "id": "SLACK-001" + }, + "categories": ["Communication", "SaaS"], + "confidence": 0.88 + }) + mock_llm_client.generate.return_value = complex_response + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['data']['vendor']['canonical_name'] == 'Slack Technologies' + assert result['data']['categories'] == ['Communication', 'SaaS'] + assert result['confidence'] == 0.88 diff --git a/backend/lookup/tests/test_services/test_lookup_orchestrator.py b/backend/lookup/tests/test_services/test_lookup_orchestrator.py new file mode 100644 index 0000000000..f1031ef4bc --- /dev/null +++ b/backend/lookup/tests/test_services/test_lookup_orchestrator.py @@ -0,0 +1,521 @@ +""" +Tests for Look-Up Orchestrator implementation. 
+ +This module tests the LookUpOrchestrator class including parallel execution, +timeout handling, result merging, and error recovery. +""" + +import time +import uuid +from concurrent.futures import TimeoutError as FutureTimeoutError +from unittest.mock import MagicMock, patch + +import pytest + +from lookup.services.lookup_orchestrator import LookUpOrchestrator + + +class TestLookUpOrchestrator: + """Test cases for LookUpOrchestrator class.""" + + @pytest.fixture + def mock_executor(self): + """Create a mock LookUpExecutor.""" + executor = MagicMock() + # Default to successful execution + executor.execute.return_value = { + 'status': 'success', + 'project_id': uuid.uuid4(), + 'project_name': 'Test Look-Up', + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + return executor + + @pytest.fixture + def mock_merger(self): + """Create a mock EnrichmentMerger.""" + merger = MagicMock() + merger.merge.return_value = { + 'data': {'merged_field': 'merged_value'}, + 'conflicts_resolved': 0, + 'enrichment_details': [] + } + return merger + + @pytest.fixture + def config(self): + """Create test configuration.""" + return { + 'max_concurrent_executions': 5, + 'queue_timeout_seconds': 10, + 'execution_timeout_seconds': 2 + } + + @pytest.fixture + def orchestrator(self, mock_executor, mock_merger, config): + """Create a LookUpOrchestrator with mocked dependencies.""" + return LookUpOrchestrator( + executor=mock_executor, + merger=mock_merger, + config=config + ) + + @pytest.fixture + def sample_input_data(self): + """Create sample input data.""" + return { + 'vendor': 'Slack Technologies', + 'amount': 5000 + } + + @pytest.fixture + def mock_projects(self): + """Create mock Look-Up projects.""" + projects = [] + for i in range(3): + project = MagicMock() + project.id = uuid.uuid4() + project.name = f"Look-Up {i+1}" + projects.append(project) + return projects + + # ========== Basic Execution Tests ========== + + def test_successful_parallel_execution( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test successful parallel execution of multiple Look-Ups.""" + # Setup executor to return different data for each project + def execute_side_effect(project, input_data): + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {f'field_{project.name}': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + mock_executor.execute.side_effect = execute_side_effect + + # Execute + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Verify execution + assert mock_executor.execute.call_count == 3 + assert mock_merger.merge.call_count == 1 + + # Check metadata + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 3 + assert metadata['lookups_failed'] == 0 + assert 'execution_id' in metadata + assert 'executed_at' in metadata + assert metadata['total_execution_time_ms'] > 0 + + def test_empty_projects_list(self, orchestrator, sample_input_data): + """Test execution with empty projects list.""" + result = orchestrator.execute_lookups(sample_input_data, []) + + assert result['lookup_enrichment'] == {} + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 0 + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 0 + assert metadata['enrichments'] == [] + + def test_single_project_execution( + self, 
orchestrator, sample_input_data, mock_executor + ): + """Test execution with single Look-Up project.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Single Look-Up" + + result = orchestrator.execute_lookups(sample_input_data, [project]) + + assert mock_executor.execute.call_count == 1 + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 1 + assert metadata['lookups_successful'] == 1 + + # ========== Failure Handling Tests ========== + + def test_partial_failures( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test handling of partial failures.""" + # Setup: First succeeds, second fails, third succeeds + def execute_side_effect(project, input_data): + if project.name == "Look-Up 2": + return { + 'status': 'failed', + 'project_id': project.id, + 'project_name': project.name, + 'error': 'Test error', + 'execution_time_ms': 50, + 'cached': False + } + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + mock_executor.execute.side_effect = execute_side_effect + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check results + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 2 + assert metadata['lookups_failed'] == 1 + + # Verify merger was called with only successful enrichments + merge_call_args = mock_merger.merge.call_args[0][0] + assert len(merge_call_args) == 2 # Only successful ones + + def test_all_failures( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test handling when all Look-Ups fail.""" + mock_executor.execute.return_value = { + 'status': 'failed', + 'project_id': uuid.uuid4(), + 'project_name': 'Failed Look-Up', + 'error': 'Test error', + 'execution_time_ms': 50, + 'cached': False + } + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check results + assert result['lookup_enrichment'] == {} # Empty merged data + metadata = result['_lookup_metadata'] + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 3 + assert metadata['conflicts_resolved'] == 0 + + # ========== Timeout Tests ========== + + @patch('lookup.services.lookup_orchestrator.ThreadPoolExecutor') + def test_individual_execution_timeout( + self, mock_executor_class, orchestrator, sample_input_data, + mock_projects + ): + """Test handling of individual execution timeouts.""" + # Setup mock executor + mock_thread_executor = MagicMock() + mock_executor_class.return_value.__enter__.return_value = mock_thread_executor + + # Create futures that will timeout + future1 = MagicMock() + future1.result.side_effect = FutureTimeoutError() + + future2 = MagicMock() + future2.result.return_value = { + 'status': 'success', + 'project_id': mock_projects[1].id, + 'project_name': mock_projects[1].name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + # Setup as_completed to return futures + with patch('lookup.services.lookup_orchestrator.as_completed') as mock_as_completed: + mock_as_completed.return_value = [future1, future2] + + # Setup submit to return futures + mock_thread_executor.submit.side_effect = [future1, future2, MagicMock()] + + result = orchestrator.execute_lookups(sample_input_data, mock_projects[:2]) + + metadata = 
result['_lookup_metadata'] + # One timeout (failed), one success + assert metadata['lookups_failed'] >= 1 + assert metadata['lookups_successful'] >= 1 + + @patch('lookup.services.lookup_orchestrator.as_completed') + def test_queue_timeout( + self, mock_as_completed, orchestrator, sample_input_data, + mock_projects + ): + """Test handling of overall queue timeout.""" + # Make as_completed raise TimeoutError + mock_as_completed.side_effect = FutureTimeoutError() + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 3 # All marked as failed due to queue timeout + + # ========== Concurrency Tests ========== + + def test_max_concurrent_limit( + self, mock_executor, mock_merger, sample_input_data + ): + """Test that max concurrent executions limit is respected.""" + # Create orchestrator with low concurrency limit + config = {'max_concurrent_executions': 2} + orchestrator = LookUpOrchestrator(mock_executor, mock_merger, config) + + # Create many projects + projects = [] + for i in range(10): + project = MagicMock() + project.id = uuid.uuid4() + project.name = f"Look-Up {i+1}" + projects.append(project) + + # Add small delay to executor to simulate work + def slow_execute(project, input_data): + time.sleep(0.01) # Small delay + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 10 + } + + mock_executor.execute.side_effect = slow_execute + + # Execute + result = orchestrator.execute_lookups(sample_input_data, projects) + + # Should complete successfully despite concurrency limit + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 10 + assert metadata['lookups_successful'] == 10 + + # ========== Result Merging Tests ========== + + def test_successful_merge( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test that successful enrichments are properly merged.""" + # Setup merger to return specific merged data + mock_merger.merge.return_value = { + 'data': { + 'vendor': 'Slack', + 'category': 'SaaS', + 'type': 'Communication' + }, + 'conflicts_resolved': 2, + 'enrichment_details': [ + {'lookup_project_name': 'Look-Up 1', 'fields_added': ['vendor']}, + {'lookup_project_name': 'Look-Up 2', 'fields_added': ['category']}, + {'lookup_project_name': 'Look-Up 3', 'fields_added': ['type']} + ] + } + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check merged enrichment + assert result['lookup_enrichment'] == { + 'vendor': 'Slack', + 'category': 'SaaS', + 'type': 'Communication' + } + + # Check metadata + metadata = result['_lookup_metadata'] + assert metadata['conflicts_resolved'] == 2 + + # ========== Error Recovery Tests ========== + + def test_executor_exception_handling( + self, orchestrator, sample_input_data, mock_projects, + mock_executor + ): + """Test handling of unexpected exceptions from executor.""" + # Make executor raise exception for one project + def execute_with_exception(project, input_data): + if project.name == "Look-Up 2": + raise ValueError("Unexpected error in executor") + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + 
} + + mock_executor.execute.side_effect = execute_with_exception + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Should handle exception gracefully + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 2 + assert metadata['lookups_failed'] == 1 + + # Check that error is captured + failed_enrichment = next( + e for e in metadata['enrichments'] + if e['status'] == 'failed' and 'Unexpected error' in e['error'] + ) + assert failed_enrichment is not None + + # ========== Metadata Tests ========== + + def test_execution_metadata( + self, orchestrator, sample_input_data, mock_projects + ): + """Test that execution metadata is properly populated.""" + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + metadata = result['_lookup_metadata'] + + # Check all required metadata fields + assert 'execution_id' in metadata + assert isinstance(metadata['execution_id'], str) + assert len(metadata['execution_id']) == 36 # UUID format + + assert 'executed_at' in metadata + assert 'T' in metadata['executed_at'] # ISO8601 format + + assert 'total_execution_time_ms' in metadata + assert metadata['total_execution_time_ms'] >= 0 + + assert 'enrichments' in metadata + assert len(metadata['enrichments']) >= metadata['lookups_successful'] + + def test_enrichments_list_includes_all_results( + self, orchestrator, sample_input_data, mock_projects, + mock_executor + ): + """Test that enrichments list includes both successful and failed results.""" + # Setup mixed results + def execute_side_effect(project, input_data): + if project.name == "Look-Up 2": + return { + 'status': 'failed', + 'project_id': project.id, + 'project_name': project.name, + 'error': 'Test failure', + 'execution_time_ms': 50, + 'cached': False + } + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + mock_executor.execute.side_effect = execute_side_effect + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + metadata = result['_lookup_metadata'] + enrichments = metadata['enrichments'] + + # Should have all 3 enrichments (2 success + 1 failed) + assert len(enrichments) == 3 + + # Check statuses + statuses = [e['status'] for e in enrichments] + assert statuses.count('success') == 2 + assert statuses.count('failed') == 1 + + # ========== Configuration Tests ========== + + def test_default_configuration(self, mock_executor, mock_merger): + """Test orchestrator with default configuration.""" + orchestrator = LookUpOrchestrator(mock_executor, mock_merger) + + assert orchestrator.max_concurrent == 10 + assert orchestrator.queue_timeout == 120 + assert orchestrator.execution_timeout == 30 + + def test_custom_configuration(self, mock_executor, mock_merger): + """Test orchestrator with custom configuration.""" + config = { + 'max_concurrent_executions': 20, + 'queue_timeout_seconds': 300, + 'execution_timeout_seconds': 60 + } + + orchestrator = LookUpOrchestrator(mock_executor, mock_merger, config) + + assert orchestrator.max_concurrent == 20 + assert orchestrator.queue_timeout == 300 + assert orchestrator.execution_timeout == 60 + + # ========== Integration Tests ========== + + @patch('lookup.services.lookup_orchestrator.logger') + def test_logging( + self, mock_logger, orchestrator, sample_input_data, + mock_projects + ): + """Test that appropriate logging is 
performed.""" + orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Should log start and completion + assert mock_logger.info.call_count >= 2 + start_log = mock_logger.info.call_args_list[0][0][0] + assert 'Starting orchestrated execution' in start_log + + completion_log = mock_logger.info.call_args_list[-1][0][0] + assert 'completed' in completion_log + + def test_execution_id_propagation( + self, orchestrator, sample_input_data, mock_projects, + mock_executor + ): + """Test that execution ID is propagated to individual executions.""" + # Capture execution results + captured_results = [] + + def capture_execute(project, input_data): + result = { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + captured_results.append(result) + return result + + mock_executor.execute.side_effect = capture_execute + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + execution_id = result['_lookup_metadata']['execution_id'] + + # Check that execution_id is added to enrichments + for enrichment in result['_lookup_metadata']['enrichments']: + if 'execution_id' in enrichment: + assert enrichment['execution_id'] == execution_id diff --git a/backend/lookup/tests/test_services/test_reference_data_loader.py b/backend/lookup/tests/test_services/test_reference_data_loader.py new file mode 100644 index 0000000000..6e3e84490d --- /dev/null +++ b/backend/lookup/tests/test_services/test_reference_data_loader.py @@ -0,0 +1,565 @@ +""" +Tests for Reference Data Loader implementation. + +This module tests the ReferenceDataLoader class including loading latest/specific +versions, concatenation, and extraction validation. 
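+
+Sketch of the loader contract asserted below (a hedged illustration taken from
+these tests; storage_client stands for any object exposing get(path)):
+
+    loader = ReferenceDataLoader(storage_client)
+    result = loader.load_latest_for_project(project_id)
+    # result keys: 'version', 'content', 'files', 'total_size'
+    # raises ExtractionNotCompleteError when any data source is not 'completed'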
+""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, Mock + +import pytest +from django.contrib.auth import get_user_model + +from lookup.exceptions import ExtractionNotCompleteError +from lookup.models import LookupDataSource, LookupProject +from lookup.services.reference_data_loader import ReferenceDataLoader + +User = get_user_model() + + +@pytest.mark.django_db +class TestReferenceDataLoader: + """Test cases for ReferenceDataLoader class.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock storage client.""" + storage = MagicMock() + storage.get = Mock(return_value="Default content") + return storage + + @pytest.fixture + def loader(self, mock_storage): + """Create a ReferenceDataLoader instance with mock storage.""" + return ReferenceDataLoader(mock_storage) + + @pytest.fixture + def test_user(self): + """Create a test user.""" + return User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + + @pytest.fixture + def test_project(self, test_user): + """Create a test Look-Up project.""" + return LookupProject.objects.create( + name="Test Project", + description="Test project for loader", + lookup_type='static_data', + llm_provider='openai', + llm_model='gpt-4', + llm_config={'temperature': 0.7}, + organization_id=uuid.uuid4(), + created_by=test_user + ) + + @pytest.fixture + def data_sources_completed(self, test_project, test_user): + """Create 3 completed data sources for testing.""" + sources = [] + for i in range(3): + source = LookupDataSource.objects.create( + project=test_project, + file_name=f"file{i+1}.csv", + file_path=f"uploads/file{i+1}.csv", + file_size=1000 * (i + 1), + file_type="text/csv", + extracted_content_path=f"extracted/file{i+1}.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + sources.append(source) + return sources + + # ========== Load Latest Tests ========== + + def test_load_latest_success(self, loader, mock_storage, test_project, + data_sources_completed): + """Test successful loading of latest reference data.""" + # Setup mock storage responses + def storage_get(path): + if 'file1' in path: + return "Content from file 1" + elif 'file2' in path: + return "Content from file 2" + elif 'file3' in path: + return "Content from file 3" + return "Unknown file" + + mock_storage.get.side_effect = storage_get + + # Load latest data + result = loader.load_latest_for_project(test_project.id) + + # Verify result structure + assert 'version' in result + assert 'content' in result + assert 'files' in result + assert 'total_size' in result + + # Check version + assert result['version'] == 1 + + # Check concatenated content + expected_content = ( + "=== File: file1.csv ===\n\n" + "Content from file 1\n\n" + "=== File: file2.csv ===\n\n" + "Content from file 2\n\n" + "=== File: file3.csv ===\n\n" + "Content from file 3" + ) + assert result['content'] == expected_content + + # Check files metadata + assert len(result['files']) == 3 + assert result['files'][0]['name'] == 'file1.csv' + assert result['files'][1]['name'] == 'file2.csv' + assert result['files'][2]['name'] == 'file3.csv' + + # Check total size + assert result['total_size'] == 6000 # 1000 + 2000 + 3000 + + # Verify storage was called for each file + assert mock_storage.get.call_count == 3 + + def test_load_latest_incomplete_extraction(self, loader, test_project, + test_user): + """Test loading fails when extraction is incomplete.""" + # Create sources 
with mixed status + LookupDataSource.objects.create( + project=test_project, + file_name="completed.csv", + file_path="uploads/completed.csv", + file_size=1000, + file_type="text/csv", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + LookupDataSource.objects.create( + project=test_project, + file_name="pending.csv", + file_path="uploads/pending.csv", + file_size=2000, + file_type="text/csv", + extraction_status='pending', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + LookupDataSource.objects.create( + project=test_project, + file_name="failed.csv", + file_path="uploads/failed.csv", + file_size=3000, + file_type="text/csv", + extraction_status='failed', + extraction_error='Parse error', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + # Attempt to load should raise exception + with pytest.raises(ExtractionNotCompleteError) as exc_info: + loader.load_latest_for_project(test_project.id) + + # Check error message includes failed files + assert 'pending.csv' in str(exc_info.value) + assert 'failed.csv' in str(exc_info.value) + assert exc_info.value.failed_files == ['pending.csv', 'failed.csv'] + + def test_load_latest_no_data_sources(self, loader): + """Test loading when no data sources exist.""" + non_existent_id = uuid.uuid4() + + with pytest.raises(LookupDataSource.DoesNotExist) as exc_info: + loader.load_latest_for_project(non_existent_id) + + assert str(non_existent_id) in str(exc_info.value) + + def test_load_latest_order_by_upload(self, loader, mock_storage, test_project, + test_user): + """Test that files are concatenated in upload order.""" + # Create sources with specific creation times + source1 = LookupDataSource.objects.create( + project=test_project, + file_name="third.csv", # Named third but uploaded first + file_path="uploads/third.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path="extracted/third.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + source1.created_at = datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc) + source1.save() + + source2 = LookupDataSource.objects.create( + project=test_project, + file_name="first.csv", # Named first but uploaded second + file_path="uploads/first.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/first.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + source2.created_at = datetime(2024, 1, 1, 11, 0, tzinfo=timezone.utc) + source2.save() + + def storage_get(path): + if 'third' in path: + return "Third content" + elif 'first' in path: + return "First content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + result = loader.load_latest_for_project(test_project.id) + + # Verify order is by upload time, not name + assert "=== File: third.csv ===" in result['content'] + assert "=== File: first.csv ===" in result['content'] + assert result['content'].index('third.csv') < result['content'].index('first.csv') + + # ========== Load Specific Version Tests ========== + + def test_load_specific_version(self, loader, mock_storage, test_project, + test_user): + """Test loading a specific version of reference data.""" + # Create multiple versions + # Version 1 + LookupDataSource.objects.create( + project=test_project, + file_name="v1_file.csv", + file_path="uploads/v1_file.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path="extracted/v1_file.txt", + 
extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + # Version 2 + LookupDataSource.objects.create( + project=test_project, + file_name="v2_file.csv", + file_path="uploads/v2_file.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/v2_file.txt", + extraction_status='completed', + version_number=2, + is_latest=False, + uploaded_by=test_user + ) + + # Version 3 (latest) + LookupDataSource.objects.create( + project=test_project, + file_name="v3_file.csv", + file_path="uploads/v3_file.csv", + file_size=3000, + file_type="text/csv", + extracted_content_path="extracted/v3_file.txt", + extraction_status='completed', + version_number=3, + is_latest=True, + uploaded_by=test_user + ) + + def storage_get(path): + if 'v1' in path: + return "Version 1 content" + elif 'v2' in path: + return "Version 2 content" + elif 'v3' in path: + return "Version 3 content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + # Load version 2 + result = loader.load_specific_version(test_project.id, 2) + + assert result['version'] == 2 + assert "Version 2 content" in result['content'] + assert "v2_file.csv" in result['content'] + assert result['files'][0]['name'] == 'v2_file.csv' + + # Load version 1 + result = loader.load_specific_version(test_project.id, 1) + + assert result['version'] == 1 + assert "Version 1 content" in result['content'] + assert "v1_file.csv" in result['content'] + + def test_load_specific_version_not_found(self, loader, test_project): + """Test loading non-existent version.""" + with pytest.raises(LookupDataSource.DoesNotExist) as exc_info: + loader.load_specific_version(test_project.id, 999) + + assert "Version 999 not found" in str(exc_info.value) + + # ========== Concatenation Tests ========== + + def test_concatenate_sources(self, loader, mock_storage, data_sources_completed): + """Test concatenation of multiple sources.""" + def storage_get(path): + if 'file1' in path: + return "Alpha content" + elif 'file2' in path: + return "Beta content" + elif 'file3' in path: + return "Gamma content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + result = loader.concatenate_sources(data_sources_completed) + + # Check all files are included with headers + assert "=== File: file1.csv ===" in result + assert "=== File: file2.csv ===" in result + assert "=== File: file3.csv ===" in result + + # Check all content is included + assert "Alpha content" in result + assert "Beta content" in result + assert "Gamma content" in result + + # Check double newline separators + assert "\n\n" in result + + def test_concatenate_with_missing_path(self, loader, test_project, test_user): + """Test concatenation when extracted_content_path is missing.""" + source = LookupDataSource.objects.create( + project=test_project, + file_name="no_path.csv", + file_path="uploads/no_path.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path=None, # No extracted path + extraction_status='completed', # But marked as completed + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + result = loader.concatenate_sources([source]) + + assert "=== File: no_path.csv ===" in result + assert "[No extracted content path]" in result + + def test_concatenate_with_storage_error(self, loader, mock_storage, + data_sources_completed): + """Test concatenation handles storage errors gracefully.""" + mock_storage.get.side_effect = Exception("Storage unavailable") + + result = 
loader.concatenate_sources(data_sources_completed) + + # Should include error messages instead of content + assert "[Error loading file: Storage unavailable]" in result + # But should still have file headers + assert "=== File: file1.csv ===" in result + + # ========== Validation Tests ========== + + def test_validate_extraction_complete_all_success(self, loader, + data_sources_completed): + """Test validation when all extractions are complete.""" + is_complete, failed_files = loader.validate_extraction_complete( + data_sources_completed + ) + + assert is_complete is True + assert failed_files == [] + + def test_validate_extraction_with_failures(self, loader, test_project, test_user): + """Test validation identifies incomplete extractions.""" + sources = [] + + # Completed + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="good.csv", + file_path="uploads/good.csv", + file_size=1000, + file_type="text/csv", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Pending + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="pending.csv", + file_path="uploads/pending.csv", + file_size=2000, + file_type="text/csv", + extraction_status='pending', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Processing + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="processing.csv", + file_path="uploads/processing.csv", + file_size=3000, + file_type="text/csv", + extraction_status='processing', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Failed + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="failed.csv", + file_path="uploads/failed.csv", + file_size=4000, + file_type="text/csv", + extraction_status='failed', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + is_complete, failed_files = loader.validate_extraction_complete(sources) + + assert is_complete is False + assert len(failed_files) == 3 + assert 'pending.csv' in failed_files + assert 'processing.csv' in failed_files + assert 'failed.csv' in failed_files + assert 'good.csv' not in failed_files + + # ========== Integration Tests ========== + + def test_end_to_end_multi_file_loading(self, loader, mock_storage, test_project, + test_user): + """Test complete workflow with multiple files and versions.""" + # Create version 1 with 2 files + v1_file1 = LookupDataSource.objects.create( + project=test_project, + file_name="vendors_v1.csv", + file_path="uploads/vendors_v1.csv", + file_size=5000, + file_type="text/csv", + extracted_content_path="extracted/vendors_v1.txt", + extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + v1_file2 = LookupDataSource.objects.create( + project=test_project, + file_name="products_v1.csv", + file_path="uploads/products_v1.csv", + file_size=3000, + file_type="text/csv", + extracted_content_path="extracted/products_v1.txt", + extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + # Create version 2 with 3 files (latest) + v2_file1 = LookupDataSource.objects.create( + project=test_project, + file_name="vendors_v2.csv", + file_path="uploads/vendors_v2.csv", + file_size=6000, + file_type="text/csv", + extracted_content_path="extracted/vendors_v2.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + v2_file2 = 
LookupDataSource.objects.create( + project=test_project, + file_name="products_v2.csv", + file_path="uploads/products_v2.csv", + file_size=4000, + file_type="text/csv", + extracted_content_path="extracted/products_v2.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + v2_file3 = LookupDataSource.objects.create( + project=test_project, + file_name="categories.csv", + file_path="uploads/categories.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/categories.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + def storage_get(path): + content_map = { + 'vendors_v1': "Slack\nMicrosoft", + 'products_v1': "Slack Workspace\nTeams", + 'vendors_v2': "Slack\nMicrosoft\nGoogle", + 'products_v2': "Slack Workspace\nTeams\nGoogle Workspace", + 'categories': "Communication\nProductivity" + } + for key, value in content_map.items(): + if key in path: + return value + return "Unknown content" + + mock_storage.get.side_effect = storage_get + + # Load latest (v2) + latest_result = loader.load_latest_for_project(test_project.id) + + assert latest_result['version'] == 2 + assert len(latest_result['files']) == 3 + assert latest_result['total_size'] == 12000 # 6000 + 4000 + 2000 + assert "Google" in latest_result['content'] + assert "categories.csv" in latest_result['content'] + + # Load specific version (v1) + v1_result = loader.load_specific_version(test_project.id, 1) + + assert v1_result['version'] == 1 + assert len(v1_result['files']) == 2 + assert v1_result['total_size'] == 8000 # 5000 + 3000 + assert "Google" not in v1_result['content'] # v1 doesn't have Google + assert "categories" not in v1_result['content'] # v1 doesn't have categories diff --git a/backend/lookup/tests/test_variable_resolver.py b/backend/lookup/tests/test_variable_resolver.py new file mode 100644 index 0000000000..826f961b10 --- /dev/null +++ b/backend/lookup/tests/test_variable_resolver.py @@ -0,0 +1,289 @@ +"""Tests for the VariableResolver class.""" + +import json +import pytest +from lookup.services.variable_resolver import VariableResolver + + +class TestVariableResolver: + """Test suite for VariableResolver.""" + + @pytest.fixture + def sample_input_data(self): + """Sample input data for testing.""" + return { + "vendor_name": "Slack India Pvt Ltd", + "contract_value": 50000, + "contract_date": "2024-01-15", + "line_items": [ + {"product": "Slack Pro", "quantity": 100, "price": 500} + ], + "metadata": { + "region": "APAC", + "currency": "INR" + }, + "none_field": None + } + + @pytest.fixture + def sample_reference_data(self): + """Sample reference data for testing.""" + return """Canonical Vendors: +- Slack (variations: Slack Inc, Slack India, Slack Singapore) +- Microsoft (variations: Microsoft Corp, MSFT) +- Google (variations: Google LLC, Google India)""" + + @pytest.fixture + def resolver(self, sample_input_data, sample_reference_data): + """Create a VariableResolver instance for testing.""" + return VariableResolver(sample_input_data, sample_reference_data) + + # ========================================================================== + # Basic Resolution Tests + # ========================================================================== + + def test_simple_variable_replacement(self, resolver): + """Test simple variable replacement.""" + template = "Vendor: {{input_data.vendor_name}}" + result = resolver.resolve(template) + assert result == "Vendor: Slack India Pvt Ltd" + + def 
test_reference_data_replacement(self, resolver): + """Test reference data replacement.""" + template = "Database: {{reference_data}}" + result = resolver.resolve(template) + assert "Canonical Vendors" in result + assert "Slack" in result + + def test_multiple_variables(self, resolver): + """Test multiple variable replacements in one template.""" + template = "Match {{input_data.vendor_name}} from {{reference_data}}" + result = resolver.resolve(template) + assert "Slack India Pvt Ltd" in result + assert "Canonical Vendors" in result + + # ========================================================================== + # Dot Notation Tests + # ========================================================================== + + def test_one_level_dot_notation(self, resolver): + """Test one level dot notation.""" + template = "Value: {{input_data.contract_value}}" + result = resolver.resolve(template) + assert result == "Value: 50000" + + def test_two_level_dot_notation(self, resolver): + """Test two level dot notation.""" + template = "Region: {{input_data.metadata.region}}" + result = resolver.resolve(template) + assert result == "Region: APAC" + + def test_array_indexing(self, resolver): + """Test array indexing with dot notation.""" + template = "Product: {{input_data.line_items.0.product}}" + result = resolver.resolve(template) + assert result == "Product: Slack Pro" + + def test_deep_nesting(self, resolver): + """Test deep nested path resolution.""" + template = "Price: {{input_data.line_items.0.price}}" + result = resolver.resolve(template) + assert result == "Price: 500" + + # ========================================================================== + # Complex Object Serialization Tests + # ========================================================================== + + def test_dict_serialization(self, resolver): + """Test that dicts are serialized to JSON.""" + template = "Metadata: {{input_data.metadata}}" + result = resolver.resolve(template) + # Parse the JSON portion to verify it's valid + json_str = result.replace("Metadata: ", "") + parsed = json.loads(json_str) + assert parsed["region"] == "APAC" + assert parsed["currency"] == "INR" + + def test_list_serialization(self, resolver): + """Test that lists are serialized to JSON.""" + template = "Items: {{input_data.line_items}}" + result = resolver.resolve(template) + # Parse the JSON portion to verify it's valid + json_str = result.replace("Items: ", "") + parsed = json.loads(json_str) + assert len(parsed) == 1 + assert parsed[0]["product"] == "Slack Pro" + + def test_full_input_serialization(self, resolver): + """Test serializing the entire input_data object.""" + template = "Full data: {{input_data}}" + result = resolver.resolve(template) + # Should contain JSON representation + assert '"vendor_name"' in result + assert '"contract_value"' in result + + # ========================================================================== + # Missing Value Tests + # ========================================================================== + + def test_missing_root_variable(self, resolver): + """Test missing root variable returns empty string.""" + template = "Missing: {{missing_data}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_missing_nested_field(self, resolver): + """Test missing nested field returns empty string.""" + template = "Missing: {{input_data.missing.field}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_partial_path_missing(self, resolver): + """Test partially 
missing path returns empty string.""" + template = "Missing: {{input_data.metadata.missing}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_none_value(self, resolver): + """Test that None values are converted to empty string.""" + template = "None field: {{input_data.none_field}}" + result = resolver.resolve(template) + assert result == "None field: " + + def test_out_of_bounds_array_index(self, resolver): + """Test out of bounds array index returns empty string.""" + template = "Missing: {{input_data.line_items.999.product}}" + result = resolver.resolve(template) + assert result == "Missing: " + + # ========================================================================== + # Edge Cases Tests + # ========================================================================== + + def test_empty_template(self, resolver): + """Test empty template returns empty string.""" + assert resolver.resolve("") == "" + + def test_no_variables(self, resolver): + """Test template without variables returns unchanged.""" + template = "Plain text without variables" + assert resolver.resolve(template) == template + + def test_malformed_variable(self, resolver): + """Test malformed variable syntax is left unchanged.""" + template = "Malformed: {{unclosed" + result = resolver.resolve(template) + assert result == "Malformed: {{unclosed" + + def test_whitespace_in_variable(self, resolver): + """Test variables with whitespace are handled correctly.""" + template = "Vendor: {{ input_data.vendor_name }}" + result = resolver.resolve(template) + assert result == "Vendor: Slack India Pvt Ltd" + + def test_empty_braces(self, resolver): + """Test empty braces are handled.""" + template = "Empty: {{}}" + result = resolver.resolve(template) + assert result == "Empty: " + + # ========================================================================== + # Variable Detection Tests + # ========================================================================== + + def test_detect_single_variable(self, resolver): + """Test detecting a single variable.""" + template = "{{input_data.vendor_name}}" + variables = resolver.detect_variables(template) + assert variables == ["input_data.vendor_name"] + + def test_detect_multiple_variables(self, resolver): + """Test detecting multiple variables.""" + template = "{{input_data.vendor_name}} and {{reference_data}}" + variables = resolver.detect_variables(template) + assert set(variables) == {"input_data.vendor_name", "reference_data"} + + def test_detect_duplicate_variables(self, resolver): + """Test that duplicate variables are deduplicated.""" + template = "{{var}} and {{var}} and {{var}}" + variables = resolver.detect_variables(template) + assert variables == ["var"] + + def test_detect_no_variables(self, resolver): + """Test detecting no variables in plain text.""" + template = "Plain text without variables" + variables = resolver.detect_variables(template) + assert variables == [] + + def test_detect_variables_with_whitespace(self, resolver): + """Test detecting variables with whitespace.""" + template = "{{ input_data.vendor }} and {{reference_data}}" + variables = resolver.detect_variables(template) + assert set(variables) == {"input_data.vendor", "reference_data"} + + # ========================================================================== + # Variable Validation Tests + # ========================================================================== + + def test_validate_existing_variables(self, resolver): + """Test validation of existing 
variables.""" + template = "{{input_data.vendor_name}} and {{reference_data}}" + validation = resolver.validate_variables(template) + assert validation["input_data.vendor_name"] is True + assert validation["reference_data"] is True + + def test_validate_missing_variables(self, resolver): + """Test validation of missing variables.""" + template = "{{input_data.missing}} and {{nonexistent}}" + validation = resolver.validate_variables(template) + assert validation["input_data.missing"] is False + assert validation["nonexistent"] is False + + def test_get_missing_variables(self, resolver): + """Test getting list of missing variables.""" + template = "{{input_data.vendor_name}} {{missing}} {{input_data.nope}}" + missing = resolver.get_missing_variables(template) + assert set(missing) == {"missing", "input_data.nope"} + + # ========================================================================== + # Integration Tests + # ========================================================================== + + def test_complete_vendor_matching_template(self, resolver): + """Test complete vendor matching template from spec.""" + template = """Match vendor "{{input_data.vendor_name}}" from: +{{reference_data}} + +Contract value: {{input_data.contract_value}} +Region: {{input_data.metadata.region}}""" + + result = resolver.resolve(template) + + # Verify all variables were replaced correctly + assert "Slack India Pvt Ltd" in result + assert "Canonical Vendors" in result + assert "50000" in result + assert "APAC" in result + assert "{{" not in result # No unresolved variables + + def test_complex_nested_template(self, resolver): + """Test complex template with various variable types.""" + template = """Input Summary: +- Vendor: {{input_data.vendor_name}} +- Date: {{input_data.contract_date}} +- First Item: {{input_data.line_items.0.product}} +- Metadata: {{input_data.metadata}} +- Missing: {{input_data.nonexistent}} + +Reference Database: +{{reference_data}}""" + + result = resolver.resolve(template) + + # Verify each part + assert "Slack India Pvt Ltd" in result + assert "2024-01-15" in result + assert "Slack Pro" in result + assert '"region": "APAC"' in result # JSON serialized + assert "Missing: \n" in result # Empty for missing field + assert "Canonical Vendors" in result diff --git a/backend/lookup/urls.py b/backend/lookup/urls.py new file mode 100644 index 0000000000..bc30dfa21d --- /dev/null +++ b/backend/lookup/urls.py @@ -0,0 +1,34 @@ +"""URL configuration for Look-Up API endpoints.""" + +from django.urls import include, path +from rest_framework.routers import DefaultRouter + +from .views import ( + LookupDataSourceViewSet, + LookupDebugView, + LookupExecutionAuditViewSet, + LookupProfileManagerViewSet, + LookupProjectViewSet, + LookupPromptTemplateViewSet, + PromptStudioLookupLinkViewSet, +) + +# Create router for viewsets +router = DefaultRouter() +router.register(r"lookup-projects", LookupProjectViewSet, basename="lookupproject") +router.register( + r"lookup-templates", LookupPromptTemplateViewSet, basename="lookuptemplate" +) +router.register(r"lookup-profiles", LookupProfileManagerViewSet, basename="lookupprofile") +router.register(r"data-sources", LookupDataSourceViewSet, basename="lookupdatasource") +router.register(r"lookup-links", PromptStudioLookupLinkViewSet, basename="lookuplink") +router.register( + r"execution-audits", LookupExecutionAuditViewSet, basename="executionaudit" +) +router.register(r"lookup-debug", LookupDebugView, basename="lookupdebug") + +app_name = "lookup" + +urlpatterns 
= [ + path("", include(router.urls)), +] diff --git a/backend/lookup/views.py b/backend/lookup/views.py new file mode 100644 index 0000000000..de6311f4f2 --- /dev/null +++ b/backend/lookup/views.py @@ -0,0 +1,924 @@ +"""Django REST Framework views for Look-Up API. + +This module provides RESTful API endpoints for managing Look-Up projects, +templates, reference data, and executing Look-Ups. +""" + +import logging +import uuid + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError, transaction +from permissions.permission import IsOwner +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper +from utils.pagination import CustomPagination + +from .constants import LookupProfileManagerErrors, LookupProfileManagerKeys +from .exceptions import ExtractionNotCompleteError +from .models import ( + LookupDataSource, + LookupExecutionAudit, + LookupProfileManager, + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, +) +from .serializers import ( + BulkLinkSerializer, + LookupDataSourceSerializer, + LookupExecutionAuditSerializer, + LookupExecutionRequestSerializer, + LookupExecutionResponseSerializer, + LookupProfileManagerSerializer, + LookupProjectSerializer, + LookupPromptTemplateSerializer, + PromptStudioLookupLinkSerializer, + ReferenceDataUploadSerializer, + TemplateValidationSerializer, +) +from .services import ( + AuditLogger, + EnrichmentMerger, + LLMResponseCache, + LookUpExecutor, + LookUpOrchestrator, + ReferenceDataLoader, + VariableResolver, +) + +logger = logging.getLogger(__name__) + + +class LookupProjectViewSet(viewsets.ModelViewSet): + """ViewSet for managing Look-Up projects. + + Provides CRUD operations and additional actions for + executing Look-Ups and managing reference data. + """ + + queryset = LookupProject.objects.all() + serializer_class = LookupProjectSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter projects by organization and active status.""" + # Note: Organization filtering is handled automatically by + # DefaultOrganizationMixin's save() method and queryset filtering + # should be handled by a custom manager if needed (like Prompt Studio) + queryset = super().get_queryset() + + # Filter by active status if requested + is_active = self.request.query_params.get("is_active") + if is_active is not None: + queryset = queryset.filter(is_active=is_active.lower() == "true") + + return queryset.select_related("template") + + def perform_create(self, serializer): + """Set created_by from request.""" + # Note: organization is set automatically by DefaultOrganizationMixin's save() method + serializer.save(created_by=self.request.user) + + @action(detail=True, methods=["post"]) + def execute(self, request, pk=None): + """Execute a Look-Up project with provided input data. 
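Illustrative request body for the execute action (a sketch only: the field names mirror the LookupExecutionRequestSerializer keys read in the method body, while the concrete values are placeholders):

    payload = {
        "input_data": {"vendor_name": "Slack India Pvt Ltd"},  # extracted fields to enrich
        "use_cache": True,         # reuse cached LLM responses when available
        "timeout_seconds": 60,     # forwarded as execution_timeout_seconds to the orchestrator
    }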
+ + POST /api/v1/lookup-projects/{id}/execute/ + """ + project = self.get_object() + + # Validate request + request_serializer = LookupExecutionRequestSerializer(data=request.data) + request_serializer.is_valid(raise_exception=True) + + input_data = request_serializer.validated_data["input_data"] + use_cache = request_serializer.validated_data["use_cache"] + timeout = request_serializer.validated_data["timeout_seconds"] + + try: + # Get the LLM adapter from Lookup profile + from .integrations.file_storage_client import FileStorageClient + from .integrations.unstract_llm_client import UnstractLLMClient + + # Get profile for this project + profile = LookupProfileManager.objects.filter(lookup_project=project).first() + + if not profile or not profile.llm: + return Response( + {"error": "No LLM profile configured for this Look-Up project"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Create real LLM client using the profile's adapter + llm_client = UnstractLLMClient(profile.llm) + storage_client = FileStorageClient() + cache = LLMResponseCache() if use_cache else None + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + ) + + orchestrator = LookUpOrchestrator( + executor=executor, + merger=merger, + config={"execution_timeout_seconds": timeout}, + ) + + # Execute Look-Up + result = orchestrator.execute_lookups( + input_data=input_data, lookup_projects=[project] + ) + + # Serialize response + response_serializer = LookupExecutionResponseSerializer(data=result) + response_serializer.is_valid(raise_exception=True) + + return Response(response_serializer.validated_data, status=status.HTTP_200_OK) + + except ExtractionNotCompleteError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception: + logger.exception(f"Error executing Look-Up project {project.id}") + return Response( + {"error": "Internal error during execution"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @action(detail=True, methods=["post"], parser_classes=[MultiPartParser]) + def upload_reference_data(self, request, pk=None): + """Upload reference data for a Look-Up project. 
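A minimal client-side upload sketch (the host placeholder, the file name, and the omission of auth headers are assumptions; the "file" and "extract_text" form fields match the ReferenceDataUploadSerializer usage in the handler below):

    import requests

    with open("vendors.csv", "rb") as fh:
        resp = requests.post(
            "https://<host>/api/v1/lookup-projects/<project_id>/upload_reference_data/",
            files={"file": fh},             # multipart file part
            data={"extract_text": "true"},  # queue text extraction after upload
        )
    # A 201 response carries the created LookupDataSource record.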
+ + POST /api/v1/lookup-projects/{id}/upload_reference_data/ + """ + project = self.get_object() + + upload_serializer = ReferenceDataUploadSerializer(data=request.data) + upload_serializer.is_valid(raise_exception=True) + + file = upload_serializer.validated_data["file"] + extract_text = upload_serializer.validated_data["extract_text"] + + try: + from django.conf import settings + from utils.file_storage.constants import FileStorageKeys + + from unstract.sdk1.file_storage.constants import StorageType + from unstract.sdk1.file_storage.env_helper import EnvHelper + + # Get file storage instance + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + # Determine file type from extension + file_ext = file.name.split(".")[-1].lower() + file_type = ( + file_ext + if file_ext in ["pdf", "xlsx", "csv", "docx", "txt", "json"] + else "txt" + ) + + # Upload file to storage following Prompt Studio's path structure + # Pattern: {base_path}/{org_id}/{project_id}/{filename} + # Keep Lookup file storage independent of PS project linkage + from utils.user_context import UserContext + + org_id = UserContext.get_organization_identifier() + base_path = settings.PROMPT_STUDIO_FILE_PATH + + # Store files under Lookup project ID, not PS tool ID + # This ensures files remain accessible regardless of PS linkage changes + file_path = f"{base_path}/{org_id}/{project.id}/{file.name}" + + # Create parent directories if they don't exist + fs_instance.mkdir(f"{base_path}/{org_id}/{project.id}", create_parents=True) + fs_instance.mkdir( + f"{base_path}/{org_id}/{project.id}/extract", create_parents=True + ) + + # Upload the file + fs_instance.write(path=file_path, mode="wb", data=file.read()) + + logger.info(f"Uploaded file to storage: {file_path}") + + # Create a data source record + data_source = LookupDataSource.objects.create( + project=project, + file_name=file.name, + file_path=file_path, + file_size=file.size, + file_type=file_type, + extraction_status="pending" if extract_text else "completed", + uploaded_by=request.user, + ) + + serializer = LookupDataSourceSerializer(data_source) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + except Exception: + logger.exception(f"Error uploading reference data for project {project.id}") + return Response( + {"error": "Failed to upload reference data"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @action(detail=True, methods=["get"]) + def data_sources(self, request, pk=None): + """List all data sources for a Look-Up project. + + GET /api/v1/lookup-projects/{id}/data_sources/ + """ + project = self.get_object() + data_sources = project.data_sources.all().order_by("-version_number") + + # Filter by is_latest if requested + is_latest = request.query_params.get("is_latest") + if is_latest is not None: + data_sources = data_sources.filter(is_latest=is_latest.lower() == "true") + + serializer = LookupDataSourceSerializer(data_sources, many=True) + return Response(serializer.data) + + @action(detail=True, methods=["post"]) + def index_all(self, request, pk=None): + """Index all reference data using the project's default profile. + + POST /api/v1/lookup-projects/{id}/index_all/ + + Triggers indexing of all completed data sources using the + configured default profile's adapters and settings. + + This calls external extraction and indexing services via PromptTool SDK. 
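Illustrative success response (the outer shape follows the handler below; the counts are placeholders and IndexingService may attach further keys to "results"):

    {
        "message": "Indexing completed",
        "results": {"success": 3, "failed": 0},
    }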
+ """ + project = self.get_object() + + try: + from utils.user_context import UserContext + + from .exceptions import DefaultProfileError + from .services import IndexingService + + # Get organization and user context + org_id = UserContext.get_organization_identifier() + user_id = str(request.user.user_id) if request.user else None + + logger.info( + f"Starting indexing for project {project.id} " + f"(org: {org_id}, user: {user_id})" + ) + + # Index all using default profile + results = IndexingService.index_with_default_profile( + project_id=str(project.id), org_id=org_id, user_id=user_id + ) + + logger.info( + f"Indexing completed for project {project.id}: " + f"{results['success']} successful, {results['failed']} failed" + ) + + return Response( + {"message": "Indexing completed", "results": results}, + status=status.HTTP_200_OK, + ) + + except DefaultProfileError as e: + logger.error(f"Default profile error for project {project.id}: {e}") + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + logger.exception(f"Error indexing reference data for project {project.id}") + return Response( + {"error": f"Failed to index reference data: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class LookupPromptTemplateViewSet(viewsets.ModelViewSet): + """ViewSet for managing Look-Up prompt templates. + + Provides CRUD operations and template validation. + """ + + queryset = LookupPromptTemplate.objects.all() + serializer_class = LookupPromptTemplateSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter templates by active status if requested.""" + queryset = super().get_queryset() + is_active = self.request.query_params.get("is_active") + + if is_active is not None: + queryset = queryset.filter(is_active=is_active.lower() == "true") + + return queryset + + def perform_create(self, serializer): + """Set created_by from request and update project's template reference.""" + template = serializer.save(created_by=self.request.user) + # Update the project's template field to point to this template + if template.project: + template.project.template = template + template.project.save(update_fields=["template"]) + + def perform_update(self, serializer): + """Update template and ensure project reference is maintained.""" + template = serializer.save() + # Ensure the project's template field points to this template + if template.project and template.project.template != template: + template.project.template = template + template.project.save(update_fields=["template"]) + + @action(detail=False, methods=["post"]) + def validate(self, request): + """Validate a template with optional sample data. 
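Illustrative request body (a sketch: field names mirror the TemplateValidationSerializer keys read below, the {{...}} variable syntax matches the VariableResolver tests, and the sample values are placeholders). A successful response returns "valid", "resolved_template" (first 1000 characters) and "variables_found":

    {
        "template_text": "Match {{input_data.vendor_name}} against: {{reference_data}}",
        "sample_data": {"vendor_name": "Slack India Pvt Ltd"},
        "sample_reference": "Canonical Vendors: Slack, Microsoft, Google",
    }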
+ + POST /api/v1/lookup-templates/validate/ + """ + validator = TemplateValidationSerializer(data=request.data) + validator.is_valid(raise_exception=True) + + template_text = validator.validated_data["template_text"] + sample_data = validator.validated_data.get("sample_data", {}) + sample_reference = validator.validated_data.get("sample_reference", "") + + try: + # Test variable resolution + resolver = VariableResolver(sample_data, sample_reference) + resolved = resolver.resolve(template_text) + + return Response( + { + "valid": True, + "resolved_template": resolved[:1000], # First 1000 chars + "variables_found": list(resolver.get_all_variables(template_text)), + } + ) + + except Exception as e: + return Response( + {"valid": False, "error": str(e)}, status=status.HTTP_400_BAD_REQUEST + ) + + +class PromptStudioLookupLinkViewSet(viewsets.ModelViewSet): + """ViewSet for managing links between Prompt Studio projects and Look-Ups.""" + + queryset = PromptStudioLookupLink.objects.all() + serializer_class = PromptStudioLookupLinkSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter links by PS project if requested.""" + queryset = super().get_queryset() + + ps_project_id = self.request.query_params.get("prompt_studio_project_id") + if ps_project_id: + queryset = queryset.filter(prompt_studio_project_id=ps_project_id) + + lookup_project_id = self.request.query_params.get("lookup_project_id") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + return queryset.select_related("lookup_project") + + @action(detail=False, methods=["post"]) + def bulk_link(self, request): + """Create or remove multiple links at once. + + POST /api/v1/lookup-links/bulk_link/ + """ + serializer = BulkLinkSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + ps_project_id = serializer.validated_data["prompt_studio_project_id"] + lookup_project_ids = serializer.validated_data["lookup_project_ids"] + unlink = serializer.validated_data["unlink"] + + results = [] + + with transaction.atomic(): + for lookup_id in lookup_project_ids: + if unlink: + # Remove link + deleted_count, _ = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id, + lookup_project_id=lookup_id, + ).delete() + results.append( + { + "lookup_project_id": str(lookup_id), + "unlinked": deleted_count > 0, + } + ) + else: + # Create link + link, created = PromptStudioLookupLink.objects.get_or_create( + prompt_studio_project_id=ps_project_id, + lookup_project_id=lookup_id, + ) + results.append( + { + "lookup_project_id": str(lookup_id), + "linked": created, + "link_id": str(link.id) if created else None, + } + ) + + return Response({"results": results, "total_processed": len(results)}) + + +class LookupExecutionAuditViewSet(viewsets.ReadOnlyModelViewSet): + """ViewSet for viewing execution audit records. + + Read-only access to execution history and statistics. 
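The list endpoint supports the query-parameter filters applied in get_queryset below, for example (the UUIDs and status value are placeholders):

    GET /api/v1/execution-audits/?lookup_project_id=<uuid>&status=<status>
    GET /api/v1/execution-audits/?execution_id=<execution_uuid>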
+ """ + + queryset = LookupExecutionAudit.objects.all() + serializer_class = LookupExecutionAuditSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + """Filter audit records by various parameters.""" + queryset = super().get_queryset() + + # Filter by Look-Up project + lookup_project_id = self.request.query_params.get("lookup_project_id") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + # Filter by PS project + ps_project_id = self.request.query_params.get("prompt_studio_project_id") + if ps_project_id: + queryset = queryset.filter(prompt_studio_project_id=ps_project_id) + + # Filter by execution ID + execution_id = self.request.query_params.get("execution_id") + if execution_id: + queryset = queryset.filter(execution_id=execution_id) + + # Filter by status + status_filter = self.request.query_params.get("status") + if status_filter: + queryset = queryset.filter(status=status_filter) + + return queryset.select_related("lookup_project").order_by("-executed_at") + + @action(detail=False, methods=["get"]) + def statistics(self, request): + """Get execution statistics for a project. + + GET /api/v1/execution-audits/statistics/?lookup_project_id={id} + """ + lookup_project_id = request.query_params.get("lookup_project_id") + if not lookup_project_id: + return Response( + {"error": "lookup_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + audit_logger = AuditLogger() + stats = audit_logger.get_project_stats( + project_id=uuid.UUID(lookup_project_id), limit=1000 + ) + return Response(stats) + + except ValueError: + return Response( + {"error": "Invalid UUID format"}, status=status.HTTP_400_BAD_REQUEST + ) + + +class LookupDebugView(viewsets.ViewSet): + """Debug endpoints for testing Look-Up execution with Prompt Studio.""" + + permission_classes = [IsAuthenticated] + + @action(detail=False, methods=["post"]) + def enrich_ps_output(self, request): + """Enrich Prompt Studio extracted output with linked Look-Ups. + + Uses real LLM clients configured in each Look-Up project's profile. 
+ + POST /api/v1/lookup-debug/enrich_ps_output/ + + Request body: + { + "prompt_studio_project_id": "uuid", + "extracted_data": {"vendor_name": "Amzn Web Services Inc", ...} + } + + Response: + { + "original_data": {...}, + "enriched_data": {...}, + "lookup_enrichment": {...}, + "_lookup_metadata": {...} + } + """ + ps_project_id = request.data.get("prompt_studio_project_id") + extracted_data = request.data.get("extracted_data", {}) + + if not ps_project_id: + return Response( + {"error": "prompt_studio_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if not extracted_data: + return Response( + {"error": "extracted_data is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + # Get linked Look-Ups + links = ( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id + ) + .select_related("lookup_project") + .order_by("execution_order") + ) + + if not links: + return Response( + { + "original_data": extracted_data, + "enriched_data": extracted_data, + "lookup_enrichment": {}, + "_lookup_metadata": { + "lookups_executed": 0, + "message": "No Look-Ups linked to this Prompt Studio project", + }, + } + ) + + # Get Look-Up projects + lookup_projects = [link.lookup_project for link in links] + + # Initialize services with real clients + from .integrations.file_storage_client import FileStorageClient + from .integrations.unstract_llm_client import UnstractLLMClient + + storage_client = FileStorageClient() + cache = LLMResponseCache() + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + # Execute each Look-Up with its own LLM profile + all_enrichment = {} + all_metadata = {"lookups_executed": 0, "lookup_details": []} + + for project in lookup_projects: + # Get profile for this project + profile = LookupProfileManager.objects.filter( + lookup_project=project + ).first() + + if not profile or not profile.llm: + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "skipped", + "reason": "No LLM profile configured", + } + ) + continue + + try: + # Create LLM client for this project's profile + llm_client = UnstractLLMClient(profile.llm) + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + ) + + orchestrator = LookUpOrchestrator(executor=executor, merger=merger) + + # Execute Look-Up with extracted data as input + result = orchestrator.execute_lookups( + input_data=extracted_data, lookup_projects=[project] + ) + + # Merge results + if result.get("lookup_enrichment"): + all_enrichment.update(result["lookup_enrichment"]) + + all_metadata["lookups_executed"] += 1 + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "success", + "enrichment_keys": list( + result.get("lookup_enrichment", {}).keys() + ), + } + ) + + except Exception as e: + logger.exception(f"Error executing Look-Up {project.id}") + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "error", + "error": str(e), + } + ) + + # Merge enrichment into extracted data + enriched_data = {**extracted_data, **all_enrichment} + + return Response( + { + "original_data": extracted_data, + "enriched_data": enriched_data, + "lookup_enrichment": all_enrichment, + "_lookup_metadata": all_metadata, + } + ) + + except Exception as e: + logger.exception(f"Error enriching PS 
output for project {ps_project_id}") + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + @action(detail=False, methods=["post"]) + def test_with_ps_project(self, request): + """Test Look-Up execution with a Prompt Studio project context. + + POST /api/v1/lookup-debug/test_with_ps_project/ + """ + ps_project_id = request.data.get("prompt_studio_project_id") + input_data = request.data.get("input_data", {}) + + if not ps_project_id: + return Response( + {"error": "prompt_studio_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + # Get linked Look-Ups + links = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id + ).select_related("lookup_project") + + if not links: + return Response( + { + "message": "No Look-Ups linked to this Prompt Studio project", + "lookup_enrichment": {}, + "_lookup_metadata": {"lookups_executed": 0}, + } + ) + + # Get Look-Up projects + lookup_projects = [link.lookup_project for link in links] + + # Initialize services + from .services.mock_clients import MockLLMClient, MockStorageClient + + llm_client = MockLLMClient() + storage_client = MockStorageClient() + cache = LLMResponseCache() + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + ) + + orchestrator = LookUpOrchestrator(executor=executor, merger=merger) + + # Execute Look-Ups + result = orchestrator.execute_lookups( + input_data=input_data, lookup_projects=lookup_projects + ) + + return Response(result) + + except Exception as e: + logger.exception(f"Error in debug execution for PS project {ps_project_id}") + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + +class LookupDataSourceViewSet(viewsets.ModelViewSet): + """ViewSet for managing LookupDataSource instances. + + Provides CRUD operations for Look-Up data sources (reference data files). + Supports listing, retrieving, and deleting data sources. + """ + + queryset = LookupDataSource.objects.all() + serializer_class = LookupDataSourceSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter data sources by project if specified.""" + queryset = super().get_queryset() + + # Filter by project if provided + project_id = self.request.query_params.get("project") + if project_id: + queryset = queryset.filter(project_id=project_id) + + # Filter by is_latest if requested + is_latest = self.request.query_params.get("is_latest") + if is_latest is not None: + queryset = queryset.filter(is_latest=is_latest.lower() == "true") + + return queryset.select_related("project", "uploaded_by").order_by( + "-version_number" + ) + + def destroy(self, request, *args, **kwargs): + """Delete a data source and its associated files from storage. 
+ + DELETE /api/v1/unstract/{org_id}/data-sources/{id}/ + """ + instance = self.get_object() + + try: + from utils.file_storage.constants import FileStorageKeys + + from unstract.sdk1.file_storage.constants import StorageType + from unstract.sdk1.file_storage.env_helper import EnvHelper + + # Get file storage instance + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + # Delete the file from storage if it exists + if instance.file_path: + try: + if fs_instance.exists(instance.file_path): + fs_instance.rm(instance.file_path) + logger.info(f"Deleted file from storage: {instance.file_path}") + except Exception as e: + logger.warning(f"Failed to delete file from storage: {e}") + + # Delete extracted content if it exists + if instance.extracted_content_path: + try: + if fs_instance.exists(instance.extracted_content_path): + fs_instance.rm(instance.extracted_content_path) + logger.info( + f"Deleted extracted content: {instance.extracted_content_path}" + ) + except Exception as e: + logger.warning(f"Failed to delete extracted content: {e}") + + # Delete associated index manager if exists + try: + from .models import LookupIndexManager + + LookupIndexManager.objects.filter(data_source=instance).delete() + logger.info(f"Deleted index manager for data source: {instance.id}") + except Exception as e: + logger.warning(f"Failed to delete index manager: {e}") + + # Delete the database record + instance.delete() + + logger.info(f"Successfully deleted data source: {instance.id}") + return Response(status=status.HTTP_204_NO_CONTENT) + + except Exception as e: + logger.exception(f"Error deleting data source {instance.id}") + return Response( + {"error": f"Failed to delete data source: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class LookupProfileManagerViewSet(viewsets.ModelViewSet): + """ViewSet for managing LookupProfileManager instances. + + Provides CRUD operations for Look-Up project profiles. + Each profile defines the set of adapters to use for a Look-Up project. + + Follows the same pattern as Prompt Studio's ProfileManagerView. + """ + + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = LookupProfileManagerSerializer + pagination_class = CustomPagination + + def get_queryset(self): + """Filter queryset by created_by if specified in query params. + Otherwise return all profiles the user has access to. + """ + filter_args = FilterHelper.build_filter_args( + self.request, + LookupProfileManagerKeys.CREATED_BY, + ) + if filter_args: + queryset = LookupProfileManager.objects.filter(**filter_args) + else: + queryset = LookupProfileManager.objects.all() + + # Filter by lookup_project if provided in query params + lookup_project_id = self.request.query_params.get("lookup_project") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + return queryset.order_by("-created_at") + + def create(self, request, *args, **kwargs): + """Create a new profile. + + Handles IntegrityError for duplicate profile names within the same project. 
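For reference, the default-profile actions declared further down in this viewset can be driven as follows (paths follow the lookup-profiles router registration in urls.py; the UUIDs are placeholders):

    GET  /api/v1/lookup-profiles/default/?lookup_project=<project_uuid>
    POST /api/v1/lookup-profiles/<profile_uuid>/set-default/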
+ """ + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData(LookupProfileManagerErrors.PROFILE_NAME_EXISTS) + + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=False, methods=["get"], url_path="default") + def get_default(self, request): + """Get the default profile for a lookup project. + + Query params: + - lookup_project: UUID of the lookup project (required) + + Returns: + Profile data or 404 if no default profile exists + """ + lookup_project_id = request.query_params.get("lookup_project") + + if not lookup_project_id: + return Response( + {"error": "lookup_project query parameter is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + project = LookupProject.objects.get(id=lookup_project_id) + profile = LookupProfileManager.get_default_profile(project) + serializer = self.get_serializer(profile) + return Response(serializer.data) + except LookupProject.DoesNotExist: + return Response( + {"error": "Lookup project not found"}, status=status.HTTP_404_NOT_FOUND + ) + except Exception as e: + return Response({"error": str(e)}, status=status.HTTP_404_NOT_FOUND) + + @action(detail=True, methods=["post"], url_path="set-default") + def set_default(self, request, pk=None): + """Set a profile as the default for its project. + + Unsets any existing default profile for the same project. + """ + profile = self.get_object() + + with transaction.atomic(): + # Unset existing default for this project + LookupProfileManager.objects.filter( + lookup_project=profile.lookup_project, is_default=True + ).update(is_default=False) + + # Set this profile as default + profile.is_default = True + profile.save() + + serializer = self.get_serializer(profile) + return Response(serializer.data) diff --git a/frontend/src/assets/lookups.svg b/frontend/src/assets/lookups.svg new file mode 100644 index 0000000000..db41980604 --- /dev/null +++ b/frontend/src/assets/lookups.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/components/custom-tools/combined-output/CombinedOutput.css b/frontend/src/components/custom-tools/combined-output/CombinedOutput.css index 13302b7e92..24c43a1104 100644 --- a/frontend/src/components/custom-tools/combined-output/CombinedOutput.css +++ b/frontend/src/components/custom-tools/combined-output/CombinedOutput.css @@ -1,39 +1,54 @@ /* Styles for CombinedOutput */ .combined-op-layout { - display: flex; - flex-direction: column; - height: 100%; - overflow-y: hidden; + display: flex; + flex-direction: column; + height: 100%; + overflow-y: hidden; } .combined-op-header { - display: flex; - margin-top: 7px; + display: flex; + margin-top: 7px; } .combined-op-segment { - margin-left: auto; + margin-left: auto; } .combined-op-body { - flex: 1; - overflow-y: auto; + flex: 1; + overflow-y: auto; } .combined-op-divider { - margin-bottom: 10px; + margin-bottom: 10px; } .code-snippet { - border: 1px solid #ECEFF3; + border: 1px solid #eceff3; } .code-snippet > .language-javascript { - margin: 0px !important; - height: 100%; + margin: 0px !important; + height: 100%; } .combined-op-layout .gap { - margin-bottom: 12px; + margin-bottom: 12px; +} + +/* Enrichment info bar styling */ +.enrichment-info-bar { + padding: 8px 12px; + background-color: #f6ffed; + border: 1px solid #b7eb8f; + border-radius: 4px; + margin-bottom: 10px; +} + +.combined-op-segment { + display: flex; + align-items: center; + gap: 8px; } diff --git 
a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx index 36c043d0d0..e42f61715a 100644 --- a/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx +++ b/frontend/src/components/custom-tools/combined-output/CombinedOutput.jsx @@ -2,7 +2,7 @@ import "prismjs/components/prism-json"; import "prismjs/plugins/line-numbers/prism-line-numbers.css"; import "prismjs/plugins/line-numbers/prism-line-numbers.js"; import "prismjs/themes/prism.css"; -import { useEffect, useState, useCallback } from "react"; +import { useEffect, useState, useCallback, useRef } from "react"; import { useParams } from "react-router-dom"; import PropTypes from "prop-types"; @@ -64,6 +64,12 @@ function CombinedOutput({ docId, setFilledFields, selectedPrompts }) { const [selectedProfile, setSelectedProfile] = useState(defaultLlmProfile); const [filteredCombinedOutput, setFilteredCombinedOutput] = useState({}); + // Lookup enrichment state + const [isEnriching, setIsEnriching] = useState(false); + const [enrichmentResult, setEnrichmentResult] = useState(null); + const [hasLinkedLookups, setHasLinkedLookups] = useState(false); + const enrichmentCheckedRef = useRef(false); + const { id } = useParams(); const { sessionDetails } = useSessionStore(); const { setAlertDetails } = useAlertStore(); @@ -97,6 +103,90 @@ function CombinedOutput({ docId, setFilledFields, selectedPrompts }) { setSelectedProfile(singlePassExtractMode ? defaultLlmProfile : null); }, [singlePassExtractMode]); + // Check if there are linked Look-Ups for this project + useEffect(() => { + if (isSimplePromptStudio || isPublicSource) return; + + // Use id from URL params as fallback (same as tool_id) + const toolId = details?.tool_id || id; + if (!toolId || !sessionDetails?.orgId) return; + + // Skip if already checked for this tool + if (enrichmentCheckedRef.current === toolId) return; + + const checkLinkedLookups = async () => { + try { + const url = `/api/v1/unstract/${sessionDetails?.orgId}/lookup-links/?prompt_studio_project_id=${toolId}`; + const res = await axiosPrivate.get(url); + const links = res?.data?.results || res?.data || []; + setHasLinkedLookups(links.length > 0); + enrichmentCheckedRef.current = toolId; + } catch (err) { + // Silently fail - lookups may not be available + console.debug("Could not check for linked Look-Ups:", err); + setHasLinkedLookups(false); + } + }; + + checkLinkedLookups(); + }, [details?.tool_id, id, sessionDetails?.orgId]); + + // Reset enrichment when document changes + useEffect(() => { + setEnrichmentResult(null); + }, [docId]); + + // Handler for enriching output with Look-Ups + const handleEnrichWithLookups = useCallback(async () => { + if (isEnriching || Object.keys(filteredCombinedOutput).length === 0) return; + + setIsEnriching(true); + try { + const toolId = details?.tool_id || id; + const url = `/api/v1/unstract/${sessionDetails?.orgId}/lookup-debug/enrich_ps_output/`; + + // Get fresh CSRF token from cookie + const csrfToken = + sessionDetails?.csrfToken || + document.cookie + .split("; ") + .find((row) => row.startsWith("csrftoken=")) + ?.split("=")[1]; + + const res = await axiosPrivate.post( + url, + { + prompt_studio_project_id: toolId, + extracted_data: filteredCombinedOutput, + }, + { + headers: { + "X-CSRFToken": csrfToken, + "Content-Type": "application/json", + }, + } + ); + + setEnrichmentResult(res.data); + setAlertDetails({ + type: "success", + content: `Successfully enriched with ${ + 
res.data._lookup_metadata?.lookups_executed || 0 + } Look-Up(s)`, + }); + } catch (err) { + setAlertDetails(handleException(err, "Failed to enrich with Look-Ups")); + } finally { + setIsEnriching(false); + } + }, [ + filteredCombinedOutput, + details?.tool_id, + id, + sessionDetails?.orgId, + isEnriching, + ]); + useEffect(() => { if (!docId || isSinglePassExtractLoading) return; @@ -229,6 +319,10 @@ function CombinedOutput({ docId, setFilledFields, selectedPrompts }) { adapterData={adapterData} isSinglePass={singlePassExtractMode} isLoading={isOutputLoading} + onEnrichWithLookups={handleEnrichWithLookups} + isEnriching={isEnriching} + enrichmentResult={enrichmentResult} + hasLinkedLookups={hasLinkedLookups} /> ); } diff --git a/frontend/src/components/custom-tools/combined-output/JsonView.jsx b/frontend/src/components/custom-tools/combined-output/JsonView.jsx index aee6c297a5..f803bab378 100644 --- a/frontend/src/components/custom-tools/combined-output/JsonView.jsx +++ b/frontend/src/components/custom-tools/combined-output/JsonView.jsx @@ -1,8 +1,9 @@ import PropTypes from "prop-types"; import Prism from "prismjs"; -import { useEffect } from "react"; +import { useEffect, useState } from "react"; import TabPane from "antd/es/tabs/TabPane"; -import { Tabs } from "antd"; +import { Tabs, Button, Tooltip, Badge } from "antd"; +import { ThunderboltOutlined, CheckCircleOutlined } from "@ant-design/icons"; import { JsonViewBody } from "./JsonViewBody"; @@ -15,10 +16,32 @@ function JsonView({ llmProfiles, isSinglePass, isLoading, + onEnrichWithLookups, + isEnriching, + enrichmentResult, + hasLinkedLookups, }) { + const [showEnriched, setShowEnriched] = useState(false); + useEffect(() => { Prism.highlightAll(); - }, [combinedOutput]); + }, [combinedOutput, enrichmentResult, showEnriched]); + + // Determine what output to display + const displayOutput = + showEnriched && enrichmentResult?.enriched_data + ? enrichmentResult.enriched_data + : combinedOutput; + + const handleEnrichClick = () => { + if (enrichmentResult) { + // Toggle between original and enriched view + setShowEnriched(!showEnriched); + } else if (onEnrichWithLookups) { + // Trigger enrichment + onEnrichWithLookups(); + } + }; return (
@@ -34,15 +57,58 @@ function JsonView({ /> ))} -
+
+ {hasLinkedLookups && onEnrichWithLookups && ( + + + + + + )} +
@@ -58,6 +124,10 @@ JsonView.propTypes = { activeKey: PropTypes.string, isSinglePass: PropTypes.bool, isLoading: PropTypes.bool.isRequired, + onEnrichWithLookups: PropTypes.func, + isEnriching: PropTypes.bool, + enrichmentResult: PropTypes.object, + hasLinkedLookups: PropTypes.bool, }; export { JsonView }; diff --git a/frontend/src/components/custom-tools/combined-output/JsonViewBody.jsx b/frontend/src/components/custom-tools/combined-output/JsonViewBody.jsx index 97ddcd212e..2947a79e18 100644 --- a/frontend/src/components/custom-tools/combined-output/JsonViewBody.jsx +++ b/frontend/src/components/custom-tools/combined-output/JsonViewBody.jsx @@ -1,4 +1,6 @@ import PropTypes from "prop-types"; +import { Tag, Space } from "antd"; +import { CheckCircleOutlined, ThunderboltOutlined } from "@ant-design/icons"; import { ProfileInfoBar } from "../profile-info-bar/ProfileInfoBar"; import { SpinnerLoader } from "../../widgets/spinner-loader/SpinnerLoader"; @@ -9,6 +11,8 @@ function JsonViewBody({ llmProfiles, combinedOutput, isLoading, + isEnriched, + enrichmentMetadata, }) { if (isLoading) { return ; @@ -19,6 +23,38 @@ function JsonViewBody({ {activeKey !== "0" && ( )} + {isEnriched && enrichmentMetadata && ( +
+ + } color="success"> + Enriched with Look-Ups + + + {enrichmentMetadata.lookups_executed} Look-Up + {enrichmentMetadata.lookups_executed !== 1 ? "s" : ""} executed + + {enrichmentMetadata.lookup_details?.map((detail, idx) => ( + + ) : undefined + } + color={ + detail.status === "success" + ? "green" + : detail.status === "error" + ? "red" + : "default" + } + > + {detail.project_name} + + ))} + +
+ )}
{combinedOutput && (
@@ -38,6 +74,8 @@ JsonViewBody.propTypes = {
   llmProfiles: PropTypes.string,
   combinedOutput: PropTypes.object,
   isLoading: PropTypes.bool.isRequired,
+  isEnriched: PropTypes.bool,
+  enrichmentMetadata: PropTypes.object,
 };
 
 export { JsonViewBody };
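
The enrichment UI above only reads a few fields from the enrich_ps_output response: enriched_data, _lookup_metadata.lookups_executed, and the _lookup_metadata.lookup_details entries with project_name and status. A minimal sketch of the request and of the response shape these components assume follows; any field or value not referenced in the diff is illustrative only and is not taken from the actual API.

// Request sent by handleEnrichWithLookups (CombinedOutput.jsx):
// POST /api/v1/unstract/<orgId>/lookup-debug/enrich_ps_output/
// body: { prompt_studio_project_id: <tool id>, extracted_data: <filtered combined output> }

// Response shape assumed by JsonView/JsonViewBody (sketch only; values are made up):
const exampleEnrichmentResult = {
  enriched_data: {
    // the original extracted fields, merged with whatever the linked Look-Ups returned
  },
  _lookup_metadata: {
    lookups_executed: 2,
    lookup_details: [
      { project_name: "Vendor Catalog", status: "success" },
      { project_name: "Pricing Data", status: "error" },
    ],
  },
};
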
diff --git a/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx b/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx
new file mode 100644
index 0000000000..1fb765b035
--- /dev/null
+++ b/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx
@@ -0,0 +1,92 @@
+import { Form, Input, Modal, Select } from "antd";
+import PropTypes from "prop-types";
+
+const { TextArea } = Input;
+
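+// Option list for the reference data type field; presumably consumed by the
+// <Select> rendered in the form below (Select is imported above for this).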
+const REFERENCE_DATA_TYPES = [
+  { value: "vendor_catalog", label: "Vendor Catalog" },
+  { value: "product_catalog", label: "Product Catalog" },
+  { value: "customer_database", label: "Customer Database" },
+  { value: "pricing_data", label: "Pricing Data" },
+  { value: "compliance_rules", label: "Compliance Rules" },
+  { value: "custom", label: "Custom" },
+];
+
+export function CreateProjectModal({ open, onCancel, onCreate }) {
+  const [form] = Form.useForm();
+
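+  // Validate the form and hand the values to the parent via onCreate; the
+  // fields are cleared only after onCreate resolves.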
+  const handleSubmit = async () => {
+    try {
+      const values = await form.validateFields();
+      await onCreate(values);
+      form.resetFields();
+    } catch (error) {
+      // Reached for both validation errors and onCreate rejections
+      console.error("Failed to create Look-Up project:", error);
+    }
+  };
+
+  return (
+    
+      
+ + + + + +