diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index 94cb55b3ae..eb708936fb 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -342,6 +342,7 @@ def filter(self, record): "prompt_studio.prompt_studio_index_manager_v2", "tags", "configuration", + "lookup", ) TENANT_APPS = [] @@ -599,3 +600,11 @@ def filter(self, record): raise ValueError(ERROR_MESSAGE) ENABLE_HIGHLIGHT_API_DEPLOYMENT = os.environ.get("ENABLE_HIGHLIGHT_API_DEPLOYMENT", False) + +# Lookup Integration Settings +# Enable/disable automatic Lookup enrichment after Prompt Studio extraction +LOOKUP_AUTO_ENRICH_ENABLED = CommonUtils.str_to_bool( + os.environ.get("LOOKUP_AUTO_ENRICH_ENABLED", "True") +) +# Maximum time (in seconds) to wait for Lookup enrichment before returning +LOOKUP_ENRICHMENT_TIMEOUT = int(os.environ.get("LOOKUP_ENRICHMENT_TIMEOUT", "30")) diff --git a/backend/backend/urls.py b/backend/backend/urls.py index 1dd4a6a692..6b66dd8991 100644 --- a/backend/backend/urls.py +++ b/backend/backend/urls.py @@ -35,6 +35,7 @@ path("platform/", include("platform_settings.urls")), path("api/", include("api.urls")), path("usage/", include("usage.urls")), + path("lookup/", include("lookup.urls")), path( UrlPathConstants.PROMPT_STUDIO, include("prompt_studio.prompt_profile_manager.urls"), diff --git a/backend/backend/urls_v2.py b/backend/backend/urls_v2.py index 988183abfb..8b25436674 100644 --- a/backend/backend/urls_v2.py +++ b/backend/backend/urls_v2.py @@ -35,6 +35,7 @@ path("usage/", include("usage_v2.urls")), path("notifications/", include("notification_v2.urls")), path("logs/", include("logs_helper.urls")), + path("lookup/", include("lookup.urls")), path( UrlPathConstants.PROMPT_STUDIO, include("prompt_studio.prompt_profile_manager_v2.urls"), diff --git a/backend/lookup/__init__.py b/backend/lookup/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/lookup/apps.py b/backend/lookup/apps.py new file mode 100644 index 0000000000..ddaf2e194a --- /dev/null +++ b/backend/lookup/apps.py @@ -0,0 +1,11 @@ +"""Lookup app configuration.""" + +from django.apps import AppConfig + + +class LookupConfig(AppConfig): + """Configuration for the Lookup application.""" + + default_auto_field = "django.db.models.BigAutoField" + name = "lookup" + verbose_name = "Look-Up System" diff --git a/backend/lookup/constants.py b/backend/lookup/constants.py new file mode 100644 index 0000000000..f41cc62bfc --- /dev/null +++ b/backend/lookup/constants.py @@ -0,0 +1,28 @@ +"""Constants for the Look-Up system.""" + + +class LookupProfileManagerKeys: + """Keys used in LookupProfileManager serialization.""" + + CREATED_BY = "created_by" + MODIFIED_BY = "modified_by" + LOOKUP_PROJECT = "lookup_project" + PROFILE_NAME = "profile_name" + LLM = "llm" + VECTOR_STORE = "vector_store" + EMBEDDING_MODEL = "embedding_model" + X2TEXT = "x2text" + CHUNK_SIZE = "chunk_size" + CHUNK_OVERLAP = "chunk_overlap" + SIMILARITY_TOP_K = "similarity_top_k" + IS_DEFAULT = "is_default" + REINDEX = "reindex" + + +class LookupProfileManagerErrors: + """Error messages for LookupProfileManager operations.""" + + SERIALIZATION_FAILED = "Data serialization failed." + PROFILE_NAME_EXISTS = "A profile with this name already exists for this project." + DUPLICATE_API = "It appears that a duplicate call may have been made." + NO_DEFAULT_PROFILE = "No default profile found for this project." 
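The two settings introduced above gate the integration: LOOKUP_AUTO_ENRICH_ENABLED toggles automatic Lookup enrichment after Prompt Studio extraction, and LOOKUP_ENRICHMENT_TIMEOUT bounds how long a caller waits for enrichment before returning. A minimal sketch of how downstream code might read them via django.conf.settings follows; the helper names and defaults-as-fallbacks are illustrative only and not part of this change.

    from django.conf import settings

    def should_auto_enrich() -> bool:
        """Whether extraction results should be enriched via Lookup (sketch)."""
        return bool(getattr(settings, "LOOKUP_AUTO_ENRICH_ENABLED", False))

    def enrichment_timeout_seconds() -> int:
        """Maximum seconds to wait for Lookup enrichment (sketch, default 30)."""
        return int(getattr(settings, "LOOKUP_ENRICHMENT_TIMEOUT", 30))

Using getattr with a fallback keeps callers working even when the settings block is absent (e.g. in tests that use a stripped-down settings module); with the diff applied, the values come straight from the LOOKUP_AUTO_ENRICH_ENABLED and LOOKUP_ENRICHMENT_TIMEOUT environment variables parsed in base.py.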
diff --git a/backend/lookup/exceptions.py b/backend/lookup/exceptions.py new file mode 100644 index 0000000000..3a8999eda2 --- /dev/null +++ b/backend/lookup/exceptions.py @@ -0,0 +1,97 @@ +"""Custom exceptions for the Look-Up system. + +This module defines custom exceptions specific to the Look-Up functionality. +""" + + +class LookupError(Exception): + """Base exception for Look-Up system errors.""" + + pass + + +class ExtractionNotCompleteError(LookupError): + """Raised when attempting to use reference data before extraction is complete. + + This exception is raised when trying to load reference data for a project + where one or more data sources have not completed extraction processing. + """ + + def __init__(self, failed_files=None): + """Initialize the exception. + + Args: + failed_files: List of file names that failed or are pending extraction + """ + self.failed_files = failed_files or [] + message = "Reference data extraction not complete" + if failed_files: + message += f" for files: {', '.join(failed_files)}" + super().__init__(message) + + +class TemplateNotFoundError(LookupError): + """Raised when a Look-Up project has no associated template. + + This exception is raised when attempting to execute a Look-Up + that doesn't have a prompt template configured. + """ + + pass + + +class ParseError(LookupError): + """Raised when LLM response cannot be parsed. + + This exception is raised when the LLM returns a response that + cannot be parsed as valid JSON or doesn't match expected format. + """ + + pass + + +class DefaultProfileError(LookupError): + """Raised when default profile is not found for a Look-Up project. + + This exception is raised when attempting to get the default profile + for a Look-Up project that doesn't have one configured. + """ + + pass + + +class ContextWindowExceededError(LookupError): + """Raised when prompt + reference data exceeds LLM context window. + + This exception is raised when the combined size of the prompt template, + reference data, and extracted data exceeds the configured LLM's context + window limit. + """ + + def __init__(self, token_count: int, context_limit: int, model: str): + """Initialize the exception. + + Args: + token_count: Number of tokens in the prompt + context_limit: Maximum tokens allowed by the model + model: Name of the LLM model + """ + self.token_count = token_count + self.context_limit = context_limit + self.model = model + message = ( + f"Context window exceeded: prompt requires {token_count:,} tokens " + f"but {model} has a limit of {context_limit:,} tokens. " + f"Reduce reference data size or use a model with larger context window." + ) + super().__init__(message) + + +class RetrievalError(LookupError): + """Raised when RAG retrieval fails. + + This exception is raised when the vector similarity search fails + to retrieve context from indexed reference data. + """ + + pass diff --git a/backend/lookup/integrations/__init__.py b/backend/lookup/integrations/__init__.py new file mode 100644 index 0000000000..53383fb10c --- /dev/null +++ b/backend/lookup/integrations/__init__.py @@ -0,0 +1,8 @@ +"""Integration modules for Look-Up functionality. + +This package contains integrations with external services: +- Object Storage (S3-compatible) +- LLM Providers (OpenAI, Anthropic, etc.) 
+- LLMWhisperer (Document extraction) +- Redis Cache +""" diff --git a/backend/lookup/integrations/file_storage_client.py b/backend/lookup/integrations/file_storage_client.py new file mode 100644 index 0000000000..28870c19e8 --- /dev/null +++ b/backend/lookup/integrations/file_storage_client.py @@ -0,0 +1,86 @@ +"""File Storage Client for Look-Up reference data. + +This module provides integration with Unstract's file storage +for loading extracted reference data content. +""" + +import logging + +from utils.file_storage.constants import FileStorageKeys + +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper + +logger = logging.getLogger(__name__) + + +class FileStorageClient: + """Storage client implementation using Unstract's file storage. + + This client uses the actual platform file storage to read + extracted reference data content. + """ + + def __init__(self): + """Initialize the file storage client.""" + self.fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + def get(self, path: str) -> str: + """Retrieve file content from storage. + + Args: + path: Storage path to the file + + Returns: + File content as string + + Raises: + FileNotFoundError: If file doesn't exist + Exception: If reading fails + """ + try: + if not self.fs_instance.exists(path): + logger.error(f"File not found: {path}") + raise FileNotFoundError(f"File not found: {path}") + + # Use read() method with text mode + content = self.fs_instance.read(path, mode="r", encoding="utf-8") + logger.debug(f"Read {len(content)} chars from {path}") + return content + + except FileNotFoundError: + raise + except Exception as e: + logger.error(f"Failed to read file {path}: {e}") + raise Exception(f"Failed to read file: {str(e)}") + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists + """ + return self.fs_instance.exists(path) + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage (alias for get). + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + try: + return self.get(path) + except FileNotFoundError: + return None + except Exception as e: + logger.warning(f"Error reading {path}: {e}") + return None diff --git a/backend/lookup/integrations/llm_provider.py b/backend/lookup/integrations/llm_provider.py new file mode 100644 index 0000000000..ca16d9b235 --- /dev/null +++ b/backend/lookup/integrations/llm_provider.py @@ -0,0 +1,311 @@ +"""LLM Provider integration for Look-Up enrichment. + +This module provides integration with various LLM providers +(OpenAI, Anthropic, etc.) for generating enrichment data. +""" + +import json +import logging +import os +import time +from typing import Any + +from ..protocols import LLMClient + +logger = logging.getLogger(__name__) + + +class UnstractLLMClient(LLMClient): + """Implementation of LLMClient using Unstract's LLM abstraction. + + This client integrates with the Unstract platform's LLM providers + to generate enrichment data for Look-Ups. + """ + + def __init__(self, provider: str | None = None, model: str | None = None): + """Initialize the LLM client. 
+ + Args: + provider: LLM provider name (e.g., 'openai', 'anthropic') + model: Model name (e.g., 'gpt-4', 'claude-2') + """ + self.default_provider = provider or os.getenv( + "LOOKUP_DEFAULT_LLM_PROVIDER", "openai" + ) + self.default_model = model or os.getenv("LOOKUP_DEFAULT_LLM_MODEL", "gpt-4") + + # Import Unstract LLM utilities + try: + from unstract.llmbox import LLMBox + from unstract.llmbox.llm import LLM + + self.LLMBox = LLMBox + self.LLM = LLM + self.llm_available = True + except ImportError: + logger.warning("Unstract LLMBox not available, using fallback implementation") + self.llm_available = False + + def _get_llm_instance(self, config: dict[str, Any]): + """Get an LLM instance based on configuration. + + Args: + config: LLM configuration + + Returns: + LLM instance + """ + if not self.llm_available: + raise RuntimeError("LLM integration not available") + + provider = config.get("provider", self.default_provider) + model = config.get("model", self.default_model) + + # Create LLM instance using Unstract's LLMBox + # This would integrate with the actual Unstract LLM abstraction + # For now, we'll use a simplified approach + + # Map provider to Unstract's LLM types + provider_map = { + "openai": "OpenAI", + "anthropic": "Anthropic", + "azure_openai": "AzureOpenAI", + "gemini": "Gemini", + "vertex_ai": "VertexAI", + } + + llm_provider = provider_map.get(provider, "OpenAI") + + # Create configuration for the provider + llm_config = { + "provider": llm_provider, + "model": model, + "temperature": config.get("temperature", 0.7), + "max_tokens": config.get("max_tokens", 1000), + } + + # Add API keys based on provider + if provider == "openai": + llm_config["api_key"] = os.getenv("OPENAI_API_KEY") + elif provider == "anthropic": + llm_config["api_key"] = os.getenv("ANTHROPIC_API_KEY") + elif provider == "azure_openai": + llm_config["api_key"] = os.getenv("AZURE_OPENAI_API_KEY") + llm_config["endpoint"] = os.getenv("AZURE_OPENAI_ENDPOINT") + + # For now, return a mock implementation + # In production, this would create actual LLM instance + return llm_config + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate LLM response for Look-Up enrichment. + + Args: + prompt: The prompt text with resolved variables + config: LLM configuration (provider, model, temperature, etc.) + timeout: Request timeout in seconds + + Returns: + JSON-formatted string with enrichment data + + Raises: + RuntimeError: If LLM call fails + """ + start_time = time.time() + + try: + if self.llm_available: + # Use actual Unstract LLM integration + llm_config = self._get_llm_instance(config) + + # In production, this would make actual LLM call + # For now, simulate the call + response = self._simulate_llm_call(prompt, llm_config) + else: + # Fallback implementation + response = self._fallback_generate(prompt, config) + + # Validate response is JSON + try: + json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from response + response = self._extract_json(response) + + elapsed_time = time.time() - start_time + logger.info(f"LLM generation completed in {elapsed_time:.2f}s") + + return response + + except Exception as e: + logger.error(f"LLM generation failed: {e}") + raise RuntimeError(f"LLM generation failed: {str(e)}") + + def _simulate_llm_call(self, prompt: str, config: dict[str, Any]) -> str: + """Simulate LLM call for development/testing. + + In production, this would be replaced with actual LLM API calls. 
+ """ + # Simulate some processing + time.sleep(0.5) + + # Generate mock response based on prompt content + if "vendor" in prompt.lower(): + return json.dumps( + { + "canonical_vendor": "Sample Vendor Inc.", + "vendor_category": "SaaS", + "vendor_type": "Software", + "confidence": 0.95, + } + ) + elif "product" in prompt.lower(): + return json.dumps( + { + "product_name": "Sample Product", + "product_category": "Enterprise Software", + "product_type": "Cloud Service", + "confidence": 0.92, + } + ) + else: + return json.dumps({"enriched_data": "Sample enrichment", "confidence": 0.88}) + + def _fallback_generate(self, prompt: str, config: dict[str, Any]) -> str: + """Fallback generation when LLM integration is not available. + + This is primarily for testing and development. + """ + logger.warning("Using fallback LLM generation") + + # Simple pattern matching for testing + response_data = { + "status": "fallback", + "message": "LLM integration not available", + "confidence": 0.5, + } + + # Add some basic enrichment based on prompt + if "vendor" in prompt.lower(): + response_data["canonical_vendor"] = "Unknown Vendor" + response_data["vendor_category"] = "Unknown" + elif "product" in prompt.lower(): + response_data["product_name"] = "Unknown Product" + response_data["product_category"] = "Unknown" + + return json.dumps(response_data) + + def _extract_json(self, response: str) -> str: + """Extract JSON from LLM response if wrapped in text. + + Args: + response: Raw LLM response + + Returns: + Extracted JSON string + """ + # Try to find JSON in the response + import re + + # Look for JSON object pattern + json_pattern = r"\{[^{}]*\}" + matches = re.findall(json_pattern, response) + + if matches: + # Try to parse each match + for match in matches: + try: + json.loads(match) + return match + except json.JSONDecodeError: + continue + + # If no valid JSON found, create a basic response + return json.dumps( + { + "raw_response": response[:500], # Truncate if too long + "confidence": 0.3, + "warning": "Could not extract structured data", + } + ) + + def validate_response(self, response: str) -> bool: + """Validate that the LLM response is properly formatted. + + Args: + response: LLM response string + + Returns: + True if valid JSON with required fields + """ + try: + data = json.loads(response) + + # Check for confidence score + if "confidence" not in data: + logger.warning("Response missing confidence score") + return False + + # Check confidence is valid + confidence = data.get("confidence", 0) + if not (0 <= confidence <= 1): + logger.warning(f"Invalid confidence score: {confidence}") + return False + + return True + + except json.JSONDecodeError: + logger.error("Response is not valid JSON") + return False + + def get_token_count(self, text: str, model: str = None) -> int: + """Estimate token count for the given text. 
+ + Args: + text: Input text + model: Model name for accurate counting + + Returns: + Estimated token count + """ + # Simple estimation: ~4 characters per token + # In production, use tiktoken or model-specific tokenizer + return len(text) // 4 + + +class OpenAILLMClient(UnstractLLMClient): + """OpenAI-specific LLM client implementation.""" + + def __init__(self): + """Initialize OpenAI client.""" + super().__init__(provider="openai", model="gpt-4") + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate using OpenAI API.""" + # Override config with OpenAI defaults + config = { + **config, + "provider": "openai", + "model": config.get("model", "gpt-4"), + "temperature": config.get("temperature", 0.7), + } + return super().generate(prompt, config, timeout) + + +class AnthropicLLMClient(UnstractLLMClient): + """Anthropic-specific LLM client implementation.""" + + def __init__(self): + """Initialize Anthropic client.""" + super().__init__(provider="anthropic", model="claude-2") + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate using Anthropic API.""" + # Override config with Anthropic defaults + config = { + **config, + "provider": "anthropic", + "model": config.get("model", "claude-2"), + "temperature": config.get("temperature", 0.7), + } + return super().generate(prompt, config, timeout) diff --git a/backend/lookup/integrations/llmwhisperer_client.py b/backend/lookup/integrations/llmwhisperer_client.py new file mode 100644 index 0000000000..5445196e97 --- /dev/null +++ b/backend/lookup/integrations/llmwhisperer_client.py @@ -0,0 +1,334 @@ +"""LLMWhisperer integration for document text extraction. + +This module provides integration with LLMWhisperer service +for extracting text from various document formats. +""" + +import logging +import time +from enum import Enum +from typing import Any + +import requests +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class ExtractionStatus(Enum): + """Status of document extraction.""" + + PENDING = "pending" + PROCESSING = "processing" + COMPLETE = "complete" + FAILED = "failed" + NOT_REQUIRED = "not_required" + + +class LLMWhispererClient: + """Client for integrating with LLMWhisperer document extraction service. + + LLMWhisperer extracts text from PDFs, images, and other document formats + for use as reference data in Look-Ups. + """ + + def __init__(self): + """Initialize LLMWhisperer client with configuration.""" + self.base_url = getattr( + settings, "LLMWHISPERER_BASE_URL", "https://api.llmwhisperer.com" + ) + self.api_key = getattr(settings, "LLMWHISPERER_API_KEY", "") + + if not self.api_key: + logger.warning("LLMWhisperer API key not configured") + + self.session = requests.Session() + self.session.headers.update( + { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + ) + + def extract_text( + self, + file_content: bytes, + file_name: str, + extraction_config: dict[str, Any] | None = None, + ) -> tuple[str, str]: + """Extract text from a document using LLMWhisperer. 
+ + Args: + file_content: File content as bytes + file_name: Original file name + extraction_config: Optional extraction configuration + + Returns: + Tuple of (extraction_id, status) + """ + if not self.api_key: + logger.error("LLMWhisperer API key not configured") + return "", ExtractionStatus.FAILED.value + + try: + # Prepare extraction request + config = extraction_config or self._get_default_config() + + # Create extraction job + url = f"{self.base_url}/v1/extract" + + # Prepare multipart form data + files = {"file": (file_name, file_content)} + + # Add configuration as form data + data = { + "processing_mode": config.get("processing_mode", "ocr"), + "output_format": config.get("output_format", "text"), + "page_separator": config.get("page_separator", "\n---\n"), + "force_text_processing": str( + config.get("force_text_processing", True) + ).lower(), + "line_splitter": config.get("line_splitter", "line"), + "horizontal_stretch": str(config.get("horizontal_stretch", 1.0)), + "vertical_stretch": str(config.get("vertical_stretch", 1.0)), + } + + # Remove JSON content type for multipart + headers = {"Authorization": f"Bearer {self.api_key}"} + + response = requests.post( + url, files=files, data=data, headers=headers, timeout=30 + ) + + if response.status_code == 200: + result = response.json() + extraction_id = result.get("extraction_id", "") + logger.info(f"Started extraction job: {extraction_id}") + return extraction_id, ExtractionStatus.PROCESSING.value + else: + logger.error( + f"Extraction failed with status {response.status_code}: {response.text}" + ) + return "", ExtractionStatus.FAILED.value + + except requests.exceptions.RequestException as e: + logger.error(f"LLMWhisperer extraction request failed: {e}") + return "", ExtractionStatus.FAILED.value + except Exception as e: + logger.error(f"Unexpected error during extraction: {e}") + return "", ExtractionStatus.FAILED.value + + def check_extraction_status(self, extraction_id: str) -> tuple[str, str | None]: + """Check the status of an extraction job. + + Args: + extraction_id: Extraction job ID + + Returns: + Tuple of (status, extracted_text) + """ + if not extraction_id: + return ExtractionStatus.FAILED.value, None + + try: + url = f"{self.base_url}/v1/status/{extraction_id}" + + response = self.session.get(url, timeout=10) + + if response.status_code == 200: + result = response.json() + status = result.get("status", "unknown") + + if status == "completed": + # Get extracted text + text = self._get_extracted_text(extraction_id) + return ExtractionStatus.COMPLETE.value, text + elif status == "processing": + return ExtractionStatus.PROCESSING.value, None + elif status == "failed": + error_msg = result.get("error", "Unknown error") + logger.error(f"Extraction {extraction_id} failed: {error_msg}") + return ExtractionStatus.FAILED.value, None + else: + logger.warning(f"Unknown extraction status: {status}") + return ExtractionStatus.PENDING.value, None + else: + logger.error(f"Status check failed with code {response.status_code}") + return ExtractionStatus.FAILED.value, None + + except Exception as e: + logger.error(f"Error checking extraction status: {e}") + return ExtractionStatus.FAILED.value, None + + def _get_extracted_text(self, extraction_id: str) -> str | None: + """Retrieve extracted text for a completed job. 
+ + Args: + extraction_id: Extraction job ID + + Returns: + Extracted text or None + """ + try: + url = f"{self.base_url}/v1/result/{extraction_id}" + + response = self.session.get(url, timeout=30) + + if response.status_code == 200: + # Response is the extracted text + return response.text + else: + logger.error(f"Failed to get extraction result: {response.status_code}") + return None + + except Exception as e: + logger.error(f"Error retrieving extracted text: {e}") + return None + + def wait_for_extraction( + self, extraction_id: str, max_wait_seconds: int = 300, poll_interval: int = 5 + ) -> tuple[str, str | None]: + """Wait for extraction to complete with polling. + + Args: + extraction_id: Extraction job ID + max_wait_seconds: Maximum time to wait + poll_interval: Seconds between status checks + + Returns: + Tuple of (final_status, extracted_text) + """ + start_time = time.time() + + while time.time() - start_time < max_wait_seconds: + status, text = self.check_extraction_status(extraction_id) + + if status == ExtractionStatus.COMPLETE.value: + logger.info(f"Extraction {extraction_id} completed successfully") + return status, text + elif status == ExtractionStatus.FAILED.value: + logger.error(f"Extraction {extraction_id} failed") + return status, None + elif status == ExtractionStatus.PROCESSING.value: + logger.debug(f"Extraction {extraction_id} still processing...") + time.sleep(poll_interval) + else: + logger.warning(f"Unexpected status for {extraction_id}: {status}") + time.sleep(poll_interval) + + logger.error( + f"Extraction {extraction_id} timed out after {max_wait_seconds} seconds" + ) + return ExtractionStatus.FAILED.value, None + + def extract_and_wait( + self, + file_content: bytes, + file_name: str, + extraction_config: dict[str, Any] | None = None, + max_wait_seconds: int = 300, + ) -> tuple[bool, str | None]: + """Extract text and wait for completion. + + Args: + file_content: File content + file_name: File name + extraction_config: Extraction configuration + max_wait_seconds: Maximum wait time + + Returns: + Tuple of (success, extracted_text) + """ + # Start extraction + extraction_id, status = self.extract_text( + file_content, file_name, extraction_config + ) + + if status == ExtractionStatus.FAILED.value: + return False, None + + # Wait for completion + final_status, text = self.wait_for_extraction(extraction_id, max_wait_seconds) + + return final_status == ExtractionStatus.COMPLETE.value, text + + def _get_default_config(self) -> dict[str, Any]: + """Get default extraction configuration. + + Returns: + Default configuration dictionary + """ + return { + "processing_mode": "ocr", # 'ocr' or 'text' + "output_format": "text", # 'text' or 'markdown' + "page_separator": "\n---\n", + "force_text_processing": True, + "line_splitter": "line", # 'line' or 'paragraph' + "horizontal_stretch": 1.0, + "vertical_stretch": 1.0, + "pages": "", # Empty for all pages + "timeout": 300, + "store_metadata": False, + } + + def is_extraction_needed(self, file_name: str) -> bool: + """Check if extraction is needed for the file type. 
+ + Args: + file_name: File name with extension + + Returns: + True if extraction is needed + """ + # File types that need extraction + extractable_extensions = { + ".pdf", + ".png", + ".jpg", + ".jpeg", + ".tiff", + ".bmp", + ".docx", + ".doc", + ".pptx", + ".ppt", + ".xlsx", + ".xls", + } + + # Check file extension + import os + + _, ext = os.path.splitext(file_name.lower()) + return ext in extractable_extensions + + def get_extraction_config_for_file(self, file_name: str) -> dict[str, Any]: + """Get optimal extraction configuration based on file type. + + Args: + file_name: File name with extension + + Returns: + Extraction configuration + """ + import os + + _, ext = os.path.splitext(file_name.lower()) + + config = self._get_default_config() + + # Optimize based on file type + if ext in [".pdf"]: + config["processing_mode"] = "ocr" + config["force_text_processing"] = True + elif ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]: + config["processing_mode"] = "ocr" + config["force_text_processing"] = False + elif ext in [".docx", ".doc"]: + config["processing_mode"] = "text" + config["output_format"] = "markdown" + elif ext in [".xlsx", ".xls"]: + config["processing_mode"] = "text" + config["line_splitter"] = "paragraph" + + return config diff --git a/backend/lookup/integrations/redis_cache.py b/backend/lookup/integrations/redis_cache.py new file mode 100644 index 0000000000..976164a6ac --- /dev/null +++ b/backend/lookup/integrations/redis_cache.py @@ -0,0 +1,384 @@ +"""Redis cache implementation for Look-Up LLM responses. + +This module provides a Redis-based cache for storing and retrieving +LLM responses to improve performance and reduce API costs. +""" + +import hashlib +import logging +import time +from typing import Any + +try: + import redis + from redis.exceptions import RedisError + + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class RedisLLMCache: + """Redis-based cache for LLM responses. + + Provides persistent caching with TTL support and + automatic failover to in-memory cache if Redis is unavailable. + """ + + def __init__( + self, + ttl_hours: int = 24, + key_prefix: str = "lookup:llm:", + fallback_to_memory: bool = True, + ): + """Initialize Redis cache. + + Args: + ttl_hours: Time-to-live for cache entries in hours + key_prefix: Prefix for all cache keys + fallback_to_memory: Use in-memory cache if Redis unavailable + """ + self.ttl_seconds = ttl_hours * 3600 + self.key_prefix = key_prefix + self.fallback_to_memory = fallback_to_memory + + # Initialize Redis connection + self.redis_client = self._init_redis() + + # Fallback in-memory cache + if self.fallback_to_memory: + from ..services.llm_cache import LLMResponseCache + + self.memory_cache = LLMResponseCache(ttl_hours) + else: + self.memory_cache = None + + def _init_redis(self) -> Any | None: + """Initialize Redis connection. 
+ + Returns: + Redis client or None if unavailable + """ + if not REDIS_AVAILABLE: + logger.warning("Redis package not installed, using fallback cache") + return None + + try: + # Get Redis configuration from settings + redis_host = getattr(settings, "REDIS_HOST", "localhost") + redis_port = getattr(settings, "REDIS_PORT", 6379) + redis_db = getattr(settings, "REDIS_CACHE_DB", 1) + redis_password = getattr(settings, "REDIS_PASSWORD", None) + + # Create Redis client + client = redis.Redis( + host=redis_host, + port=redis_port, + db=redis_db, + password=redis_password, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + ) + + # Test connection + client.ping() + logger.info(f"Connected to Redis at {redis_host}:{redis_port}/{redis_db}") + return client + + except Exception as e: + logger.warning(f"Failed to connect to Redis: {e}") + return None + + def generate_cache_key(self, prompt: str, reference_data: str) -> str: + """Generate a cache key from prompt and reference data. + + Args: + prompt: The resolved prompt + reference_data: The reference data content + + Returns: + SHA256 hash as cache key + """ + combined = f"{prompt}{reference_data}" + hash_obj = hashlib.sha256(combined.encode("utf-8")) + return f"{self.key_prefix}{hash_obj.hexdigest()}" + + def get(self, key: str) -> str | None: + """Get cached response. + + Args: + key: Cache key + + Returns: + Cached response or None if not found/expired + """ + # Try Redis first + if self.redis_client: + try: + value = self.redis_client.get(key) + if value: + logger.debug(f"Redis cache hit for key: {key[:20]}...") + self._update_stats("hits") + return value + else: + logger.debug(f"Redis cache miss for key: {key[:20]}...") + self._update_stats("misses") + except RedisError as e: + logger.error(f"Redis get error: {e}") + # Fall through to memory cache + + # Fallback to memory cache + if self.memory_cache: + # Remove prefix for memory cache + memory_key = key.replace(self.key_prefix, "") + value = self.memory_cache.get(memory_key) + if value: + logger.debug(f"Memory cache hit for key: {key[:20]}...") + return value + + return None + + def set(self, key: str, value: str, ttl: int | None = None) -> bool: + """Set cache value. + + Args: + key: Cache key + value: Response to cache + ttl: Optional TTL override in seconds + + Returns: + True if successful + """ + ttl = ttl or self.ttl_seconds + + # Try Redis first + if self.redis_client: + try: + self.redis_client.setex(name=key, time=ttl, value=value) + logger.debug(f"Cached to Redis with key: {key[:20]}... (TTL: {ttl}s)") + self._update_stats("sets") + return True + except RedisError as e: + logger.error(f"Redis set error: {e}") + # Fall through to memory cache + + # Fallback to memory cache + if self.memory_cache: + # Remove prefix for memory cache + memory_key = key.replace(self.key_prefix, "") + self.memory_cache.set(memory_key, value) + logger.debug(f"Cached to memory with key: {key[:20]}...") + return True + + return False + + def delete(self, key: str) -> bool: + """Delete a cache entry. 
+ + Args: + key: Cache key + + Returns: + True if deleted + """ + deleted = False + + # Try Redis + if self.redis_client: + try: + result = self.redis_client.delete(key) + if result > 0: + deleted = True + logger.debug(f"Deleted from Redis: {key[:20]}...") + except RedisError as e: + logger.error(f"Redis delete error: {e}") + + # Also delete from memory cache + if self.memory_cache: + memory_key = key.replace(self.key_prefix, "") + if memory_key in self.memory_cache.cache: + del self.memory_cache.cache[memory_key] + deleted = True + logger.debug(f"Deleted from memory: {key[:20]}...") + + return deleted + + def clear_pattern(self, pattern: str) -> int: + """Clear all cache entries matching a pattern. + + Args: + pattern: Pattern to match (e.g., "lookup:llm:project_*") + + Returns: + Number of entries cleared + """ + count = 0 + + # Clear from Redis + if self.redis_client: + try: + # Use SCAN to find matching keys + cursor = 0 + while True: + cursor, keys = self.redis_client.scan( + cursor=cursor, match=pattern, count=100 + ) + if keys: + count += self.redis_client.delete(*keys) + if cursor == 0: + break + + logger.info(f"Cleared {count} entries from Redis matching: {pattern}") + except RedisError as e: + logger.error(f"Redis clear pattern error: {e}") + + # Clear from memory cache + if self.memory_cache: + # Remove prefix from pattern + memory_pattern = pattern.replace(self.key_prefix, "") + keys_to_delete = [ + k + for k in self.memory_cache.cache.keys() + if self._match_pattern(k, memory_pattern) + ] + for key in keys_to_delete: + del self.memory_cache.cache[key] + count += 1 + + logger.info(f"Cleared {len(keys_to_delete)} entries from memory") + + return count + + def _match_pattern(self, key: str, pattern: str) -> bool: + """Simple pattern matching for memory cache. + + Args: + key: Key to check + pattern: Pattern with * wildcards + + Returns: + True if matches + """ + import fnmatch + + return fnmatch.fnmatch(key, pattern) + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary of cache statistics + """ + stats = { + "backend": "redis" if self.redis_client else "memory", + "ttl_hours": self.ttl_seconds // 3600, + "key_prefix": self.key_prefix, + } + + # Get Redis stats + if self.redis_client: + try: + info = self.redis_client.info("stats") + stats["redis"] = { + "total_connections": info.get("total_connections_received", 0), + "keyspace_hits": info.get("keyspace_hits", 0), + "keyspace_misses": info.get("keyspace_misses", 0), + "hit_rate": self._calculate_hit_rate( + info.get("keyspace_hits", 0), info.get("keyspace_misses", 0) + ), + } + + # Get key count + dbinfo = self.redis_client.info("keyspace") + db_key = f"db{self.redis_client.connection_pool.connection_kwargs.get('db', 0)}" + if db_key in dbinfo: + stats["redis"]["total_keys"] = dbinfo[db_key].get("keys", 0) + + except RedisError as e: + logger.error(f"Failed to get Redis stats: {e}") + stats["redis"] = {"error": str(e)} + + # Get memory cache stats + if self.memory_cache: + stats["memory"] = { + "entries": len(self.memory_cache.cache), + "size_estimate": sum( + len(k) + len(v[0]) for k, v in self.memory_cache.cache.items() + ), + } + + return stats + + def _calculate_hit_rate(self, hits: int, misses: int) -> float: + """Calculate cache hit rate.""" + total = hits + misses + return (hits / total) if total > 0 else 0.0 + + def _update_stats(self, stat_type: str) -> None: + """Update internal statistics. 
+ + Args: + stat_type: Type of stat to update ('hits', 'misses', 'sets') + """ + # This could be extended to track more detailed statistics + pass + + def cleanup_expired(self) -> int: + """Clean up expired entries. + + For Redis, this happens automatically with TTL. + For memory cache, we clean up lazily. + + Returns: + Number of entries cleaned up + """ + count = 0 + + # Redis handles expiration automatically + if self.redis_client: + logger.debug("Redis handles expiration automatically") + + # Clean memory cache + if self.memory_cache: + current_time = time.time() + keys_to_delete = [] + + for key, (value, expiry) in list(self.memory_cache.cache.items()): + if current_time >= expiry: + keys_to_delete.append(key) + + for key in keys_to_delete: + del self.memory_cache.cache[key] + count += 1 + + if count > 0: + logger.info(f"Cleaned up {count} expired entries from memory cache") + + return count + + def warmup(self, project_id: str, preload_data: dict[str, str]) -> int: + """Warm up cache with preloaded data. + + Args: + project_id: Project ID for namespacing + preload_data: Dictionary of prompt->response to preload + + Returns: + Number of entries preloaded + """ + count = 0 + + for prompt, response in preload_data.items(): + # Generate key with project namespace + key = f"{self.key_prefix}project:{project_id}:{hashlib.md5(prompt.encode()).hexdigest()}" + + if self.set(key, response): + count += 1 + + logger.info(f"Warmed up cache with {count} entries for project {project_id}") + return count diff --git a/backend/lookup/integrations/storage_client.py b/backend/lookup/integrations/storage_client.py new file mode 100644 index 0000000000..816cf72f9c --- /dev/null +++ b/backend/lookup/integrations/storage_client.py @@ -0,0 +1,281 @@ +"""Object Storage integration for Look-Up reference data. + +This module provides integration with PERMANENT_REMOTE_STORAGE +for storing and retrieving reference data files. +""" + +import logging +from pathlib import Path +from typing import Any +from uuid import UUID + +from utils.file_storage.constants import FileStorageKeys + +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper + +from ..protocols import StorageClient + +logger = logging.getLogger(__name__) + + +class RemoteStorageClient(StorageClient): + """Implementation of StorageClient using PERMANENT_REMOTE_STORAGE. + + This client integrates with the Unstract file storage system + to store and retrieve Look-Up reference data. + """ + + def __init__(self, base_path: str = "lookup/reference_data"): + """Initialize the remote storage client. + + Args: + base_path: Base path for Look-Up data in storage + """ + self.base_path = base_path + self.fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + def _get_project_path(self, project_id: UUID) -> str: + """Get the storage path for a project. + + Args: + project_id: Project UUID + + Returns: + Storage path string + """ + return str(Path(self.base_path) / str(project_id)) + + def _get_file_path(self, project_id: UUID, filename: str) -> str: + """Get the full storage path for a file. + + Args: + project_id: Project UUID + filename: File name + + Returns: + Full storage path + """ + return str(Path(self._get_project_path(project_id)) / filename) + + def upload(self, path: str, content: bytes) -> bool: + """Upload content to remote storage. 
+ + Args: + path: Storage path + content: File content as bytes + + Returns: + True if successful, False otherwise + """ + try: + # Ensure parent directory exists + parent_dir = str(Path(path).parent) + self.fs_instance.mkdir(parent_dir, create_parents=True) + + # Write content + self.fs_instance.write(path=path, mode="wb", data=content) + logger.info(f"Successfully uploaded file to {path}") + return True + + except Exception as e: + logger.error(f"Failed to upload file to {path}: {e}") + return False + + def download(self, path: str) -> bytes | None: + """Download content from remote storage. + + Args: + path: Storage path + + Returns: + File content as bytes or None if not found + """ + try: + if not self.exists(path): + logger.warning(f"File not found at path: {path}") + return None + + # Read content + content = self.fs_instance.read(path=path, mode="rb") + logger.info(f"Successfully downloaded file from {path}") + return content + + except Exception as e: + logger.error(f"Failed to download file from {path}: {e}") + return None + + def delete(self, path: str) -> bool: + """Delete content from remote storage. + + Args: + path: Storage path + + Returns: + True if deleted, False otherwise + """ + try: + if not self.exists(path): + logger.warning(f"File not found for deletion: {path}") + return False + + self.fs_instance.delete(path) + logger.info(f"Successfully deleted file at {path}") + return True + + except Exception as e: + logger.error(f"Failed to delete file at {path}: {e}") + return False + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists, False otherwise + """ + try: + return self.fs_instance.exists(path) + except Exception as e: + logger.error(f"Failed to check existence of {path}: {e}") + return False + + def list_files(self, prefix: str) -> list[str]: + """List files with given prefix. + + Args: + prefix: Path prefix + + Returns: + List of matching paths + """ + try: + # List all files in the directory + files = self.fs_instance.listdir(prefix) + return [str(Path(prefix) / f) for f in files if not f.startswith(".")] + + except Exception as e: + logger.error(f"Failed to list files with prefix {prefix}: {e}") + return [] + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage. + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + content = self.download(path) + if content: + try: + return content.decode("utf-8") + except UnicodeDecodeError: + logger.error(f"Failed to decode text content from {path}") + return None + + def save_text_content(self, path: str, text: str) -> bool: + """Save text content to storage. + + Args: + path: Storage path + text: Text content + + Returns: + True if successful + """ + return self.upload(path, text.encode("utf-8")) + + def upload_reference_data( + self, + project_id: UUID, + filename: str, + content: bytes, + metadata: dict[str, Any] | None = None, + ) -> str: + """Upload reference data for a Look-Up project. 
+ + Args: + project_id: Project UUID + filename: Original filename + content: File content + metadata: Optional metadata + + Returns: + Storage path of uploaded file + """ + # Generate storage path + storage_path = self._get_file_path(project_id, filename) + + # Upload file + if self.upload(storage_path, content): + # Store metadata if provided + if metadata: + meta_path = f"{storage_path}.meta.json" + import json + + meta_content = json.dumps(metadata, indent=2) + self.save_text_content(meta_path, meta_content) + + return storage_path + else: + raise Exception(f"Failed to upload reference data to {storage_path}") + + def get_reference_data(self, project_id: UUID, filename: str) -> str | None: + """Get reference data text content. + + Args: + project_id: Project UUID + filename: File name + + Returns: + Text content or None + """ + storage_path = self._get_file_path(project_id, filename) + return self.get_text_content(storage_path) + + def list_project_files(self, project_id: UUID) -> list[str]: + """List all files for a project. + + Args: + project_id: Project UUID + + Returns: + List of file paths + """ + project_path = self._get_project_path(project_id) + return self.list_files(project_path) + + def delete_project_data(self, project_id: UUID) -> bool: + """Delete all data for a project. + + Args: + project_id: Project UUID + + Returns: + True if successful + """ + try: + project_path = self._get_project_path(project_id) + files = self.list_project_files(project_id) + + # Delete all files + for file_path in files: + self.delete(file_path) + + # Delete directory + if self.exists(project_path): + self.fs_instance.rmdir(project_path) + + logger.info(f"Deleted all data for project {project_id}") + return True + + except Exception as e: + logger.error(f"Failed to delete project data: {e}") + return False diff --git a/backend/lookup/integrations/unstract_llm_client.py b/backend/lookup/integrations/unstract_llm_client.py new file mode 100644 index 0000000000..f18e9c2844 --- /dev/null +++ b/backend/lookup/integrations/unstract_llm_client.py @@ -0,0 +1,195 @@ +"""Unstract LLM Client for Look-Up enrichment. + +This module provides integration with Unstract's LLM abstraction +using the SDK's LLM class for generating enrichment data. +""" + +import json +import logging +from typing import Any, Protocol + +from adapter_processor_v2.models import AdapterInstance +from litellm import get_max_tokens, token_counter + +from lookup.exceptions import ContextWindowExceededError +from unstract.sdk1.adapters.constants import Common +from unstract.sdk1.adapters.llm1 import adapters +from unstract.sdk1.llm import LLM + +logger = logging.getLogger(__name__) + + +class LLMClient(Protocol): + """Protocol for LLM client abstraction.""" + + def generate(self, prompt: str, config: dict[str, Any]) -> str: + """Generate LLM response for the prompt.""" + ... + + +class UnstractLLMClient(LLMClient): + """LLM client implementation using Unstract's SDK LLM class. + + This client uses the actual LLM adapters configured in the platform + to generate enrichment data for Look-Ups. + """ + + # Reserve tokens for LLM response output + RESERVED_OUTPUT_TOKENS = 2048 + # Default context window if we can't determine the model's limit + DEFAULT_CONTEXT_WINDOW = 4096 + + def __init__(self, adapter_instance: AdapterInstance): + """Initialize the LLM client with an adapter instance. 
+ + Args: + adapter_instance: The AdapterInstance model object containing + the LLM configuration + """ + self.adapter_instance = adapter_instance + self.adapter_id = adapter_instance.adapter_id + self.adapter_metadata = adapter_instance.metadata # Decrypted metadata + + # Initialize model info for context validation + self._model_name = self._get_model_name() + self._context_limit = self._get_context_limit() + + def _get_model_name(self) -> str: + """Get the model name from adapter metadata. + + Returns: + The model name string used by litellm + """ + try: + adapter = adapters[self.adapter_id][Common.MODULE] + return adapter.validate_model(self.adapter_metadata) + except Exception as e: + logger.warning(f"Failed to get model name: {e}") + return "unknown" + + def _get_context_limit(self) -> int: + """Get context window limit for the configured LLM. + + Returns: + Maximum number of tokens the model can handle + """ + try: + return get_max_tokens(self._model_name) + except Exception as e: + logger.warning( + f"Failed to get context limit for {self._model_name}: {e}. " + f"Using default: {self.DEFAULT_CONTEXT_WINDOW}" + ) + return self.DEFAULT_CONTEXT_WINDOW + + def validate_context_size(self, prompt: str) -> None: + """Validate that the prompt fits within the LLM's context window. + + This method counts the tokens in the prompt and compares against + the model's context window limit, accounting for reserved output tokens. + + Args: + prompt: The complete prompt to send to the LLM + + Raises: + ContextWindowExceededError: If the prompt exceeds the context limit + """ + try: + # Count tokens using litellm's accurate counter + messages = [{"role": "user", "content": prompt}] + token_count = token_counter(model=self._model_name, messages=messages) + except Exception as e: + # Fallback to rough estimation if token counting fails + logger.warning(f"Token counting failed, using estimation: {e}") + token_count = len(prompt) // 4 # Rough estimate: ~4 chars per token + + # Account for reserved output tokens + available_tokens = self._context_limit - self.RESERVED_OUTPUT_TOKENS + + logger.debug( + f"Context validation: {token_count:,} tokens in prompt, " + f"{available_tokens:,} available (limit: {self._context_limit:,}, " + f"reserved: {self.RESERVED_OUTPUT_TOKENS})" + ) + + if token_count > available_tokens: + raise ContextWindowExceededError( + token_count=token_count, + context_limit=available_tokens, + model=self._model_name, + ) + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate LLM response for Look-Up enrichment. + + Args: + prompt: The prompt text with resolved variables and reference data + config: Additional LLM configuration (temperature, etc.) 
+ timeout: Request timeout in seconds + + Returns: + JSON-formatted string with enrichment data + + Raises: + ContextWindowExceededError: If prompt exceeds context window + RuntimeError: If LLM call fails + """ + # Validate context size before calling LLM + self.validate_context_size(prompt) + + try: + # Create LLM instance using SDK + llm = LLM(adapter_id=self.adapter_id, adapter_metadata=self.adapter_metadata) + + # Call the LLM + logger.debug(f"Calling LLM with prompt length: {len(prompt)}") + response = llm.complete(prompt) + + # Extract the response text + response_text = response["response"].text + + logger.debug(f"LLM response: {response_text[:500]}...") + + # Validate it's valid JSON + try: + json.loads(response_text) + return response_text + except json.JSONDecodeError: + # Try to extract JSON from response + return self._extract_json(response_text) + + except Exception as e: + logger.error(f"LLM generation failed: {e}") + raise RuntimeError(f"LLM generation failed: {str(e)}") + + def _extract_json(self, response: str) -> str: + """Extract JSON from LLM response if wrapped in text. + + Args: + response: Raw LLM response + + Returns: + Extracted JSON string + """ + # Look for JSON object pattern (handles nested objects) + # Try to find content between first { and last } + start_idx = response.find("{") + end_idx = response.rfind("}") + + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + potential_json = response[start_idx : end_idx + 1] + try: + json.loads(potential_json) + return potential_json + except json.JSONDecodeError: + pass + + # If no valid JSON found, create a basic response + logger.warning(f"Could not extract JSON from response: {response[:200]}") + return json.dumps( + { + "raw_response": response[:500], + "confidence": 0.3, + "warning": "Could not extract structured data from LLM response", + } + ) diff --git a/backend/lookup/migrations/0001_initial.py b/backend/lookup/migrations/0001_initial.py new file mode 100644 index 0000000000..da583747de --- /dev/null +++ b/backend/lookup/migrations/0001_initial.py @@ -0,0 +1,886 @@ +# Generated by Django - Squashed migration for lookup app + +import uuid + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + """Squashed initial migration for the lookup app. + + This migration creates all lookup-related models: + - LookupProject: Main project configuration + - LookupPromptTemplate: Prompt templates with variable detection + - LookupDataSource: Reference data file management with versioning + - LookupProfileManager: Adapter profiles for indexing/querying + - LookupIndexManager: Index tracking for vector DB + - PromptStudioLookupLink: Links between PS projects and Lookups + - LookupExecutionAudit: Execution history and metrics + """ + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("account_v2", "0001_initial"), + ("adapter_processor_v2", "0001_initial"), + ] + + operations = [ + # 1. 
Create LookupProject model + migrations.CreateModel( + name="LookupProject", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "name", + models.CharField( + help_text="Name of the Look-Up project", max_length=255 + ), + ), + ( + "description", + models.TextField( + blank=True, + help_text="Description of the Look-Up project's purpose", + null=True, + ), + ), + ( + "lookup_type", + models.CharField( + choices=[("static_data", "Static Data")], + default="static_data", + help_text="Type of Look-Up (only static_data for POC)", + max_length=50, + ), + ), + ( + "reference_data_type", + models.CharField( + choices=[ + ("vendor_catalog", "Vendor Catalog"), + ("product_catalog", "Product Catalog"), + ("customer_database", "Customer Database"), + ("pricing_data", "Pricing Data"), + ("compliance_rules", "Compliance Rules"), + ("custom", "Custom"), + ], + help_text="Category of reference data being stored", + max_length=50, + ), + ), + ( + "is_active", + models.BooleanField( + default=True, help_text="Whether this project is active" + ), + ), + ( + "metadata", + models.JSONField( + blank=True, + default=dict, + help_text="Additional metadata for the project", + ), + ), + ( + "llm_provider", + models.CharField( + blank=True, + choices=[ + ("openai", "OpenAI"), + ("anthropic", "Anthropic"), + ("azure", "Azure OpenAI"), + ("custom", "Custom Provider"), + ], + help_text="LLM provider to use for matching", + max_length=50, + null=True, + ), + ), + ( + "llm_model", + models.CharField( + blank=True, + help_text="Specific model name (e.g., gpt-4-turbo, claude-3-opus)", + max_length=100, + null=True, + ), + ), + ( + "llm_config", + models.JSONField( + blank=True, + default=dict, + help_text="Additional LLM configuration (temperature, max_tokens, etc.)", + ), + ), + ( + "created_by", + models.ForeignKey( + help_text="User who created this project", + on_delete=django.db.models.deletion.RESTRICT, + related_name="created_lookup_projects", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "organization", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="lookup_projects", + to="account_v2.organization", + ), + ), + ], + options={ + "verbose_name": "Look-Up Project", + "verbose_name_plural": "Look-Up Projects", + "db_table": "lookup_projects", + "ordering": ["-created_at"], + }, + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["organization"], name="lookup_proj_organiz_idx"), + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["created_by"], name="lookup_proj_created_idx"), + ), + migrations.AddIndex( + model_name="lookupproject", + index=models.Index(fields=["modified_at"], name="lookup_proj_modifie_idx"), + ), + # 2. 
Create LookupPromptTemplate model + migrations.CreateModel( + name="LookupPromptTemplate", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "name", + models.CharField( + help_text="Template name for identification", max_length=255 + ), + ), + ( + "template_text", + models.TextField(help_text="Template with {{variable}} placeholders"), + ), + ( + "llm_config", + models.JSONField( + blank=True, + default=dict, + help_text="LLM configuration including adapter_id", + ), + ), + ( + "is_active", + models.BooleanField( + default=True, help_text="Whether this template is active" + ), + ), + ( + "variable_mappings", + models.JSONField( + blank=True, + default=dict, + help_text="Optional documentation of variable mappings", + ), + ), + ( + "created_by", + models.ForeignKey( + help_text="User who created this template", + on_delete=django.db.models.deletion.RESTRICT, + related_name="created_lookup_templates", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "project", + models.OneToOneField( + help_text="Parent Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="prompt_template_link", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Look-Up Prompt Template", + "verbose_name_plural": "Look-Up Prompt Templates", + "db_table": "lookup_prompt_templates", + "ordering": ["-modified_at"], + }, + ), + # 3. Add template FK to LookupProject (after template is created) + migrations.AddField( + model_name="lookupproject", + name="template", + field=models.ForeignKey( + blank=True, + help_text="Prompt template for this project", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="projects", + to="lookup.lookupprompttemplate", + ), + ), + # 4. 
Create LookupDataSource model + migrations.CreateModel( + name="LookupDataSource", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "file_name", + models.CharField(help_text="Original filename", max_length=255), + ), + ("file_path", models.TextField(help_text="Path in object storage")), + ("file_size", models.BigIntegerField(help_text="File size in bytes")), + ( + "file_type", + models.CharField( + choices=[ + ("pdf", "PDF"), + ("xlsx", "Excel"), + ("csv", "CSV"), + ("docx", "Word"), + ("txt", "Text"), + ("json", "JSON"), + ], + help_text="Type of file", + max_length=50, + ), + ), + ( + "extracted_content_path", + models.TextField( + blank=True, + help_text="Path to extracted text in object storage", + null=True, + ), + ), + ( + "extraction_status", + models.CharField( + choices=[ + ("pending", "Pending"), + ("processing", "Processing"), + ("completed", "Completed"), + ("failed", "Failed"), + ], + default="pending", + help_text="Status of text extraction", + max_length=20, + ), + ), + ( + "extraction_error", + models.TextField( + blank=True, + help_text="Error details if extraction failed", + null=True, + ), + ), + ( + "version_number", + models.IntegerField( + default=1, + help_text="Version number of this data source (auto-incremented)", + ), + ), + ( + "is_latest", + models.BooleanField( + default=True, help_text="Whether this is the latest version" + ), + ), + ( + "project", + models.ForeignKey( + help_text="Parent Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="data_sources", + to="lookup.lookupproject", + ), + ), + ( + "uploaded_by", + models.ForeignKey( + help_text="User who uploaded this file", + on_delete=django.db.models.deletion.RESTRICT, + related_name="uploaded_lookup_data", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Look-Up Data Source", + "verbose_name_plural": "Look-Up Data Sources", + "db_table": "lookup_data_sources", + "ordering": ["-version_number"], + "unique_together": {("project", "version_number")}, + }, + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index(fields=["project"], name="lookup_data_project_idx"), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index( + fields=["project", "is_latest"], name="lookup_data_proj_latest_idx" + ), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index(fields=["created_at"], name="lookup_data_created_idx"), + ), + migrations.AddIndex( + model_name="lookupdatasource", + index=models.Index( + fields=["extraction_status"], name="lookup_data_extract_idx" + ), + ), + # 5. 
Create LookupProfileManager model + migrations.CreateModel( + name="LookupProfileManager", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "profile_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("profile_name", models.TextField(db_comment="Name of the profile")), + ( + "chunk_size", + models.IntegerField( + db_comment="Size of text chunks for indexing", default=1000 + ), + ), + ( + "chunk_overlap", + models.IntegerField( + db_comment="Overlap between consecutive chunks", default=200 + ), + ), + ( + "similarity_top_k", + models.IntegerField( + db_comment="Number of top similar chunks to retrieve", default=5 + ), + ), + ( + "is_default", + models.BooleanField( + db_comment="Whether this is the default profile for the project", + default=False, + ), + ), + ( + "reindex", + models.BooleanField( + db_comment="Flag to trigger re-indexing of reference data", + default=False, + ), + ), + ( + "lookup_project", + models.ForeignKey( + db_comment="Look-Up project this profile belongs to", + on_delete=django.db.models.deletion.CASCADE, + related_name="profiles", + to="lookup.lookupproject", + ), + ), + ( + "vector_store", + models.ForeignKey( + db_comment="Vector database adapter for storing embeddings", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_vector_store", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "embedding_model", + models.ForeignKey( + db_comment="Embedding model adapter for generating vectors", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_embedding_model", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "llm", + models.ForeignKey( + db_comment="LLM adapter for query processing and response generation", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_llm", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "x2text", + models.ForeignKey( + db_comment="X2Text adapter for extracting text from various file formats", + on_delete=django.db.models.deletion.PROTECT, + related_name="lookup_profiles_x2text", + to="adapter_processor_v2.adapterinstance", + ), + ), + ( + "created_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_profile_managers_created", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "modified_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_profile_managers_modified", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Lookup Profile Manager", + "verbose_name_plural": "Lookup Profile Managers", + "db_table": "lookup_profile_manager", + }, + ), + migrations.AddConstraint( + model_name="lookupprofilemanager", + constraint=models.UniqueConstraint( + fields=("lookup_project", "profile_name"), + name="unique_lookup_project_profile_name_index", + ), + ), + # 6. 
Create LookupIndexManager model + migrations.CreateModel( + name="LookupIndexManager", + fields=[ + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + db_comment="Timestamp when the record was created", + ), + ), + ( + "modified_at", + models.DateTimeField( + auto_now=True, + db_comment="Timestamp when the record was last modified", + ), + ), + ( + "index_manager_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "raw_index_id", + models.CharField( + blank=True, + db_comment="Raw index ID for vector DB", + editable=False, + max_length=255, + null=True, + ), + ), + ( + "index_ids_history", + models.JSONField( + blank=False, + db_comment="List of all index IDs created for this data source", + default=list, + null=False, + ), + ), + ( + "extraction_status", + models.JSONField( + blank=False, + db_comment='Extraction status per X2Text config: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str}}', + default=dict, + null=False, + ), + ), + ( + "status", + models.JSONField( + blank=False, + db_comment="Extraction and indexing status", + default=dict, + null=False, + ), + ), + ( + "data_source", + models.ForeignKey( + db_comment="Reference data source being indexed", + editable=False, + on_delete=django.db.models.deletion.CASCADE, + related_name="index_managers", + to="lookup.lookupdatasource", + ), + ), + ( + "profile_manager", + models.ForeignKey( + blank=True, + db_comment="Profile used for indexing this data source", + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="index_managers", + to="lookup.lookupprofilemanager", + ), + ), + ( + "created_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_index_managers_created", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "modified_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="lookup_index_managers_modified", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "Lookup Index Manager", + "verbose_name_plural": "Lookup Index Managers", + "db_table": "lookup_index_manager", + }, + ), + migrations.AddConstraint( + model_name="lookupindexmanager", + constraint=models.UniqueConstraint( + fields=("data_source", "profile_manager"), + name="unique_data_source_profile_manager_index", + ), + ), + # 7. 
Create PromptStudioLookupLink model + migrations.CreateModel( + name="PromptStudioLookupLink", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "prompt_studio_project_id", + models.UUIDField(help_text="UUID of the Prompt Studio project"), + ), + ( + "execution_order", + models.PositiveIntegerField( + default=0, + help_text="Order in which this Look-Up executes (lower numbers first)", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ( + "lookup_project", + models.ForeignKey( + help_text="Linked Look-Up project", + on_delete=django.db.models.deletion.CASCADE, + related_name="ps_links", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Prompt Studio Look-Up Link", + "verbose_name_plural": "Prompt Studio Look-Up Links", + "db_table": "prompt_studio_lookup_links", + "ordering": ["execution_order", "created_at"], + "unique_together": {("prompt_studio_project_id", "lookup_project")}, + }, + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["prompt_studio_project_id"], name="ps_lookup_link_ps_proj_idx" + ), + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["lookup_project"], name="ps_lookup_link_lookup_idx" + ), + ), + migrations.AddIndex( + model_name="promptstudiolookuplink", + index=models.Index( + fields=["prompt_studio_project_id", "execution_order"], + name="ps_lookup_link_order_idx", + ), + ), + # 8. Create LookupExecutionAudit model + migrations.CreateModel( + name="LookupExecutionAudit", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "prompt_studio_project_id", + models.UUIDField( + blank=True, + help_text="Associated Prompt Studio project if applicable", + null=True, + ), + ), + ( + "execution_id", + models.UUIDField( + help_text="Groups all Look-Ups in a single execution batch" + ), + ), + ( + "input_data", + models.JSONField( + help_text="Input data from Prompt Studio extraction" + ), + ), + ( + "reference_data_version", + models.IntegerField(help_text="Version of reference data used"), + ), + ( + "enriched_output", + models.JSONField( + blank=True, help_text="Enrichment data produced", null=True + ), + ), + ( + "llm_provider", + models.CharField(help_text="LLM provider used", max_length=50), + ), + ( + "llm_model", + models.CharField(help_text="LLM model used", max_length=100), + ), + ("llm_prompt", models.TextField(help_text="Full prompt sent to LLM")), + ( + "llm_response", + models.TextField(blank=True, help_text="Raw LLM response", null=True), + ), + ( + "llm_response_cached", + models.BooleanField( + default=False, help_text="Whether response was from cache" + ), + ), + ( + "execution_time_ms", + models.IntegerField( + blank=True, + help_text="Total execution time in milliseconds", + null=True, + ), + ), + ( + "llm_call_time_ms", + models.IntegerField( + blank=True, + help_text="LLM API call time in milliseconds", + null=True, + ), + ), + ( + "status", + models.CharField( + choices=[ + ("success", "Success"), + ("partial", "Partial Success"), + ("failed", "Failed"), + ], + help_text="Execution status", + max_length=20, + ), + ), + ( + "error_message", + models.TextField( + blank=True, help_text="Error details if failed", null=True + ), + ), + ( + "confidence_score", + models.DecimalField( + blank=True, + decimal_places=2, + help_text="Confidence score from LLM 
(0.00 to 1.00)", + max_digits=3, + null=True, + ), + ), + ( + "executed_at", + models.DateTimeField( + auto_now_add=True, help_text="When the execution occurred" + ), + ), + ( + "lookup_project", + models.ForeignKey( + help_text="Look-Up project that was executed", + on_delete=django.db.models.deletion.CASCADE, + related_name="execution_audits", + to="lookup.lookupproject", + ), + ), + ], + options={ + "verbose_name": "Look-Up Execution Audit", + "verbose_name_plural": "Look-Up Execution Audits", + "db_table": "lookup_execution_audit", + "ordering": ["-executed_at"], + }, + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["lookup_project"], name="lookup_audit_proj_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["execution_id"], name="lookup_audit_exec_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["executed_at"], name="lookup_audit_time_idx"), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index(fields=["status"], name="lookup_audit_status_idx"), + ), + ] diff --git a/backend/lookup/migrations/0002_remove_reference_data_type.py b/backend/lookup/migrations/0002_remove_reference_data_type.py new file mode 100644 index 0000000000..44823bda04 --- /dev/null +++ b/backend/lookup/migrations/0002_remove_reference_data_type.py @@ -0,0 +1,22 @@ +# Generated manually - Remove reference_data_type field from LookupProject + +from django.db import migrations + + +class Migration(migrations.Migration): + """Remove reference_data_type field from LookupProject model. + + This field is no longer needed as the type categorization + has been removed from the Lookup projects feature. + """ + + dependencies = [ + ("lookup", "0001_initial"), + ] + + operations = [ + migrations.RemoveField( + model_name="lookupproject", + name="reference_data_type", + ), + ] diff --git a/backend/lookup/migrations/0003_add_file_execution_id.py b/backend/lookup/migrations/0003_add_file_execution_id.py new file mode 100644 index 0000000000..363b109fbd --- /dev/null +++ b/backend/lookup/migrations/0003_add_file_execution_id.py @@ -0,0 +1,28 @@ +# Generated manually for adding file_execution_id to LookupExecutionAudit + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("lookup", "0002_remove_reference_data_type"), + ] + + operations = [ + migrations.AddField( + model_name="lookupexecutionaudit", + name="file_execution_id", + field=models.UUIDField( + blank=True, + help_text="Workflow file execution ID for tracking in API/ETL pipelines", + null=True, + ), + ), + migrations.AddIndex( + model_name="lookupexecutionaudit", + index=models.Index( + fields=["file_execution_id"], + name="lookup_exec_file_ex_idx", + ), + ), + ] diff --git a/backend/lookup/migrations/0004_fix_file_paths_for_minio.py b/backend/lookup/migrations/0004_fix_file_paths_for_minio.py new file mode 100644 index 0000000000..db1c4ee0ca --- /dev/null +++ b/backend/lookup/migrations/0004_fix_file_paths_for_minio.py @@ -0,0 +1,70 @@ +"""Migration to fix file paths for MinIO storage. + +This migration converts local filesystem paths to MinIO-compatible paths +for existing LookupDataSource records. 
+ +Old format: /app/prompt-studio-data/{org_id}/{project_id}/{filename} +New format: unstract/prompt-studio-data/{org_id}/{project_id}/{filename} +""" + +from django.db import migrations + + +def fix_file_paths(apps, schema_editor): + """Convert local paths to MinIO paths.""" + LookupDataSource = apps.get_model("lookup", "LookupDataSource") + + # Define path mappings + old_prefix = "/app/prompt-studio-data/" + new_prefix = "unstract/prompt-studio-data/" + + # Update file_path + updated_count = 0 + for data_source in LookupDataSource.objects.filter(file_path__startswith=old_prefix): + data_source.file_path = data_source.file_path.replace(old_prefix, new_prefix, 1) + + # Also fix extracted_content_path if it exists + if ( + data_source.extracted_content_path + and data_source.extracted_content_path.startswith(old_prefix) + ): + data_source.extracted_content_path = ( + data_source.extracted_content_path.replace(old_prefix, new_prefix, 1) + ) + + data_source.save(update_fields=["file_path", "extracted_content_path"]) + updated_count += 1 + + if updated_count > 0: + print(f" Updated {updated_count} LookupDataSource records with corrected paths") + + +def reverse_file_paths(apps, schema_editor): + """Revert MinIO paths back to local paths.""" + LookupDataSource = apps.get_model("lookup", "LookupDataSource") + + old_prefix = "unstract/prompt-studio-data/" + new_prefix = "/app/prompt-studio-data/" + + for data_source in LookupDataSource.objects.filter(file_path__startswith=old_prefix): + data_source.file_path = data_source.file_path.replace(old_prefix, new_prefix, 1) + + if ( + data_source.extracted_content_path + and data_source.extracted_content_path.startswith(old_prefix) + ): + data_source.extracted_content_path = ( + data_source.extracted_content_path.replace(old_prefix, new_prefix, 1) + ) + + data_source.save(update_fields=["file_path", "extracted_content_path"]) + + +class Migration(migrations.Migration): + dependencies = [ + ("lookup", "0003_add_file_execution_id"), + ] + + operations = [ + migrations.RunPython(fix_file_paths, reverse_file_paths), + ] diff --git a/backend/lookup/migrations/0005_add_reindex_required_field.py b/backend/lookup/migrations/0005_add_reindex_required_field.py new file mode 100644 index 0000000000..5749dd91a2 --- /dev/null +++ b/backend/lookup/migrations/0005_add_reindex_required_field.py @@ -0,0 +1,27 @@ +# Generated migration for adding reindex_required field to LookupIndexManager + +from django.db import migrations, models + + +class Migration(migrations.Migration): + """Add reindex_required field to LookupIndexManager. + + This field tracks whether indexes are stale and need re-indexing + when profile settings change (chunk_size, embedding_model, etc.). 
+ """ + + dependencies = [ + ("lookup", "0004_fix_file_paths_for_minio"), + ] + + operations = [ + migrations.AddField( + model_name="lookupindexmanager", + name="reindex_required", + field=models.BooleanField( + default=False, + db_comment="Flag indicating indexes are stale and need re-indexing", + help_text="Set to True when profile settings change and re-indexing is needed", + ), + ), + ] diff --git a/backend/lookup/migrations/__init__.py b/backend/lookup/migrations/__init__.py new file mode 100644 index 0000000000..b98f22cc20 --- /dev/null +++ b/backend/lookup/migrations/__init__.py @@ -0,0 +1 @@ +# Lookup migrations package diff --git a/backend/lookup/models/__init__.py b/backend/lookup/models/__init__.py new file mode 100644 index 0000000000..4c196dc030 --- /dev/null +++ b/backend/lookup/models/__init__.py @@ -0,0 +1,24 @@ +"""Look-Up system models.""" + +from .lookup_data_source import LookupDataSource, LookupDataSourceManager +from .lookup_execution_audit import LookupExecutionAudit +from .lookup_index_manager import LookupIndexManager +from .lookup_profile_manager import LookupProfileManager +from .lookup_project import LookupProject +from .lookup_prompt_template import LookupPromptTemplate +from .prompt_studio_lookup_link import ( + PromptStudioLookupLink, + PromptStudioLookupLinkManager, +) + +__all__ = [ + "LookupProject", + "LookupDataSource", + "LookupDataSourceManager", + "LookupPromptTemplate", + "LookupProfileManager", + "LookupIndexManager", + "PromptStudioLookupLink", + "PromptStudioLookupLinkManager", + "LookupExecutionAudit", +] diff --git a/backend/lookup/models/lookup_data_source.py b/backend/lookup/models/lookup_data_source.py new file mode 100644 index 0000000000..afa0406e66 --- /dev/null +++ b/backend/lookup/models/lookup_data_source.py @@ -0,0 +1,220 @@ +"""LookupDataSource model for managing reference data versions.""" + +import uuid + +from django.contrib.auth import get_user_model +from django.db import models +from django.db.models.signals import post_delete, pre_save +from django.dispatch import receiver +from utils.models.base_model import BaseModel + +User = get_user_model() + + +class LookupDataSourceManager(models.Manager): + """Custom manager for LookupDataSource.""" + + def get_latest_for_project(self, project_id: uuid.UUID): + """Get the latest data sources for a project. + + Args: + project_id: UUID of the lookup project + + Returns: + QuerySet of latest data sources + """ + return self.filter(project_id=project_id, is_latest=True) + + def get_ready_for_project(self, project_id: uuid.UUID): + """Get all completed latest data sources for a project. + + Args: + project_id: UUID of the lookup project + + Returns: + QuerySet of completed latest data sources + """ + return self.get_latest_for_project(project_id).filter( + extraction_status="completed" + ) + + +class LookupDataSource(BaseModel): + """Represents a reference data source with version management. + + Each upload creates a new version, with automatic version numbering + and latest flag management. 
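+
+    Example (illustrative; ``project`` and ``user`` are assumed to exist):
+
+        ds = LookupDataSource.objects.create(
+            project=project,
+            file_name="prices.csv",
+            file_path="unstract/prompt-studio-data/org/proj/prices.csv",
+            file_size=2048,
+            file_type="csv",
+            uploaded_by=user,
+        )
+        # version_number is auto-incremented and is_latest is set to True by the
+        # pre_save signal below; earlier versions are flipped to is_latest=False.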
+ """ + + EXTRACTION_STATUS_CHOICES = [ + ("pending", "Pending"), + ("processing", "Processing"), + ("completed", "Completed"), + ("failed", "Failed"), + ] + + FILE_TYPE_CHOICES = [ + ("pdf", "PDF"), + ("xlsx", "Excel"), + ("csv", "CSV"), + ("docx", "Word"), + ("txt", "Text"), + ("json", "JSON"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="data_sources", + help_text="Parent Look-Up project", + ) + + # File Information + file_name = models.CharField(max_length=255, help_text="Original filename") + file_path = models.TextField(help_text="Path in object storage") + file_size = models.BigIntegerField(help_text="File size in bytes") + file_type = models.CharField( + max_length=50, choices=FILE_TYPE_CHOICES, help_text="Type of file" + ) + + # Extracted Content + extracted_content_path = models.TextField( + blank=True, null=True, help_text="Path to extracted text in object storage" + ) + extraction_status = models.CharField( + max_length=20, + choices=EXTRACTION_STATUS_CHOICES, + default="pending", + help_text="Status of text extraction", + ) + extraction_error = models.TextField( + blank=True, null=True, help_text="Error details if extraction failed" + ) + + # Version Management + version_number = models.IntegerField( + default=1, help_text="Version number of this data source (auto-incremented)" + ) + is_latest = models.BooleanField( + default=True, help_text="Whether this is the latest version" + ) + + # Upload Information + uploaded_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="uploaded_lookup_data", + help_text="User who uploaded this file", + ) + + objects = LookupDataSourceManager() + + class Meta: + """Model metadata.""" + + db_table = "lookup_data_sources" + ordering = ["-version_number"] + unique_together = [["project", "version_number"]] + verbose_name = "Look-Up Data Source" + verbose_name_plural = "Look-Up Data Sources" + indexes = [ + models.Index(fields=["project"]), + models.Index(fields=["project", "is_latest"]), + models.Index(fields=["created_at"]), + models.Index(fields=["extraction_status"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"{self.project.name} - v{self.version_number} - {self.file_name}" + + def get_file_size_display(self) -> str: + """Get human-readable file size. + + Returns: + Formatted file size string (e.g., "51.2 KB") + """ + size = self.file_size + for unit in ["B", "KB", "MB", "GB"]: + if size < 1024.0: + return f"{size:.1f} {unit}" + size /= 1024.0 + return f"{size:.1f} TB" + + @property + def is_extraction_complete(self) -> bool: + """Check if extraction is successfully completed.""" + return self.extraction_status == "completed" + + def get_extracted_content(self) -> str | None: + """Load extracted content from object storage. + + Returns: + Extracted text content or None if not available. + + Note: + This is a placeholder - actual implementation will + load from object storage using extracted_content_path. + """ + if not self.extracted_content_path or not self.is_extraction_complete: + return None + + # TODO: Implement actual object storage retrieval + # For now, return a placeholder + return f"[Content from {self.extracted_content_path}]" + + +@receiver(pre_save, sender=LookupDataSource) +def auto_increment_version_and_update_latest(sender, instance, **kwargs): + """Signal to auto-increment version number and manage is_latest flag. 
+ + This signal: + 1. Auto-increments version_number if not set + 2. Marks all previous versions as not latest + """ + # Check if this is a new instance by querying the database + # (instance.pk is always truthy for UUIDField with default=uuid.uuid4) + is_new = not LookupDataSource.objects.filter(pk=instance.pk).exists() + + if is_new: + # Get the highest version number for this project + max_version = LookupDataSource.objects.filter(project=instance.project).aggregate( + max_version=models.Max("version_number") + )["max_version"] + + # Always auto-increment version for new instances + instance.version_number = (max_version or 0) + 1 + + # Mark all previous versions as not latest + LookupDataSource.objects.filter(project=instance.project, is_latest=True).update( + is_latest=False + ) + + # Ensure new version is marked as latest + instance.is_latest = True + + +@receiver(post_delete, sender=LookupDataSource) +def promote_previous_version_to_latest(sender, instance, **kwargs): + """Signal to promote the previous version to latest when the current latest is deleted. + + This signal: + 1. Checks if the deleted instance was the latest version + 2. If so, promotes the next most recent version to be the latest + """ + # Only act if the deleted instance was the latest + if not instance.is_latest: + return + + # Find the next most recent data source for this project + # (highest version_number among remaining records) + next_latest = ( + LookupDataSource.objects.filter(project_id=instance.project_id) + .order_by("-version_number") + .first() + ) + + if next_latest: + next_latest.is_latest = True + next_latest.save(update_fields=["is_latest"]) diff --git a/backend/lookup/models/lookup_execution_audit.py b/backend/lookup/models/lookup_execution_audit.py new file mode 100644 index 0000000000..585766dcd3 --- /dev/null +++ b/backend/lookup/models/lookup_execution_audit.py @@ -0,0 +1,140 @@ +"""LookupExecutionAudit model for tracking Look-Up execution history.""" + +import uuid +from decimal import Decimal + +from django.db import models + + +class LookupExecutionAudit(models.Model): + """Audit log for Look-Up executions. + + Tracks all execution attempts with detailed metadata for debugging + and performance monitoring. 
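+
+    Rows created in one enrichment run share an ``execution_id``; when the run
+    originates from an API or ETL pipeline, ``file_execution_id`` additionally
+    ties each row to the workflow file execution.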
+ """ + + STATUS_CHOICES = [ + ("success", "Success"), + ("partial", "Partial Success"), + ("failed", "Failed"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + # Execution Context + lookup_project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="execution_audits", + help_text="Look-Up project that was executed", + ) + prompt_studio_project_id = models.UUIDField( + null=True, blank=True, help_text="Associated Prompt Studio project if applicable" + ) + execution_id = models.UUIDField( + help_text="Groups all Look-Ups in a single execution batch" + ) + file_execution_id = models.UUIDField( + null=True, + blank=True, + help_text="Workflow file execution ID for tracking in API/ETL pipelines", + ) + + # Input/Output + input_data = models.JSONField(help_text="Input data from Prompt Studio extraction") + reference_data_version = models.IntegerField( + help_text="Version of reference data used" + ) + enriched_output = models.JSONField( + null=True, blank=True, help_text="Enrichment data produced" + ) + + # LLM Details + llm_provider = models.CharField(max_length=50, help_text="LLM provider used") + llm_model = models.CharField(max_length=100, help_text="LLM model used") + llm_prompt = models.TextField(help_text="Full prompt sent to LLM") + llm_response = models.TextField(null=True, blank=True, help_text="Raw LLM response") + llm_response_cached = models.BooleanField( + default=False, help_text="Whether response was from cache" + ) + + # Performance Metrics + execution_time_ms = models.IntegerField( + null=True, blank=True, help_text="Total execution time in milliseconds" + ) + llm_call_time_ms = models.IntegerField( + null=True, blank=True, help_text="LLM API call time in milliseconds" + ) + + # Status & Errors + status = models.CharField( + max_length=20, choices=STATUS_CHOICES, help_text="Execution status" + ) + error_message = models.TextField( + null=True, blank=True, help_text="Error details if failed" + ) + confidence_score = models.DecimalField( + max_digits=3, + decimal_places=2, + null=True, + blank=True, + validators=[], + help_text="Confidence score from LLM (0.00 to 1.00)", + ) + + # Timestamps + executed_at = models.DateTimeField( + auto_now_add=True, help_text="When the execution occurred" + ) + + class Meta: + """Model metadata.""" + + db_table = "lookup_execution_audit" + ordering = ["-executed_at"] + verbose_name = "Look-Up Execution Audit" + verbose_name_plural = "Look-Up Execution Audits" + indexes = [ + models.Index(fields=["lookup_project"]), + models.Index(fields=["execution_id"]), + models.Index(fields=["file_execution_id"]), + models.Index(fields=["executed_at"]), + models.Index(fields=["status"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"{self.lookup_project.name} - {self.status} - {self.executed_at}" + + @property + def was_successful(self) -> bool: + """Check if execution was successful.""" + return self.status in ["success", "partial"] + + @property + def execution_duration_seconds(self) -> float: + """Get execution duration in seconds.""" + if self.execution_time_ms: + return self.execution_time_ms / 1000.0 + return 0.0 + + def clean(self): + """Validate the audit record.""" + super().clean() + from django.core.exceptions import ValidationError + + # Validate confidence score range + if self.confidence_score is not None: + if not (Decimal("0.00") <= self.confidence_score <= Decimal("1.00")): + raise ValidationError( + f"Confidence score must be between 
0.00 and 1.00, " + f"got {self.confidence_score}" + ) + + # Ensure error_message is provided for failed status + if self.status == "failed" and not self.error_message: + raise ValidationError("Error message is required for failed executions") + + # Ensure enriched_output is provided for success status + if self.status == "success" and not self.enriched_output: + raise ValidationError("Enriched output is required for successful executions") diff --git a/backend/lookup/models/lookup_index_manager.py b/backend/lookup/models/lookup_index_manager.py new file mode 100644 index 0000000000..8cc501e222 --- /dev/null +++ b/backend/lookup/models/lookup_index_manager.py @@ -0,0 +1,183 @@ +"""LookupIndexManager model for tracking indexed reference data.""" + +import logging +import uuid + +from account_v2.models import User +from django.db import models +from django.db.models.signals import pre_delete +from django.dispatch import receiver +from utils.models.base_model import BaseModel + +logger = logging.getLogger(__name__) + + +class LookupIndexManager(BaseModel): + """Model to store indexing details for Look-Up reference data. + + Tracks which data sources have been indexed with which profile, + stores vector DB index IDs, and manages extraction status. + + Follows the same pattern as Prompt Studio's IndexManager. + """ + + index_manager_id = models.UUIDField( + primary_key=True, default=uuid.uuid4, editable=False + ) + + # Reference to the data source being indexed + data_source = models.ForeignKey( + "LookupDataSource", + on_delete=models.CASCADE, + related_name="index_managers", + editable=False, + null=False, + blank=False, + db_comment="Reference data source being indexed", + ) + + # Reference to the profile used for indexing + profile_manager = models.ForeignKey( + "LookupProfileManager", + on_delete=models.SET_NULL, + related_name="index_managers", + editable=False, + null=True, + blank=True, + db_comment="Profile used for indexing this data source", + ) + + # Vector DB index ID for this data source (raw index) + raw_index_id = models.CharField( + max_length=255, + db_comment="Raw index ID for vector DB", + editable=False, + null=True, + blank=True, + ) + + # History of all index IDs (for cleanup on deletion) + index_ids_history = models.JSONField( + db_comment="List of all index IDs created for this data source", + default=list, + null=False, + blank=False, + ) + + # Extraction status per X2Text configuration + # Format: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str|null}} + extraction_status = models.JSONField( + db_comment='Extraction status per X2Text config: {x2text_config_hash: {"extracted": bool, "enable_highlight": bool, "error": str}}', + default=dict, + null=False, + blank=False, + ) + + # Overall extraction and indexing status (legacy field, kept for compatibility) + # Format: {"extracted": bool, "indexed": bool, "error": str|null} + status = models.JSONField( + db_comment="Extraction and indexing status", null=False, blank=False, default=dict + ) + + # Flag to indicate that indexes are stale and need re-indexing + # Set to True when profile settings change (chunk_size, embedding_model, etc.) 
+ reindex_required = models.BooleanField( + default=False, + db_comment="Flag indicating indexes are stale and need re-indexing", + help_text="Set to True when profile settings change and re-indexing is needed", + ) + + # Audit fields + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_index_managers_created", + null=True, + blank=True, + editable=False, + ) + + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_index_managers_modified", + null=True, + blank=True, + editable=False, + ) + + class Meta: + verbose_name = "Lookup Index Manager" + verbose_name_plural = "Lookup Index Managers" + db_table = "lookup_index_manager" + constraints = [ + models.UniqueConstraint( + fields=["data_source", "profile_manager"], + name="unique_data_source_profile_manager_index", + ), + ] + + def __str__(self): + return f"Index for {self.data_source.file_name} with {self.profile_manager.profile_name if self.profile_manager else 'No Profile'}" + + +def delete_from_vector_db(index_ids_history, vector_db_instance_id): + """Delete index IDs from vector database. + + This function is kept for backward compatibility and signal handler usage. + For new code, prefer using VectorDBCleanupService directly. + + Args: + index_ids_history: List of index IDs to delete + vector_db_instance_id: UUID of the vector DB adapter instance + """ + from lookup.services.vector_db_cleanup_service import VectorDBCleanupService + + cleanup_service = VectorDBCleanupService() + result = cleanup_service.cleanup_index_ids( + index_ids=index_ids_history, + vector_db_instance_id=vector_db_instance_id, + ) + + if result["errors"]: + for error in result["errors"]: + logger.error(error) + + +# Signal to perform vector DB cleanup on deletion +@receiver(pre_delete, sender=LookupIndexManager) +def perform_vector_db_cleanup(sender, instance, **kwargs): + """Signal handler to clean up vector DB entries when index is deleted. + + This ensures that when a LookupIndexManager is deleted (e.g., when + a data source is deleted or re-indexed), the corresponding vectors + are removed from the vector database. 
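+
+    Cleanup is best-effort: failures are logged as warnings and never block
+    deletion of the index manager record.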
+ """ + logger.debug( + f"Performing vector DB cleanup for data source: " + f"{instance.data_source.file_name}" + ) + + try: + # Get the index_ids_history to clean up from the vector db + index_ids_history = instance.index_ids_history + + if not index_ids_history: + logger.debug("No index IDs to clean up") + return + + if instance.profile_manager and instance.profile_manager.vector_store: + vector_db_instance_id = str(instance.profile_manager.vector_store.id) + delete_from_vector_db(index_ids_history, vector_db_instance_id) + else: + logger.warning( + f"Cannot cleanup vector DB: missing profile or vector store " + f"for data source {instance.data_source.file_name}" + ) + + except Exception as e: + logger.warning( + f"Error during vector DB cleanup for data source " + f"{instance.data_source.file_name}: {e}", + exc_info=True, + ) diff --git a/backend/lookup/models/lookup_profile_manager.py b/backend/lookup/models/lookup_profile_manager.py new file mode 100644 index 0000000000..681a8897e9 --- /dev/null +++ b/backend/lookup/models/lookup_profile_manager.py @@ -0,0 +1,226 @@ +"""LookupProfileManager model for managing adapter profiles in Look-Up projects.""" + +import logging +import uuid + +from account_v2.models import User +from adapter_processor_v2.models import AdapterInstance +from django.db import models +from django.db.models.signals import pre_delete +from django.dispatch import receiver +from utils.models.base_model import BaseModel + +from lookup.exceptions import DefaultProfileError + +logger = logging.getLogger(__name__) + + +class LookupProfileManager(BaseModel): + """Model to store adapter configuration profiles for Look-Up projects. + + Each profile defines the set of adapters (X2Text, Embedding, VectorDB, LLM) + to use for text extraction, indexing, and lookup operations. + + Follows the same pattern as Prompt Studio's ProfileManager. 
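+
+    Example (illustrative; assumes a default profile exists for ``project``):
+
+        profile = LookupProfileManager.get_default_profile(project)
+        # Raises DefaultProfileError when no default profile is configured.
+        # Unless overridden at creation time, chunking uses the field defaults
+        # below: chunk_size=1000, chunk_overlap=200, similarity_top_k=5.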
+ """ + + profile_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + profile_name = models.TextField( + blank=False, null=False, db_comment="Name of the profile" + ) + + # Foreign key to LookupProject + lookup_project = models.ForeignKey( + "LookupProject", + on_delete=models.CASCADE, + related_name="profiles", + db_comment="Look-Up project this profile belongs to", + ) + + # Required Adapters - All must be configured + vector_store = models.ForeignKey( + AdapterInstance, + db_comment="Vector database adapter for storing embeddings", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_vector_store", + ) + + embedding_model = models.ForeignKey( + AdapterInstance, + db_comment="Embedding model adapter for generating vectors", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_embedding_model", + ) + + llm = models.ForeignKey( + AdapterInstance, + db_comment="LLM adapter for query processing and response generation", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_llm", + ) + + x2text = models.ForeignKey( + AdapterInstance, + db_comment="X2Text adapter for extracting text from various file formats", + blank=False, + null=False, + on_delete=models.PROTECT, + related_name="lookup_profiles_x2text", + ) + + # Configuration fields + chunk_size = models.IntegerField( + default=1000, + null=False, + blank=False, + db_comment="Size of text chunks for indexing", + ) + + chunk_overlap = models.IntegerField( + default=200, + null=False, + blank=False, + db_comment="Overlap between consecutive chunks", + ) + + similarity_top_k = models.IntegerField( + default=5, + null=False, + blank=False, + db_comment="Number of top similar chunks to retrieve", + ) + + # Flags + is_default = models.BooleanField( + default=False, db_comment="Whether this is the default profile for the project" + ) + + reindex = models.BooleanField( + default=False, db_comment="Flag to trigger re-indexing of reference data" + ) + + # Audit fields + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_profile_managers_created", + null=True, + blank=True, + editable=False, + ) + + modified_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="lookup_profile_managers_modified", + null=True, + blank=True, + editable=False, + ) + + class Meta: + verbose_name = "Lookup Profile Manager" + verbose_name_plural = "Lookup Profile Managers" + db_table = "lookup_profile_manager" + constraints = [ + models.UniqueConstraint( + fields=["lookup_project", "profile_name"], + name="unique_lookup_project_profile_name_index", + ), + ] + + def __str__(self): + return f"{self.profile_name} ({self.lookup_project.name})" + + @staticmethod + def get_default_profile(project): + """Get the default profile for a Look-Up project. + + Args: + project: LookupProject instance + + Returns: + LookupProfileManager: The default profile + + Raises: + DefaultProfileError: If no default profile exists + """ + try: + return LookupProfileManager.objects.get( + lookup_project=project, is_default=True + ) + except LookupProfileManager.DoesNotExist: + raise DefaultProfileError( + f"No default profile found for project {project.name}" + ) + + +@receiver(pre_delete, sender=LookupProfileManager) +def cleanup_profile_indexes(sender, instance, **kwargs): + """Clean up all vector DB indexes created with this profile before deletion. 
+ + This signal handler ensures that when a LookupProfileManager is deleted, + all associated vector DB indexes are cleaned up to prevent stale data + accumulation. + + Args: + sender: The model class (LookupProfileManager) + instance: The profile instance being deleted + **kwargs: Additional arguments from the signal + """ + # Import here to avoid circular imports + from lookup.services.vector_db_cleanup_service import VectorDBCleanupService + + try: + # Get all index managers associated with this profile + index_managers = instance.index_managers.all() + + if not index_managers.exists(): + logger.debug( + f"No index managers found for profile {instance.profile_name}, " + "skipping cleanup" + ) + return + + cleanup_service = VectorDBCleanupService() + total_deleted = 0 + total_failed = 0 + errors = [] + + for index_manager in index_managers: + if index_manager.index_ids_history: + result = cleanup_service.cleanup_index_ids( + index_ids=index_manager.index_ids_history, + vector_db_instance_id=str(instance.vector_store.id), + ) + total_deleted += result.get("deleted", 0) + total_failed += result.get("failed", 0) + if result.get("errors"): + errors.extend(result["errors"]) + + if total_deleted > 0: + logger.info( + f"Profile deletion cleanup for '{instance.profile_name}': " + f"deleted {total_deleted} indexes from vector DB" + ) + + if total_failed > 0: + logger.warning( + f"Profile deletion cleanup for '{instance.profile_name}': " + f"failed to delete {total_failed} indexes. Errors: {errors}" + ) + + except Exception as e: + # Log error but don't block deletion - cleanup is best-effort + logger.error( + f"Error during profile deletion cleanup for '{instance.profile_name}': " + f"{str(e)}", + exc_info=True, + ) diff --git a/backend/lookup/models/lookup_project.py b/backend/lookup/models/lookup_project.py new file mode 100644 index 0000000000..7a9014184b --- /dev/null +++ b/backend/lookup/models/lookup_project.py @@ -0,0 +1,163 @@ +"""LookupProject model for Static Data-based Look-Ups.""" + +import uuid + +from django.contrib.auth import get_user_model +from django.db import models +from django.urls import reverse +from utils.models.base_model import BaseModel +from utils.models.organization_mixin import ( + DefaultOrganizationMixin, +) + +User = get_user_model() + + +class LookupProject(DefaultOrganizationMixin, BaseModel): + """Represents a Look-Up project for static data-based enrichment. + + This model stores the configuration for a Look-Up project including + LLM settings and organization association. 
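+
+    Example (illustrative; ``project`` is an existing instance):
+
+        if project.is_ready:
+            version = project.get_latest_reference_version()
+        # is_ready is True only when every latest data source has finished
+        # extraction (extraction_status == "completed").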
+ """ + + LOOKUP_TYPE_CHOICES = [ + ("static_data", "Static Data"), + ] + + LLM_PROVIDER_CHOICES = [ + ("openai", "OpenAI"), + ("anthropic", "Anthropic"), + ("azure", "Azure OpenAI"), + ("custom", "Custom Provider"), + ] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + name = models.CharField(max_length=255, help_text="Name of the Look-Up project") + description = models.TextField( + blank=True, null=True, help_text="Description of the Look-Up project's purpose" + ) + lookup_type = models.CharField( + max_length=50, + choices=LOOKUP_TYPE_CHOICES, + default="static_data", + help_text="Type of Look-Up (only static_data for POC)", + ) + + # Template and status + template = models.ForeignKey( + "LookupPromptTemplate", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="projects", + help_text="Prompt template for this project", + ) + is_active = models.BooleanField( + default=True, help_text="Whether this project is active" + ) + metadata = models.JSONField( + default=dict, blank=True, help_text="Additional metadata for the project" + ) + + # LLM Configuration + llm_provider = models.CharField( + max_length=50, + choices=LLM_PROVIDER_CHOICES, + null=True, + blank=True, + help_text="LLM provider to use for matching", + ) + llm_model = models.CharField( + max_length=100, + null=True, + blank=True, + help_text="Specific model name (e.g., gpt-4-turbo, claude-3-opus)", + ) + llm_config = models.JSONField( + default=dict, + blank=True, + help_text="Additional LLM configuration (temperature, max_tokens, etc.)", + ) + + # Ownership + created_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="created_lookup_projects", + help_text="User who created this project", + ) + + # Note: created_at and modified_at are inherited from BaseModel + # Note: organization ForeignKey is inherited from DefaultOrganizationMixin + + class Meta: + """Model metadata.""" + + db_table = "lookup_projects" + ordering = ["-created_at"] + verbose_name = "Look-Up Project" + verbose_name_plural = "Look-Up Projects" + indexes = [ + models.Index(fields=["organization"]), + models.Index(fields=["created_by"]), + models.Index(fields=["modified_at"]), + ] + + def __str__(self) -> str: + """String representation of the project.""" + return self.name + + def get_absolute_url(self) -> str: + """Get the URL for this project's detail view.""" + return reverse("lookup:project-detail", kwargs={"pk": self.pk}) + + @property + def is_ready(self) -> bool: + """Check if the project has reference data ready for use. + + Returns: + True if all reference data is extracted and ready, False otherwise. + """ + if not hasattr(self, "data_sources"): + return False + + latest_sources = self.data_sources.filter(is_latest=True) + if not latest_sources.exists(): + return False + + return all(source.extraction_status == "completed" for source in latest_sources) + + def get_latest_reference_version(self) -> int | None: + """Get the latest version number of reference data. + + Returns: + Latest version number or None if no data sources exist. 
+ """ + if not hasattr(self, "data_sources"): + return None + + latest = self.data_sources.filter(is_latest=True).first() + return latest.version_number if latest else None + + def clean(self): + """Validate model fields.""" + super().clean() + from django.core.exceptions import ValidationError + + # Validate LLM provider + valid_providers = [choice[0] for choice in self.LLM_PROVIDER_CHOICES] + if self.llm_provider not in valid_providers: + raise ValidationError( + f"Invalid LLM provider: {self.llm_provider}. " + f"Must be one of: {', '.join(valid_providers)}" + ) + + # Validate LLM config structure + if self.llm_config and not isinstance(self.llm_config, dict): + raise ValidationError("LLM config must be a dictionary") + + # Validate lookup_type (only static_data for POC) + if self.lookup_type != "static_data": + raise ValidationError( + "Only 'static_data' lookup type is supported in this POC" + ) diff --git a/backend/lookup/models/lookup_prompt_template.py b/backend/lookup/models/lookup_prompt_template.py new file mode 100644 index 0000000000..62bd7aae5d --- /dev/null +++ b/backend/lookup/models/lookup_prompt_template.py @@ -0,0 +1,178 @@ +"""LookupPromptTemplate model for managing prompt templates.""" + +import re +import uuid + +from account_v2.models import User +from django.core.exceptions import ValidationError +from django.db import models +from utils.models.base_model import BaseModel + + +class LookupPromptTemplate(BaseModel): + """Represents a prompt template with variable detection and validation. + + Each Look-Up project has one template that defines how to construct + the LLM prompt with {{variable}} placeholders. + """ + + VARIABLE_PATTERN = r"\{\{([^}]+)\}\}" + RESERVED_PREFIXES = ["_", "_lookup_"] + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + project = models.OneToOneField( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="prompt_template_link", + help_text="Parent Look-Up project", + ) + + # Template Configuration + name = models.CharField(max_length=255, help_text="Template name for identification") + template_text = models.TextField(help_text="Template with {{variable}} placeholders") + llm_config = models.JSONField( + default=dict, blank=True, help_text="LLM configuration including adapter_id" + ) + is_active = models.BooleanField( + default=True, help_text="Whether this template is active" + ) + created_by = models.ForeignKey( + User, + on_delete=models.RESTRICT, + related_name="created_lookup_templates", + help_text="User who created this template", + ) + variable_mappings = models.JSONField( + default=dict, blank=True, help_text="Optional documentation of variable mappings" + ) + + class Meta: + """Model metadata.""" + + db_table = "lookup_prompt_templates" + ordering = ["-modified_at"] + verbose_name = "Look-Up Prompt Template" + verbose_name_plural = "Look-Up Prompt Templates" + + def __str__(self) -> str: + """String representation.""" + return f"Template for {self.project.name}" + + def detect_variables(self) -> list[str]: + """Detect all {{variable}} references in the template. + + Returns: + List of unique variable paths found in the template. 
+ + Example: + Template: "Match {{input_data.vendor}} from {{reference_data}}" + Returns: ["input_data.vendor", "reference_data"] + """ + if not self.template_text: + return [] + + matches = re.findall(self.VARIABLE_PATTERN, self.template_text) + # Strip whitespace and deduplicate + unique_vars = list({m.strip() for m in matches}) + return sorted(unique_vars) + + def validate_syntax(self) -> bool: + """Validate template syntax for matching braces. + + Returns: + True if syntax is valid, False otherwise. + + Raises: + ValidationError: If syntax is invalid. + """ + # Check for unmatched opening braces + open_count = self.template_text.count("{{") + close_count = self.template_text.count("}}") + + if open_count != close_count: + raise ValidationError( + f"Mismatched braces in template: {open_count} opening, " + f"{close_count} closing" + ) + + # Check for nested braces (not allowed) + if re.search(r"\{\{[^}]*\{\{", self.template_text): + raise ValidationError("Nested variable placeholders are not allowed") + + return True + + def validate_reserved_keywords(self) -> bool: + """Check that template doesn't use reserved keywords. + + Returns: + True if no reserved keywords are used. + + Raises: + ValidationError: If reserved keywords are found. + """ + variables = self.detect_variables() + + for var in variables: + # Check if variable starts with reserved prefixes + for prefix in self.RESERVED_PREFIXES: + if var.startswith(prefix): + raise ValidationError( + f"Variable '{var}' uses reserved prefix '{prefix}'. " + f"Reserved prefixes: {', '.join(self.RESERVED_PREFIXES)}" + ) + + # Check if trying to write to reserved fields + if "=" in var or var.endswith("_metadata"): + raise ValidationError( + f"Variable '{var}' appears to be trying to set a value. " + f"Variables should only reference existing data." + ) + + return True + + def get_variable_info(self) -> dict: + """Get detailed information about detected variables. + + Returns: + Dictionary with variable paths and their types/documentation. 
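+
+        Example (assuming ``variable_mappings`` has no entry for the variable):
+            Template: "Match {{input_data.vendor}}"
+            Returns: {"input_data.vendor": {"root": "input_data", "path": "vendor",
+                      "depth": 2, "description": "No description"}}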
+ """ + variables = self.detect_variables() + info = {} + + for var in variables: + parts = var.split(".") + if len(parts) > 0: + root = parts[0] + path = ".".join(parts[1:]) if len(parts) > 1 else "" + + info[var] = { + "root": root, + "path": path, + "depth": len(parts), + "description": self.variable_mappings.get(var, "No description"), + } + + return info + + def clean(self): + """Validate the template on save.""" + super().clean() + + if not self.template_text: + raise ValidationError("Template text cannot be empty") + + try: + self.validate_syntax() + self.validate_reserved_keywords() + except ValidationError as e: + raise e + + # Warn if no variables detected (might be intentional) + if not self.detect_variables(): + # This is just a warning, not an error + pass + + @property + def variable_count(self) -> int: + """Get the count of unique variables in the template.""" + return len(self.detect_variables()) diff --git a/backend/lookup/models/prompt_studio_lookup_link.py b/backend/lookup/models/prompt_studio_lookup_link.py new file mode 100644 index 0000000000..0a03cd7f63 --- /dev/null +++ b/backend/lookup/models/prompt_studio_lookup_link.py @@ -0,0 +1,128 @@ +"""PromptStudioLookupLink model for linking PS projects with Look-Up projects.""" + +import uuid + +from django.db import models + + +class PromptStudioLookupLinkManager(models.Manager): + """Custom manager for PromptStudioLookupLink.""" + + def get_links_for_ps_project(self, ps_project_id: uuid.UUID): + """Get all Look-Up links for a Prompt Studio project, ordered by execution order. + + Args: + ps_project_id: UUID of the Prompt Studio project + + Returns: + QuerySet of links ordered by execution_order + """ + return self.filter(prompt_studio_project_id=ps_project_id).order_by( + "execution_order", "created_at" + ) + + def get_lookup_projects_for_ps(self, ps_project_id: uuid.UUID): + """Get all Look-Up projects linked to a Prompt Studio project. + + Args: + ps_project_id: UUID of the Prompt Studio project + + Returns: + QuerySet of LookupProject instances + """ + from .lookup_project import LookupProject + + link_ids = self.filter(prompt_studio_project_id=ps_project_id).values_list( + "lookup_project_id", flat=True + ) + + return LookupProject.objects.filter(id__in=link_ids) + + +class PromptStudioLookupLink(models.Model): + """Many-to-many relationship between Prompt Studio projects and Look-Up projects. + + Manages the linking and execution order of Look-Up projects within + a Prompt Studio extraction pipeline. 
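+
+    Example (illustrative; ``ps_project_id`` and ``lookup_project`` are assumed):
+
+        PromptStudioLookupLink.objects.create(
+            prompt_studio_project_id=ps_project_id,
+            lookup_project=lookup_project,
+            execution_order=1,
+        )
+        links = PromptStudioLookupLink.objects.get_links_for_ps_project(ps_project_id)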
+ """ + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + + # Note: We're using UUIDs for now since PS project model isn't defined yet + # In production, this should be a ForeignKey to the actual PS project model + prompt_studio_project_id = models.UUIDField( + help_text="UUID of the Prompt Studio project" + ) + + lookup_project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.CASCADE, + related_name="ps_links", + help_text="Linked Look-Up project", + ) + + execution_order = models.PositiveIntegerField( + default=0, help_text="Order in which this Look-Up executes (lower numbers first)" + ) + + # Timestamps + created_at = models.DateTimeField(auto_now_add=True) + + objects = PromptStudioLookupLinkManager() + + class Meta: + """Model metadata.""" + + db_table = "prompt_studio_lookup_links" + ordering = ["execution_order", "created_at"] + unique_together = [["prompt_studio_project_id", "lookup_project"]] + verbose_name = "Prompt Studio Look-Up Link" + verbose_name_plural = "Prompt Studio Look-Up Links" + indexes = [ + models.Index(fields=["prompt_studio_project_id"]), + models.Index(fields=["lookup_project"]), + models.Index(fields=["prompt_studio_project_id", "execution_order"]), + ] + + def __str__(self) -> str: + """String representation.""" + return f"PS Project {self.prompt_studio_project_id} → {self.lookup_project.name}" + + def save(self, *args, **kwargs): + """Override save to auto-assign execution_order if not set. + + Auto-assigns the next available execution order number for the + Prompt Studio project if not explicitly provided. + """ + if self.execution_order == 0 and not self.pk: + # Get the maximum execution order for this PS project + max_order = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.prompt_studio_project_id + ).aggregate(max_order=models.Max("execution_order"))["max_order"] + + # Set to max + 1, or 1 if no existing links + self.execution_order = (max_order or 0) + 1 + + super().save(*args, **kwargs) + + def clean(self): + """Validate the link.""" + super().clean() + + # In production, we would validate that both projects + # belong to the same organization here + # For now, we'll just ensure the lookup project is ready + + if self.lookup_project and not self.lookup_project.is_ready: + # This is a warning, not an error - allow linking but warn + # In production, this might log a warning + pass + + @property + def is_enabled(self) -> bool: + """Check if this link is enabled and ready for execution. + + Returns: + True if the linked Look-Up project is ready for use. + """ + return self.lookup_project.is_ready if self.lookup_project else False diff --git a/backend/lookup/serializers.py b/backend/lookup/serializers.py new file mode 100644 index 0000000000..5c30e8ac0e --- /dev/null +++ b/backend/lookup/serializers.py @@ -0,0 +1,362 @@ +"""Django REST Framework serializers for Look-Up API. + +This module provides serializers for all Look-Up models +to support RESTful API operations. 
+""" + +import logging +from typing import Any + +from adapter_processor_v2.adapter_processor import AdapterProcessor +from rest_framework import serializers + +from backend.serializers import AuditSerializer + +from .constants import LookupProfileManagerKeys +from .models import ( + LookupDataSource, + LookupExecutionAudit, + LookupProfileManager, + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, +) + +logger = logging.getLogger(__name__) + + +class LookupPromptTemplateSerializer(serializers.ModelSerializer): + """Serializer for LookupPromptTemplate model.""" + + class Meta: + model = LookupPromptTemplate + fields = [ + "id", + "project", + "name", + "template_text", + "llm_config", + "is_active", + "created_by", + "created_at", + "modified_at", + ] + read_only_fields = ["id", "created_by", "created_at", "modified_at"] + + def validate_template_text(self, value: str) -> str: + """Validate that template text contains required placeholders.""" + if "{{reference_data}}" not in value: + raise serializers.ValidationError( + "Template must contain {{reference_data}} placeholder" + ) + return value + + def validate_llm_config(self, value: dict[str, Any]) -> dict[str, Any]: + """Validate LLM configuration structure.""" + # Accept either new format (adapter_id) or legacy format (provider + model) + has_adapter_id = "adapter_id" in value + has_legacy = "provider" in value and "model" in value + + if not has_adapter_id and not has_legacy: + raise serializers.ValidationError( + "llm_config must contain either 'adapter_id' (recommended) " + "or both 'provider' and 'model' fields" + ) + return value + + +class LookupDataSourceSerializer(serializers.ModelSerializer): + """Serializer for LookupDataSource model.""" + + extraction_status_display = serializers.CharField( + source="get_extraction_status_display", read_only=True + ) + + class Meta: + model = LookupDataSource + fields = [ + "id", + "project", + "file_name", + "file_path", + "file_size", + "file_type", + "extracted_content_path", + "extraction_status", + "extraction_status_display", + "extraction_error", + "version_number", + "is_latest", + "uploaded_by", + "created_at", + "modified_at", + ] + read_only_fields = [ + "id", + "version_number", + "is_latest", + "uploaded_by", + "created_at", + "modified_at", + "extraction_status_display", + ] + + +class LookupProjectSerializer(serializers.ModelSerializer): + """Serializer for LookupProject model.""" + + template = LookupPromptTemplateSerializer(read_only=True) + template_id = serializers.PrimaryKeyRelatedField( + queryset=LookupPromptTemplate.objects.all(), + source="template", + write_only=True, + allow_null=True, + required=False, + ) + data_source_count = serializers.SerializerMethodField() + latest_version = serializers.SerializerMethodField() + + class Meta: + model = LookupProject + fields = [ + "id", + "name", + "description", + "template", + "template_id", + "is_active", + "data_source_count", + "latest_version", + "metadata", + "created_by", + "created_at", + "modified_at", + ] + read_only_fields = [ + "id", + "data_source_count", + "latest_version", + "created_by", + "created_at", + "modified_at", + ] + + def get_data_source_count(self, obj) -> int: + """Get count of data sources for this project.""" + return obj.data_sources.filter(is_latest=True).count() + + def get_latest_version(self, obj) -> int: + """Get latest version number of data sources.""" + latest = obj.data_sources.filter(is_latest=True).first() + return latest.version_number if latest else 0 + + +class 
PromptStudioLookupLinkSerializer(serializers.ModelSerializer): + """Serializer for linking Look-Ups to Prompt Studio projects.""" + + lookup_project_name = serializers.CharField( + source="lookup_project.name", read_only=True + ) + + class Meta: + model = PromptStudioLookupLink + fields = [ + "id", + "prompt_studio_project_id", + "lookup_project", + "lookup_project_name", + "created_at", + ] + read_only_fields = ["id", "lookup_project_name", "created_at"] + + def validate(self, attrs): + """Validate that the link doesn't already exist.""" + ps_project_id = attrs.get("prompt_studio_project_id") + lookup_project = attrs.get("lookup_project") + + # Check if this combination already exists + existing = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id, lookup_project=lookup_project + ).exists() + + if existing and not self.instance: # Only check for creation, not update + raise serializers.ValidationError( + "This Look-Up is already linked to the Prompt Studio project" + ) + + return attrs + + +class LookupExecutionAuditSerializer(serializers.ModelSerializer): + """Serializer for execution audit records.""" + + lookup_project_name = serializers.CharField( + source="lookup_project.name", read_only=True + ) + + class Meta: + model = LookupExecutionAudit + fields = [ + "id", + "lookup_project", + "lookup_project_name", + "prompt_studio_project_id", + "execution_id", + "file_execution_id", + "input_data", + "enriched_output", + "reference_data_version", + "llm_provider", + "llm_model", + "llm_prompt", + "llm_response", + "llm_response_cached", + "execution_time_ms", + "llm_call_time_ms", + "status", + "error_message", + "confidence_score", + "executed_at", + ] + read_only_fields = fields # All fields are read-only for audit records + + +class LookupExecutionAuditSummarySerializer(serializers.ModelSerializer): + """Lightweight serializer for lookup audit summaries in log views. + + Used by the Nav Bar Logs page to show lookup enrichment details + without the full prompt/response content. 
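+
+    Example (illustrative; ``audit`` is any LookupExecutionAudit instance):
+        >>> LookupExecutionAuditSummarySerializer(audit).data["status"]
+        'success'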
+ """ + + lookup_project_name = serializers.CharField( + source="lookup_project.name", read_only=True + ) + + class Meta: + model = LookupExecutionAudit + fields = [ + "id", + "lookup_project", + "lookup_project_name", + "execution_id", + "file_execution_id", + "status", + "llm_response_cached", + "execution_time_ms", + "confidence_score", + "error_message", + "executed_at", + ] + read_only_fields = fields + + +class LookupExecutionRequestSerializer(serializers.Serializer): + """Serializer for Look-Up execution requests.""" + + input_data = serializers.JSONField(help_text="Input data for variable resolution") + lookup_project_ids = serializers.ListField( + child=serializers.UUIDField(), + required=False, + help_text="List of Look-Up project IDs to execute", + ) + use_cache = serializers.BooleanField( + default=True, help_text="Whether to use cached LLM responses" + ) + timeout_seconds = serializers.IntegerField( + default=30, + min_value=1, + max_value=300, + help_text="Timeout for execution in seconds", + ) + + +class LookupExecutionResponseSerializer(serializers.Serializer): + """Serializer for Look-Up execution responses.""" + + lookup_enrichment = serializers.JSONField( + help_text="Merged enrichment data from all Look-Ups" + ) + _lookup_metadata = serializers.JSONField( + help_text="Execution metadata including timing and status" + ) + + +class ReferenceDataUploadSerializer(serializers.Serializer): + """Serializer for reference data upload requests.""" + + file = serializers.FileField(help_text="Reference data file to upload") + extract_text = serializers.BooleanField( + default=True, help_text="Whether to extract text from the file" + ) + metadata = serializers.JSONField( + required=False, default=dict, help_text="Additional metadata for the data source" + ) + + +class BulkLinkSerializer(serializers.Serializer): + """Serializer for bulk linking operations.""" + + prompt_studio_project_id = serializers.UUIDField(help_text="Prompt Studio project ID") + lookup_project_ids = serializers.ListField( + child=serializers.UUIDField(), help_text="List of Look-Up project IDs to link" + ) + unlink = serializers.BooleanField( + default=False, help_text="If true, unlink instead of link" + ) + + +class TemplateValidationSerializer(serializers.Serializer): + """Serializer for template validation requests.""" + + template_text = serializers.CharField(help_text="Template text to validate") + sample_data = serializers.JSONField( + required=False, help_text="Sample input data for testing variable resolution" + ) + sample_reference = serializers.CharField( + required=False, help_text="Sample reference data for testing" + ) + + +class LookupProfileManagerSerializer(AuditSerializer): + """Serializer for LookupProfileManager model. + + Follows the same pattern as Prompt Studio's ProfileManagerSerializer. + Expands adapter UUIDs to full adapter details in the response. + """ + + class Meta: + model = LookupProfileManager + fields = "__all__" + + def to_representation(self, instance): + """Expand adapter UUIDs to full adapter details. + + This converts the FK references to AdapterInstance objects + into full adapter details including adapter_name, adapter_type, etc. 
+ """ + rep: dict[str, str] = super().to_representation(instance) + + # Expand each adapter FK to full details + llm = rep.get(LookupProfileManagerKeys.LLM) + embedding = rep.get(LookupProfileManagerKeys.EMBEDDING_MODEL) + vector_db = rep.get(LookupProfileManagerKeys.VECTOR_STORE) + x2text = rep.get(LookupProfileManagerKeys.X2TEXT) + + if llm: + rep[LookupProfileManagerKeys.LLM] = ( + AdapterProcessor.get_adapter_instance_by_id(llm) + ) + if embedding: + rep[LookupProfileManagerKeys.EMBEDDING_MODEL] = ( + AdapterProcessor.get_adapter_instance_by_id(embedding) + ) + if vector_db: + rep[LookupProfileManagerKeys.VECTOR_STORE] = ( + AdapterProcessor.get_adapter_instance_by_id(vector_db) + ) + if x2text: + rep[LookupProfileManagerKeys.X2TEXT] = ( + AdapterProcessor.get_adapter_instance_by_id(x2text) + ) + + return rep diff --git a/backend/lookup/services/__init__.py b/backend/lookup/services/__init__.py new file mode 100644 index 0000000000..1e7332ee1c --- /dev/null +++ b/backend/lookup/services/__init__.py @@ -0,0 +1,29 @@ +"""Service layer implementations for the Look-Up system. + +This package contains the core business logic and service classes +for the Static Data-based Look-Ups feature. +""" + +from .audit_logger import AuditLogger +from .document_indexing_service import LookupDocumentIndexingService +from .enrichment_merger import EnrichmentMerger +from .indexing_service import IndexingService +from .llm_cache import LLMResponseCache +from .lookup_executor import LookUpExecutor +from .lookup_index_helper import LookupIndexHelper +from .lookup_orchestrator import LookUpOrchestrator +from .reference_data_loader import ReferenceDataLoader +from .variable_resolver import VariableResolver + +__all__ = [ + "AuditLogger", + "LookupDocumentIndexingService", + "EnrichmentMerger", + "IndexingService", + "LookupIndexHelper", + "LLMResponseCache", + "LookUpExecutor", + "LookUpOrchestrator", + "ReferenceDataLoader", + "VariableResolver", +] diff --git a/backend/lookup/services/audit_logger.py b/backend/lookup/services/audit_logger.py new file mode 100644 index 0000000000..d8abbe9e76 --- /dev/null +++ b/backend/lookup/services/audit_logger.py @@ -0,0 +1,313 @@ +"""Audit Logger implementation for tracking Look-Up executions. + +This module provides functionality to log all Look-Up execution details +to the database for debugging, monitoring, and compliance purposes. +""" + +import logging +from decimal import Decimal +from typing import Any +from uuid import UUID + +from lookup.models import LookupExecutionAudit, LookupProject + +logger = logging.getLogger(__name__) + + +class AuditLogger: + """Logs Look-Up execution details to lookup_execution_audit table. + + This class provides methods to record all aspects of Look-Up executions + including inputs, outputs, performance metrics, and errors for audit + trail and debugging purposes. + """ + + def log_execution( + self, + execution_id: str, + lookup_project_id: UUID, + prompt_studio_project_id: UUID | None, + input_data: dict[str, Any], + reference_data_version: int, + llm_provider: str, + llm_model: str, + llm_prompt: str, + llm_response: str | None, + enriched_output: dict[str, Any] | None, + status: str, # 'success', 'partial', 'failed' + confidence_score: float | None = None, + execution_time_ms: int | None = None, + llm_call_time_ms: int | None = None, + llm_response_cached: bool = False, + error_message: str | None = None, + file_execution_id: UUID | None = None, + ) -> LookupExecutionAudit | None: + """Log execution to database. 
+ + Records comprehensive details about a Look-Up execution including + the input data, LLM interaction, output, and performance metrics. + + Args: + execution_id: UUID of the orchestration execution + lookup_project_id: UUID of the Look-Up project + prompt_studio_project_id: Optional UUID of PS project + input_data: Input data used for variable resolution + reference_data_version: Version of reference data used + llm_provider: LLM provider name (e.g., 'openai') + llm_model: LLM model name (e.g., 'gpt-4') + llm_prompt: Final resolved prompt sent to LLM + llm_response: Raw response from LLM + enriched_output: Parsed enrichment data + status: Execution status ('success', 'partial', 'failed') + confidence_score: Optional confidence score (0.0-1.0) + execution_time_ms: Total execution time in milliseconds + llm_call_time_ms: Time spent calling LLM in milliseconds + llm_response_cached: Whether response was from cache + error_message: Error message if execution failed + file_execution_id: Optional workflow file execution ID for tracking + + Returns: + Created LookupExecutionAudit instance or None if logging fails + + Example: + >>> logger = AuditLogger() + >>> audit = logger.log_execution( + ... execution_id='abc-123', + ... lookup_project_id=project_id, + ... status='success', + ... ... + ... ) + """ + try: + # Get the Look-Up project + try: + lookup_project = LookupProject.objects.get(id=lookup_project_id) + except LookupProject.DoesNotExist: + logger.error(f"Look-Up project {lookup_project_id} not found for audit") + return None + + # Convert confidence score to Decimal if provided + if confidence_score is not None: + confidence_score = Decimal(str(confidence_score)) + + # Create audit record + audit = LookupExecutionAudit.objects.create( + lookup_project=lookup_project, + prompt_studio_project_id=prompt_studio_project_id, + execution_id=execution_id, + file_execution_id=file_execution_id, + input_data=input_data, + reference_data_version=reference_data_version, + enriched_output=enriched_output, + llm_provider=llm_provider, + llm_model=llm_model, + llm_prompt=llm_prompt, + llm_response=llm_response, + llm_response_cached=llm_response_cached, + execution_time_ms=execution_time_ms, + llm_call_time_ms=llm_call_time_ms, + status=status, + error_message=error_message, + confidence_score=confidence_score, + ) + + logger.debug( + f"Logged execution audit {audit.id} for Look-Up {lookup_project.name} " + f"(execution {execution_id})" + ) + + return audit + + except Exception as e: + # Log error but don't fail the execution + logger.exception( + f"Failed to log execution audit for {lookup_project_id}: {str(e)}" + ) + return None + + def log_success( + self, execution_id: str, project_id: UUID, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging successful execution. + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_success( + ... execution_id="abc-123", + ... project_id=project_id, + ... input_data={"vendor": "Slack"}, + ... enriched_output={"canonical_vendor": "Slack"}, + ... confidence_score=0.92, + ... 
) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="success", + **kwargs, + ) + + def log_failure( + self, execution_id: str, project_id: UUID, error: str, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging failed execution. + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + error: Error message describing the failure + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_failure( + ... execution_id="abc-123", + ... project_id=project_id, + ... error="LLM timeout after 30 seconds", + ... input_data={"vendor": "Slack"}, + ... ) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="failed", + error_message=error, + **kwargs, + ) + + def log_partial( + self, execution_id: str, project_id: UUID, **kwargs + ) -> LookupExecutionAudit | None: + """Convenience method for logging partial success. + + Used when some enrichment was achieved but with issues + (e.g., low confidence, incomplete data). + + Args: + execution_id: UUID of the orchestration execution + project_id: UUID of the Look-Up project + **kwargs: Additional parameters to pass to log_execution + + Returns: + Created audit record or None if logging fails + + Example: + >>> audit = logger.log_partial( + ... execution_id="abc-123", + ... project_id=project_id, + ... input_data={"vendor": "Unknown Corp"}, + ... enriched_output={"canonical_vendor": "Unknown"}, + ... confidence_score=0.35, + ... error_message="Low confidence match", + ... ) + """ + return self.log_execution( + execution_id=execution_id, + lookup_project_id=project_id, + status="partial", + **kwargs, + ) + + def get_execution_history(self, execution_id: str, limit: int = 100) -> list: + """Retrieve audit records for a specific execution. + + Args: + execution_id: UUID of the orchestration execution + limit: Maximum number of records to return + + Returns: + List of LookupExecutionAudit instances + + Example: + >>> history = logger.get_execution_history("abc-123") + >>> for audit in history: + ... print(f"{audit.lookup_project.name}: {audit.status}") + """ + try: + return list( + LookupExecutionAudit.objects.filter(execution_id=execution_id) + .select_related("lookup_project") + .order_by("executed_at")[:limit] + ) + except Exception as e: + logger.error(f"Failed to retrieve execution history: {str(e)}") + return [] + + def get_project_stats(self, project_id: UUID, limit: int = 1000) -> dict[str, Any]: + """Get execution statistics for a Look-Up project. + + Args: + project_id: UUID of the Look-Up project + limit: Maximum number of records to analyze + + Returns: + Dictionary with statistics including success rate, + average execution time, cache hit rate, etc. 
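+            Keys (when history is non-empty): total_executions, success_rate,
+            avg_execution_time_ms, cache_hit_rate, avg_confidence, successful,
+            failed, partial.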
+ + Example: + >>> stats = logger.get_project_stats(project_id) + >>> print(f"Success rate: {stats['success_rate']:.1%}") + """ + try: + audits = LookupExecutionAudit.objects.filter( + lookup_project_id=project_id + ).order_by("-executed_at")[:limit] + + total = len(audits) + if total == 0: + return { + "total_executions": 0, + "success_rate": 0.0, + "avg_execution_time_ms": 0, + "cache_hit_rate": 0.0, + "avg_confidence": 0.0, + } + + successful = sum(1 for a in audits if a.status == "success") + cached = sum(1 for a in audits if a.llm_response_cached) + + exec_times = [ + a.execution_time_ms for a in audits if a.execution_time_ms is not None + ] + avg_exec_time = sum(exec_times) / len(exec_times) if exec_times else 0 + + confidence_scores = [ + float(a.confidence_score) + for a in audits + if a.confidence_score is not None + ] + avg_confidence = ( + sum(confidence_scores) / len(confidence_scores) + if confidence_scores + else 0.0 + ) + + return { + "total_executions": total, + "success_rate": successful / total if total > 0 else 0.0, + "avg_execution_time_ms": int(avg_exec_time), + "cache_hit_rate": cached / total if total > 0 else 0.0, + "avg_confidence": avg_confidence, + "successful": successful, + "failed": sum(1 for a in audits if a.status == "failed"), + "partial": sum(1 for a in audits if a.status == "partial"), + } + + except Exception as e: + logger.error(f"Failed to get project stats: {str(e)}") + return { + "total_executions": 0, + "success_rate": 0.0, + "avg_execution_time_ms": 0, + "cache_hit_rate": 0.0, + "avg_confidence": 0.0, + } diff --git a/backend/lookup/services/document_indexing_service.py b/backend/lookup/services/document_indexing_service.py new file mode 100644 index 0000000000..5da68e8542 --- /dev/null +++ b/backend/lookup/services/document_indexing_service.py @@ -0,0 +1,141 @@ +"""Document Indexing Service for Lookup projects. + +This service manages indexing status tracking using cache to prevent +duplicate indexing operations and track in-progress indexing. + +Based on Prompt Studio's DocumentIndexingService pattern. +""" + +import logging + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +class LookupDocumentIndexingService: + """Cache-based service to track document indexing status. + + Prevents duplicate indexing and tracks in-progress operations. + """ + + # Cache key format: lookup_indexing:{org_id}:{user_id}:{doc_id_key} + CACHE_KEY_PREFIX = "lookup_indexing" + CACHE_TIMEOUT = 3600 # 1 hour + + # Status values + STATUS_INDEXING = "INDEXING" # Currently being indexed + + @staticmethod + def _get_cache_key(org_id: str, user_id: str, doc_id_key: str) -> str: + """Generate cache key for document indexing status. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key (hash of indexing parameters) + + Returns: + Cache key string + """ + return f"{LookupDocumentIndexingService.CACHE_KEY_PREFIX}:{org_id}:{user_id}:{doc_id_key}" + + @staticmethod + def is_document_indexing(org_id: str, user_id: str, doc_id_key: str) -> bool: + """Check if document is currently being indexed. 
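+
+        Typical guard pattern (illustrative; ``org``, ``user`` and ``key``
+        are placeholders):
+            >>> if not LookupDocumentIndexingService.is_document_indexing(org, user, key):
+            ...     LookupDocumentIndexingService.set_document_indexing(org, user, key)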
+ + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + + Returns: + True if document is being indexed, False otherwise + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + status = cache.get(cache_key) + + if status == LookupDocumentIndexingService.STATUS_INDEXING: + logger.debug(f"Document {doc_id_key} is currently being indexed") + return True + + return False + + @staticmethod + def set_document_indexing(org_id: str, user_id: str, doc_id_key: str) -> None: + """Mark document as being indexed. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cache.set( + cache_key, + LookupDocumentIndexingService.STATUS_INDEXING, + LookupDocumentIndexingService.CACHE_TIMEOUT, + ) + logger.debug(f"Marked document {doc_id_key} as being indexed") + + @staticmethod + def mark_document_indexed( + org_id: str, user_id: str, doc_id_key: str, doc_id: str + ) -> None: + """Mark document as indexed with final doc_id. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + doc_id: Final document ID from indexing service + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + # Store the final doc_id instead of status + cache.set(cache_key, doc_id, LookupDocumentIndexingService.CACHE_TIMEOUT) + logger.debug(f"Marked document {doc_id_key} as indexed with ID {doc_id}") + + @staticmethod + def get_indexed_document_id(org_id: str, user_id: str, doc_id_key: str) -> str | None: + """Get indexed document ID if already indexed. + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + + Returns: + Document ID if indexed, None otherwise + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cached_value = cache.get(cache_key) + + # Return doc_id only if it's not the "INDEXING" status + if cached_value and cached_value != LookupDocumentIndexingService.STATUS_INDEXING: + logger.debug(f"Document {doc_id_key} already indexed with ID {cached_value}") + return cached_value + + return None + + @staticmethod + def clear_indexing_status(org_id: str, user_id: str, doc_id_key: str) -> None: + """Clear indexing status from cache (e.g., on error). + + Args: + org_id: Organization ID + user_id: User ID + doc_id_key: Document ID key + """ + cache_key = LookupDocumentIndexingService._get_cache_key( + org_id, user_id, doc_id_key + ) + cache.delete(cache_key) + logger.debug(f"Cleared indexing status for document {doc_id_key}") diff --git a/backend/lookup/services/enrichment_merger.py b/backend/lookup/services/enrichment_merger.py new file mode 100644 index 0000000000..47da11b079 --- /dev/null +++ b/backend/lookup/services/enrichment_merger.py @@ -0,0 +1,174 @@ +"""Enrichment Merger implementation for combining multiple Look-Up results. + +This module provides functionality to merge enrichment data from multiple +Look-Up executions with confidence-based conflict resolution. +""" + +from typing import Any + + +class EnrichmentMerger: + """Merges enrichments from multiple Look-Ups with conflict resolution. + + When multiple Look-Ups run on the same input data, they may produce + overlapping enrichment fields. This class handles merging those results + and resolving conflicts based on confidence scores. 
+ """ + + def merge(self, enrichments: list[dict[str, Any]]) -> dict[str, Any]: + """Merge enrichments with conflict resolution. + + Combines enrichment data from multiple Look-Up executions, + resolving conflicts based on confidence scores when the same + field appears in multiple enrichments. + + Args: + enrichments: List of dicts with structure: + { + 'project_id': UUID, + 'project_name': str, + 'data': Dict[str, Any], # Enrichment fields + 'confidence': Optional[float], # 0.0-1.0 + 'execution_time_ms': int, + 'cached': bool + } + + Returns: + Dictionary containing: + - data: Merged enrichment data + - conflicts_resolved: Number of conflicts resolved + - enrichment_details: Per-lookup metadata + + Example: + >>> merger = EnrichmentMerger() + >>> enrichments = [ + ... { + ... "project_id": uuid1, + ... "project_name": "Vendor Matcher", + ... "data": {"vendor": "Slack", "category": "SaaS"}, + ... "confidence": 0.95, + ... "execution_time_ms": 1234, + ... "cached": False, + ... }, + ... { + ... "project_id": uuid2, + ... "project_name": "Product Classifier", + ... "data": {"category": "Communication", "type": "Software"}, + ... "confidence": 0.80, + ... "execution_time_ms": 567, + ... "cached": True, + ... }, + ... ] + >>> result = merger.merge(enrichments) + >>> print(result["data"]) + {'vendor': 'Slack', 'category': 'SaaS', 'type': 'Software'} + >>> print(result["conflicts_resolved"]) + 1 + """ + merged_data = {} + field_sources = {} # Track which lookup contributed each field + conflicts_resolved = 0 + enrichment_details = [] + + # Process each enrichment + for enrichment in enrichments: + project_id = enrichment.get("project_id") + project_name = enrichment.get("project_name", "Unknown") + data = enrichment.get("data", {}) + confidence = enrichment.get("confidence") + execution_time_ms = enrichment.get("execution_time_ms", 0) + cached = enrichment.get("cached", False) + + fields_added = [] + + # Process each field in the enrichment + for field_name, field_value in data.items(): + field_entry = { + "lookup_id": project_id, + "lookup_name": project_name, + "confidence": confidence, + "value": field_value, + } + + if field_name not in field_sources: + # No conflict, add the field + field_sources[field_name] = field_entry + merged_data[field_name] = field_value + fields_added.append(field_name) + else: + # Conflict! Resolve it + existing = field_sources[field_name] + winner = self._resolve_conflict(field_name, existing, field_entry) + + # Check if resolution changed the winner + if winner["lookup_id"] != existing["lookup_id"]: + # New winner, update the merged data + field_sources[field_name] = winner + merged_data[field_name] = winner["value"] + fields_added.append(field_name) + conflicts_resolved += 1 + elif winner["lookup_id"] == field_entry["lookup_id"]: + # Current enrichment won but was not originally there + conflicts_resolved += 1 + + # Track enrichment details + enrichment_details.append( + { + "lookup_project_id": str(project_id) if project_id else None, + "lookup_project_name": project_name, + "confidence": confidence, + "cached": cached, + "execution_time_ms": execution_time_ms, + "fields_added": fields_added, + } + ) + + return { + "data": merged_data, + "conflicts_resolved": conflicts_resolved, + "enrichment_details": enrichment_details, + } + + def _resolve_conflict(self, field_name: str, existing: dict, new: dict) -> dict: + """Resolve conflict for a single field. + + Uses confidence scores to determine which value to keep. 
+ When confidence scores are equal or absent, uses first-complete-wins + strategy (keeps the existing value). + + Args: + field_name: Name of the field with conflict (for context) + existing: Dict with {lookup_id, lookup_name, confidence, value} + new: Dict with {lookup_id, lookup_name, confidence, value} + + Returns: + Winner dict with same structure + + Resolution rules: + 1. Both have confidence: higher confidence wins + 2. Equal confidence: first-complete wins (existing) + 3. One has confidence: confidence one wins + 4. Neither has confidence: first-complete wins (existing) + """ + existing_confidence = existing.get("confidence") + new_confidence = new.get("confidence") + + # Case 1: Both have confidence scores + if existing_confidence is not None and new_confidence is not None: + if new_confidence > existing_confidence: + return new + else: + # Equal or existing is higher: keep existing (first-complete-wins) + return existing + + # Case 2: Only new has confidence + elif existing_confidence is None and new_confidence is not None: + return new + + # Case 3: Only existing has confidence + elif existing_confidence is not None and new_confidence is None: + return existing + + # Case 4: Neither has confidence - first-complete-wins + else: + return existing diff --git a/backend/lookup/services/execution_context.py b/backend/lookup/services/execution_context.py new file mode 100644 index 0000000000..e5bcf8da1a --- /dev/null +++ b/backend/lookup/services/execution_context.py @@ -0,0 +1,104 @@ +"""Execution context for Look-up operations. + +This module provides a dataclass for managing execution context +across different execution environments (Prompt Studio vs Workflow). +""" + +from dataclasses import dataclass, field +from uuid import UUID + + +@dataclass +class LookupExecutionContext: + """Context for Look-up execution supporting both PS and Workflow contexts. + + This dataclass encapsulates all context information needed for Look-up + execution, including logging configuration for both real-time WebSocket + logs (Prompt Studio) and file-centric logs (ETL/Workflow/API). + + Attributes: + organization_id: The tenant organization ID for multi-tenancy + prompt_studio_project_id: The Prompt Studio project UUID + workflow_execution_id: Workflow execution UUID (for ETL/Workflow/API) + file_execution_id: File execution UUID (for file-centric logging) + session_id: WebSocket session ID (for real-time Prompt Studio logs) + doc_name: Current document name being processed + publish_logs: Whether to publish logs (default True) + execution_id: Unique execution ID for grouping related Look-ups + + Example: + # Prompt Studio context (real-time WebSocket logs) + >>> ps_context = LookupExecutionContext( + ... organization_id="org-123", + ... prompt_studio_project_id=UUID("..."), + ... session_id="ws-session-abc", + ... doc_name="invoice.pdf", + ... ) + >>> ps_context.is_prompt_studio_context + True + + # Workflow context (file-centric logs) + >>> wf_context = LookupExecutionContext( + ... organization_id="org-123", + ... prompt_studio_project_id=UUID("..."), + ... workflow_execution_id=UUID("..."), + ... file_execution_id=UUID("..."), + ... doc_name="invoice.pdf", + ... 
) + >>> wf_context.is_workflow_context + True + """ + + # Required fields + organization_id: str + prompt_studio_project_id: UUID + + # Optional - for file-centric logging (ETL/Workflow/API) + workflow_execution_id: UUID | None = None + file_execution_id: UUID | None = None + + # Optional - for real-time logging (Prompt Studio) + session_id: str | None = None + doc_name: str | None = None + + # Logging control + publish_logs: bool = True + + # Execution tracking + execution_id: str | None = field(default=None) + + @property + def is_workflow_context(self) -> bool: + """Check if executing within a workflow context (ETL/Workflow/API). + + Returns: + True if file_execution_id is set, indicating workflow execution. + """ + return self.file_execution_id is not None + + @property + def is_prompt_studio_context(self) -> bool: + """Check if executing within Prompt Studio IDE (real-time logs). + + Returns: + True if session_id is set and NOT in workflow context. + """ + return self.session_id is not None and not self.is_workflow_context + + @property + def should_emit_websocket_logs(self) -> bool: + """Check if WebSocket logs should be emitted. + + Returns: + True if in Prompt Studio context and logging is enabled. + """ + return self.publish_logs and self.is_prompt_studio_context + + @property + def should_persist_execution_logs(self) -> bool: + """Check if execution logs should be persisted to database. + + Returns: + True if in workflow context and logging is enabled. + """ + return self.publish_logs and self.is_workflow_context diff --git a/backend/lookup/services/indexing_service.py b/backend/lookup/services/indexing_service.py new file mode 100644 index 0000000000..6dd427f5c5 --- /dev/null +++ b/backend/lookup/services/indexing_service.py @@ -0,0 +1,471 @@ +"""Service for indexing reference data using configured profiles. + +This service implements the actual indexing workflow by calling external +extraction and indexing services via the PromptTool SDK, following the +same pattern as Prompt Studio's indexing implementation. +""" + +import json +import logging +import os +from typing import Any + +from django.conf import settings +from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import ( + PromptIdeBaseTool, +) +from utils.file_storage.constants import FileStorageKeys +from utils.user_context import UserContext + +from lookup.models import LookupDataSource, LookupProfileManager +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.exceptions import SdkError +from unstract.sdk1.file_storage.constants import StorageType +from unstract.sdk1.file_storage.env_helper import EnvHelper +from unstract.sdk1.prompt import PromptTool +from unstract.sdk1.utils.indexing import IndexingUtils +from unstract.sdk1.utils.tool import ToolUtils + +from .document_indexing_service import LookupDocumentIndexingService +from .lookup_index_helper import LookupIndexHelper + +logger = logging.getLogger(__name__) + + +class IndexingService: + """Service to orchestrate indexing of reference data. + + Uses PromptTool SDK to call external extraction and indexing services, + similar to Prompt Studio's implementation but adapted for Lookup projects. + """ + + def __init__(self, profile: LookupProfileManager): + """Initialize indexing service with profile configuration. 
+ + Args: + profile: LookupProfileManager instance with adapter configuration + """ + self.profile = profile + self.chunk_size = profile.chunk_size + self.chunk_overlap = profile.chunk_overlap + self.similarity_top_k = profile.similarity_top_k + + # Adapters from profile + self.llm = profile.llm + self.embedding_model = profile.embedding_model + self.vector_store = profile.vector_store + self.x2text = profile.x2text + + @staticmethod + def extract_text( + data_source: LookupDataSource, + profile: LookupProfileManager, + org_id: str, + run_id: str = None, + ) -> str: + """Extract text from data source using X2Text adapter via external service. + + Args: + data_source: LookupDataSource instance to extract + profile: LookupProfileManager with X2Text adapter configuration + org_id: Organization ID + run_id: Optional run ID for tracking + + Returns: + Extracted text content + + Raises: + SdkError: If extraction service fails + """ + # Generate X2Text config hash for tracking + metadata = profile.x2text.metadata or {} + x2text_config_hash = ToolUtils.hash_str(json.dumps(metadata, sort_keys=True)) + + # Check if already extracted + is_extracted = LookupIndexHelper.check_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, # Lookup doesn't need highlighting + ) + + # Get file storage instance + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + logger.info(f"File storage instance type: {type(fs_instance)}") + logger.info( + f"File storage config: {fs_instance.fs if hasattr(fs_instance, 'fs') else 'N/A'}" + ) + + # Construct file paths + file_path = data_source.file_path + logger.info(f"Data source file_path from DB: {file_path}") + logger.info( + f"Storage type: {StorageType.PERMANENT}, env: {FileStorageKeys.PERMANENT_REMOTE_STORAGE}" + ) + + directory, filename = os.path.split(file_path) + extract_file_path = os.path.join( + directory, "extract", os.path.splitext(filename)[0] + ".txt" + ) + + logger.info(f"Constructed paths - directory: {directory}, filename: {filename}") + logger.info(f"Extract file path: {extract_file_path}") + + if is_extracted: + try: + extracted_text = fs_instance.read(path=extract_file_path, mode="r") + logger.info(f"Extracted text found for {filename}, reading from file") + return extracted_text + except FileNotFoundError as e: + logger.warning( + f"File not found for extraction: {extract_file_path}. {e}. " + "Continuing with extraction..." 
+ ) + + # Call extraction service via PromptTool SDK + usage_kwargs = {"run_id": run_id, "file_name": filename} + payload = { + "x2text_instance_id": str(profile.x2text.id), + "file_path": file_path, + "enable_highlight": False, + "usage_kwargs": usage_kwargs.copy(), + "run_id": run_id, + "execution_source": "ide", + "output_file_path": extract_file_path, + } + + logger.info( + f"Extraction payload: x2text_id={payload['x2text_instance_id']}, " + f"file_path={payload['file_path']}, " + f"execution_source={payload['execution_source']}, " + f"output_file_path={payload['output_file_path']}" + ) + + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=None, + ) + extracted_text = responder.extract(payload=payload) + + # Mark extraction success in IndexManager + success = LookupIndexHelper.mark_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, + ) + + if not success: + logger.warning( + f"Failed to mark extraction success for data source {data_source.id}. " + "Extraction completed but status not saved." + ) + + # Update the data source extraction_status field for UI display + data_source.extraction_status = "completed" + data_source.save(update_fields=["extraction_status"]) + + logger.info(f"Successfully extracted text from {filename}") + return extracted_text + + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + + # Mark extraction failure in IndexManager + LookupIndexHelper.mark_extraction_status( + data_source_id=str(data_source.id), + profile_manager=profile, + x2text_config_hash=x2text_config_hash, + enable_highlight=False, + extracted=False, + error_message=msg, + ) + + # Update the data source extraction_status field for UI display + data_source.extraction_status = "failed" + data_source.extraction_error = msg + data_source.save(update_fields=["extraction_status", "extraction_error"]) + + raise Exception(f"Failed to extract '{filename}': {msg}") from e + + @staticmethod + def index_data_source( + data_source: LookupDataSource, + profile: LookupProfileManager, + org_id: str, + user_id: str, + extracted_text: str, + run_id: str = None, + reindex: bool = True, + ) -> str: + """Index extracted text using profile's adapters via external indexing service. 
+ + Args: + data_source: LookupDataSource instance + profile: LookupProfileManager with adapter configuration + org_id: Organization ID + user_id: User ID + extracted_text: Pre-extracted text content + run_id: Optional run ID for tracking + reindex: Whether to reindex if already indexed + + Returns: + Document ID from indexing service + + Raises: + SdkError: If indexing service fails + """ + # Skip indexing if chunk_size is 0 (full context mode) + if profile.chunk_size == 0: + # Generate doc_id for tracking + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + doc_id = IndexingUtils.generate_index_key( + vector_db=str(profile.vector_store.id), + embedding=str(profile.embedding_model.id), + x2text=str(profile.x2text.id), + chunk_size=str(profile.chunk_size), + chunk_overlap=str(profile.chunk_overlap), + file_path=data_source.file_path, + file_hash=None, + fs=fs_instance, + tool=util, + ) + + # Update index manager without actual indexing + LookupIndexHelper.handle_index_manager( + data_source_id=str(data_source.id), + profile_manager=profile, + doc_id=doc_id, + ) + + logger.info("Skipping vector DB indexing since chunk size is 0") + return doc_id + + # Get adapter IDs + embedding_model = str(profile.embedding_model.id) + vector_db = str(profile.vector_store.id) + x2text_adapter = str(profile.x2text.id) + + # Construct file paths + directory, filename = os.path.split(data_source.file_path) + file_path = os.path.join( + directory, "extract", os.path.splitext(filename)[0] + ".txt" + ) + + # Generate index key + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + + doc_id_key = IndexingUtils.generate_index_key( + vector_db=vector_db, + embedding=embedding_model, + x2text=x2text_adapter, + chunk_size=str(profile.chunk_size), + chunk_overlap=str(profile.chunk_overlap), + file_path=data_source.file_path, + file_hash=None, + fs=fs_instance, + tool=util, + ) + + try: + usage_kwargs = {"run_id": run_id, "file_name": filename} + + # Check if already indexed (unless reindexing) + if not reindex: + indexed_doc_id = LookupDocumentIndexingService.get_indexed_document_id( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + if indexed_doc_id: + logger.info(f"Document {filename} already indexed: {indexed_doc_id}") + return indexed_doc_id + + # Check if currently being indexed + if LookupDocumentIndexingService.is_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ): + raise Exception(f"Document {filename} is currently being indexed") + + # Mark as being indexed + LookupDocumentIndexingService.set_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + + logger.info(f"Invoking indexing service for: {doc_id_key}") + + # Build payload for indexing service + payload = { + "tool_id": str(data_source.project.id), # Use project ID as tool ID + "embedding_instance_id": embedding_model, + "vector_db_instance_id": vector_db, + "x2text_instance_id": x2text_adapter, + "file_path": file_path, + "file_hash": None, + "chunk_overlap": profile.chunk_overlap, + "chunk_size": profile.chunk_size, + "reindex": reindex, + "enable_highlight": False, + "usage_kwargs": usage_kwargs.copy(), + "extracted_text": extracted_text, + "run_id": run_id, + "execution_source": "ide", + } + 
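+            # Note: the payload carries the pre-extracted text alongside the
+            # file path; the assumption is that the prompt service uses
+            # "extracted_text" directly and does not re-run extraction (see
+            # the extract_text() step above).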
+ # Call indexing service via PromptTool SDK + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=None, + ) + doc_id = responder.index(payload=payload) + + # Update index manager with doc_id + LookupIndexHelper.handle_index_manager( + data_source_id=str(data_source.id), + profile_manager=profile, + doc_id=doc_id, + ) + + # Mark as indexed in cache + LookupDocumentIndexingService.mark_document_indexed( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id + ) + + logger.info(f"Successfully indexed {filename} with doc_id: {doc_id}") + return doc_id + + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + raise Exception(f"Failed to index '{filename}': {msg}") from e + + except Exception as e: + logger.error(f"Error indexing {filename}: {e}", exc_info=True) + # Clear indexing status on error + LookupDocumentIndexingService.clear_indexing_status( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + raise + + @classmethod + def index_with_default_profile( + cls, project_id: str, org_id: str = None, user_id: str = None + ) -> dict[str, Any]: + """Index all completed data sources using the project's default profile. + + Args: + project_id: UUID of the lookup project + org_id: Organization ID (if None, gets from UserContext) + user_id: User ID (if None, gets from UserContext) + + Returns: + Dict with indexing results summary + + Raises: + DefaultProfileError: If no default profile exists for project + ValueError: If project not found + """ + from lookup.models import LookupProject + + # Get context if not provided + if org_id is None: + org_id = UserContext.get_organization_identifier() + if user_id is None: + user_id = UserContext.get_user_id() + + try: + project = LookupProject.objects.get(id=project_id) + except LookupProject.DoesNotExist: + raise ValueError(f"Project {project_id} not found") + + # Get default profile + default_profile = LookupProfileManager.get_default_profile(project) + + # Get all data sources (extraction will be done as part of indexing) + data_sources = LookupDataSource.objects.filter(project_id=project_id).order_by( + "-version_number" + ) + + logger.info(f"Found {data_sources.count()} data sources for project {project_id}") + + # Log each data source status + for ds in data_sources: + logger.info(f" - {ds.file_name}: extraction_status={ds.extraction_status}") + + results = { + "total": data_sources.count(), + "success": 0, + "failed": 0, + "errors": [], + } + + for data_source in data_sources: + try: + logger.info( + f"Indexing data source {data_source.id}: {data_source.file_name}" + ) + + # Extract text + extracted_text = cls.extract_text( + data_source=data_source, + profile=default_profile, + org_id=org_id, + run_id=None, + ) + + # Index the extracted text + doc_id = cls.index_data_source( + data_source=data_source, + profile=default_profile, + org_id=org_id, + user_id=user_id, + extracted_text=extracted_text, + run_id=None, + reindex=True, + ) + + results["success"] += 1 + logger.info( + f"Successfully indexed {data_source.file_name} with doc_id: {doc_id}" + ) + + except Exception as e: + results["failed"] += 1 + error_msg = str(e) + results["errors"].append( + { + "data_source_id": str(data_source.id), + "file_name": data_source.file_name, + "error": error_msg, + } + ) + logger.error(f"Failed to index {data_source.file_name}: {error_msg}") + + logger.info( + 
f"Indexing complete for project {project_id}: " + f"{results['success']} successful, {results['failed']} failed" + ) + + return results diff --git a/backend/lookup/services/llm_cache.py b/backend/lookup/services/llm_cache.py new file mode 100644 index 0000000000..e7e46bf0bd --- /dev/null +++ b/backend/lookup/services/llm_cache.py @@ -0,0 +1,186 @@ +"""LLM Response Cache implementation for caching LLM API responses. + +This module provides an in-memory cache with TTL (Time-To-Live) support +for storing and retrieving LLM responses based on prompt and reference data. +""" + +import hashlib +import time + + +class LLMResponseCache: + """In-memory cache for LLM responses with TTL expiration. + + This cache is used to store LLM API responses to avoid redundant + API calls for identical prompts and reference data combinations. + Uses SHA256 hashing for cache key generation and supports TTL-based + expiration for automatic cleanup. + """ + + def __init__(self, ttl_hours: int = 24): + """Initialize the LLM response cache. + + Args: + ttl_hours: Time-to-live in hours (default 24). + Cached entries expire after this duration. + """ + self.cache: dict[str, tuple[str, float]] = {} + self.ttl_seconds = ttl_hours * 3600 + + def get(self, key: str) -> str | None: + """Get cached response if not expired. + + Performs lazy cleanup by removing expired entries when accessed. + + Args: + key: Cache key (generated by generate_cache_key) + + Returns: + Cached response string if valid and not expired, None otherwise + + Example: + >>> cache = LLMResponseCache(ttl_hours=1) + >>> cache.set("key123", "response") + >>> cache.get("key123") + 'response' + """ + if key not in self.cache: + return None + + response, expiry = self.cache[key] + current_time = time.time() + + if current_time >= expiry: + # Entry has expired, remove it (lazy cleanup) + del self.cache[key] + return None + + return response + + def set(self, key: str, response: str) -> None: + """Cache response with TTL. + + Args: + key: Cache key (generated by generate_cache_key) + response: LLM response to cache + + Example: + >>> cache = LLMResponseCache() + >>> key = cache.generate_cache_key("prompt", "ref_data") + >>> cache.set(key, "LLM response") + """ + expiry = time.time() + self.ttl_seconds + self.cache[key] = (response, expiry) + + def invalidate(self, key: str) -> bool: + """Remove specific key from cache. + + Args: + key: Cache key to invalidate + + Returns: + True if key was removed, False if key didn't exist + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key123", "response") + >>> cache.invalidate("key123") + True + """ + if key in self.cache: + del self.cache[key] + return True + return False + + def invalidate_all(self) -> int: + """Clear entire cache. + + Returns: + Count of invalidated entries + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key1", "response1") + >>> cache.set("key2", "response2") + >>> cache.invalidate_all() + 2 + """ + count = len(self.cache) + self.cache.clear() + return count + + def generate_cache_key(self, prompt: str, reference_data: str) -> str: + r"""Generate SHA256 hash from prompt + reference data. + + Creates a deterministic cache key based on the prompt and + reference data combination. Same inputs always produce the + same key. 
+ + Args: + prompt: The resolved prompt text + reference_data: The reference data text + + Returns: + 64-character hexadecimal SHA256 hash + + Example: + >>> cache = LLMResponseCache() + >>> key = cache.generate_cache_key("Match vendor", "Slack\nGoogle") + >>> len(key) + 64 + """ + combined = f"{prompt}{reference_data}" + hash_obj = hashlib.sha256(combined.encode("utf-8")) + return hash_obj.hexdigest() + + def get_stats(self) -> dict[str, int]: + """Return cache statistics. + + Analyzes the current cache state and returns counts of + total, expired, and valid entries. + + Returns: + Dictionary with keys: 'total', 'expired', 'valid' + + Example: + >>> cache = LLMResponseCache() + >>> cache.set("key1", "response1") + >>> stats = cache.get_stats() + >>> stats["total"] + 1 + """ + current_time = time.time() + total = len(self.cache) + expired = sum( + 1 for _, (_, expiry) in self.cache.items() if current_time >= expiry + ) + valid = total - expired + + return {"total": total, "expired": expired, "valid": valid} + + def cleanup_expired(self) -> int: + """Remove expired entries. + + Performs a full scan of the cache and removes all entries + that have exceeded their TTL. Can be called periodically + for cache maintenance. + + Returns: + Count of removed entries + + Example: + >>> cache = LLMResponseCache(ttl_hours=0) # Immediate expiry + >>> cache.set("key1", "response1") + >>> time.sleep(0.1) + >>> cache.cleanup_expired() + 1 + """ + current_time = time.time() + keys_to_remove = [ + key for key, (_, expiry) in self.cache.items() if current_time >= expiry + ] + + for key in keys_to_remove: + del self.cache[key] + + return len(keys_to_remove) diff --git a/backend/lookup/services/log_emitter.py b/backend/lookup/services/log_emitter.py new file mode 100644 index 0000000000..07bed5e703 --- /dev/null +++ b/backend/lookup/services/log_emitter.py @@ -0,0 +1,427 @@ +"""Look-up Log Emitter for WebSocket and file-centric logging. + +This module provides functionality to emit Look-up execution logs +via WebSocket for real-time display in Prompt Studio and to persist +logs for file-centric logging in ETL/Workflow/API executions. +""" + +import logging +import time +from typing import Any +from uuid import UUID + +from unstract.core.pubsub_helper import LogPublisher +from unstract.workflow_execution.enums import LogLevel, LogStage + +logger = logging.getLogger(__name__) + + +class LookupLogEmitter: + """Emits Lookup enrichment logs via WebSocket and persists to execution logs. + + This class provides methods to emit logs at different stages of Look-up + execution, with support for both real-time WebSocket logs (Prompt Studio) + and file-centric execution logs (ETL/Workflow/API). + + The logs use a purple color scheme (#722ed1) in the frontend for + visual distinction from other log stages. + + Example: + >>> emitter = LookupLogEmitter( + ... session_id="ws-session-123", + ... execution_id="exec-456", + ... organization_id="org-789", + ... ) + >>> emitter.emit_enrichment_start("Vendor Lookup", ["vendor_name"]) + >>> emitter.emit_enrichment_success("Vendor Lookup", 3, False, 150) + """ + + LOG_STAGE = LogStage.LOOKUP.value + LOG_TYPE = "LOOKUP_ENRICHMENT" + + def __init__( + self, + session_id: str | None = None, + execution_id: str | None = None, + file_execution_id: str | UUID | None = None, + organization_id: str | None = None, + doc_name: str | None = None, + ): + """Initialize the log emitter. 
+ + Args: + session_id: WebSocket session ID for real-time logs + execution_id: Workflow execution ID for file-centric logs + file_execution_id: File execution ID for file-centric logs + organization_id: Organization ID for multi-tenancy + doc_name: Current document name being processed + """ + self.session_id = session_id + self.execution_id = str(execution_id) if execution_id else None + self.file_execution_id = str(file_execution_id) if file_execution_id else None + self.organization_id = organization_id + self.doc_name = doc_name + + def _build_component( + self, + lookup_project_name: str, + **extra: Any, + ) -> dict[str, Any]: + """Build the component metadata for the log entry. + + Args: + lookup_project_name: Name of the Look-up project + **extra: Additional metadata to include + + Returns: + Dictionary with component metadata + """ + component = { + "type": self.LOG_TYPE, + "lookup_project": lookup_project_name, + } + if self.doc_name: + component["doc_name"] = self.doc_name + component.update(extra) + return component + + def emit_log( + self, + level: str, + message: str, + lookup_project_name: str = "", + state: str = "INFO", + **extra: Any, + ) -> None: + """Emit a lookup log event via WebSocket and/or ExecutionLog. + + For Prompt Studio (session_id set): Emits to WebSocket for real-time display + For Workflow/API (file_execution_id set): Persists to ExecutionLog for Nav bar + + Args: + level: Log level (INFO, ERROR, WARN, DEBUG) + message: Log message + lookup_project_name: Name of the Look-up project + state: Log state (STARTED, COMPLETED, FAILED, SKIPPED) + **extra: Additional metadata for the component + """ + log_details = LogPublisher.log_workflow( + stage=self.LOG_STAGE, + message=message, + level=level, + execution_id=self.execution_id, + file_execution_id=self.file_execution_id, + organization_id=self.organization_id, + ) + + # Add component metadata for frontend rendering + log_details["component"] = self._build_component( + lookup_project_name=lookup_project_name, + state=state, + **extra, + ) + + # Emit to WebSocket if session_id is available (Prompt Studio) + if self.session_id: + LogPublisher.publish(self.session_id, log_details) + logger.debug(f"Emitted lookup log to WebSocket: {message}") + + # Persist to ExecutionLog if in workflow context (Nav bar logs) + if self.file_execution_id and self.execution_id: + self._persist_to_execution_log(log_details, level, message) + + def _persist_to_execution_log( + self, + log_details: dict[str, Any], + level: str, + message: str, + ) -> None: + """Persist log entry to ExecutionLog via Redis queue for Nav bar display. + + Uses the same Redis queue mechanism as other workflow logs to ensure + proper ordering of logs when displayed in the Nav bar. 
+ + Args: + log_details: The log details dictionary + level: Log level + message: Log message + """ + try: + import redis + from django.conf import settings + + from unstract.core.log_utils import store_execution_log + + # Build log data matching the expected format for the queue + log_data = { + "timestamp": time.time(), # Unix timestamp for queue processing + "type": "LOG", + "level": level, + "stage": self.LOG_STAGE, + "log": message, + "execution_id": self.execution_id, + "file_execution_id": self.file_execution_id, + "organization_id": self.organization_id, + **log_details, + } + + # Use the same Redis queue as other workflow logs + redis_client = redis.Redis( + host=settings.REDIS_HOST, + port=int(settings.REDIS_PORT), + username=settings.REDIS_USER, + password=settings.REDIS_PASSWORD, + ) + + from utils.constants import ExecutionLogConstants + + store_execution_log( + data=log_data, + redis_client=redis_client, + log_queue_name=ExecutionLogConstants.LOG_QUEUE_NAME, + is_enabled=ExecutionLogConstants.IS_ENABLED, + ) + logger.debug(f"Queued lookup log for ExecutionLog: {message}") + + except Exception as e: + logger.warning(f"Failed to queue lookup log for ExecutionLog: {e}") + + def emit_enrichment_start( + self, + lookup_project_name: str, + input_fields: list[str] | None = None, + ) -> None: + """Emit log when Look-up enrichment starts. + + Args: + lookup_project_name: Name of the Look-up project + input_fields: List of input field names being used + """ + input_fields = input_fields or [] + fields_str = ", ".join(input_fields[:3]) + if len(input_fields) > 3: + fields_str += f" (+{len(input_fields) - 3} more)" + + message = f"Starting enrichment with Look-Up '{lookup_project_name}'" + if fields_str: + message += f" for: {fields_str}" + + self.emit_log( + level=LogLevel.INFO.value, + state="STARTED", + message=message, + lookup_project_name=lookup_project_name, + input_fields=input_fields, + ) + + def emit_enrichment_success( + self, + lookup_project_name: str, + enriched_count: int, + cached: bool, + execution_time_ms: int, + confidence: float | None = None, + context_type: str = "full", + ) -> None: + """Emit log when Look-up enrichment succeeds. + + Args: + lookup_project_name: Name of the Look-up project + enriched_count: Number of fields enriched + cached: Whether the response was from cache + execution_time_ms: Execution time in milliseconds + confidence: Optional confidence score (0.0-1.0) + context_type: Type of context used - "rag" or "full" + """ + cache_msg = " (cached)" if cached else "" + # Display context type clearly: RAG-based or Full context + context_display = "RAG" if context_type == "rag" else "Full context" + message = ( + f"Look-Up '{lookup_project_name}' [{context_display}] enriched " + f"{enriched_count} field(s){cache_msg} in {execution_time_ms}ms" + ) + + if confidence is not None: + message += f" (confidence: {confidence:.0%})" + + self.emit_log( + level=LogLevel.INFO.value, + state="COMPLETED", + message=message, + lookup_project_name=lookup_project_name, + cached=cached, + execution_time_ms=execution_time_ms, + enriched_count=enriched_count, + confidence=confidence, + context_type=context_type, + ) + + def emit_enrichment_failure( + self, + lookup_project_name: str, + error_message: str, + ) -> None: + """Emit log when Look-up enrichment fails. 
+ + Args: + lookup_project_name: Name of the Look-up project + error_message: Error message describing the failure + """ + message = f"Look-Up '{lookup_project_name}' failed: {error_message}" + + self.emit_log( + level=LogLevel.ERROR.value, + state="FAILED", + message=message, + lookup_project_name=lookup_project_name, + error_message=error_message, + ) + + def emit_context_overflow_error( + self, + lookup_project_name: str, + token_count: int, + context_limit: int, + model: str, + ) -> None: + """Emit log when context window is exceeded. + + This provides a clear, actionable error message when the prompt + (reference data + template + extracted data) exceeds the LLM's + context window limit. + + Args: + lookup_project_name: Name of the Look-up project + token_count: Number of tokens in the prompt + context_limit: Maximum tokens allowed by the model + model: Name of the LLM model + """ + message = ( + f"Look-Up '{lookup_project_name}' failed: Context window exceeded - " + f"prompt requires {token_count:,} tokens but {model} limit is " + f"{context_limit:,} tokens. Reduce reference data size or use a " + f"model with larger context window." + ) + + self.emit_log( + level=LogLevel.ERROR.value, + state="FAILED", + message=message, + lookup_project_name=lookup_project_name, + error_type="context_window_exceeded", + token_count=token_count, + context_limit=context_limit, + model=model, + suggestion="Reduce reference data or use larger context model", + ) + + def emit_enrichment_partial( + self, + lookup_project_name: str, + enriched_count: int, + execution_time_ms: int, + warning_message: str, + confidence: float | None = None, + context_type: str = "full", + ) -> None: + """Emit log when Look-up enrichment partially succeeds. + + Args: + lookup_project_name: Name of the Look-up project + enriched_count: Number of fields enriched + execution_time_ms: Execution time in milliseconds + warning_message: Warning message about partial success + confidence: Optional confidence score (0.0-1.0) + context_type: Type of context used - "rag" or "full" + """ + # Display context type clearly: RAG-based or Full context + context_display = "RAG" if context_type == "rag" else "Full context" + message = ( + f"Look-Up '{lookup_project_name}' [{context_display}] partial success: " + f"enriched {enriched_count} field(s) in {execution_time_ms}ms - " + f"{warning_message}" + ) + + self.emit_log( + level=LogLevel.WARN.value, + state="PARTIAL", + message=message, + lookup_project_name=lookup_project_name, + enriched_count=enriched_count, + execution_time_ms=execution_time_ms, + warning_message=warning_message, + confidence=confidence, + context_type=context_type, + ) + + def emit_no_linked_lookups(self) -> None: + """Emit debug log when no Look-Ups are linked.""" + self.emit_log( + level=LogLevel.INFO.value, + state="SKIPPED", + message="No linked Look-Up projects found", + lookup_project_name="", + ) + + def emit_orchestration_start( + self, + lookup_count: int, + lookup_names: list[str], + ) -> None: + """Emit log when Look-up orchestration starts. 
+ + Args: + lookup_count: Number of Look-ups to execute + lookup_names: Names of the Look-up projects + """ + names_str = ", ".join(lookup_names[:3]) + if len(lookup_names) > 3: + names_str += f" (+{len(lookup_names) - 3} more)" + + message = f"Starting Look-Up enrichment: {lookup_count} project(s) [{names_str}]" + + self.emit_log( + level=LogLevel.INFO.value, + state="STARTED", + message=message, + lookup_project_name="", + lookup_count=lookup_count, + lookup_names=lookup_names, + ) + + def emit_orchestration_complete( + self, + total_lookups: int, + successful: int, + failed: int, + total_time_ms: int, + total_enriched_fields: int, + ) -> None: + """Emit log when Look-up orchestration completes. + + Args: + total_lookups: Total number of Look-ups executed + successful: Number of successful Look-ups + failed: Number of failed Look-ups + total_time_ms: Total execution time in milliseconds + total_enriched_fields: Total number of fields enriched + """ + status = "completed" if failed == 0 else "completed with errors" + message = ( + f"Look-Up enrichment {status}: {successful}/{total_lookups} succeeded, " + f"{total_enriched_fields} field(s) enriched in {total_time_ms}ms" + ) + + level = LogLevel.INFO.value if failed == 0 else LogLevel.WARN.value + + self.emit_log( + level=level, + state="COMPLETED" if failed == 0 else "PARTIAL", + message=message, + lookup_project_name="", + total_lookups=total_lookups, + successful=successful, + failed=failed, + total_time_ms=total_time_ms, + total_enriched_fields=total_enriched_fields, + ) diff --git a/backend/lookup/services/lookup_executor.py b/backend/lookup/services/lookup_executor.py new file mode 100644 index 0000000000..0fb6349f22 --- /dev/null +++ b/backend/lookup/services/lookup_executor.py @@ -0,0 +1,716 @@ +"""Look-Up Executor implementation for single Look-Up execution. + +This module provides functionality to execute a single Look-Up project +against input data, including variable resolution, LLM calling, and +response caching. +""" + +import json +import logging +import time +import uuid +from typing import Any, Protocol + +from lookup.exceptions import ( + ContextWindowExceededError, + ExtractionNotCompleteError, + ParseError, + TemplateNotFoundError, +) +from lookup.models import LookupProfileManager, LookupProject +from lookup.services.audit_logger import AuditLogger +from lookup.services.lookup_retrieval_service import LookupRetrievalService + +logger = logging.getLogger(__name__) + + +class LLMClient(Protocol): + """Protocol for LLM client abstraction.""" + + def generate(self, prompt: str, config: dict[str, Any]) -> str: + """Generate LLM response for the prompt.""" + ... + + +class LookUpExecutor: + """Executes a single Look-Up project against input data. + + This class handles the complete execution flow of a Look-Up: + loading reference data, resolving variables in the prompt template, + calling the LLM, caching responses, and parsing the results. + """ + + def __init__( + self, + variable_resolver, # Class, not instance + cache_manager, + reference_loader, + llm_client: LLMClient, + org_id: str | None = None, + ): + """Initialize the Look-Up executor. 
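+
+        Note that variable_resolver is the resolver class, not an instance:
+        a fresh resolver is constructed for each execution with that run's
+        input data and reference data.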
+ + Args: + variable_resolver: VariableResolver class (not instance) + cache_manager: LLMResponseCache instance + reference_loader: ReferenceDataLoader instance + llm_client: LLM provider client implementing LLMClient protocol + org_id: Organization ID for multi-tenancy (used in RAG retrieval) + """ + self.variable_resolver_class = variable_resolver + self.cache = cache_manager + self.ref_loader = reference_loader + self.llm_client = llm_client + self.org_id = org_id + logger.info(f"LookUpExecutor initialized with org_id='{org_id}'") + + def execute( + self, + lookup_project: LookupProject, + input_data: dict[str, Any], + execution_id: str | None = None, + prompt_studio_project_id: str | None = None, + ) -> dict[str, Any]: + """Execute single Look-Up. + + Performs the complete Look-Up execution including variable resolution, + LLM calling with caching, and response parsing. + + Args: + lookup_project: The Look-Up project to execute + input_data: Input data containing variables to resolve + execution_id: Optional UUID to group related executions for audit + prompt_studio_project_id: Optional PS project ID for audit tracking + + Returns: + Dictionary containing: + - status: 'success' or 'failed' + - project_id: UUID of the project + - project_name: Name of the project + - data: Enrichment data (if success) + - confidence: Confidence score 0.0-1.0 (if available) + - cached: Whether response was from cache + - execution_time_ms: Time taken in milliseconds + - error: Error message (if failed) + + Example: + >>> executor = LookUpExecutor(...) + >>> result = executor.execute(project, {"vendor": "Slack"}) + >>> if result["status"] == "success": + ... print(result["data"]) + {'canonical_vendor': 'Slack Technologies', 'confidence': 0.92} + """ + start_time = time.time() + + # Initialize audit tracking variables + exec_id = execution_id or str(uuid.uuid4()) + resolved_prompt = None + llm_response = None + llm_time_ms = None + reference_data_version = 1 + llm_provider = "unknown" + llm_model = "unknown" + _cached = False # noqa F841 + context_type = "full" # Track whether using RAG or full context + + try: + # Step 1: Load reference data (RAG or full context based on chunk_size) + try: + # Get the default profile to check chunk_size + profile = LookupProfileManager.get_default_profile(lookup_project) + + # Log profile configuration for debugging + logger.info( + f"Profile config for {lookup_project.name}: " + f"chunk_size={profile.chunk_size}, " + f"chunk_overlap={profile.chunk_overlap}, " + f"similarity_top_k={profile.similarity_top_k}, " + f"vector_store={profile.vector_store.id}, " + f"embedding_model={profile.embedding_model.id}" + ) + + if profile.chunk_size > 0: + # RAG Mode: Retrieve relevant chunks from vector DB + context_type = "rag" + logger.info( + f"Using RAG mode (chunk_size={profile.chunk_size}) " + f"for project {lookup_project.name}" + ) + reference_data = self._retrieve_rag_context( + lookup_project=lookup_project, + profile=profile, + input_data=input_data, + ) + reference_data_version = 1 # RAG mode doesn't have versions + logger.info( + f"RAG retrieval returned {len(reference_data)} chars " + f"for project {lookup_project.name}" + ) + + # Warn if RAG returned empty - likely indexing issue + if not reference_data.strip(): + logger.warning( + f"RAG returned EMPTY context for project {lookup_project.name}. 
" + "Check: 1) Data sources are uploaded, 2) Extraction completed, " + "3) Indexing completed with chunk_size > 0" + ) + else: + # Full Context Mode: Load all reference data (existing behavior) + logger.info( + f"Using full context mode (chunk_size=0) " + f"for project {lookup_project.name}" + ) + reference_data_dict = self.ref_loader.load_latest_for_project( + lookup_project.id + ) + reference_data = reference_data_dict["content"] + reference_data_version = reference_data_dict.get("version", 1) + + except ExtractionNotCompleteError as e: + result = self._failed_result( + lookup_project, + f"Reference data extraction not complete: {str(e)}", + start_time, + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + None, + "failed", + None, + result["execution_time_ms"], + None, + False, + result["error"], + ) + return result + except Exception as e: + result = self._failed_result( + lookup_project, f"Failed to load reference data: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + None, + "failed", + None, + result["execution_time_ms"], + None, + False, + result["error"], + ) + return result + + # Step 2: Load prompt template + try: + template = lookup_project.template + if not template: + raise TemplateNotFoundError("No template configured") + template_text = template.template_text + + # Extract LLM info from template config if available + if template.llm_config: + llm_provider = template.llm_config.get("provider", "unknown") + llm_model = template.llm_config.get("model", "unknown") + except (AttributeError, TemplateNotFoundError) as e: + result = self._failed_result( + lookup_project, f"Missing prompt template: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + None, + "failed", + None, + result["execution_time_ms"], + None, + False, + result["error"], + ) + return result + + # Step 3: Resolve variables + logger.info(f"Input data received: {input_data}") + logger.info(f"Reference data length: {len(reference_data)} chars") + logger.info(f"Template text: {template_text[:200]}...") + resolver = self.variable_resolver_class(input_data, reference_data) + resolved_prompt = resolver.resolve(template_text) + logger.info(f"Resolved prompt: {resolved_prompt[:500]}...") + + # Step 4: Check cache (if caching is enabled) + cache_key = None + cached_response = None + if self.cache: + cache_key = self.cache.generate_cache_key(resolved_prompt, reference_data) + cached_response = self.cache.get(cache_key) + + if cached_response: + # Cache hit - parse and return + enrichment_data, confidence = self._parse_llm_response( + cached_response + ) + result = self._success_result( + lookup_project, + enrichment_data, + confidence, + cached=True, + execution_time_ms=0, # No execution time for cached response + context_type=context_type, + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + cached_response, + enrichment_data, + "success", + confidence, + 0, + None, + True, + None, + ) + return result + + # Step 5: Call LLM (cache miss or caching disabled) + try: + 
llm_start = time.time() + llm_response = self.llm_client.generate( + resolved_prompt, lookup_project.llm_config or {} + ) + llm_time_ms = int((time.time() - llm_start) * 1000) + + # Store in cache (if caching is enabled) + if self.cache and cache_key: + self.cache.set(cache_key, llm_response) + + except ContextWindowExceededError as e: + # Context window exceeded - provide clear actionable error + error_msg = ( + f"Context window exceeded: prompt requires {e.token_count:,} " + f"tokens but {e.model} limit is {e.context_limit:,} tokens. " + f"Reduce reference data size or use a model with larger " + f"context window." + ) + result = self._failed_result(lookup_project, error_msg, start_time) + # Add context window error details for specialized logging + result["error_type"] = "context_window_exceeded" + result["token_count"] = e.token_count + result["context_limit"] = e.context_limit + result["model"] = e.model + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + None, + None, + "failed", + None, + result["execution_time_ms"], + None, + False, + error_msg, + ) + return result + except TimeoutError as e: + result = self._failed_result( + lookup_project, f"LLM request timed out: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + None, + None, + "failed", + None, + result["execution_time_ms"], + llm_time_ms, + False, + result["error"], + ) + return result + except Exception as e: + result = self._failed_result( + lookup_project, f"LLM request failed: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + None, + None, + "failed", + None, + result["execution_time_ms"], + llm_time_ms, + False, + result["error"], + ) + return result + + # Step 6: Parse response + try: + enrichment_data, confidence = self._parse_llm_response(llm_response) + except ParseError as e: + result = self._failed_result( + lookup_project, f"Failed to parse LLM response: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + None, + "failed", + None, + result["execution_time_ms"], + llm_time_ms, + False, + result["error"], + ) + return result + + # Step 7: Return result and log success + result = self._success_result( + lookup_project, + enrichment_data, + confidence, + cached=False, + execution_time_ms=llm_time_ms, + context_type=context_type, + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + enrichment_data, + "success", + confidence, + llm_time_ms, + llm_time_ms, + False, + None, + ) + return result + + except Exception as e: + # Catch-all for unexpected errors + logger.exception(f"Unexpected error executing Look-Up {lookup_project.id}") + result = self._failed_result( + lookup_project, f"Unexpected error: {str(e)}", start_time + ) + self._log_audit( + exec_id, + lookup_project, + prompt_studio_project_id, + input_data, + reference_data_version, + llm_provider, + llm_model, + resolved_prompt, + llm_response, + None, + "failed", + None, + 
result["execution_time_ms"], + llm_time_ms, + False, + result["error"], + ) + return result + + def _retrieve_rag_context( + self, + lookup_project: LookupProject, + profile: LookupProfileManager, + input_data: dict[str, Any], + ) -> str: + """Retrieve context using RAG when chunk_size > 0. + + Builds a semantic query from input_data and retrieves relevant + chunks from the vector DB using similarity search. + + Args: + lookup_project: The lookup project being executed + profile: Profile with RAG configuration (vector store, embeddings, etc.) + input_data: Input data to build the retrieval query from + + Returns: + Retrieved context as a string (concatenated chunks) + """ + query = self._build_retrieval_query(input_data) + + retrieval_service = LookupRetrievalService(profile, org_id=self.org_id) + context = retrieval_service.retrieve_context( + query=query, + project_id=str(lookup_project.id), + ) + + if not context: + logger.warning( + f"RAG retrieval returned empty for project {lookup_project.id}. " + "Ensure data sources are indexed with chunk_size > 0." + ) + + return context + + def _build_retrieval_query(self, input_data: dict[str, Any]) -> str: + """Build retrieval query from input data fields with semantic context. + + Includes field names with values to preserve semantic meaning for + vector similarity search. This allows the embedding model to understand + the context of each value (e.g., "vendor: Slack" vs just "Slack"). + + Args: + input_data: Dictionary of extracted input fields + + Returns: + Query string for vector retrieval with field context + + Example: + >>> executor._build_retrieval_query({"vendor": "Slack", "type": "SaaS"}) + 'vendor: Slack, type: SaaS' + """ + query_parts = [] + + for key, value in input_data.items(): + if value is not None: + # Include field name to preserve semantic context + if isinstance(value, dict): + # For nested dicts, include key-value pairs + nested_parts = [f"{k}: {v}" for k, v in value.items() if v] + if nested_parts: + query_parts.append(f"{key}: {', '.join(nested_parts)}") + elif isinstance(value, list): + # For lists, join values with the field name + list_values = [str(v) for v in value if v] + if list_values: + query_parts.append(f"{key}: {', '.join(list_values)}") + else: + # Simple key-value pair + query_parts.append(f"{key}: {value}") + + query = ", ".join(query_parts) + + if not query.strip(): + query = "find relevant reference data" + + logger.info(f"Built retrieval query: {query[:200]}...") + return query + + def _log_audit( + self, + execution_id: str, + lookup_project: LookupProject, + prompt_studio_project_id: str | None, + input_data: dict[str, Any], + reference_data_version: int, + llm_provider: str, + llm_model: str, + llm_prompt: str | None, + llm_response: str | None, + enriched_output: dict[str, Any] | None, + status: str, + confidence_score: float | None, + execution_time_ms: int | None, + llm_call_time_ms: int | None, + llm_response_cached: bool, + error_message: str | None, + ) -> None: + """Log execution to audit table. + + This method wraps AuditLogger to capture all execution details + for the Execution History tab. 
+ """ + try: + audit_logger = AuditLogger() + audit_logger.log_execution( + execution_id=execution_id, + lookup_project_id=lookup_project.id, + prompt_studio_project_id=( + uuid.UUID(prompt_studio_project_id) + if prompt_studio_project_id + else None + ), + input_data=input_data, + reference_data_version=reference_data_version, + llm_provider=llm_provider, + llm_model=llm_model, + llm_prompt=llm_prompt or "", + llm_response=llm_response, + enriched_output=enriched_output, + status=status, + confidence_score=confidence_score, + execution_time_ms=execution_time_ms, + llm_call_time_ms=llm_call_time_ms, + llm_response_cached=llm_response_cached, + error_message=error_message, + ) + logger.debug( + f"Audit logged for Look-Up {lookup_project.name} " + f"(execution_id={execution_id}, status={status})" + ) + except Exception as e: + # Don't fail the execution if audit logging fails + logger.warning(f"Failed to log audit for Look-Up execution: {e}") + + def _parse_llm_response(self, response_text: str) -> tuple[dict, float | None]: + """Parse LLM response to extract enrichment data. + + Attempts to parse the LLM response as JSON and extract + enrichment fields and optional confidence score. + + Args: + response_text: Raw text response from the LLM + + Returns: + Tuple of (enrichment_data, confidence) + - enrichment_data: Dictionary of extracted fields + - confidence: Optional confidence score (0.0-1.0) + + Raises: + ParseError: If response cannot be parsed as valid JSON + + Example: + >>> response = '{"vendor": "Slack", "confidence": 0.92}' + >>> data, conf = executor._parse_llm_response(response) + >>> print(data) + {'vendor': 'Slack'} + >>> print(conf) + 0.92 + """ + try: + # Try direct JSON parse + parsed = json.loads(response_text) + + if not isinstance(parsed, dict): + raise ParseError(f"Expected JSON object, got {type(parsed).__name__}") + + # Extract confidence if present + confidence = None + if "confidence" in parsed: + confidence = parsed.pop("confidence") + + # Validate confidence is a number between 0 and 1 + if isinstance(confidence, (int, float)): + confidence = float(confidence) + if not 0.0 <= confidence <= 1.0: + logger.warning( + f"Confidence {confidence} outside valid range [0.0, 1.0]" + ) + confidence = max(0.0, min(1.0, confidence)) # Clamp to range + else: + logger.warning(f"Invalid confidence type: {type(confidence)}") + confidence = None + + # Remaining fields are the enrichment data + enrichment_data = parsed + + return enrichment_data, confidence + + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse LLM response as JSON: {e}") + raise ParseError(f"Invalid JSON response: {str(e)}") + except Exception as e: + logger.warning(f"Unexpected error parsing LLM response: {e}") + raise ParseError(f"Parse error: {str(e)}") + + def _success_result( + self, + project: LookupProject, + data: dict[str, Any], + confidence: float | None, + cached: bool, + execution_time_ms: int, + context_type: str = "full", + ) -> dict[str, Any]: + """Build success result dictionary. 
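+
+        The project ID is serialized to a string so the result is
+        JSON-serializable.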
+ + Args: + project: The LookupProject that was executed + data: Enrichment data from the LLM response + confidence: Confidence score (0.0-1.0) if available + cached: Whether the response was from cache + execution_time_ms: Execution time in milliseconds + context_type: Type of context used - "rag" or "full" + + Returns: + Dictionary with execution results including context_type + """ + return { + "status": "success", + "project_id": str(project.id), # Convert UUID to string for JSON + "project_name": project.name, + "data": data, + "confidence": confidence, + "cached": cached, + "execution_time_ms": execution_time_ms, + "context_type": context_type, + } + + def _failed_result( + self, project: LookupProject, error: str, start_time: float + ) -> dict[str, Any]: + """Build failed result dictionary.""" + execution_time_ms = int((time.time() - start_time) * 1000) + return { + "status": "failed", + "project_id": str(project.id), # Convert UUID to string for JSON + "project_name": project.name, + "error": error, + "execution_time_ms": execution_time_ms, + "cached": False, + } diff --git a/backend/lookup/services/lookup_index_helper.py b/backend/lookup/services/lookup_index_helper.py new file mode 100644 index 0000000000..81e5f2fa8e --- /dev/null +++ b/backend/lookup/services/lookup_index_helper.py @@ -0,0 +1,211 @@ +"""Helper service for Lookup IndexManager operations. + +Based on Prompt Studio's PromptStudioIndexHelper pattern. +""" + +import logging + +from django.db import transaction + +from lookup.models import LookupIndexManager +from lookup.services.vector_db_cleanup_service import VectorDBCleanupService + +logger = logging.getLogger(__name__) + + +class LookupIndexHelper: + """Helper class for LookupIndexManager CRUD operations.""" + + @staticmethod + @transaction.atomic + def handle_index_manager( + data_source_id: str, + profile_manager, + doc_id: str, + cleanup_old_indexes: bool = True, + ) -> LookupIndexManager: + """Create or update LookupIndexManager with doc_id. + + When re-indexing (updating existing index manager), this method + will clean up old vector DB nodes before adding the new doc_id + to prevent stale data accumulation. 
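+
+        The operation runs inside a database transaction (transaction.atomic),
+        so a failure part-way through rolls back any changes to the doc_id
+        history and status fields.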
+ + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + doc_id: Document ID returned from indexing service + cleanup_old_indexes: If True, deletes old indexes from vector DB + before adding new doc_id (default: True) + + Returns: + LookupIndexManager instance + """ + from lookup.models import LookupDataSource + + try: + data_source = LookupDataSource.objects.get(pk=data_source_id) + + # Get or create index manager for this data source + profile combination + index_manager, created = LookupIndexManager.objects.get_or_create( + data_source=data_source, + profile_manager=profile_manager, + defaults={ + "raw_index_id": doc_id, + "index_ids_history": [doc_id], + "status": {"indexed": True, "error": None}, + }, + ) + + if not created: + # Re-indexing: Clean up old indexes before adding new one + if cleanup_old_indexes and index_manager.index_ids_history: + cleanup_service = VectorDBCleanupService() + cleanup_result = cleanup_service.cleanup_before_reindex(index_manager) + if cleanup_result["deleted"] > 0: + logger.info( + f"Cleaned up {cleanup_result['deleted']} old index(es) " + f"for data source {data_source.file_name}" + ) + if cleanup_result["errors"]: + logger.warning( + f"Some cleanup errors occurred: {cleanup_result['errors']}" + ) + + # Update with new doc_id + index_manager.raw_index_id = doc_id + # Start fresh history with only the new doc_id + index_manager.index_ids_history = [doc_id] + # Update status and clear reindex_required flag + index_manager.status = {"indexed": True, "error": None} + if hasattr(index_manager, "reindex_required"): + index_manager.reindex_required = False + + index_manager.save() + logger.debug(f"Updated index manager for data source {data_source_id}") + else: + logger.debug(f"Created index manager for data source {data_source_id}") + + return index_manager + + except LookupDataSource.DoesNotExist: + logger.error(f"Data source {data_source_id} not found") + raise + except Exception as e: + logger.error(f"Error handling index manager: {e}", exc_info=True) + raise + + @staticmethod + def check_extraction_status( + data_source_id: str, + profile_manager, + x2text_config_hash: str, + enable_highlight: bool = False, + ) -> bool: + """Check if extraction already completed for this configuration. 
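+
+        Extraction status is tracked per X2Text configuration hash, so a
+        change to the extraction settings (or to the highlight flag) makes
+        this check return False.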
+ + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + x2text_config_hash: Hash of X2Text adapter configuration + enable_highlight: Whether highlighting is enabled + + Returns: + True if extraction completed with same settings, False otherwise + """ + try: + index_manager = LookupIndexManager.objects.get( + data_source_id=data_source_id, profile_manager=profile_manager + ) + + extraction_status = index_manager.extraction_status.get( + x2text_config_hash, {} + ) + + if not extraction_status.get("extracted", False): + return False + + # Check if highlight setting matches + stored_highlight = extraction_status.get("enable_highlight", False) + if stored_highlight != enable_highlight: + logger.debug( + f"Highlight setting mismatch: stored={stored_highlight}, " + f"requested={enable_highlight}" + ) + return False + + logger.debug(f"Extraction already completed for {x2text_config_hash}") + return True + + except LookupIndexManager.DoesNotExist: + logger.debug(f"No index manager found for data source {data_source_id}") + return False + except Exception as e: + logger.error(f"Error checking extraction status: {e}", exc_info=True) + return False + + @staticmethod + @transaction.atomic + def mark_extraction_status( + data_source_id: str, + profile_manager, + x2text_config_hash: str, + enable_highlight: bool = False, + extracted: bool = True, + error_message: str = None, + ) -> bool: + """Mark extraction success or failure in IndexManager. + + Args: + data_source_id: UUID of the LookupDataSource + profile_manager: LookupProfileManager instance + x2text_config_hash: Hash of X2Text adapter configuration + enable_highlight: Whether highlighting is enabled + extracted: Whether extraction succeeded + error_message: Error message if extraction failed + + Returns: + True if status marked successfully, False otherwise + """ + from lookup.models import LookupDataSource + + try: + data_source = LookupDataSource.objects.get(pk=data_source_id) + + # Get or create index manager + index_manager, created = LookupIndexManager.objects.get_or_create( + data_source=data_source, + profile_manager=profile_manager, + defaults={"extraction_status": {}, "status": {}}, + ) + + # Update extraction status for this configuration + index_manager.extraction_status[x2text_config_hash] = { + "extracted": extracted, + "enable_highlight": enable_highlight, + "error": error_message, + } + + # Also update overall status + if extracted: + index_manager.status["extracted"] = True + index_manager.status["error"] = None + else: + index_manager.status["extracted"] = False + index_manager.status["error"] = error_message + + index_manager.save() + + status_text = "success" if extracted else "failure" + logger.debug( + f"Marked extraction {status_text} for data source {data_source_id}, " + f"config {x2text_config_hash}" + ) + + return True + + except LookupDataSource.DoesNotExist: + logger.error(f"Data source {data_source_id} not found") + return False + except Exception as e: + logger.error(f"Error marking extraction status: {e}", exc_info=True) + return False diff --git a/backend/lookup/services/lookup_integration_service.py b/backend/lookup/services/lookup_integration_service.py new file mode 100644 index 0000000000..bfb3219c1f --- /dev/null +++ b/backend/lookup/services/lookup_integration_service.py @@ -0,0 +1,455 @@ +"""Service for automatic Lookup integration with Prompt Studio. 
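+
+Behavior is gated by the LOOKUP_AUTO_ENRICH_ENABLED and
+LOOKUP_ENRICHMENT_TIMEOUT settings (see the module-level defaults below).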
+ +This module provides seamless enrichment of extraction results when Lookup +projects are linked to a Prompt Studio project. It executes automatically +after PS extraction completes. +""" + +import logging +import uuid +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError +from typing import Any + +from django.conf import settings + +logger = logging.getLogger(__name__) + +# Configuration defaults +LOOKUP_AUTO_ENRICH_ENABLED = getattr(settings, "LOOKUP_AUTO_ENRICH_ENABLED", True) +LOOKUP_ENRICHMENT_TIMEOUT = getattr(settings, "LOOKUP_ENRICHMENT_TIMEOUT", 30) + + +class LookupIntegrationService: + """Service for automatic Lookup integration with Prompt Studio. + + Provides seamless enrichment of extraction results when Lookup + projects are linked to a Prompt Studio project. + """ + + @staticmethod + def enrich_if_linked( + prompt_studio_project_id: str, + extracted_data: dict[str, Any], + run_id: str | None = None, + timeout: float | None = None, + session_id: str | None = None, + doc_name: str | None = None, + file_execution_id: str | None = None, + workflow_execution_id: str | None = None, + organization_id: str | None = None, + prompt_lookup_map: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Execute Lookup enrichment if PS project has linked Lookups. + + Features: + - Zero overhead if no links exist + - Timeout protection to not block extraction + - Graceful degradation on errors + - Full audit logging + - WebSocket log emission for Prompt Studio UI + - ExecutionLog persistence for Nav bar display (workflow context) + - Prompt-level lookup support: specific lookups per field + + Args: + prompt_studio_project_id: UUID of Prompt Studio project + extracted_data: Dict of extracted field values + run_id: Optional execution run ID for tracking + timeout: Max seconds to wait (default from settings) + session_id: WebSocket session ID for real-time log emission + doc_name: Document name being processed + file_execution_id: File execution ID for Nav bar logs + workflow_execution_id: Workflow execution ID for Nav bar logs + organization_id: Organization ID for multi-tenancy + prompt_lookup_map: Optional mapping of field names (prompt_key) to + specific lookup_project_id. Fields with a specific lookup will + ONLY be enriched by that lookup. Fields without a specific + lookup will be SKIPPED (no enrichment applied). + + Returns: + Dict with 'lookup_enrichment' and '_lookup_metadata' keys, + or empty dict if no links or enrichment disabled. 
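+
+        Example (illustrative only; the enriched fields depend entirely on
+        the linked Look-Up projects and their reference data, and the IDs
+        below are placeholders):
+            >>> result = LookupIntegrationService.enrich_if_linked(
+            ...     prompt_studio_project_id="<ps-project-uuid>",
+            ...     extracted_data={"vendor": "Slack"},
+            ...     prompt_lookup_map={"vendor": "<lookup-project-uuid>"},
+            ... )
+            >>> sorted(result.keys())
+            ['_lookup_metadata', 'lookup_enrichment']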
+ """ + # Check if auto-enrichment is enabled + if not LOOKUP_AUTO_ENRICH_ENABLED: + logger.debug("Lookup auto-enrichment is disabled") + return {} + + if not extracted_data: + logger.debug("No extracted data provided for enrichment") + return {} + + timeout = timeout or LOOKUP_ENRICHMENT_TIMEOUT + prompt_lookup_map = prompt_lookup_map or {} + + # Skip enrichment if no prompts have lookup enabled (prompt_lookup_map empty) + # This ensures lookups only run when explicitly enabled at the prompt level + if not prompt_lookup_map: + logger.debug( + "No prompts have lookup enabled (prompt_lookup_map is empty), " + "skipping enrichment" + ) + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "skipped", + "message": "No prompts have lookup enabled", + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + + try: + return LookupIntegrationService._execute_enrichment( + prompt_studio_project_id=prompt_studio_project_id, + extracted_data=extracted_data, + run_id=run_id, + timeout=timeout, + session_id=session_id, + doc_name=doc_name, + file_execution_id=file_execution_id, + workflow_execution_id=workflow_execution_id, + organization_id=organization_id, + prompt_lookup_map=prompt_lookup_map, + ) + except FuturesTimeoutError: + logger.warning( + f"Lookup enrichment timed out for PS project " + f"{prompt_studio_project_id} after {timeout}s" + ) + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "timeout", + "message": f"Enrichment timed out after {timeout}s", + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + except Exception as e: + logger.error( + f"Lookup enrichment failed for PS project " + f"{prompt_studio_project_id}: {e}", + exc_info=True, + ) + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "error", + "message": str(e), + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + + @staticmethod + def _execute_enrichment( + prompt_studio_project_id: str, + extracted_data: dict[str, Any], + run_id: str | None, + timeout: float, + session_id: str | None = None, + doc_name: str | None = None, + file_execution_id: str | None = None, + workflow_execution_id: str | None = None, + organization_id: str | None = None, + prompt_lookup_map: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Internal method to execute enrichment with timeout. + + Args: + prompt_lookup_map: Mapping of field names to specific lookup project IDs. + Fields in this map will only be enriched by their specific lookup. 
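+
+        Raises:
+            concurrent.futures.TimeoutError: Propagated if the orchestrator
+                does not finish within the given timeout; enrich_if_linked
+                converts this into a "timeout" metadata result.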
+ """ + from lookup.models import PromptStudioLookupLink + from lookup.services.log_emitter import LookupLogEmitter + + # Initialize log emitter for WebSocket and/or ExecutionLog + # When file_execution_id is set (workflow context), logs persist to Nav bar + log_emitter = LookupLogEmitter( + session_id=session_id, + execution_id=workflow_execution_id or run_id, + file_execution_id=file_execution_id, + organization_id=organization_id, + doc_name=doc_name, + ) + + # Quick existence check - minimal overhead if no links + links = ( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=prompt_studio_project_id + ) + .select_related("lookup_project") + .order_by("execution_order") + ) + + if not links.exists(): + logger.debug( + f"No Lookup links found for PS project {prompt_studio_project_id}" + ) + log_emitter.emit_no_linked_lookups() + return {} + + # Get enabled lookup projects (those with ready status) + lookup_projects = [link.lookup_project for link in links if link.is_enabled] + + if not lookup_projects: + logger.debug( + f"No enabled Lookup projects for PS project {prompt_studio_project_id}" + ) + return {} + + logger.info( + f"Executing {len(lookup_projects)} Lookup(s) for PS project " + f"{prompt_studio_project_id}" + ) + + # Emit orchestration start log + lookup_names = [lp.name for lp in lookup_projects] + log_emitter.emit_orchestration_start( + lookup_count=len(lookup_projects), + lookup_names=lookup_names, + ) + + # Execute with timeout protection + import time + + start_time = time.time() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + LookupIntegrationService._run_orchestrator, + lookup_projects=lookup_projects, + input_data=extracted_data, + execution_id=run_id or str(uuid.uuid4()), + prompt_studio_project_id=prompt_studio_project_id, + log_emitter=log_emitter, + organization_id=organization_id, + prompt_lookup_map=prompt_lookup_map or {}, + ) + result = future.result(timeout=timeout) + + # Emit orchestration complete log + total_time_ms = int((time.time() - start_time) * 1000) + metadata = result.get("_lookup_metadata", {}) + log_emitter.emit_orchestration_complete( + total_lookups=len(lookup_projects), + successful=metadata.get("lookups_successful", 0), + failed=metadata.get("lookups_executed", 0) + - metadata.get("lookups_successful", 0), + total_time_ms=total_time_ms, + total_enriched_fields=len(result.get("lookup_enrichment", {})), + ) + + return result + + @staticmethod + def _run_orchestrator( + lookup_projects: list, + input_data: dict[str, Any], + execution_id: str, + prompt_studio_project_id: str, + log_emitter: Any = None, + organization_id: str | None = None, + prompt_lookup_map: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Execute the lookup orchestrator for all linked projects. + + Supports prompt-level lookups: if a field has a specific lookup assigned + via prompt_lookup_map, only that lookup will enrich it. Fields without + specific lookups will be SKIPPED (no enrichment applied). 
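+
+        For example, a prompt_lookup_map of {"vendor": "<lookup-project-uuid>"}
+        routes only the "vendor" field to that Look-Up; all other extracted
+        fields are skipped.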
+ + Args: + lookup_projects: List of LookupProject instances linked at project level + input_data: Dict of extracted field values + execution_id: Execution run ID for tracking + prompt_studio_project_id: UUID of Prompt Studio project + log_emitter: Optional log emitter for WebSocket logs + organization_id: Organization ID for multi-tenancy + prompt_lookup_map: Mapping of field names to specific lookup project IDs + """ + from lookup.integrations.file_storage_client import FileStorageClient + from lookup.integrations.unstract_llm_client import UnstractLLMClient + from lookup.models import LookupProfileManager + from lookup.services.enrichment_merger import EnrichmentMerger + from lookup.services.llm_cache import LLMResponseCache + from lookup.services.lookup_executor import LookUpExecutor + from lookup.services.lookup_orchestrator import LookUpOrchestrator + from lookup.services.reference_data_loader import ReferenceDataLoader + from lookup.services.variable_resolver import VariableResolver + + try: + # Get profile manager for LLM client + # Use the first lookup project's default profile + first_project = lookup_projects[0] + profile_manager = LookupProfileManager.objects.filter( + lookup_project=first_project, is_default=True + ).first() + + if not profile_manager: + logger.warning( + f"No default profile for Lookup project {first_project.id}" + ) + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "error", + "message": "No LLM profile configured for Lookup project", + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + + # Get LLM adapter instance from profile + llm_adapter_instance = profile_manager.llm + if not llm_adapter_instance: + logger.warning("No LLM adapter configured in profile") + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "error", + "message": "No LLM adapter configured", + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + + # Create LLM client using the existing UnstractLLMClient + llm_client = UnstractLLMClient(llm_adapter_instance) + + # Initialize services + cache = LLMResponseCache() + merger = EnrichmentMerger() + + # Create file storage client for reference data loading + storage_client = FileStorageClient() + ref_loader = ReferenceDataLoader(storage_client) + + # Create executor with wrapper for LLM client interface + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=LLMClientWrapper(llm_client), + org_id=organization_id, + ) + + # Create orchestrator with log emitter for WebSocket logs + orchestrator = LookUpOrchestrator( + executor=executor, merger=merger, log_emitter=log_emitter + ) + + # Handle prompt-level lookups if mapping is provided + prompt_lookup_map = prompt_lookup_map or {} + + # Separate fields by their lookup assignment + # Fields with specific lookups: only that lookup enriches them + # Fields without specific lookups: SKIP enrichment (no lookup enabled) + fields_with_specific_lookup: dict[str, dict[str, Any]] = {} + fields_skipped: list[str] = [] + + for field_name, field_value in input_data.items(): + if field_name in prompt_lookup_map: + lookup_id = prompt_lookup_map[field_name] + if lookup_id not in fields_with_specific_lookup: + fields_with_specific_lookup[lookup_id] = {} + fields_with_specific_lookup[lookup_id][field_name] = field_value + else: + # Field has no lookup assigned - skip enrichment entirely + fields_skipped.append(field_name) + + if fields_skipped: + logger.info( + f"Skipping 
enrichment for fields without lookup assigned: " + f"{fields_skipped}" + ) + + all_enrichment: dict[str, Any] = {} + all_enrichments: list[dict[str, Any]] = [] # Collect all enrichment results + total_executed = 0 + total_successful = 0 + + # Execute specific lookups for their assigned fields + for lookup_id, fields in fields_with_specific_lookup.items(): + specific_project = next( + (p for p in lookup_projects if str(p.id) == lookup_id), None + ) + if specific_project: + logger.info( + f"Executing prompt-level lookup {specific_project.name} " + f"for fields: {list(fields.keys())}" + ) + result = orchestrator.execute_lookups( + input_data=fields, + lookup_projects=[specific_project], + execution_id=execution_id, + prompt_studio_project_id=prompt_studio_project_id, + ) + all_enrichment.update(result.get("lookup_enrichment", {})) + metadata = result.get("_lookup_metadata", {}) + total_executed += metadata.get("lookups_executed", 0) + total_successful += metadata.get("lookups_successful", 0) + # Collect enrichment details for error checking + all_enrichments.extend(metadata.get("enrichments", [])) + else: + logger.warning( + f"Lookup project {lookup_id} not found in linked projects " + f"for fields: {list(fields.keys())}" + ) + + return { + "lookup_enrichment": all_enrichment, + "_lookup_metadata": { + "status": "success", + "lookups_executed": total_executed, + "lookups_successful": total_successful, + "prompt_level_lookups": len(fields_with_specific_lookup), + "fields_skipped": len(fields_skipped), + "enrichments": all_enrichments, # Include for error checking + }, + } + + except Exception as e: + logger.error(f"Error in lookup orchestrator: {e}", exc_info=True) + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "error", + "message": str(e), + "lookups_executed": 0, + "lookups_successful": 0, + }, + } + + +class LLMClientWrapper: + """Wrapper to adapt UnstractLLMClient to LookUpExecutor interface. + + The LookUpExecutor expects an LLM client with a `generate(prompt, config)` + method that returns a string. + """ + + def __init__(self, unstract_client: Any) -> None: + """Initialize wrapper. + + Args: + unstract_client: UnstractLLMClient instance + """ + self.client = unstract_client + + def generate(self, prompt: str, config: dict[str, Any] | None = None) -> str: + """Execute LLM generation. + + Args: + prompt: The prompt to send to LLM + config: Optional LLM configuration + + Returns: + LLM response text + """ + try: + # Use the generate method from UnstractLLMClient + response = self.client.generate(prompt=prompt, config=config or {}) + return response + except Exception as e: + logger.error(f"LLM generation failed: {e}") + raise diff --git a/backend/lookup/services/lookup_orchestrator.py b/backend/lookup/services/lookup_orchestrator.py new file mode 100644 index 0000000000..9f20e59726 --- /dev/null +++ b/backend/lookup/services/lookup_orchestrator.py @@ -0,0 +1,451 @@ +"""Look-Up Orchestrator implementation for parallel execution. + +This module provides functionality to execute multiple Look-Up projects +in parallel and merge their results into a single enriched output. 
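+
+Execution is fan-out/fan-in: each Look-Up runs in its own ThreadPoolExecutor
+worker and successful results are combined by the EnrichmentMerger.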
+""" + +import logging +import time +import uuid +from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed +from datetime import UTC, datetime +from typing import Any + +from lookup.models import LookupProject +from lookup.services.enrichment_merger import EnrichmentMerger +from lookup.services.lookup_executor import LookUpExecutor + +logger = logging.getLogger(__name__) + + +class LookUpOrchestrator: + """Orchestrates parallel execution of multiple Look-Up projects. + + This class manages the concurrent execution of multiple Look-Up projects, + handles timeouts, collects results, and merges them into a single + enriched output using the EnrichmentMerger. + """ + + def __init__( + self, + executor: LookUpExecutor, + merger: EnrichmentMerger, + config: dict[str, Any] = None, + log_emitter: Any = None, + ): + """Initialize the Look-Up orchestrator. + + Args: + executor: LookUpExecutor instance for single Look-Up execution + merger: EnrichmentMerger instance for combining results + config: Configuration dictionary with optional keys: + - max_concurrent_executions: Maximum parallel executions (default 10) + - queue_timeout_seconds: Overall queue timeout (default 120) + - execution_timeout_seconds: Per-execution timeout (default 30) + log_emitter: Optional LookupLogEmitter for WebSocket logging + """ + self.executor = executor + self.merger = merger + self.log_emitter = log_emitter + + config = config or {} + self.max_concurrent = config.get("max_concurrent_executions", 10) + self.queue_timeout = config.get("queue_timeout_seconds", 120) + self.execution_timeout = config.get("execution_timeout_seconds", 30) + + logger.info( + f"Orchestrator initialized with max_concurrent={self.max_concurrent}, " + f"queue_timeout={self.queue_timeout}s, " + f"execution_timeout={self.execution_timeout}s" + ) + + def execute_lookups( + self, + input_data: dict[str, Any], + lookup_projects: list[LookupProject], + execution_id: str | None = None, + prompt_studio_project_id: str | None = None, + ) -> dict[str, Any]: + """Execute all Look-Ups in parallel and merge results. + + Submits all Look-Up projects for parallel execution, collects + the results, and merges successful enrichments into a single + output. Handles timeouts and failures gracefully. 
+ + Args: + input_data: Input data to enrich + lookup_projects: List of Look-Up projects to execute + execution_id: Optional UUID to group related executions for audit + prompt_studio_project_id: Optional PS project ID for audit tracking + + Returns: + Dictionary containing: + - lookup_enrichment: Merged enrichment data + - _lookup_metadata: Execution metadata including: + - execution_id: Unique ID for this execution + - executed_at: ISO8601 timestamp + - total_execution_time_ms: Total time in milliseconds + - lookups_executed: Number of Look-Ups attempted + - lookups_successful: Number of successful executions + - lookups_failed: Number of failed executions + - conflicts_resolved: Number of field conflicts resolved + - enrichments: List of individual enrichment results + + Example: + >>> orchestrator = LookUpOrchestrator(executor, merger) + >>> projects = [vendor_lookup, product_lookup] + >>> result = orchestrator.execute_lookups({"vendor": "Slack"}, projects) + >>> print(result["lookup_enrichment"]) + {'canonical_vendor': 'Slack', 'product_type': 'SaaS'} + >>> print(result["_lookup_metadata"]["lookups_successful"]) + 2 + """ + execution_id = execution_id or str(uuid.uuid4()) + start_time = time.time() + executed_at = datetime.now(UTC).isoformat() + + logger.info( + f"Starting orchestrated execution {execution_id} for " + f"{len(lookup_projects)} Look-Up projects" + ) + + if not lookup_projects: + # No Look-Ups to execute + return self._empty_result(execution_id, executed_at, start_time) + + successful_enrichments = [] + failed_lookups = [] + timeout_count = 0 + + # Build project order mapping for sorting results later + project_order = { + str(project.id): idx for idx, project in enumerate(lookup_projects) + } + + # Execute Look-Ups in parallel + with ThreadPoolExecutor(max_workers=self.max_concurrent) as thread_executor: + # Submit all tasks + futures = { + thread_executor.submit( + self._execute_single, + execution_id, + input_data, + lookup_project, + prompt_studio_project_id, + ): lookup_project + for lookup_project in lookup_projects + } + + logger.debug(f"Submitted {len(futures)} Look-Up tasks for parallel execution") + + # Collect results with timeout + try: + for future in as_completed(futures, timeout=self.queue_timeout): + lookup_project = futures[future] + try: + result = future.result(timeout=self.execution_timeout) + + if result["status"] == "success": + successful_enrichments.append(result) + logger.debug( + f"Look-Up {lookup_project.name} completed successfully" + ) + # Emit success log via WebSocket + self._emit_success_log(result, lookup_project) + else: + failed_lookups.append(result) + logger.warning( + f"Look-Up {lookup_project.name} failed: {result.get('error')}" + ) + # Emit failure log via WebSocket + self._emit_failure_log(result, lookup_project) + + except TimeoutError: + # Individual execution timeout + timeout_count += 1 + logger.error( + f"Look-Up {lookup_project.name} timed out after " + f"{self.execution_timeout}s" + ) + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Execution timeout ({self.execution_timeout}s)", + "execution_time_ms": self.execution_timeout * 1000, + "cached": False, + } + ) + + except Exception as e: + # Unexpected error in future.result() + logger.exception( + f"Unexpected error getting result for {lookup_project.name}" + ) + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": 
lookup_project.name, + "error": f"Execution error: {str(e)}", + "execution_time_ms": 0, + "cached": False, + } + ) + + except TimeoutError: + # Overall queue timeout + logger.error( + f"Queue timeout after {self.queue_timeout}s, " + f"some Look-Ups may not have completed" + ) + # Cancel remaining futures + for future in futures: + if not future.done(): + future.cancel() + lookup_project = futures[future] + failed_lookups.append( + { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Queue timeout ({self.queue_timeout}s)", + "execution_time_ms": 0, + "cached": False, + } + ) + + # Sort successful enrichments by original execution order before merging + # This ensures that when there's no confidence score, the lookup with + # lower execution_order (higher priority) wins in conflict resolution + if successful_enrichments: + successful_enrichments.sort( + key=lambda x: project_order.get(x.get("project_id"), 999) + ) + merge_result = self.merger.merge(successful_enrichments) + merged_data = merge_result["data"] + conflicts_resolved = merge_result["conflicts_resolved"] + # enrichment_details = merge_result["enrichment_details"] + else: + # No successful enrichments + merged_data = {} + conflicts_resolved = 0 + + # Calculate execution time + total_execution_time_ms = int((time.time() - start_time) * 1000) + + # Combine all enrichment results (successful and failed) + all_enrichments = successful_enrichments + failed_lookups + + logger.info( + f"Orchestration {execution_id} completed: " + f"{len(successful_enrichments)} successful, " + f"{len(failed_lookups)} failed, " + f"{timeout_count} timeouts, " + f"{conflicts_resolved} conflicts resolved, " + f"total time {total_execution_time_ms}ms" + ) + logger.info(f"Merged enrichment data: {merged_data}") + + return { + "lookup_enrichment": merged_data, + "_lookup_metadata": { + "execution_id": execution_id, + "executed_at": executed_at, + "total_execution_time_ms": total_execution_time_ms, + "lookups_executed": len(lookup_projects), + "lookups_successful": len(successful_enrichments), + "lookups_failed": len(failed_lookups), + "conflicts_resolved": conflicts_resolved, + "enrichments": all_enrichments, + }, + } + + def _execute_single( + self, + execution_id: str, + input_data: dict[str, Any], + lookup_project: LookupProject, + prompt_studio_project_id: str | None = None, + ) -> dict[str, Any]: + """Execute a single Look-Up project. + + Wrapper around the executor to add execution context and + handle any unexpected errors. 
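+
+        Successful results are also filtered down to fields that actually
+        changed, so a Look-Up cannot overwrite fields it merely echoed back
+        unchanged.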
+ + Args: + execution_id: ID of the orchestration execution + input_data: Input data for enrichment + lookup_project: Look-Up project to execute + prompt_studio_project_id: Optional PS project ID for audit tracking + + Returns: + Enrichment result dictionary from the executor + """ + try: + logger.debug( + f"Executing Look-Up {lookup_project.name} for execution {execution_id}" + ) + + # Execute the Look-Up with audit context + result = self.executor.execute( + lookup_project=lookup_project, + input_data=input_data, + execution_id=execution_id, + prompt_studio_project_id=prompt_studio_project_id, + ) + + # Add execution context + result["execution_id"] = execution_id + + # Filter enrichment data to only include fields that actually changed + # This prevents a lookup from overwriting fields it didn't canonicalize + if result.get("status") == "success" and result.get("data"): + result["data"] = self._filter_changed_fields(input_data, result["data"]) + logger.debug( + f"Filtered enrichment for {lookup_project.name}: " + f"{list(result['data'].keys())}" + ) + + return result + + except Exception as e: + # Catch any unexpected errors from the executor + logger.exception(f"Unexpected error executing Look-Up {lookup_project.name}") + return { + "status": "failed", + "project_id": str(lookup_project.id), + "project_name": lookup_project.name, + "error": f"Unexpected error: {str(e)}", + "execution_time_ms": 0, + "cached": False, + "execution_id": execution_id, + } + + def _filter_changed_fields( + self, + input_data: dict[str, Any], + enrichment_data: dict[str, Any], + ) -> dict[str, Any]: + """Filter enrichment data to only include fields that changed. + + When an LLM returns the entire input with modifications, this method + identifies which fields actually changed and returns only those. + This prevents one lookup from overwriting fields that another lookup + is responsible for canonicalizing. + + Args: + input_data: Original input data before enrichment + enrichment_data: Data returned by the lookup + + Returns: + Dictionary containing only fields that differ from input_data, + plus any new fields not present in input_data + """ + changed_fields = {} + + for field_name, enriched_value in enrichment_data.items(): + # Include field if: + # 1. It's a new field not in input_data, OR + # 2. The value is different from the input value + if field_name not in input_data: + # New field added by the lookup + changed_fields[field_name] = enriched_value + elif input_data[field_name] != enriched_value: + # Field value was changed by the lookup + changed_fields[field_name] = enriched_value + # else: field unchanged, don't include it + + return changed_fields + + def _empty_result( + self, execution_id: str, executed_at: str, start_time: float + ) -> dict[str, Any]: + """Build result for empty Look-Up list. 
+ + Args: + execution_id: Execution ID + executed_at: Execution timestamp + start_time: Start time for calculating duration + + Returns: + Empty result dictionary with metadata + """ + total_execution_time_ms = int((time.time() - start_time) * 1000) + + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "execution_id": execution_id, + "executed_at": executed_at, + "total_execution_time_ms": total_execution_time_ms, + "lookups_executed": 0, + "lookups_successful": 0, + "lookups_failed": 0, + "conflicts_resolved": 0, + "enrichments": [], + }, + } + + def _emit_success_log( + self, result: dict[str, Any], lookup_project: LookupProject + ) -> None: + """Emit success log via WebSocket if log_emitter is available. + + Args: + result: Execution result dictionary + lookup_project: The Look-Up project that was executed + """ + if not self.log_emitter: + return + + try: + enriched_count = len(result.get("data", {})) + cached = result.get("cached", False) + execution_time_ms = result.get("execution_time_ms", 0) + confidence = result.get("confidence") + context_type = result.get("context_type", "full") + + self.log_emitter.emit_enrichment_success( + lookup_project_name=lookup_project.name, + enriched_count=enriched_count, + cached=cached, + execution_time_ms=execution_time_ms, + confidence=confidence, + context_type=context_type, + ) + except Exception as e: + logger.warning(f"Failed to emit success log: {e}") + + def _emit_failure_log( + self, result: dict[str, Any], lookup_project: LookupProject + ) -> None: + """Emit failure log via WebSocket if log_emitter is available. + + Args: + result: Execution result dictionary + lookup_project: The Look-Up project that was executed + """ + if not self.log_emitter: + return + + try: + error_type = result.get("error_type") + + # Use specialized log for context window exceeded errors + if error_type == "context_window_exceeded": + self.log_emitter.emit_context_overflow_error( + lookup_project_name=lookup_project.name, + token_count=result.get("token_count", 0), + context_limit=result.get("context_limit", 0), + model=result.get("model", "unknown"), + ) + else: + error_message = result.get("error", "Unknown error") + self.log_emitter.emit_enrichment_failure( + lookup_project_name=lookup_project.name, + error_message=error_message, + ) + except Exception as e: + logger.warning(f"Failed to emit failure log: {e}") diff --git a/backend/lookup/services/lookup_retrieval_service.py b/backend/lookup/services/lookup_retrieval_service.py new file mode 100644 index 0000000000..efb0673f47 --- /dev/null +++ b/backend/lookup/services/lookup_retrieval_service.py @@ -0,0 +1,262 @@ +"""RAG retrieval service for Lookup projects using similarity search with ranking. + +This module provides functionality to retrieve relevant context chunks from +indexed reference data using vector similarity search. Chunks are ranked by +similarity score and annotated with source information for better context. 
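+
+Retrieval is used when the profile's chunk_size is greater than zero; with a
+chunk_size of 0 the Look-Up executor loads the full reference data instead of
+calling this service.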
+""" + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters +from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import PromptIdeBaseTool +from utils.user_context import UserContext + +from lookup.models import LookupIndexManager + +if TYPE_CHECKING: + from lookup.models import LookupProfileManager + +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.embedding import EmbeddingCompat +from unstract.sdk1.vector_db import VectorDB + +logger = logging.getLogger(__name__) + + +@dataclass +class RetrievedChunk: + """A retrieved chunk with metadata for ranking and attribution.""" + + content: str + score: float + source_file: str + doc_id: str + + +class LookupRetrievalService: + """Service to perform RAG retrieval for Lookup projects. + + Uses simple cosine similarity search against indexed reference data. + When chunk_size > 0, this service retrieves relevant chunks instead + of loading the entire reference data. + """ + + def __init__(self, profile: "LookupProfileManager", org_id: str | None = None): + """Initialize the retrieval service. + + Args: + profile: LookupProfileManager with adapter configuration + org_id: Organization ID (if None, gets from UserContext) + """ + self.profile = profile + self.org_id = org_id or UserContext.get_organization_identifier() + logger.info( + f"LookupRetrievalService initialized with org_id='{self.org_id}' " + f"(provided={org_id}, from_context={UserContext.get_organization_identifier()})" + ) + + def retrieve_context( + self, query: str, project_id: str, min_score: float = 0.3 + ) -> str: + r"""Retrieve relevant context chunks from indexed data sources. + + Queries the vector DB for chunks semantically similar to the query, + filtering by doc_id to ensure results come from the correct + indexed reference data. Chunks are ranked by similarity score and + annotated with source file information. + + Args: + query: The semantic query to search for (built from input_data) + project_id: UUID of the lookup project + min_score: Minimum similarity score threshold (default 0.3) + + Returns: + Concatenated retrieved chunks as context string, sorted by + relevance score (highest first) with source attribution. + Returns empty string if no indexed sources found. + + Example: + >>> service = LookupRetrievalService(profile) + >>> context = service.retrieve_context("vendor: Slack, type: SaaS", project_id) + >>> print(context[:100]) + '=== Source: vendors.csv (relevance: 0.89) ===\nSlack Technologies Inc...' + """ + # Get all indexed data sources for this project with extraction complete + index_managers = LookupIndexManager.objects.filter( + data_source__project_id=project_id, + data_source__is_latest=True, + data_source__extraction_status="completed", # Only fully extracted sources + profile_manager=self.profile, + raw_index_id__isnull=False, + ).select_related("data_source") + + if not index_managers.exists(): + logger.warning( + f"No indexed data sources for project {project_id}. " + "Ensure data sources are uploaded and extraction is complete." 
+ ) + return "" + + # Build doc_id to source file mapping for attribution + doc_id_to_source: dict[str, str] = {} + for im in index_managers: + if im.raw_index_id: + doc_id_to_source[im.raw_index_id] = im.data_source.file_name + + doc_ids = list(doc_id_to_source.keys()) + + if not doc_ids: + logger.warning(f"No valid doc_ids found for project {project_id}") + return "" + + logger.info( + f"Retrieving context for project {project_id} " + f"from {len(doc_ids)} indexed sources: {list(doc_id_to_source.values())}" + ) + + # Retrieve from each doc_id and aggregate results with metadata + all_chunks: list[RetrievedChunk] = [] + for doc_id in doc_ids: + source_file = doc_id_to_source.get(doc_id, "unknown") + try: + chunks = self._retrieve_chunks(query, doc_id, source_file, min_score) + all_chunks.extend(chunks) + except Exception as e: + logger.error( + f"Failed to retrieve chunks for doc_id {doc_id} " + f"(source: {source_file}): {e}" + ) + # Continue with other doc_ids + + if not all_chunks: + logger.warning( + f"No chunks retrieved for project {project_id} with query: {query[:100]}. " + f"This may indicate poor semantic match or indexing issues." + ) + return "" + + # Sort by score (highest first) for better context quality + all_chunks.sort(key=lambda c: c.score, reverse=True) + + # Deduplicate by content while preserving score-based order + seen_content: set[str] = set() + unique_chunks: list[RetrievedChunk] = [] + for chunk in all_chunks: + if chunk.content not in seen_content: + seen_content.add(chunk.content) + unique_chunks.append(chunk) + + logger.info( + f"Retrieved {len(unique_chunks)} unique chunks " + f"(from {len(all_chunks)} total) for project {project_id}. " + f"Score range: {unique_chunks[-1].score:.3f} - {unique_chunks[0].score:.3f}" + ) + + # Format chunks with source attribution for LLM context + formatted_chunks = [] + for chunk in unique_chunks: + header = f"=== Source: {chunk.source_file} (relevance: {chunk.score:.2f}) ===" + formatted_chunks.append(f"{header}\n{chunk.content}") + + return "\n\n".join(formatted_chunks) + + def _retrieve_chunks( + self, query: str, doc_id: str, source_file: str, min_score: float = 0.3 + ) -> list[RetrievedChunk]: + """Retrieve chunks from vector DB with similarity scores. + + Uses the configured embedding model and vector store from the profile + to perform cosine similarity search with doc_id filtering. Returns + chunks with their similarity scores for ranking. 
+ + Args: + query: The semantic query to search for + doc_id: Document ID to filter results (from LookupIndexManager) + source_file: Source file name for attribution + min_score: Minimum similarity score threshold (default 0.3) + + Returns: + List of RetrievedChunk objects with content, score, and source info + """ + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=self.org_id) + + # Initialize embedding model from profile + embedding = EmbeddingCompat( + adapter_instance_id=str(self.profile.embedding_model.id), + tool=util, + kwargs={}, + ) + + # Initialize vector DB from profile + vector_db = VectorDB( + tool=util, + adapter_instance_id=str(self.profile.vector_store.id), + embedding=embedding, + ) + + try: + # Get vector store index for retrieval + vector_store_index = vector_db.get_vector_store_index() + + # Create retriever with doc_id filter and top-k from profile + retriever = vector_store_index.as_retriever( + similarity_top_k=self.profile.similarity_top_k, + filters=MetadataFilters( + filters=[ExactMatchFilter(key="doc_id", value=doc_id)], + ), + ) + + # Execute retrieval + logger.info( + f"Executing vector retrieval for doc_id={doc_id}, " + f"query='{query[:100]}...', top_k={self.profile.similarity_top_k}" + ) + nodes = retriever.retrieve(query) + + logger.info(f"Vector DB returned {len(nodes)} nodes for doc_id {doc_id}") + + # Extract chunks with scores above threshold + chunks: list[RetrievedChunk] = [] + for node in nodes: + logger.debug( + f"Node {node.node_id}: score={node.score:.4f}, " + f"content_length={len(node.get_content())}" + ) + if node.score >= min_score: + chunks.append( + RetrievedChunk( + content=node.get_content(), + score=node.score, + source_file=source_file, + doc_id=doc_id, + ) + ) + else: + logger.debug( + f"Ignored node {node.node_id} with score {node.score:.3f} " + f"(below threshold {min_score})" + ) + + if len(nodes) > 0 and len(chunks) == 0: + logger.warning( + f"All {len(nodes)} nodes for doc_id {doc_id} were below " + f"min_score threshold ({min_score}). Highest score: " + f"{max(n.score for n in nodes):.4f}" + ) + elif len(nodes) == 0: + logger.warning( + f"Vector DB returned 0 nodes for doc_id {doc_id}. " + "This doc_id may not exist in the vector DB - check indexing." + ) + + logger.info( + f"Retrieved {len(chunks)} chunks for doc_id {doc_id} " + f"(source: {source_file})" + ) + return chunks + + finally: + # Always close vector DB connection + vector_db.close() diff --git a/backend/lookup/services/mock_clients.py b/backend/lookup/services/mock_clients.py new file mode 100644 index 0000000000..9d505f84ef --- /dev/null +++ b/backend/lookup/services/mock_clients.py @@ -0,0 +1,177 @@ +"""Mock implementations of LLM and Storage clients for testing. + +These are temporary implementations for testing the API layer +before integrating with real services in Phase 4. +""" + +import json +import random +import time +from typing import Any + + +class MockLLMClient: + """Mock LLM client for testing Look-Up execution. + + Generates synthetic responses for testing purposes. + """ + + def generate(self, prompt: str, config: dict[str, Any], timeout: int = 30) -> str: + """Generate a mock LLM response. + + Returns JSON-formatted enrichment data with random confidence. 
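+
+ Example (illustrative; field values are randomised on every call, but
+ the key set is fixed):
+
+ >>> client = MockLLMClient()
+ >>> data = json.loads(client.generate("vendor: Slack Technologies", config={}))
+ >>> sorted(data.keys())
+ ['canonical_vendor', 'confidence', 'vendor_category', 'vendor_type']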
+ """ + # Simulate processing time + time.sleep(random.uniform(0.1, 0.5)) + + # Extract vendor name from prompt if available + vendor = "Unknown" + if "vendor" in prompt.lower(): + # Try to extract vendor name from prompt + lines = prompt.split("\n") + for line in lines: + if "vendor" in line.lower() and ":" in line: + vendor = line.split(":")[-1].strip() + break + + # Generate mock enrichment data + confidence = random.uniform(0.6, 0.98) + enrichment_data = { + "canonical_vendor": self._canonicalize(vendor), + "vendor_category": random.choice( + ["SaaS", "Infrastructure", "Security", "Analytics"] + ), + "vendor_type": random.choice(["Software", "Service", "Platform"]), + "confidence": round(confidence, 2), + } + + return json.dumps(enrichment_data) + + def _canonicalize(self, vendor: str) -> str: + """Mock canonicalization of vendor names.""" + # Simple mock canonicalization + mappings = { + "Slack Technologies": "Slack", + "Microsoft Corp": "Microsoft", + "Amazon Web Services": "AWS", + "Google Cloud Platform": "GCP", + "International Business Machines": "IBM", + } + return mappings.get(vendor, vendor) + + +class MockStorageClient: + """Mock storage client for testing reference data operations. + + Stores data in memory for testing purposes. + """ + + def __init__(self): + """Initialize in-memory storage.""" + self.storage = {} + + def upload(self, path: str, content: bytes) -> bool: + """Upload content to mock storage. + + Args: + path: Storage path + content: File content + + Returns: + True if successful + """ + self.storage[path] = content + return True + + def download(self, path: str) -> bytes | None: + """Download content from mock storage. + + Args: + path: Storage path + + Returns: + File content or None if not found + """ + return self.storage.get(path) + + def delete(self, path: str) -> bool: + """Delete content from mock storage. + + Args: + path: Storage path + + Returns: + True if deleted, False if not found + """ + if path in self.storage: + del self.storage[path] + return True + return False + + def exists(self, path: str) -> bool: + """Check if path exists in storage. + + Args: + path: Storage path + + Returns: + True if exists + """ + return path in self.storage + + def list_files(self, prefix: str) -> list: + """List files with given prefix. + + Args: + prefix: Path prefix + + Returns: + List of matching paths + """ + return [path for path in self.storage.keys() if path.startswith(prefix)] + + def get_text_content(self, path: str) -> str | None: + """Get text content from storage. + + Args: + path: Storage path + + Returns: + Text content or None if not found + """ + content = self.download(path) + if content: + return content.decode("utf-8") + return None + + def get(self, path: str) -> str: + """Retrieve file content from storage. + + This method implements the StorageClient protocol expected + by ReferenceDataLoader. + + Args: + path: Storage path + + Returns: + Text content of the file + + Raises: + FileNotFoundError: If file not found in storage + """ + content = self.get_text_content(path) + if content is None: + raise FileNotFoundError(f"File not found: {path}") + return content + + def save_text_content(self, path: str, text: str) -> bool: + """Save text content to storage. 
+ + Args: + path: Storage path + text: Text content + + Returns: + True if successful + """ + return self.upload(path, text.encode("utf-8")) diff --git a/backend/lookup/services/reference_data_loader.py b/backend/lookup/services/reference_data_loader.py new file mode 100644 index 0000000000..30799b4570 --- /dev/null +++ b/backend/lookup/services/reference_data_loader.py @@ -0,0 +1,267 @@ +"""Reference Data Loader implementation for loading and concatenating reference data. + +This module provides functionality to load reference data from object storage +and concatenate multiple sources into a single text for LLM processing. +""" + +from typing import Any, Protocol +from uuid import UUID + +from django.db.models import QuerySet + +from lookup.exceptions import ExtractionNotCompleteError +from lookup.models import LookupDataSource + + +class StorageClient(Protocol): + """Protocol for object storage client abstraction. + + Any storage client implementation must provide a get() method + that retrieves file content by path. + """ + + def get(self, path: str) -> str: + """Retrieve file content from storage.""" + ... + + +class ReferenceDataLoader: + """Loads and concatenates reference data from object storage. + + This class handles loading reference data files that have been + extracted from uploaded documents and stored in object storage. + It ensures all files are properly extracted before loading and + concatenates multiple sources in the order they were uploaded. + """ + + def __init__(self, storage_client: StorageClient): + """Initialize the reference data loader. + + Args: + storage_client: Object storage client (abstraction). + Must implement the StorageClient protocol. + """ + self.storage = storage_client + + def load_latest_for_project(self, project_id: UUID) -> dict[str, Any]: + r"""Load latest reference data for a project. + + Retrieves the most recent version of reference data for the specified + project. Ensures all data sources have completed extraction before + loading and concatenating their content. + + Args: + project_id: UUID of the Look-Up project + + Returns: + Dictionary containing: + - version: Version number of the reference data + - content: Concatenated text from all files + - files: List of metadata about source files + - total_size: Total size in bytes + + Raises: + ExtractionNotCompleteError: If any data source extraction is incomplete + LookupDataSource.DoesNotExist: If no data sources found + + Example: + >>> loader = ReferenceDataLoader(storage) + >>> data = loader.load_latest_for_project(project_id) + >>> print(data["version"]) + 3 + >>> print(data["content"][:50]) + '=== File: vendors.csv ===\n\nSlack\nMicrosoft...' 
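+
+ Any object with a ``get(path) -> str`` method satisfies the
+ ``StorageClient`` protocol, so a tiny in-memory stub is enough for local
+ experiments (illustrative sketch; ``MockStorageClient`` in
+ ``lookup.services.mock_clients`` provides the same interface):
+
+ >>> class DictStorage:
+ ... def __init__(self, files): self.files = files
+ ... def get(self, path): return self.files[path]
+ >>> loader = ReferenceDataLoader(DictStorage({"vendors.csv": "Slack"}))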
+ """ + # Get latest version data sources + data_sources = LookupDataSource.objects.filter( + project_id=project_id, is_latest=True + ).order_by("created_at") + + if not data_sources.exists(): + raise LookupDataSource.DoesNotExist( + f"No data sources found for project {project_id}" + ) + + # Validate all extractions are complete + is_complete, failed_files = self.validate_extraction_complete(data_sources) + if not is_complete: + raise ExtractionNotCompleteError(failed_files) + + # Get version number from first source (all should have same version) + version_number = data_sources.first().version_number + + # Concatenate content from all sources + content = self.concatenate_sources(data_sources) + + # Build file metadata + files = [] + total_size = 0 + for source in data_sources: + files.append( + { + "id": str(source.id), + "name": source.file_name, + "size": source.file_size, + "type": source.file_type, + "uploaded_at": source.created_at.isoformat(), + } + ) + total_size += source.file_size + + return { + "version": version_number, + "content": content, + "files": files, + "total_size": total_size, + } + + def load_specific_version(self, project_id: UUID, version: int) -> dict[str, Any]: + """Load specific version of reference data. + + Retrieves a specific version of reference data for the project, + regardless of whether it's the latest version. + + Args: + project_id: UUID of the Look-Up project + version: Version number to load + + Returns: + Dictionary with same structure as load_latest_for_project() + + Raises: + ExtractionNotCompleteError: If any data source extraction is incomplete + LookupDataSource.DoesNotExist: If version not found + + Example: + >>> loader = ReferenceDataLoader(storage) + >>> data = loader.load_specific_version(project_id, 2) + >>> print(data["version"]) + 2 + """ + # Get specific version data sources + data_sources = LookupDataSource.objects.filter( + project_id=project_id, version_number=version + ).order_by("created_at") + + if not data_sources.exists(): + raise LookupDataSource.DoesNotExist( + f"Version {version} not found for project {project_id}" + ) + + # Validate all extractions are complete + is_complete, failed_files = self.validate_extraction_complete(data_sources) + if not is_complete: + raise ExtractionNotCompleteError(failed_files) + + # Concatenate content from all sources + content = self.concatenate_sources(data_sources) + + # Build file metadata + files = [] + total_size = 0 + for source in data_sources: + files.append( + { + "id": str(source.id), + "name": source.file_name, + "size": source.file_size, + "type": source.file_type, + "uploaded_at": source.created_at.isoformat(), + } + ) + total_size += source.file_size + + return { + "version": version, + "content": content, + "files": files, + "total_size": total_size, + } + + def concatenate_sources(self, data_sources: QuerySet) -> str: + """Concatenate extracted content from multiple sources in upload order. + + Loads the extracted content for each data source from object storage + and concatenates them with file headers for clarity. 
+ + Args: + data_sources: QuerySet of LookupDataSource objects, + should be ordered by created_at + + Returns: + Concatenated string with all file contents separated by headers + + Example: + >>> content = loader.concatenate_sources(sources) + >>> print(content) + === File: vendors.csv === + + Slack + Microsoft + Google + + === File: products.txt === + + Slack Workspace + Microsoft Teams + """ + contents = [] + + for source in data_sources: + # Add file header + header = f"=== File: {source.file_name} ===" + + # Load content from storage + # First try extracted_content_path, then fall back to original file_path + # for text-based files (CSV, TXT, JSON) + content_path = source.extracted_content_path + if not content_path: + # For text files, use the original file path + text_file_types = ["csv", "txt", "json"] + if source.file_type in text_file_types: + content_path = source.file_path + + if content_path: + try: + file_content = self.storage.get(content_path) + except Exception as e: + # If storage fails, include error in output + file_content = f"[Error loading file: {str(e)}]" + else: + file_content = "[No content path available]" + + # Combine header and content + contents.append(f"{header}\n\n{file_content}") + + # Join all contents with double newline separator + return "\n\n".join(contents) + + def validate_extraction_complete( + self, data_sources: QuerySet + ) -> tuple[bool, list[str]]: + """Check if all sources have completed extraction. + + Verifies that all data sources in the queryset have successfully + completed the extraction process. + + Args: + data_sources: QuerySet of LookupDataSource objects + + Returns: + Tuple of: + - all_complete: True if all extractions complete, False otherwise + - failed_files: List of filenames that are not complete + + Example: + >>> is_complete, failed = loader.validate_extraction_complete(sources) + >>> if not is_complete: + ... print(f"Waiting for: {', '.join(failed)}") + """ + failed_files = [] + + for source in data_sources: + if source.extraction_status != "completed": + failed_files.append(source.file_name) + + all_complete = len(failed_files) == 0 + return all_complete, failed_files diff --git a/backend/lookup/services/variable_resolver.py b/backend/lookup/services/variable_resolver.py new file mode 100644 index 0000000000..45d57fb6bb --- /dev/null +++ b/backend/lookup/services/variable_resolver.py @@ -0,0 +1,158 @@ +"""Variable resolver for template variable replacement.""" + +import json +import re +from typing import Any + + +class VariableResolver: + """Resolves {{variable}} placeholders in prompt templates with actual values. + + Supports dot notation for nested field access and handles complex + data types (dicts, lists) by converting them to JSON. + """ + + VARIABLE_PATTERN = r"\{\{([^}]*)\}\}" + + def __init__(self, input_data: dict[str, Any], reference_data: str): + """Initialize the variable resolver with context data. 
+ + Args: + input_data: Extracted data from Prompt Studio + reference_data: Concatenated text from all reference files + + Note: + Variables can be accessed in templates as: + - {{reference_data}} - the full reference data string + - {{input_data.field_name}} - explicit input_data prefix + - {{field_name}} - shorthand for input_data fields (auto-resolved) + """ + self.context = {"input_data": input_data, "reference_data": reference_data} + # Also add input_data fields at top level for shorthand access + # This allows {{vendor_name}} instead of {{input_data.vendor_name}} + for key, value in input_data.items(): + if key not in self.context: # Don't override reference_data + self.context[key] = value + + def resolve(self, template: str) -> str: + r"""Replace all {{variable}} references in template with actual values. + + Args: + template: Prompt template with {{variable}} placeholders + + Returns: + Resolved prompt with variables replaced + + Example: + >>> resolver = VariableResolver( + ... {"vendor": "Slack Inc"}, "Slack\nMicrosoft\nGoogle" + ... ) + >>> template = "Match {{input_data.vendor}} against: {{reference_data}}" + >>> resolver.resolve(template) + 'Match Slack Inc against: Slack\nMicrosoft\nGoogle' + """ + + def replacer(match): + variable_path = match.group(1).strip() + return str(self._get_nested_value(variable_path)) + + return re.sub(self.VARIABLE_PATTERN, replacer, template) + + def detect_variables(self, template: str) -> list[str]: + """Extract all {{variable}} references from template. + + Args: + template: Template text to analyze + + Returns: + List of unique variable paths found in template + + Example: + >>> resolver = VariableResolver({}, "") + >>> template = "Match {{input_data.vendor}} from {{reference_data}} and {{input_data.vendor}}" + >>> resolver.detect_variables(template) + ['input_data.vendor', 'reference_data'] + """ + if not template: + return [] + + matches = re.findall(self.VARIABLE_PATTERN, template) + # Strip whitespace and deduplicate + unique_vars = list({m.strip() for m in matches}) + return sorted(unique_vars) + + def _get_nested_value(self, path: str) -> Any: + """Get value from context using dot notation path. + + Args: + path: Dot-separated path (e.g., "input_data.vendor.name") + + Returns: + Value at path, or empty string if not found + + Examples: + >>> resolver = VariableResolver({"vendor": {"name": "Slack"}}, "") + >>> resolver._get_nested_value("input_data.vendor.name") + 'Slack' + >>> resolver._get_nested_value("input_data.missing") + '' + """ + if not path: + return "" + + keys = path.split(".") + value = self.context + + for key in keys: + if isinstance(value, dict): + value = value.get(key, "") + elif isinstance(value, list): + # Try to convert key to integer for list indexing + try: + index = int(key) + value = value[index] if 0 <= index < len(value) else "" + except (ValueError, IndexError): + return "" + else: + return "" + + # If value is complex object, return JSON representation + if isinstance(value, (dict, list)): + return json.dumps(value, indent=2, ensure_ascii=False) + + # Handle None values + if value is None: + return "" + + return value + + def validate_variables(self, template: str) -> dict[str, bool]: + """Validate that all variables in template can be resolved. 
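+
+ Unresolved paths resolve to an empty string and are therefore reported
+ as unavailable. For example (illustrative):
+
+ >>> resolver = VariableResolver({"vendor": "Slack"}, "reference text")
+ >>> resolver.validate_variables("{{vendor}} vs {{missing_field}}")
+ {'missing_field': False, 'vendor': True}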
+ + Args: + template: Template to validate + + Returns: + Dictionary mapping variable paths to their availability status + """ + variables = self.detect_variables(template) + validation = {} + + for var in variables: + value = self._get_nested_value(var) + # Consider empty string as not available (could be missing) + validation[var] = value != "" + + return validation + + def get_missing_variables(self, template: str) -> list[str]: + """Get list of variables that cannot be resolved. + + Args: + template: Template to check + + Returns: + List of variable paths that resolve to empty/missing values + """ + validation = self.validate_variables(template) + return [var for var, available in validation.items() if not available] diff --git a/backend/lookup/services/vector_db_cleanup_service.py b/backend/lookup/services/vector_db_cleanup_service.py new file mode 100644 index 0000000000..3319e61e60 --- /dev/null +++ b/backend/lookup/services/vector_db_cleanup_service.py @@ -0,0 +1,366 @@ +"""Vector DB Cleanup Service for Lookup feature. + +Provides centralized cleanup operations for removing obsolete +vector DB nodes when reference data is re-indexed, profiles +are changed, or data sources are deleted. +""" + +import logging +from typing import Any + +from utils.user_context import UserContext + +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.vector_db import VectorDB + +logger = logging.getLogger(__name__) + + +class VectorDBCleanupService: + """Centralized service for vector DB cleanup operations. + + This service handles all vector DB node deletion scenarios: + - Cleanup on re-indexing (delete old nodes before adding new) + - Cleanup on profile deletion + - Cleanup on data source deletion + - Manual cleanup of stale indexes + - Cleanup when switching from RAG to full context mode + + Example: + >>> service = VectorDBCleanupService() + >>> result = service.cleanup_index_ids( + ... index_ids=["doc_id_1", "doc_id_2"], vector_db_instance_id="uuid-of-vector-db" + ... ) + >>> print(result) + {'success': True, 'deleted': 2, 'failed': 0, 'errors': []} + """ + + def __init__(self, org_id: str | None = None): + """Initialize the cleanup service. + + Args: + org_id: Organization ID for multi-tenancy. If not provided, + will be fetched from UserContext. + """ + self.org_id = org_id or UserContext.get_organization_identifier() + + def _get_vector_db_client(self, vector_db_instance_id: str) -> VectorDB: + """Get a VectorDB client for the given adapter instance. + + Args: + vector_db_instance_id: UUID of the vector DB adapter instance + + Returns: + VectorDB client instance + """ + from prompt_studio.prompt_studio_core_v2.prompt_ide_base_tool import ( + PromptIdeBaseTool, + ) + + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=self.org_id) + return VectorDB(tool=util, adapter_instance_id=vector_db_instance_id) + + def cleanup_index_ids( + self, + index_ids: list[str], + vector_db_instance_id: str, + ) -> dict[str, Any]: + """Delete specific index IDs from vector DB. 
+ + Args: + index_ids: List of index IDs (doc_ids) to delete + vector_db_instance_id: UUID of the vector DB adapter instance + + Returns: + Dictionary with cleanup results: + - success: True if all deletions succeeded + - deleted: Number of successfully deleted indexes + - failed: Number of failed deletions + - errors: List of error messages for failed deletions + """ + if not index_ids: + logger.debug("No index IDs to clean up") + return {"success": True, "deleted": 0, "failed": 0, "errors": []} + + if not vector_db_instance_id: + logger.warning("Cannot cleanup: vector_db_instance_id not provided") + return { + "success": False, + "deleted": 0, + "failed": len(index_ids), + "errors": ["vector_db_instance_id not provided"], + } + + deleted = 0 + failed = 0 + errors = [] + + try: + vector_db = self._get_vector_db_client(vector_db_instance_id) + + for index_id in index_ids: + try: + logger.debug(f"Deleting from VectorDB - index id: {index_id}") + vector_db.delete(ref_doc_id=index_id) + deleted += 1 + except Exception as e: + error_msg = f"Error deleting index {index_id}: {e}" + logger.error(error_msg) + errors.append(error_msg) + failed += 1 + + logger.info( + f"Vector DB cleanup completed: {deleted} deleted, {failed} failed" + ) + + except Exception as e: + error_msg = f"Error initializing vector DB client: {e}" + logger.error(error_msg, exc_info=True) + return { + "success": False, + "deleted": deleted, + "failed": len(index_ids) - deleted, + "errors": [error_msg] + errors, + } + + return { + "success": failed == 0, + "deleted": deleted, + "failed": failed, + "errors": errors, + } + + def cleanup_for_data_source( + self, + data_source_id: str, + profile_id: str | None = None, + ) -> dict[str, Any]: + """Clean up all indexes for a data source. + + Args: + data_source_id: UUID of the LookupDataSource + profile_id: Optional profile ID to filter by. If not provided, + cleans up indexes for all profiles. + + Returns: + Dictionary with cleanup results + """ + from lookup.models import LookupIndexManager + + try: + queryset = LookupIndexManager.objects.filter(data_source_id=data_source_id) + if profile_id: + queryset = queryset.filter(profile_manager_id=profile_id) + + total_deleted = 0 + total_failed = 0 + all_errors = [] + + for index_manager in queryset: + if ( + index_manager.index_ids_history + and index_manager.profile_manager + and index_manager.profile_manager.vector_store + ): + result = self.cleanup_index_ids( + index_ids=index_manager.index_ids_history, + vector_db_instance_id=str( + index_manager.profile_manager.vector_store.id + ), + ) + total_deleted += result["deleted"] + total_failed += result["failed"] + all_errors.extend(result["errors"]) + + # Clear the history after successful cleanup + if result["success"]: + index_manager.index_ids_history = [] + index_manager.raw_index_id = None + index_manager.status = {"indexed": False, "cleaned": True} + index_manager.save() + + return { + "success": total_failed == 0, + "deleted": total_deleted, + "failed": total_failed, + "errors": all_errors, + } + + except Exception as e: + error_msg = f"Error cleaning up data source {data_source_id}: {e}" + logger.error(error_msg, exc_info=True) + return {"success": False, "deleted": 0, "failed": 0, "errors": [error_msg]} + + def cleanup_stale_indexes( + self, + index_manager, + keep_current: bool = True, + ) -> dict[str, Any]: + """Clean up old indexes, optionally keeping the current one. + + This is useful when re-indexing - delete old nodes but keep + the most recent one (which will be replaced). 
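+
+ For example (illustrative; assumes a ``LookupIndexManager`` row with a
+ populated ``index_ids_history``):
+
+ >>> result = VectorDBCleanupService().cleanup_stale_indexes(
+ ... index_manager, keep_current=True
+ ... )
+ >>> sorted(result.keys())
+ ['deleted', 'errors', 'failed', 'success']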
+ + Args: + index_manager: LookupIndexManager instance + keep_current: If True, keeps the current raw_index_id + + Returns: + Dictionary with cleanup results + """ + if not index_manager.index_ids_history: + return {"success": True, "deleted": 0, "failed": 0, "errors": []} + + if ( + not index_manager.profile_manager + or not index_manager.profile_manager.vector_store + ): + logger.warning( + f"Cannot cleanup stale indexes: missing profile or vector store " + f"for index manager {index_manager.index_manager_id}" + ) + return { + "success": False, + "deleted": 0, + "failed": 0, + "errors": ["Missing profile or vector store"], + } + + # Determine which IDs to delete + ids_to_delete = list(index_manager.index_ids_history) + if keep_current and index_manager.raw_index_id: + ids_to_delete = [ + id for id in ids_to_delete if id != index_manager.raw_index_id + ] + + if not ids_to_delete: + return {"success": True, "deleted": 0, "failed": 0, "errors": []} + + result = self.cleanup_index_ids( + index_ids=ids_to_delete, + vector_db_instance_id=str(index_manager.profile_manager.vector_store.id), + ) + + # Update history to remove deleted IDs + if result["deleted"] > 0: + remaining_ids = [ + id for id in index_manager.index_ids_history if id not in ids_to_delete + ] + index_manager.index_ids_history = remaining_ids + index_manager.save() + + return result + + def cleanup_for_profile(self, profile_id: str) -> dict[str, Any]: + """Clean up all indexes created with a specific profile. + + Use this when a profile is being deleted or when switching + from RAG mode to full context mode. + + Args: + profile_id: UUID of the LookupProfileManager + + Returns: + Dictionary with cleanup results + """ + from lookup.models import LookupIndexManager, LookupProfileManager + + try: + profile = LookupProfileManager.objects.get(pk=profile_id) + if not profile.vector_store: + logger.warning(f"Profile {profile_id} has no vector store configured") + return { + "success": False, + "deleted": 0, + "failed": 0, + "errors": ["Profile has no vector store configured"], + } + + vector_db_instance_id = str(profile.vector_store.id) + index_managers = LookupIndexManager.objects.filter(profile_manager=profile) + + total_deleted = 0 + total_failed = 0 + all_errors = [] + + for index_manager in index_managers: + if index_manager.index_ids_history: + result = self.cleanup_index_ids( + index_ids=index_manager.index_ids_history, + vector_db_instance_id=vector_db_instance_id, + ) + total_deleted += result["deleted"] + total_failed += result["failed"] + all_errors.extend(result["errors"]) + + # Clear the history after cleanup + index_manager.index_ids_history = [] + index_manager.raw_index_id = None + index_manager.status = {"indexed": False, "cleaned": True} + index_manager.reindex_required = True + index_manager.save() + + logger.info( + f"Profile cleanup completed for {profile_id}: " + f"{total_deleted} deleted, {total_failed} failed" + ) + + return { + "success": total_failed == 0, + "deleted": total_deleted, + "failed": total_failed, + "errors": all_errors, + } + + except LookupProfileManager.DoesNotExist: + error_msg = f"Profile {profile_id} not found" + logger.error(error_msg) + return {"success": False, "deleted": 0, "failed": 0, "errors": [error_msg]} + except Exception as e: + error_msg = f"Error cleaning up profile {profile_id}: {e}" + logger.error(error_msg, exc_info=True) + return {"success": False, "deleted": 0, "failed": 0, "errors": [error_msg]} + + def cleanup_before_reindex( + self, + index_manager, + ) -> dict[str, Any]: 
+ """Clean up all existing indexes before re-indexing. + + This should be called before adding a new doc_id during re-indexing + to ensure old stale data is removed from the vector DB. + + Args: + index_manager: LookupIndexManager instance + + Returns: + Dictionary with cleanup results + """ + if not index_manager.index_ids_history: + return {"success": True, "deleted": 0, "failed": 0, "errors": []} + + if ( + not index_manager.profile_manager + or not index_manager.profile_manager.vector_store + ): + logger.warning( + "Cannot cleanup before reindex: missing profile or vector store" + ) + return { + "success": False, + "deleted": 0, + "failed": 0, + "errors": ["Missing profile or vector store"], + } + + logger.info( + f"Cleaning up {len(index_manager.index_ids_history)} old index(es) " + f"before re-indexing data source {index_manager.data_source.file_name}" + ) + + result = self.cleanup_index_ids( + index_ids=index_manager.index_ids_history, + vector_db_instance_id=str(index_manager.profile_manager.vector_store.id), + ) + + return result diff --git a/backend/lookup/services/workflow_integration.py b/backend/lookup/services/workflow_integration.py new file mode 100644 index 0000000000..637f3e438b --- /dev/null +++ b/backend/lookup/services/workflow_integration.py @@ -0,0 +1,419 @@ +"""Workflow Integration Service for Look-up enrichment. + +This module provides integration between Look-up enrichment and +workflow file execution (ETL, Workflow, API deployments). +It handles logging to both WebSocket (real-time) and ExecutionLog +(file-centric) based on execution context. +""" + +import json +import logging +import uuid +from typing import Any +from uuid import UUID + +from lookup.models import LookupExecutionAudit, PromptStudioLookupLink + +logger = logging.getLogger(__name__) + + +class LookupWorkflowIntegration: + """Service for integrating Look-ups with workflow file execution. + + This service provides methods to execute Look-ups within workflow + contexts (ETL, Workflow, API) with proper logging and audit trail. + + Example: + >>> from workflow_manager.file_execution.models import WorkflowFileExecution + >>> file_exec = WorkflowFileExecution.objects.get(id=file_id) + >>> result = LookupWorkflowIntegration.execute_lookups_for_file( + ... prompt_studio_project_id=ps_project_id, + ... extraction_output={"vendor": "Acme Corp"}, + ... workflow_file_execution=file_exec, + ... organization_id="org-123", + ... ) + """ + + @classmethod + def execute_lookups_for_file( + cls, + prompt_studio_project_id: UUID, + extraction_output: dict[str, Any], + workflow_execution_id: UUID, + file_execution_id: UUID, + organization_id: str, + file_name: str | None = None, + session_id: str | None = None, + ) -> dict[str, Any]: + """Execute linked Look-ups for a file being processed in a workflow. + + This method is called from ETL, Workflow, and API execution pipelines + after Prompt Studio extraction completes for a file. + + Args: + prompt_studio_project_id: The PS project UUID + extraction_output: Output from Prompt Studio extraction + workflow_execution_id: The workflow execution UUID + file_execution_id: The file execution UUID + organization_id: Tenant organization ID + file_name: Optional file name for logging + session_id: Optional WebSocket session for real-time logs + + Returns: + Enriched output with Look-up data merged, or original output + if no Look-ups are linked or enrichment fails. 
+ """ + # Check for linked lookups first + if not cls.has_linked_lookups(prompt_studio_project_id): + logger.debug(f"No linked Look-ups for PS project {prompt_studio_project_id}") + return extraction_output + + try: + # Execute lookups using the integration service + # LookupIntegrationService handles all logging via LookupLogEmitter + from lookup.services.lookup_integration_service import ( + LookupIntegrationService, + ) + + result = LookupIntegrationService.enrich_if_linked( + prompt_studio_project_id=str(prompt_studio_project_id), + extracted_data=extraction_output, + run_id=str(uuid.uuid4()), + session_id=session_id, + doc_name=file_name, + file_execution_id=str(file_execution_id), + workflow_execution_id=str(workflow_execution_id), + organization_id=organization_id, + ) + + # Get enrichment result + enrichment = result.get("lookup_enrichment", {}) + + # Replace enriched values in extraction output (not add at top level) + # The enrichment dict contains {field_name: enriched_value} pairs + # These should replace the original values in extraction_output + if enrichment: + merged_output = extraction_output.copy() + for field_name, enriched_value in enrichment.items(): + if field_name in merged_output: + logger.info( + f"[LOOKUP] Replacing '{field_name}' value: " + f"'{merged_output[field_name]}' -> '{enriched_value}'" + ) + merged_output[field_name] = enriched_value + else: + logger.warning( + f"[LOOKUP] Field '{field_name}' not found in extraction_output, " + f"skipping enrichment" + ) + return merged_output + + return extraction_output + + except Exception as e: + logger.error( + f"Look-up enrichment failed for file execution " + f"{file_execution_id}: {e}", + exc_info=True, + ) + # Return original output on failure + return extraction_output + + @classmethod + def process_workflow_enrichment( + cls, + workflow_id: str, + original_output: str, + file_execution_id: str, + ) -> tuple[str | dict[str, Any], bool]: + """Process Look-up enrichment for workflow output. + + This method is called from file_execution_tasks._try_lookup_enrichment + to enrich extraction output with Look-up data. 
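+
+ The output is expected to follow the workflow result shape
+ ``{"metadata": {...}, "metrics": {...}, "output": {...}}``; when the
+ ``output`` key is present, only the fields inside it are enriched,
+ otherwise the top-level fields are used as a fallback.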
+ + Args: + workflow_id: The workflow UUID as string + original_output: The extraction output (JSON string or dict) + file_execution_id: The file execution UUID as string + + Returns: + Tuple of (enriched_output, was_enriched): + - enriched_output: The enriched data (same type as input) + - was_enriched: True if enrichment was applied + """ + from prompt_studio.prompt_studio_registry_v2.models import PromptStudioRegistry + from tool_instance_v2.models import ToolInstance + from workflow_manager.file_execution.models import WorkflowFileExecution + from workflow_manager.workflow_v2.models.workflow import Workflow + + try: + logger.info( + f"[LOOKUP] process_workflow_enrichment called for workflow " + f"{workflow_id}, file_execution {file_execution_id}" + ) + + # Parse output if string + if isinstance(original_output, str): + try: + output_data = json.loads(original_output) + logger.info( + f"[LOOKUP] Parsed output data keys: {list(output_data.keys())}" + ) + except json.JSONDecodeError: + logger.warning( + f"[LOOKUP] Could not parse output as JSON for workflow {workflow_id}" + ) + return original_output, False + else: + output_data = original_output + logger.info( + f"[LOOKUP] Output data keys: {list(output_data.keys()) if isinstance(output_data, dict) else type(output_data)}" + ) + + # Get workflow and its prompt studio registry + workflow = Workflow.objects.get(id=workflow_id) + logger.info(f"[LOOKUP] Found workflow: {workflow.id}") + + # Get prompt studio project ID from workflow's tool instance + # The tool_id in ToolInstance is the prompt_registry_id + tool_instance = ToolInstance.objects.filter(workflow_id=workflow_id).first() + + if not tool_instance: + logger.info(f"[LOOKUP] No tool instance found for workflow {workflow_id}") + return original_output, False + + logger.info( + f"[LOOKUP] Found tool instance: {tool_instance.id}, tool_id: {tool_instance.tool_id}" + ) + + # Get the PromptStudioRegistry to find the custom_tool (PS project) + try: + prompt_registry = PromptStudioRegistry.objects.get( + prompt_registry_id=tool_instance.tool_id + ) + logger.info( + f"[LOOKUP] Found prompt registry: {prompt_registry.prompt_registry_id}" + ) + if prompt_registry.custom_tool: + prompt_studio_project_id = str(prompt_registry.custom_tool.tool_id) + logger.info( + f"[LOOKUP] Found PS project ID: {prompt_studio_project_id}" + ) + else: + logger.info( + f"[LOOKUP] No custom tool linked to registry {tool_instance.tool_id}" + ) + return original_output, False + except PromptStudioRegistry.DoesNotExist: + logger.info( + f"[LOOKUP] No prompt registry found for tool {tool_instance.tool_id}" + ) + return original_output, False + + if not prompt_studio_project_id: + logger.info( + f"[LOOKUP] No Prompt Studio project found for workflow {workflow_id}" + ) + return original_output, False + + # Check for linked lookups + if not cls.has_linked_lookups(UUID(prompt_studio_project_id)): + logger.info( + f"[LOOKUP] No linked Look-ups for PS project {prompt_studio_project_id}" + ) + return original_output, False + + logger.info( + f"[LOOKUP] Found linked lookups for PS project {prompt_studio_project_id}" + ) + + # Get file execution for context + file_execution = WorkflowFileExecution.objects.get(id=file_execution_id) + workflow_execution_id = file_execution.workflow_execution_id + organization_id = str(workflow.organization_id) + + # Extract the actual output data for enrichment + # The workflow output structure is: {metadata: {...}, metrics: {...}, output: {...}} + # The extracted fields are inside the 'output' 
key + extracted_fields = output_data.get("output", {}) + if not extracted_fields or not isinstance(extracted_fields, dict): + logger.info( + "[LOOKUP] No 'output' key found or not a dict in output_data, " + "trying to use output_data directly" + ) + extracted_fields = output_data + + logger.info( + f"[LOOKUP] Extracted fields for enrichment: {list(extracted_fields.keys())}" + ) + + # Execute lookups with file execution context for Nav bar logging + # LookupIntegrationService handles all logging via LookupLogEmitter + from lookup.services.lookup_integration_service import ( + LookupIntegrationService, + ) + + result = LookupIntegrationService.enrich_if_linked( + prompt_studio_project_id=prompt_studio_project_id, + extracted_data=extracted_fields, + run_id=str(uuid.uuid4()), + file_execution_id=file_execution_id, + workflow_execution_id=str(workflow_execution_id), + organization_id=organization_id, + doc_name=file_execution.file_name, + ) + + # Get result metadata + _metadata = result.get("_lookup_metadata", {}) # noqa F841 + enrichment = result.get("lookup_enrichment", {}) + + # Replace enriched values in the output structure + # The enrichment dict contains {field_name: enriched_value} pairs + if enrichment: + merged_output = output_data.copy() + # Check if we need to update inside 'output' key or at top level + if "output" in merged_output and isinstance( + merged_output["output"], dict + ): + # Update inside the 'output' sub-object + merged_output["output"] = merged_output["output"].copy() + for field_name, enriched_value in enrichment.items(): + if field_name in merged_output["output"]: + logger.info( + f"[LOOKUP] Replacing output['{field_name}'] value: " + f"'{merged_output['output'][field_name]}' -> '{enriched_value}'" + ) + merged_output["output"][field_name] = enriched_value + else: + logger.warning( + f"[LOOKUP] Field '{field_name}' not found in output, " + f"skipping enrichment" + ) + else: + # Update at top level (fallback for flat structure) + for field_name, enriched_value in enrichment.items(): + if field_name in merged_output: + logger.info( + f"[LOOKUP] Replacing '{field_name}' value: " + f"'{merged_output[field_name]}' -> '{enriched_value}'" + ) + merged_output[field_name] = enriched_value + else: + logger.warning( + f"[LOOKUP] Field '{field_name}' not found in output_data, " + f"skipping enrichment" + ) + return merged_output, True + + return original_output, False + + except Exception as e: + logger.error( + f"Look-up enrichment failed for workflow {workflow_id}, " + f"file execution {file_execution_id}: {e}", + exc_info=True, + ) + return original_output, False + + @classmethod + def has_linked_lookups(cls, prompt_studio_project_id: UUID) -> bool: + """Check if PS project has linked Look-ups. + + Args: + prompt_studio_project_id: The PS project UUID + + Returns: + True if at least one Look-up is linked + """ + return PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=prompt_studio_project_id + ).exists() + + @classmethod + def _get_enabled_lookup_projects(cls, prompt_studio_project_id: UUID) -> list: + """Get enabled lookup projects for a PS project. 
+ + Args: + prompt_studio_project_id: The PS project UUID + + Returns: + List of enabled LookupProject instances + """ + links = ( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=prompt_studio_project_id + ) + .select_related("lookup_project") + .order_by("execution_order") + ) + + return [link.lookup_project for link in links if link.is_enabled] + + @classmethod + def get_lookup_logs_for_file( + cls, + file_execution_id: UUID, + ) -> list[dict]: + """Get all Look-up related logs for a file execution. + + Args: + file_execution_id: The file execution UUID + + Returns: + List of log dictionaries with data and event_time + """ + from workflow_manager.workflow_v2.models import ExecutionLog + + return list( + ExecutionLog.objects.filter( + file_execution_id=file_execution_id, + data__stage="LOOKUP", + ) + .values("data", "event_time") + .order_by("event_time") + ) + + @classmethod + def get_lookup_audits_for_file( + cls, + file_execution_id: UUID, + ) -> list[LookupExecutionAudit]: + """Get all Look-up audit records for a file execution. + + Args: + file_execution_id: The file execution UUID + + Returns: + List of LookupExecutionAudit instances + """ + return list( + LookupExecutionAudit.objects.filter(file_execution_id=file_execution_id) + .select_related("lookup_project") + .order_by("executed_at") + ) + + @classmethod + def get_lookup_audits_for_workflow( + cls, + workflow_execution_id: UUID, + ) -> list[LookupExecutionAudit]: + """Get all Look-up audit records for a workflow execution. + + Args: + workflow_execution_id: The workflow execution UUID + + Returns: + List of LookupExecutionAudit instances + """ + # Get all file execution IDs for this workflow + from workflow_manager.file_execution.models import WorkflowFileExecution + + file_execution_ids = WorkflowFileExecution.objects.filter( + workflow_execution_id=workflow_execution_id + ).values_list("id", flat=True) + + return list( + LookupExecutionAudit.objects.filter(file_execution_id__in=file_execution_ids) + .select_related("lookup_project") + .order_by("executed_at") + ) diff --git a/backend/lookup/tests/__init__.py b/backend/lookup/tests/__init__.py new file mode 100644 index 0000000000..291608f561 --- /dev/null +++ b/backend/lookup/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Look-Up system.""" diff --git a/backend/lookup/tests/test_api/__init__.py b/backend/lookup/tests/test_api/__init__.py new file mode 100644 index 0000000000..58150d0e27 --- /dev/null +++ b/backend/lookup/tests/test_api/__init__.py @@ -0,0 +1,3 @@ +""" +API tests for Look-Up functionality. +""" diff --git a/backend/lookup/tests/test_api/test_execution_api.py b/backend/lookup/tests/test_api/test_execution_api.py new file mode 100644 index 0000000000..a64bd37e61 --- /dev/null +++ b/backend/lookup/tests/test_api/test_execution_api.py @@ -0,0 +1,306 @@ +""" +Tests for Look-Up execution and debug API endpoints. 
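+
+Run with Django's test runner, e.g. (adjust the settings module to your
+environment):
+
+    python manage.py test lookup.tests.test_api.test_execution_api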
+""" + +import uuid +from decimal import Decimal +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import ( + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, + LookupExecutionAudit +) + +User = get_user_model() + + +class LookupExecutionAPITest(TestCase): + """Test cases for Look-Up execution API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\n{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create projects + self.lookup1 = LookupProject.objects.create( + name="Vendor Lookup", + description="Vendor enrichment", + template=self.template, + created_by=self.user + ) + self.lookup2 = LookupProject.objects.create( + name="Product Lookup", + description="Product enrichment", + template=self.template, + created_by=self.user + ) + + # Create PS project and links + self.ps_project_id = uuid.uuid4() + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + @patch('lookup.views.LookUpOrchestrator') + @patch('lookup.views.LookUpExecutor') + def test_debug_with_ps_project(self, mock_executor_class, mock_orchestrator_class): + """Test debug execution with PS project context.""" + # Mock the execution + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.return_value = { + 'lookup_enrichment': { + 'canonical_vendor': 'Test Vendor', + 'vendor_category': 'SaaS', + 'product_type': 'Software' + }, + '_lookup_metadata': { + 'lookups_executed': 2, + 'successful_lookups': 2, + 'failed_lookups': 0, + 'execution_time_ms': 250, + 'conflicts_resolved': 0 + } + } + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'input_data': { + 'vendor_name': 'Test Vendor Inc', + 'product_id': 'PROD-123' + } + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('lookup_enrichment', response.data) + self.assertIn('canonical_vendor', response.data['lookup_enrichment']) + self.assertEqual(response.data['_lookup_metadata']['lookups_executed'], 2) + + # Verify orchestrator was called with correct projects + mock_orchestrator.execute_lookups.assert_called_once() + call_args = mock_orchestrator.execute_lookups.call_args + self.assertEqual(len(call_args.kwargs['lookup_projects']), 2) + + def test_debug_without_ps_project_id(self): + """Test debug endpoint requires PS project ID.""" + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = {'input_data': {}} + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('prompt_studio_project_id is required', response.data['error']) + + def 
test_debug_with_no_linked_lookups(self): + """Test debug with PS project that has no linked Look-Ups.""" + unlinked_ps_id = uuid.uuid4() + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(unlinked_ps_id), + 'input_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['message'], 'No Look-Ups linked to this Prompt Studio project') + self.assertEqual(response.data['lookup_enrichment'], {}) + self.assertEqual(response.data['_lookup_metadata']['lookups_executed'], 0) + + @patch('lookup.views.LookUpOrchestrator') + def test_debug_with_execution_error(self, mock_orchestrator_class): + """Test debug endpoint handles execution errors gracefully.""" + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.side_effect = Exception("Test error") + + url = reverse('lookup:lookupdebug-test-with-ps-project') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'input_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + self.assertIn('Test error', response.data['error']) + + +class LookupAuditAPITest(TestCase): + """Test cases for execution audit API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template and project + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + self.lookup = LookupProject.objects.create( + name="Test Lookup", + description="Test", + template=self.template, + created_by=self.user + ) + + # Create audit records + self.execution_id = str(uuid.uuid4()) + + self.audit1 = LookupExecutionAudit.objects.create( + lookup_project=self.lookup, + prompt_studio_project_id=uuid.uuid4(), + execution_id=self.execution_id, + input_data={'vendor': 'Test1'}, + enriched_output={'canonical_vendor': 'Test'}, + reference_data_version=1, + llm_provider='openai', + llm_model='gpt-4', + llm_prompt='Test prompt', + llm_response='{"canonical_vendor": "Test"}', + llm_response_cached=False, + execution_time_ms=150, + llm_call_time_ms=100, + status='success', + confidence_score=Decimal('0.95') + ) + + self.audit2 = LookupExecutionAudit.objects.create( + lookup_project=self.lookup, + prompt_studio_project_id=uuid.uuid4(), + execution_id=str(uuid.uuid4()), + input_data={'vendor': 'Test2'}, + status='failure', + error_message='LLM timeout' + ) + + def test_list_audit_records(self): + """Test listing all audit records.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_audits_by_lookup_project(self): + """Test filtering audits by Look-Up project.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url, {'lookup_project_id': str(self.lookup.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_audits_by_execution_id(self): + 
"""Test filtering audits by execution ID.""" + url = reverse('lookup:executionaudit-list') + response = self.client.get(url, {'execution_id': self.execution_id}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['execution_id'], self.execution_id) + + def test_filter_audits_by_status(self): + """Test filtering audits by status.""" + url = reverse('lookup:executionaudit-list') + + # Get successful executions + response = self.client.get(url, {'status': 'success'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['status'], 'success') + + # Get failed executions + response = self.client.get(url, {'status': 'failure'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['status'], 'failure') + + def test_retrieve_audit_record(self): + """Test retrieving a specific audit record.""" + url = reverse('lookup:executionaudit-detail', args=[self.audit1.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['execution_id'], self.execution_id) + self.assertEqual(response.data['status'], 'success') + + def test_audit_records_are_readonly(self): + """Test that audit records cannot be modified.""" + url = reverse('lookup:executionaudit-detail', args=[self.audit1.id]) + + # Try to update + response = self.client.patch(url, {'status': 'modified'}, format='json') + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + # Try to delete + response = self.client.delete(url) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + @patch('lookup.views.AuditLogger') + def test_get_statistics(self, mock_logger_class): + """Test getting execution statistics.""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + mock_logger.get_project_stats.return_value = { + 'total_executions': 100, + 'success_rate': 0.95, + 'avg_execution_time_ms': 150.5, + 'cache_hit_rate': 0.30, + 'avg_confidence_score': 0.92 + } + + url = reverse('lookup:executionaudit-statistics') + response = self.client.get(url, {'lookup_project_id': str(self.lookup.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_executions'], 100) + self.assertEqual(response.data['success_rate'], 0.95) + + def test_statistics_requires_project_id(self): + """Test that statistics endpoint requires project ID.""" + url = reverse('lookup:executionaudit-statistics') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('lookup_project_id is required', response.data['error']) diff --git a/backend/lookup/tests/test_api/test_linking_api.py b/backend/lookup/tests/test_api/test_linking_api.py new file mode 100644 index 0000000000..92aec4937d --- /dev/null +++ b/backend/lookup/tests/test_api/test_linking_api.py @@ -0,0 +1,258 @@ +""" +Tests for Prompt Studio Look-Up linking API endpoints. 
+""" + +import uuid +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import ( + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink +) + +User = get_user_model() + + +class PromptStudioLinkingAPITest(TestCase): + """Test cases for PS Look-Up linking API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create Look-Up projects + self.lookup1 = LookupProject.objects.create( + name="Lookup 1", + description="First lookup", + template=self.template, + created_by=self.user + ) + self.lookup2 = LookupProject.objects.create( + name="Lookup 2", + description="Second lookup", + template=self.template, + created_by=self.user + ) + + # Create PS project ID + self.ps_project_id = uuid.uuid4() + + def test_create_link(self): + """Test creating a link between PS project and Look-Up.""" + url = reverse('lookup:lookuplink-list') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project': str(self.lookup1.id) + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['lookup_project_name'], 'Lookup 1') + self.assertTrue( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ).exists() + ) + + def test_create_duplicate_link(self): + """Test that duplicate links are rejected.""" + # Create first link + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + # Try to create duplicate + url = reverse('lookup:lookuplink-list') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project': str(self.lookup1.id) + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('already linked', str(response.data)) + + def test_list_links(self): + """Test listing all links.""" + # Create links + link1 = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + link2 = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + url = reverse('lookup:lookuplink-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_filter_links_by_ps_project(self): + """Test filtering links by PS project ID.""" + # Create links for different PS projects + ps_project_id_2 = uuid.uuid4() + + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=ps_project_id_2, + lookup_project=self.lookup2 + ) + + url = reverse('lookup:lookuplink-list') + response = self.client.get(url, {'prompt_studio_project_id': 
str(self.ps_project_id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['lookup_project_name'], 'Lookup 1') + + def test_filter_links_by_lookup_project(self): + """Test filtering links by Look-Up project ID.""" + # Create links + ps_project_id_2 = uuid.uuid4() + + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=ps_project_id_2, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-list') + response = self.client.get(url, {'lookup_project_id': str(self.lookup1.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + def test_delete_link(self): + """Test deleting a link.""" + link = PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-detail', args=[link.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertFalse( + PromptStudioLookupLink.objects.filter(id=link.id).exists() + ) + + def test_bulk_link(self): + """Test bulk linking multiple Look-Ups to a PS project.""" + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': [str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': False + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertTrue(response.data['results'][0]['linked']) + self.assertTrue(response.data['results'][1]['linked']) + + # Verify links were created + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 2 + ) + + def test_bulk_unlink(self): + """Test bulk unlinking Look-Ups from a PS project.""" + # Create links first + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup2 + ) + + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': [str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertTrue(response.data['results'][0]['unlinked']) + self.assertTrue(response.data['results'][1]['unlinked']) + + # Verify links were removed + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 0 + ) + + def test_bulk_link_idempotent(self): + """Test that bulk link is idempotent.""" + # Create one link first + PromptStudioLookupLink.objects.create( + prompt_studio_project_id=self.ps_project_id, + lookup_project=self.lookup1 + ) + + url = reverse('lookup:lookuplink-bulk-link') + data = { + 'prompt_studio_project_id': str(self.ps_project_id), + 'lookup_project_ids': [str(self.lookup1.id), str(self.lookup2.id)], + 'unlink': False + } + + response = self.client.post(url, data, 
format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['total_processed'], 2) + self.assertFalse(response.data['results'][0]['linked']) # Already existed + self.assertTrue(response.data['results'][1]['linked']) # Newly created + + # Still only 2 links total + self.assertEqual( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=self.ps_project_id + ).count(), + 2 + ) diff --git a/backend/lookup/tests/test_api/test_profile_manager_api.py b/backend/lookup/tests/test_api/test_profile_manager_api.py new file mode 100644 index 0000000000..850c5157be --- /dev/null +++ b/backend/lookup/tests/test_api/test_profile_manager_api.py @@ -0,0 +1,395 @@ +""" +Tests for LookupProfileManager API endpoints. +""" + +import uuid +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupProject, LookupProfileManager + +User = get_user_model() + + +class LookupProfileManagerAPITest(TestCase): + """Test cases for LookupProfileManager API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create lookup project + self.project = LookupProject.objects.create( + name="Test Lookup Project", + description="Test Description", + created_by=self.user + ) + + # Mock adapter instances (UUIDs) + self.mock_llm_id = str(uuid.uuid4()) + self.mock_embedding_id = str(uuid.uuid4()) + self.mock_vector_db_id = str(uuid.uuid4()) + self.mock_x2text_id = str(uuid.uuid4()) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_create_profile(self, mock_get_adapter): + """Test creating a new profile.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + url = reverse('lookup:lookupprofile-list') + data = { + 'profile_name': 'Default Profile', + 'lookup_project': str(self.project.id), + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + 'x2text': self.mock_x2text_id, + 'chunk_size': 1000, + 'chunk_overlap': 200, + 'similarity_top_k': 5, + 'is_default': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['profile_name'], 'Default Profile') + self.assertTrue(response.data['is_default']) + self.assertEqual(LookupProfileManager.objects.count(), 1) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_duplicate_profile_name(self, mock_get_adapter): + """Test that duplicate profile names for same project are rejected.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + # Create first profile + LookupProfileManager.objects.create( + profile_name='Test Profile', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + # Try to create duplicate + url = reverse('lookup:lookupprofile-list') + data = { + 'profile_name': 
'Test Profile', # Same name + 'lookup_project': str(self.project.id), # Same project + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + 'x2text': self.mock_x2text_id, + } + + response = self.client.post(url, data, format='json') + + # Should fail due to unique constraint + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_list_profiles(self, mock_get_by_id, mock_get_adapter): + """Test listing all profiles.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create test profiles + LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + url = reverse('lookup:lookupprofile-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_filter_by_project(self, mock_get_by_id, mock_get_adapter): + """Test filtering profiles by project.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create another project + project2 = LookupProject.objects.create( + name="Project 2", + description="Description 2", + created_by=self.user + ) + + # Create profiles for different projects + LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=project2, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + # Filter by project 1 + url = reverse('lookup:lookupprofile-list') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['profile_name'], 'Profile 1') + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + 
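# Hypothetical note: both the adapter instance lookup and the adapter detail expansion appear to be patched here so no real adapters are required +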
@patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_get_default_profile(self, mock_get_by_id, mock_get_adapter): + """Test getting the default profile for a project.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create profiles + profile1 = LookupProfileManager.objects.create( + profile_name='Non-Default', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=False, + created_by=self.user + ) + + profile2 = LookupProfileManager.objects.create( + profile_name='Default Profile', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=True, + created_by=self.user + ) + + # Get default profile + url = reverse('lookup:lookupprofile-default') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['profile_name'], 'Default Profile') + self.assertTrue(response.data['is_default']) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_set_default_profile(self, mock_get_by_id, mock_get_adapter): + """Test setting a profile as default.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create two profiles + profile1 = LookupProfileManager.objects.create( + profile_name='Profile 1', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=True, + created_by=self.user + ) + + profile2 = LookupProfileManager.objects.create( + profile_name='Profile 2', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + is_default=False, + created_by=self.user + ) + + # Set profile2 as default + url = reverse('lookup:lookupprofile-set-default', args=[profile2.profile_id]) + response = self.client.post(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.data['is_default']) + + # Verify profile1 is no longer default + profile1.refresh_from_db() + self.assertFalse(profile1.is_default) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + @patch('adapter_processor_v2.adapter_processor.AdapterProcessor.get_adapter_instance_by_id') + def test_update_profile(self, mock_get_by_id, mock_get_adapter): + """Test updating a profile.""" + # Mock adapters + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + mock_adapter_detail = { + 'id': str(uuid.uuid4()), + 'adapter_name': 'Test Adapter', + 
'adapter_type': 'LLM' + } + mock_get_by_id.return_value = mock_adapter_detail + + # Create profile + profile = LookupProfileManager.objects.create( + profile_name='Original Name', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + chunk_size=1000, + created_by=self.user + ) + + # Update profile + url = reverse('lookup:lookupprofile-detail', args=[profile.profile_id]) + data = { + 'chunk_size': 2000, + 'chunk_overlap': 300, + 'similarity_top_k': 10 + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['chunk_size'], 2000) + self.assertEqual(response.data['chunk_overlap'], 300) + self.assertEqual(response.data['similarity_top_k'], 10) + + @patch('adapter_processor_v2.models.AdapterInstance.objects.get') + def test_delete_profile(self, mock_get_adapter): + """Test deleting a profile.""" + # Mock adapter instances + mock_adapter = MagicMock() + mock_adapter.id = uuid.uuid4() + mock_get_adapter.return_value = mock_adapter + + # Create profile + profile = LookupProfileManager.objects.create( + profile_name='To Delete', + lookup_project=self.project, + llm_id=self.mock_llm_id, + embedding_model_id=self.mock_embedding_id, + vector_store_id=self.mock_vector_db_id, + x2text_id=self.mock_x2text_id, + created_by=self.user + ) + + url = reverse('lookup:lookupprofile-detail', args=[profile.profile_id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(LookupProfileManager.objects.count(), 0) + + def test_get_default_profile_no_default_exists(self): + """Test getting default profile when none exists.""" + url = reverse('lookup:lookupprofile-default') + response = self.client.get(url, {'lookup_project': str(self.project.id)}) + + # Should return 404 when no default profile exists + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_required_adapters(self): + """Test that all 4 adapters are required.""" + url = reverse('lookup:lookupprofile-list') + + # Missing x2text adapter + data = { + 'profile_name': 'Incomplete Profile', + 'lookup_project': str(self.project.id), + 'llm': self.mock_llm_id, + 'embedding_model': self.mock_embedding_id, + 'vector_store': self.mock_vector_db_id, + # Missing x2text + } + + response = self.client.post(url, data, format='json') + + # Should fail validation + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('x2text', str(response.data)) diff --git a/backend/lookup/tests/test_api/test_project_api.py b/backend/lookup/tests/test_api/test_project_api.py new file mode 100644 index 0000000000..0213a9a83e --- /dev/null +++ b/backend/lookup/tests/test_api/test_project_api.py @@ -0,0 +1,211 @@ +""" +Tests for Look-Up Project API endpoints. 
+""" + +import json +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupProject, LookupPromptTemplate, LookupDataSource + +User = get_user_model() + + +class LookupProjectAPITest(TestCase): + """Test cases for Look-Up Project API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\nReference: {{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + # Create project + self.project = LookupProject.objects.create( + name="Test Project", + description="Test Description", + template=self.template, + created_by=self.user + ) + + def test_list_projects(self): + """Test listing all projects.""" + url = reverse('lookup:lookupproject-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Project') + + def test_create_project(self): + """Test creating a new project.""" + url = reverse('lookup:lookupproject-list') + data = { + 'name': 'New Project', + 'description': 'New Description', + 'template_id': str(self.template.id), + 'is_active': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['name'], 'New Project') + self.assertEqual(LookupProject.objects.count(), 2) + + def test_retrieve_project(self): + """Test retrieving a specific project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Test Project') + self.assertIn('template', response.data) + + def test_update_project(self): + """Test updating a project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + data = { + 'name': 'Updated Project', + 'description': 'Updated Description', + 'is_active': False + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Updated Project') + self.assertFalse(response.data['is_active']) + + def test_delete_project(self): + """Test deleting a project.""" + url = reverse('lookup:lookupproject-detail', args=[self.project.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(LookupProject.objects.count(), 0) + + @patch('lookup.views.LookUpOrchestrator') + def test_execute_project(self, mock_orchestrator_class): + """Test executing a Look-Up project.""" + # Mock the orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator_class.return_value = mock_orchestrator + mock_orchestrator.execute_lookups.return_value = { + 'lookup_enrichment': { + 'canonical_vendor': 'Test Vendor', + 'vendor_category': 'SaaS' + }, + '_lookup_metadata': { + 'lookups_executed': 1, + 
'successful_lookups': 1, + 'execution_time_ms': 150 + } + } + + url = reverse('lookup:lookupproject-execute', args=[self.project.id]) + data = { + 'input_data': {'vendor_name': 'Test Vendor Inc'}, + 'use_cache': True, + 'timeout_seconds': 30 + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('lookup_enrichment', response.data) + self.assertIn('_lookup_metadata', response.data) + + def test_execute_project_without_auth(self): + """Test that unauthenticated requests are rejected.""" + self.client.force_authenticate(user=None) + url = reverse('lookup:lookupproject-execute', args=[self.project.id]) + data = {'input_data': {}} + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_filter_projects_by_active_status(self): + """Test filtering projects by active status.""" + # Create inactive project + LookupProject.objects.create( + name="Inactive Project", + description="Inactive", + is_active=False, + created_by=self.user + ) + + url = reverse('lookup:lookupproject-list') + response = self.client.get(url, {'is_active': 'true'}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Project') + + def test_upload_reference_data(self): + """Test uploading reference data.""" + url = reverse('lookup:lookupproject-upload-reference-data', args=[self.project.id]) + + # Create a mock file + from django.core.files.uploadedfile import SimpleUploadedFile + file_content = b"vendor1,category1\nvendor2,category2" + file = SimpleUploadedFile("vendors.csv", file_content, content_type="text/csv") + + data = { + 'file': file, + 'extract_text': True, + 'metadata': json.dumps({'source': 'manual_upload'}) + } + + response = self.client.post(url, data, format='multipart') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertIn('source_file_path', response.data) + self.assertEqual(response.data['extraction_status'], 'pending') + + def test_list_data_sources(self): + """Test listing data sources for a project.""" + # Create data sources + LookupDataSource.objects.create( + project=self.project, + source_file_path="test/file1.csv", + extraction_status='complete', + version=1, + is_latest=False + ) + LookupDataSource.objects.create( + project=self.project, + source_file_path="test/file2.csv", + extraction_status='complete', + version=2, + is_latest=True + ) + + url = reverse('lookup:lookupproject-data-sources', args=[self.project.id]) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + + # Test filtering by is_latest + response = self.client.get(url, {'is_latest': 'true'}) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['version'], 2) diff --git a/backend/lookup/tests/test_api/test_template_api.py b/backend/lookup/tests/test_api/test_template_api.py new file mode 100644 index 0000000000..0883049ffa --- /dev/null +++ b/backend/lookup/tests/test_api/test_template_api.py @@ -0,0 +1,175 @@ +""" +Tests for Look-Up Template API endpoints. 
+""" + +from unittest.mock import patch, MagicMock + +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from ...models import LookupPromptTemplate + +User = get_user_model() + + +class LookupTemplateAPITest(TestCase): + """Test cases for Look-Up Template API.""" + + def setUp(self): + """Set up test data.""" + self.client = APIClient() + self.user = User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + self.client.force_authenticate(user=self.user) + + # Create template + self.template = LookupPromptTemplate.objects.create( + name="Test Template", + template_text="Vendor: {{vendor_name}}\nReference: {{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + created_by=self.user + ) + + def test_list_templates(self): + """Test listing all templates.""" + url = reverse('lookup:lookuptemplate-list') + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Template') + + def test_create_template(self): + """Test creating a new template.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'New Template', + 'template_text': 'Product: {{product_name}}\n{{reference_data}}', + 'llm_config': {'provider': 'anthropic', 'model': 'claude-2'}, + 'is_active': True + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['name'], 'New Template') + self.assertEqual(LookupPromptTemplate.objects.count(), 2) + + def test_create_template_without_reference_placeholder(self): + """Test that template without {{reference_data}} is rejected.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'Invalid Template', + 'template_text': 'Product: {{product_name}}', # Missing {{reference_data}} + 'llm_config': {'provider': 'openai', 'model': 'gpt-4'} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('template_text', response.data) + + def test_create_template_invalid_llm_config(self): + """Test that template with invalid LLM config is rejected.""" + url = reverse('lookup:lookuptemplate-list') + data = { + 'name': 'Invalid Config Template', + 'template_text': '{{reference_data}}', + 'llm_config': {'model': 'gpt-4'} # Missing provider + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn('llm_config', response.data) + + def test_update_template(self): + """Test updating a template.""" + url = reverse('lookup:lookuptemplate-detail', args=[self.template.id]) + data = { + 'name': 'Updated Template', + 'is_active': False + } + + response = self.client.patch(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['name'], 'Updated Template') + self.assertFalse(response.data['is_active']) + + def test_delete_template(self): + """Test deleting a template.""" + url = reverse('lookup:lookuptemplate-detail', args=[self.template.id]) + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + 
self.assertEqual(LookupPromptTemplate.objects.count(), 0) + + @patch('lookup.views.VariableResolver') + def test_validate_template(self, mock_resolver_class): + """Test template validation endpoint.""" + # Mock the variable resolver + mock_resolver = MagicMock() + mock_resolver_class.return_value = mock_resolver + mock_resolver.resolve.return_value = "Resolved template text" + mock_resolver.get_all_variables.return_value = {'vendor_name', 'product_id'} + + url = reverse('lookup:lookuptemplate-validate') + data = { + 'template_text': 'Vendor: {{vendor_name}}\nProduct: {{product_id}}\n{{reference_data}}', + 'sample_data': {'vendor_name': 'Test Vendor', 'product_id': '123'}, + 'sample_reference': 'Sample reference data' + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertTrue(response.data['valid']) + self.assertIn('resolved_template', response.data) + self.assertIn('variables_found', response.data) + self.assertEqual(set(response.data['variables_found']), {'vendor_name', 'product_id'}) + + def test_validate_template_with_error(self): + """Test template validation with error.""" + url = reverse('lookup:lookuptemplate-validate') + data = { + 'template_text': 'Invalid: {{unclosed_variable', # Invalid template + 'sample_data': {} + } + + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertFalse(response.data['valid']) + self.assertIn('error', response.data) + + def test_filter_templates_by_active_status(self): + """Test filtering templates by active status.""" + # Create inactive template + LookupPromptTemplate.objects.create( + name="Inactive Template", + template_text="{{reference_data}}", + llm_config={"provider": "openai", "model": "gpt-4"}, + is_active=False, + created_by=self.user + ) + + url = reverse('lookup:lookuptemplate-list') + + # Get only active templates + response = self.client.get(url, {'is_active': 'true'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Test Template') + + # Get only inactive templates + response = self.client.get(url, {'is_active': 'false'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['name'], 'Inactive Template') diff --git a/backend/lookup/tests/test_integrations/__init__.py b/backend/lookup/tests/test_integrations/__init__.py new file mode 100644 index 0000000000..c0565ab6fe --- /dev/null +++ b/backend/lookup/tests/test_integrations/__init__.py @@ -0,0 +1,3 @@ +""" +Integration tests for Look-Up external service integrations. +""" diff --git a/backend/lookup/tests/test_integrations/test_llm_integration.py b/backend/lookup/tests/test_integrations/test_llm_integration.py new file mode 100644 index 0000000000..620d9cd6a7 --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_llm_integration.py @@ -0,0 +1,239 @@ +""" +Tests for LLM provider integration. 
+""" + +import json +from unittest.mock import patch + +from django.test import TestCase + +from ...integrations.llm_provider import ( + UnstractLLMClient, + OpenAILLMClient, + AnthropicLLMClient +) + + +class UnstractLLMClientTest(TestCase): + """Test cases for UnstractLLMClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Patch environment variables + self.env_patcher = patch('lookup.integrations.llm_provider.os.getenv') + self.mock_getenv = self.env_patcher.start() + self.mock_getenv.side_effect = self._mock_getenv + + # Initialize client + self.client = UnstractLLMClient() + + def tearDown(self): + """Clean up patches.""" + self.env_patcher.stop() + + def _mock_getenv(self, key, default=None): + """Mock environment variables.""" + env_vars = { + 'LOOKUP_DEFAULT_LLM_PROVIDER': 'openai', + 'LOOKUP_DEFAULT_LLM_MODEL': 'gpt-4', + 'OPENAI_API_KEY': 'test-openai-key', + 'ANTHROPIC_API_KEY': 'test-anthropic-key', + 'AZURE_OPENAI_API_KEY': 'test-azure-key', + 'AZURE_OPENAI_ENDPOINT': 'https://test.azure.com' + } + return env_vars.get(key, default) + + def test_initialization(self): + """Test client initialization.""" + # Test default initialization + client = UnstractLLMClient() + self.assertEqual(client.default_provider, 'openai') + self.assertEqual(client.default_model, 'gpt-4') + + # Test custom initialization + client = UnstractLLMClient(provider='anthropic', model='claude-2') + self.assertEqual(client.default_provider, 'anthropic') + self.assertEqual(client.default_model, 'claude-2') + + def test_generate_with_valid_json(self): + """Test generation with valid JSON response.""" + prompt = "Extract vendor information" + config = { + 'provider': 'openai', + 'model': 'gpt-4', + 'temperature': 0.7 + } + + # Since we don't have actual LLM integration, test fallback + response = self.client.generate(prompt, config) + + # Verify response is valid JSON + data = json.loads(response) + self.assertIsInstance(data, dict) + + # Should have confidence score + if 'confidence' in data: + self.assertTrue(0 <= data['confidence'] <= 1) + + def test_generate_with_timeout(self): + """Test generation respects timeout.""" + import time + start = time.time() + + response = self.client.generate( + "Test prompt", + {'provider': 'openai'}, + timeout=1 + ) + + elapsed = time.time() - start + + # Should complete within reasonable time + self.assertLess(elapsed, 2) + self.assertIsNotNone(response) + + def test_extract_json_from_text(self): + """Test JSON extraction from mixed text.""" + # Test with embedded JSON + text = "Here is the result: {\"vendor\": \"Test Corp\", \"confidence\": 0.9} end of response" + result = self.client._extract_json(text) + + data = json.loads(result) + self.assertEqual(data['vendor'], 'Test Corp') + self.assertEqual(data['confidence'], 0.9) + + # Test with no valid JSON + text = "No JSON here" + result = self.client._extract_json(text) + + data = json.loads(result) + self.assertIn('raw_response', data) + self.assertIn('warning', data) + + def test_validate_response(self): + """Test response validation.""" + # Valid response + valid_response = json.dumps({ + 'vendor': 'Test', + 'confidence': 0.85 + }) + self.assertTrue(self.client.validate_response(valid_response)) + + # Missing confidence + no_confidence = json.dumps({'vendor': 'Test'}) + self.assertFalse(self.client.validate_response(no_confidence)) + + # Invalid confidence + bad_confidence = json.dumps({'confidence': 1.5}) + self.assertFalse(self.client.validate_response(bad_confidence)) + + # Invalid JSON + not_json = "not 
json" + self.assertFalse(self.client.validate_response(not_json)) + + def test_get_token_count(self): + """Test token counting estimation.""" + # Test basic estimation + text = "This is a test prompt with some content" + count = self.client.get_token_count(text) + + # Should be roughly len/4 + expected = len(text) // 4 + self.assertAlmostEqual(count, expected, delta=2) + + # Test empty text + self.assertEqual(self.client.get_token_count(""), 0) + + def test_fallback_generation(self): + """Test fallback generation when LLM unavailable.""" + # Force fallback mode + self.client.llm_available = False + + response = self.client.generate( + "vendor extraction prompt", + {'provider': 'openai'} + ) + + # Should return valid JSON + data = json.loads(response) + self.assertEqual(data['status'], 'fallback') + self.assertIn('canonical_vendor', data) + + def test_simulate_llm_call(self): + """Test simulated LLM call.""" + # Test vendor prompt + vendor_response = self.client._simulate_llm_call( + "Extract vendor information", + {'provider': 'openai'} + ) + vendor_data = json.loads(vendor_response) + self.assertIn('canonical_vendor', vendor_data) + self.assertIn('vendor_category', vendor_data) + + # Test product prompt + product_response = self.client._simulate_llm_call( + "Extract product details", + {'provider': 'openai'} + ) + product_data = json.loads(product_response) + self.assertIn('product_name', product_data) + self.assertIn('product_category', product_data) + + +class OpenAILLMClientTest(TestCase): + """Test cases for OpenAI-specific client.""" + + def test_openai_client_initialization(self): + """Test OpenAI client initialization.""" + client = OpenAILLMClient() + self.assertEqual(client.default_provider, 'openai') + self.assertEqual(client.default_model, 'gpt-4') + + @patch('lookup.integrations.llm_provider.UnstractLLMClient.generate') + def test_openai_generate(self, mock_generate): + """Test OpenAI-specific generation.""" + mock_generate.return_value = '{"result": "test"}' + + client = OpenAILLMClient() + response = client.generate( + "test prompt", + {'temperature': 0.5} + ) + + # Verify OpenAI config was used + mock_generate.assert_called_once() + call_args = mock_generate.call_args[0] + config = call_args[1] + + self.assertEqual(config['provider'], 'openai') + self.assertEqual(config['temperature'], 0.5) + + +class AnthropicLLMClientTest(TestCase): + """Test cases for Anthropic-specific client.""" + + def test_anthropic_client_initialization(self): + """Test Anthropic client initialization.""" + client = AnthropicLLMClient() + self.assertEqual(client.default_provider, 'anthropic') + self.assertEqual(client.default_model, 'claude-2') + + @patch('lookup.integrations.llm_provider.UnstractLLMClient.generate') + def test_anthropic_generate(self, mock_generate): + """Test Anthropic-specific generation.""" + mock_generate.return_value = '{"result": "test"}' + + client = AnthropicLLMClient() + response = client.generate( + "test prompt", + {'temperature': 0.8} + ) + + # Verify Anthropic config was used + mock_generate.assert_called_once() + call_args = mock_generate.call_args[0] + config = call_args[1] + + self.assertEqual(config['provider'], 'anthropic') + self.assertEqual(config['model'], 'claude-2') + self.assertEqual(config['temperature'], 0.8) diff --git a/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py b/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py new file mode 100644 index 0000000000..a03773159f --- /dev/null +++ 
b/backend/lookup/tests/test_integrations/test_llmwhisperer_integration.py @@ -0,0 +1,297 @@ +""" +Tests for LLMWhisperer document extraction integration. +""" + +from unittest.mock import patch, Mock + +from django.test import TestCase +from django.conf import settings + +from ...integrations.llmwhisperer_client import ( + LLMWhispererClient, + ExtractionStatus +) + + +class LLMWhispererClientTest(TestCase): + """Test cases for LLMWhispererClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock settings + self.settings_patcher = patch.multiple( + settings, + LLMWHISPERER_BASE_URL='https://test.llmwhisperer.com', + LLMWHISPERER_API_KEY='test-api-key' + ) + self.settings_patcher.start() + + # Mock requests + self.requests_patcher = patch('lookup.integrations.llmwhisperer_client.requests') + self.mock_requests = self.requests_patcher.start() + + # Initialize client + self.client = LLMWhispererClient() + + def tearDown(self): + """Clean up patches.""" + self.settings_patcher.stop() + self.requests_patcher.stop() + + def test_initialization(self): + """Test client initialization.""" + self.assertEqual(self.client.base_url, 'https://test.llmwhisperer.com') + self.assertEqual(self.client.api_key, 'test-api-key') + self.assertIsNotNone(self.client.session) + + def test_extract_text_success(self): + """Test successful text extraction.""" + # Mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'extraction_id': 'test-id-123', + 'status': 'processing' + } + self.mock_requests.post.return_value = mock_response + + # Test extraction + file_content = b"PDF content here" + file_name = "test.pdf" + + extraction_id, status = self.client.extract_text( + file_content, + file_name + ) + + # Verify + self.assertEqual(extraction_id, 'test-id-123') + self.assertEqual(status, ExtractionStatus.PROCESSING.value) + + # Check API call + self.mock_requests.post.assert_called_once() + call_args = self.mock_requests.post.call_args + + self.assertIn('/v1/extract', call_args[0][0]) + self.assertIn('files', call_args[1]) + self.assertIn('data', call_args[1]) + + def test_extract_text_failure(self): + """Test extraction failure.""" + # Mock failed response + mock_response = Mock() + mock_response.status_code = 400 + mock_response.text = "Bad request" + self.mock_requests.post.return_value = mock_response + + # Test extraction + extraction_id, status = self.client.extract_text( + b"content", + "test.pdf" + ) + + # Verify + self.assertEqual(extraction_id, "") + self.assertEqual(status, ExtractionStatus.FAILED.value) + + def test_check_extraction_status_complete(self): + """Test checking completed extraction status.""" + # Mock status response + mock_status_response = Mock() + mock_status_response.status_code = 200 + mock_status_response.json.return_value = { + 'status': 'completed', + 'extraction_id': 'test-id' + } + + # Mock result response + mock_result_response = Mock() + mock_result_response.status_code = 200 + mock_result_response.text = "Extracted text content" + + # Set up session mock + self.client.session.get = Mock(side_effect=[ + mock_status_response, + mock_result_response + ]) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.COMPLETE.value) + self.assertEqual(text, "Extracted text content") + + def test_check_extraction_status_processing(self): + """Test checking processing extraction status.""" + # Mock response + mock_response = Mock() + 
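# Simulates an in-progress extraction: the service reports 'processing' and returns no text payload yet +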
mock_response.status_code = 200 + mock_response.json.return_value = { + 'status': 'processing' + } + self.client.session.get = Mock(return_value=mock_response) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.PROCESSING.value) + self.assertIsNone(text) + + def test_check_extraction_status_failed(self): + """Test checking failed extraction status.""" + # Mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'status': 'failed', + 'error': 'Extraction error' + } + self.client.session.get = Mock(return_value=mock_response) + + # Test status check + status, text = self.client.check_extraction_status('test-id') + + # Verify + self.assertEqual(status, ExtractionStatus.FAILED.value) + self.assertIsNone(text) + + @patch('time.sleep') + def test_wait_for_extraction_success(self, mock_sleep): + """Test waiting for extraction completion.""" + # Mock status checks + mock_responses = [ + (ExtractionStatus.PROCESSING.value, None), + (ExtractionStatus.PROCESSING.value, None), + (ExtractionStatus.COMPLETE.value, "Extracted text") + ] + + self.client.check_extraction_status = Mock( + side_effect=mock_responses + ) + + # Test wait + status, text = self.client.wait_for_extraction( + 'test-id', + max_wait_seconds=60, + poll_interval=5 + ) + + # Verify + self.assertEqual(status, ExtractionStatus.COMPLETE.value) + self.assertEqual(text, "Extracted text") + self.assertEqual(self.client.check_extraction_status.call_count, 3) + + @patch('time.sleep') + @patch('time.time') + def test_wait_for_extraction_timeout(self, mock_time, mock_sleep): + """Test extraction timeout.""" + # Mock time to simulate timeout + mock_time.side_effect = [0, 10, 20, 35, 40] # Exceeds 30 second limit + + self.client.check_extraction_status = Mock( + return_value=(ExtractionStatus.PROCESSING.value, None) + ) + + # Test wait with short timeout + status, text = self.client.wait_for_extraction( + 'test-id', + max_wait_seconds=30, + poll_interval=5 + ) + + # Verify + self.assertEqual(status, ExtractionStatus.FAILED.value) + self.assertIsNone(text) + + def test_extract_and_wait(self): + """Test combined extract and wait.""" + # Mock extraction + self.client.extract_text = Mock( + return_value=('test-id', ExtractionStatus.PROCESSING.value) + ) + + # Mock wait + self.client.wait_for_extraction = Mock( + return_value=(ExtractionStatus.COMPLETE.value, "Extracted text") + ) + + # Test + success, text = self.client.extract_and_wait( + b"content", + "test.pdf" + ) + + # Verify + self.assertTrue(success) + self.assertEqual(text, "Extracted text") + + def test_is_extraction_needed(self): + """Test checking if extraction is needed.""" + # Files that need extraction + extractable = [ + 'document.pdf', + 'image.png', + 'photo.jpg', + 'scan.tiff', + 'presentation.pptx', + 'spreadsheet.xlsx' + ] + + for filename in extractable: + self.assertTrue( + self.client.is_extraction_needed(filename), + f"{filename} should need extraction" + ) + + # Files that don't need extraction + non_extractable = [ + 'data.json', + 'script.py', + 'text.txt', + 'config.yml' + ] + + for filename in non_extractable: + self.assertFalse( + self.client.is_extraction_needed(filename), + f"{filename} should not need extraction" + ) + + def test_get_extraction_config_for_file(self): + """Test getting extraction config based on file type.""" + # Test PDF config + pdf_config = self.client.get_extraction_config_for_file('test.pdf') + 
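# PDFs are expected to default to OCR mode with text processing forced on +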
self.assertEqual(pdf_config['processing_mode'], 'ocr') + self.assertTrue(pdf_config['force_text_processing']) + + # Test image config + img_config = self.client.get_extraction_config_for_file('test.jpg') + self.assertEqual(img_config['processing_mode'], 'ocr') + self.assertFalse(img_config['force_text_processing']) + + # Test Word config + doc_config = self.client.get_extraction_config_for_file('test.docx') + self.assertEqual(doc_config['processing_mode'], 'text') + self.assertEqual(doc_config['output_format'], 'markdown') + + # Test Excel config + xls_config = self.client.get_extraction_config_for_file('test.xlsx') + self.assertEqual(xls_config['processing_mode'], 'text') + self.assertEqual(xls_config['line_splitter'], 'paragraph') + + def test_default_config(self): + """Test default extraction configuration.""" + config = self.client._get_default_config() + + # Verify required fields + self.assertIn('processing_mode', config) + self.assertIn('output_format', config) + self.assertIn('page_separator', config) + self.assertIn('timeout', config) + + # Verify defaults + self.assertEqual(config['processing_mode'], 'ocr') + self.assertEqual(config['output_format'], 'text') + self.assertEqual(config['timeout'], 300) diff --git a/backend/lookup/tests/test_integrations/test_redis_cache_integration.py b/backend/lookup/tests/test_integrations/test_redis_cache_integration.py new file mode 100644 index 0000000000..de1d203a99 --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_redis_cache_integration.py @@ -0,0 +1,272 @@ +""" +Tests for Redis cache integration. +""" + +from unittest.mock import MagicMock, patch + +from django.test import TestCase +from django.conf import settings + +from ...integrations.redis_cache import RedisLLMCache + + +class RedisLLMCacheTest(TestCase): + """Test cases for RedisLLMCache.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock Redis availability + self.redis_patcher = patch('lookup.integrations.redis_cache.REDIS_AVAILABLE', True) + self.redis_patcher.start() + + # Mock Redis client + self.redis_client_patcher = patch('lookup.integrations.redis_cache.redis.Redis') + self.mock_redis_class = self.redis_client_patcher.start() + + # Create mock Redis instance + self.mock_redis = MagicMock() + self.mock_redis_class.return_value = self.mock_redis + self.mock_redis.ping.return_value = True + + # Mock settings + self.settings_patcher = patch.multiple( + settings, + REDIS_HOST='localhost', + REDIS_PORT=6379, + REDIS_CACHE_DB=1, + REDIS_PASSWORD='test-pass' + ) + self.settings_patcher.start() + + # Initialize cache + self.cache = RedisLLMCache(ttl_hours=24) + + def tearDown(self): + """Clean up patches.""" + self.redis_patcher.stop() + self.redis_client_patcher.stop() + self.settings_patcher.stop() + + def test_initialization_with_redis(self): + """Test cache initialization with Redis available.""" + self.assertIsNotNone(self.cache.redis_client) + self.mock_redis.ping.assert_called_once() + self.assertEqual(self.cache.ttl_seconds, 24 * 3600) + self.assertEqual(self.cache.key_prefix, "lookup:llm:") + + def test_initialization_without_redis(self): + """Test cache initialization when Redis unavailable.""" + # Mock Redis connection failure + self.mock_redis.ping.side_effect = Exception("Connection failed") + + # Reinitialize cache + cache = RedisLLMCache(fallback_to_memory=True) + + # Should fall back to memory cache + self.assertIsNone(cache.redis_client) + self.assertIsNotNone(cache.memory_cache) + + def test_generate_cache_key(self): + """Test cache key 
generation.""" + prompt = "Test prompt" + reference = "Reference data" + + key = self.cache.generate_cache_key(prompt, reference) + + # Should be prefixed and hashed + self.assertTrue(key.startswith("lookup:llm:")) + self.assertEqual(len(key), len("lookup:llm:") + 64) # SHA256 hex length + + def test_set_and_get(self): + """Test setting and getting cache values.""" + # Test set + key = "lookup:llm:test-key" + value = '{"result": "test"}' + + result = self.cache.set(key, value) + self.assertTrue(result) + + # Verify Redis setex called + self.mock_redis.setex.assert_called_once_with( + name=key, + time=24 * 3600, + value=value + ) + + # Test get + self.mock_redis.get.return_value = value + retrieved = self.cache.get(key) + + self.assertEqual(retrieved, value) + self.mock_redis.get.assert_called_once_with(key) + + def test_get_cache_miss(self): + """Test cache miss.""" + self.mock_redis.get.return_value = None + + result = self.cache.get("lookup:llm:nonexistent") + + self.assertIsNone(result) + + def test_delete(self): + """Test deleting cache entries.""" + key = "lookup:llm:test-key" + self.mock_redis.delete.return_value = 1 + + result = self.cache.delete(key) + + self.assertTrue(result) + self.mock_redis.delete.assert_called_once_with(key) + + def test_delete_nonexistent(self): + """Test deleting non-existent entry.""" + self.mock_redis.delete.return_value = 0 + + result = self.cache.delete("lookup:llm:nonexistent") + + # Redis returns 0 for non-existent keys + self.assertFalse(result) + + def test_clear_pattern(self): + """Test clearing entries by pattern.""" + # Mock SCAN operation + self.mock_redis.scan.side_effect = [ + (100, ["lookup:llm:project1:key1", "lookup:llm:project1:key2"]), + (0, ["lookup:llm:project1:key3"]) + ] + self.mock_redis.delete.return_value = 3 + + count = self.cache.clear_pattern("lookup:llm:project1:*") + + self.assertEqual(count, 3) + self.mock_redis.delete.assert_called() + + def test_fallback_to_memory_cache(self): + """Test fallback to memory cache when Redis fails.""" + # Make Redis operations fail + from redis.exceptions import RedisError + self.mock_redis.get.side_effect = RedisError("Connection lost") + self.mock_redis.setex.side_effect = RedisError("Connection lost") + + # Cache should have memory fallback + self.assertIsNotNone(self.cache.memory_cache) + + # Test set with fallback + key = "lookup:llm:test" + value = "test-value" + + result = self.cache.set(key, value) + self.assertTrue(result) # Should succeed with memory cache + + # Test get with fallback + retrieved = self.cache.get(key) + # Memory cache uses key without prefix + self.assertEqual(retrieved, value) + + def test_get_stats(self): + """Test getting cache statistics.""" + # Mock Redis info + self.mock_redis.info.side_effect = [ + { # stats info + 'total_connections_received': 100, + 'keyspace_hits': 80, + 'keyspace_misses': 20 + }, + { # keyspace info + 'db1': {'keys': 50, 'expires': 45} + } + ] + + stats = self.cache.get_stats() + + # Verify stats structure + self.assertEqual(stats['backend'], 'redis') + self.assertEqual(stats['ttl_hours'], 24) + self.assertIn('redis', stats) + self.assertEqual(stats['redis']['keyspace_hits'], 80) + self.assertEqual(stats['redis']['hit_rate'], 0.8) + + def test_cleanup_expired(self): + """Test cleanup of expired entries.""" + # Redis handles expiration automatically + count = self.cache.cleanup_expired() + + # Should return 0 for Redis (automatic expiry) + self.assertEqual(count, 0) + + def test_warmup(self): + """Test cache warmup.""" + project_id = 
"test-project" + preload_data = { + "prompt1": "response1", + "prompt2": "response2" + } + + count = self.cache.warmup(project_id, preload_data) + + self.assertEqual(count, 2) + self.assertEqual(self.mock_redis.setex.call_count, 2) + + def test_custom_ttl(self): + """Test setting cache with custom TTL.""" + key = "lookup:llm:test" + value = "test-value" + custom_ttl = 3600 # 1 hour + + self.cache.set(key, value, ttl=custom_ttl) + + # Verify custom TTL was used + self.mock_redis.setex.assert_called_with( + name=key, + time=custom_ttl, + value=value + ) + + def test_hit_rate_calculation(self): + """Test cache hit rate calculation.""" + # Test with hits and misses + rate = self.cache._calculate_hit_rate(80, 20) + self.assertEqual(rate, 0.8) + + # Test with no data + rate = self.cache._calculate_hit_rate(0, 0) + self.assertEqual(rate, 0.0) + + def test_pattern_matching(self): + """Test pattern matching for memory cache.""" + # Test exact match + self.assertTrue(self.cache._match_pattern("key1", "key1")) + + # Test wildcard match + self.assertTrue(self.cache._match_pattern("project:key1", "project:*")) + self.assertTrue(self.cache._match_pattern("project:subkey:value", "project:*:value")) + + # Test non-match + self.assertFalse(self.cache._match_pattern("other:key", "project:*")) + + +class RedisUnavailableTest(TestCase): + """Test cases when Redis is not available.""" + + def setUp(self): + """Set up test without Redis.""" + # Mock Redis as unavailable + self.redis_patcher = patch('lookup.integrations.redis_cache.REDIS_AVAILABLE', False) + self.redis_patcher.start() + + def tearDown(self): + """Clean up patches.""" + self.redis_patcher.stop() + + def test_initialization_without_redis_package(self): + """Test when redis package is not installed.""" + cache = RedisLLMCache() + + # Should initialize without Redis + self.assertIsNone(cache.redis_client) + self.assertIsNotNone(cache.memory_cache) + + # Should still be functional with memory cache + key = cache.generate_cache_key("test", "data") + cache.set(key, "value") + self.assertEqual(cache.get(key), "value") diff --git a/backend/lookup/tests/test_integrations/test_storage_integration.py b/backend/lookup/tests/test_integrations/test_storage_integration.py new file mode 100644 index 0000000000..5939a2fc7e --- /dev/null +++ b/backend/lookup/tests/test_integrations/test_storage_integration.py @@ -0,0 +1,247 @@ +""" +Tests for object storage integration. 
+""" + +import uuid +from unittest.mock import MagicMock, patch + +from django.test import TestCase + +from ...integrations.storage_client import RemoteStorageClient + + +class RemoteStorageClientTest(TestCase): + """Test cases for RemoteStorageClient.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock the FileStorage instance + self.mock_fs_patcher = patch('lookup.integrations.storage_client.EnvHelper.get_storage') + self.mock_get_storage = self.mock_fs_patcher.start() + + # Create mock file storage + self.mock_fs = MagicMock() + self.mock_get_storage.return_value = self.mock_fs + + # Initialize client + self.client = RemoteStorageClient(base_path="test/lookup") + self.project_id = uuid.uuid4() + + def tearDown(self): + """Clean up patches.""" + self.mock_fs_patcher.stop() + + def test_upload_success(self): + """Test successful file upload.""" + # Setup mock + self.mock_fs.write.return_value = None # Successful write + + # Test upload + content = b"test content" + path = "test/file.txt" + + result = self.client.upload(path, content) + + # Verify + self.assertTrue(result) + self.mock_fs.mkdir.assert_called_once() + self.mock_fs.write.assert_called_once_with( + path=path, + mode="wb", + data=content + ) + + def test_upload_failure(self): + """Test file upload failure.""" + # Setup mock to raise exception + self.mock_fs.write.side_effect = Exception("Storage error") + + # Test upload + result = self.client.upload("test/file.txt", b"content") + + # Verify + self.assertFalse(result) + + def test_download_success(self): + """Test successful file download.""" + # Setup mock + expected_content = b"test content" + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = expected_content + + # Test download + content = self.client.download("test/file.txt") + + # Verify + self.assertEqual(content, expected_content) + self.mock_fs.read.assert_called_once_with( + path="test/file.txt", + mode="rb" + ) + + def test_download_file_not_found(self): + """Test download when file doesn't exist.""" + # Setup mock + self.mock_fs.exists.return_value = False + + # Test download + content = self.client.download("nonexistent.txt") + + # Verify + self.assertIsNone(content) + self.mock_fs.read.assert_not_called() + + def test_delete_success(self): + """Test successful file deletion.""" + # Setup mock + self.mock_fs.exists.return_value = True + self.mock_fs.delete.return_value = None + + # Test delete + result = self.client.delete("test/file.txt") + + # Verify + self.assertTrue(result) + self.mock_fs.delete.assert_called_once_with("test/file.txt") + + def test_delete_file_not_found(self): + """Test delete when file doesn't exist.""" + # Setup mock + self.mock_fs.exists.return_value = False + + # Test delete + result = self.client.delete("nonexistent.txt") + + # Verify + self.assertFalse(result) + self.mock_fs.delete.assert_not_called() + + def test_exists_check(self): + """Test file existence check.""" + # Test existing file + self.mock_fs.exists.return_value = True + self.assertTrue(self.client.exists("test/file.txt")) + + # Test non-existing file + self.mock_fs.exists.return_value = False + self.assertFalse(self.client.exists("nonexistent.txt")) + + def test_list_files(self): + """Test listing files with prefix.""" + # Setup mock + self.mock_fs.listdir.return_value = ["file1.txt", "file2.txt", ".hidden"] + + # Test list + files = self.client.list_files("test/prefix") + + # Verify - hidden files should be excluded + self.assertEqual(len(files), 2) + self.assertIn("test/prefix/file1.txt", 
files) + self.assertIn("test/prefix/file2.txt", files) + self.assertNotIn("test/prefix/.hidden", files) + + def test_text_content_operations(self): + """Test text content save and retrieve.""" + # Test save + text = "Hello, World!" + self.mock_fs.write.return_value = None + + result = self.client.save_text_content("test.txt", text) + self.assertTrue(result) + self.mock_fs.write.assert_called_with( + path="test.txt", + mode="wb", + data=text.encode('utf-8') + ) + + # Test get + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = text.encode('utf-8') + + retrieved = self.client.get_text_content("test.txt") + self.assertEqual(retrieved, text) + + def test_upload_reference_data(self): + """Test uploading reference data with metadata.""" + # Setup + content = b"reference data" + filename = "vendors.csv" + metadata = {"source": "manual", "version": 1} + + # Mock JSON encoding for metadata + import json + expected_meta = json.dumps(metadata, indent=2) + + # Test upload + path = self.client.upload_reference_data( + self.project_id, + filename, + content, + metadata + ) + + # Verify + expected_path = f"test/lookup/{self.project_id}/{filename}" + self.assertEqual(path, expected_path) + + # Check main file upload + call_args = [call for call in self.mock_fs.write.call_args_list + if call[1]['path'] == expected_path] + self.assertEqual(len(call_args), 1) + + # Check metadata upload + meta_path = f"{expected_path}.meta.json" + meta_calls = [call for call in self.mock_fs.write.call_args_list + if call[1]['path'] == meta_path] + self.assertEqual(len(meta_calls), 1) + + def test_get_reference_data(self): + """Test retrieving reference data.""" + # Setup + expected_data = "reference content" + self.mock_fs.exists.return_value = True + self.mock_fs.read.return_value = expected_data.encode('utf-8') + + # Test + data = self.client.get_reference_data(self.project_id, "data.txt") + + # Verify + self.assertEqual(data, expected_data) + expected_path = f"test/lookup/{self.project_id}/data.txt" + self.mock_fs.read.assert_called_with(path=expected_path, mode="rb") + + def test_list_project_files(self): + """Test listing all files for a project.""" + # Setup + self.mock_fs.listdir.return_value = ["file1.csv", "file2.json"] + + # Test + files = self.client.list_project_files(self.project_id) + + # Verify + expected_prefix = f"test/lookup/{self.project_id}" + self.assertEqual(len(files), 2) + self.assertIn(f"{expected_prefix}/file1.csv", files) + self.assertIn(f"{expected_prefix}/file2.json", files) + + def test_delete_project_data(self): + """Test deleting all project data.""" + # Setup + project_files = ["file1.csv", "file2.json"] + self.mock_fs.listdir.return_value = project_files + self.mock_fs.exists.return_value = True + + # Test + result = self.client.delete_project_data(self.project_id) + + # Verify + self.assertTrue(result) + + # Check files were deleted + expected_prefix = f"test/lookup/{self.project_id}" + for filename in project_files: + expected_path = f"{expected_prefix}/{filename}" + self.mock_fs.delete.assert_any_call(expected_path) + + # Check directory was removed + self.mock_fs.rmdir.assert_called_once_with(expected_prefix) diff --git a/backend/lookup/tests/test_migrations.py b/backend/lookup/tests/test_migrations.py new file mode 100644 index 0000000000..496891a816 --- /dev/null +++ b/backend/lookup/tests/test_migrations.py @@ -0,0 +1,286 @@ +"""Tests for Look-Up system database migrations.""" + +import pytest +from django.db import connection +from 
django.db.migrations.executor import MigrationExecutor
+from django.test import TransactionTestCase
+
+
+class TestLookupMigrations(TransactionTestCase):
+    """Test suite for Look-Up database migrations."""
+
+    app = "lookup"
+    # (app, None) is the state before any "lookup" migration is applied.
+    migrate_from = [(app, None)]
+    migrate_to = [(app, "0001_initial")]
+
+    def setUp(self):
+        """Set up test environment by rolling back to the pre-migration state."""
+        executor = MigrationExecutor(connection)
+        old_apps = executor.loader.project_state(self.migrate_from).apps
+
+        # Reverse to the state before the initial lookup migration
+        executor.migrate(self.migrate_from)
+
+        self.apps_before = old_apps
+
+    def test_migration_can_be_applied(self):
+        """Test that the migration can be applied successfully."""
+        executor = MigrationExecutor(connection)
+
+        # Apply migrations
+        executor.migrate(self.migrate_to)
+
+        # Check tables exist
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """
+                SELECT table_name FROM information_schema.tables
+                WHERE table_name IN (
+                    'lookup_projects',
+                    'lookup_data_sources',
+                    'lookup_prompt_templates',
+                    'prompt_studio_lookup_links'
+                )
+                """
+            )
+            tables = [row[0] for row in cursor.fetchall()]
+
+        assert len(tables) == 4, f"Expected 4 tables, found {len(tables)}: {tables}"
+        assert "lookup_projects" in tables
+        assert "lookup_data_sources" in tables
+        assert "lookup_prompt_templates" in tables
+        assert "prompt_studio_lookup_links" in tables
+
+    def test_migration_can_be_reversed(self):
+        """Test that the migration can be reversed successfully."""
+        executor = MigrationExecutor(connection)
+
+        # Apply migration
+        executor.migrate(self.migrate_to)
+
+        # Reverse migration
+        executor.migrate(self.migrate_from)
+
+        # Check tables don't exist
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """
+                SELECT table_name FROM information_schema.tables
+                WHERE table_name IN (
+                    'lookup_projects',
+                    'lookup_data_sources',
+                    'lookup_prompt_templates',
+                    'prompt_studio_lookup_links'
+                )
+                """
+            )
+            tables = [row[0] for row in cursor.fetchall()]
+
+        assert (
+            len(tables) == 0
+        ), f"Tables should not exist after reversal, found: {tables}"
+
+    def test_constraints_are_enforced(self):
+        """Test that all database constraints are properly enforced."""
+        from django.contrib.auth import get_user_model
+        from account.models import Organization
+        from lookup.models import LookupProject, LookupDataSource
+
+        User = get_user_model()
+
+        # Create test user and organization
+        org = Organization.objects.create(name="Test Org")
+        user = User.objects.create_user(
+            username="testuser", email="test@example.com", password="password"
+        )
+
+        # Test CHECK constraint on lookup_type
+        with pytest.raises(Exception):
+            LookupProject.objects.create(
+                name="Invalid Type Project",
+                lookup_type="invalid_type",
+                llm_provider="openai",
+                llm_model="gpt-4",
+                organization=org,
+                created_by=user,
+            )
+
+        # Test CHECK constraint on extraction_status
+        project = LookupProject.objects.create(
+            name="Valid Project",
+            lookup_type="static_data",
+            llm_provider="openai",
+            llm_model="gpt-4",
+            organization=org,
+            created_by=user,
+        )
+
+        with pytest.raises(Exception):
+            LookupDataSource.objects.create(
+                project=project,
+                file_name="test.pdf",
+
file_path="/path/to/file.pdf", + file_size=1024, + file_type="pdf", + extraction_status="invalid_status", + uploaded_by=user, + ) + + def test_version_trigger_functionality(self): + """Test that the version management trigger works correctly.""" + from django.contrib.auth import get_user_model + from account.models import Organization + from lookup.models import LookupProject, LookupDataSource + + User = get_user_model() + + # Create test data + org = Organization.objects.create(name="Test Org") + user = User.objects.create_user( + username="testuser", email="test@example.com", password="password" + ) + + project = LookupProject.objects.create( + name="Version Test Project", + lookup_type="static_data", + llm_provider="openai", + llm_model="gpt-4", + organization=org, + created_by=user, + ) + + # Create first data source + ds1 = LookupDataSource.objects.create( + project=project, + file_name="v1.pdf", + file_path="/path/v1.pdf", + file_size=1024, + file_type="pdf", + uploaded_by=user, + ) + + # Verify first version + ds1.refresh_from_db() + assert ds1.version_number == 1 + assert ds1.is_latest is True + + # Create second data source + ds2 = LookupDataSource.objects.create( + project=project, + file_name="v2.pdf", + file_path="/path/v2.pdf", + file_size=2048, + file_type="pdf", + uploaded_by=user, + ) + + # Verify second version + ds2.refresh_from_db() + assert ds2.version_number == 2 + assert ds2.is_latest is True + + # Verify first version is no longer latest + ds1.refresh_from_db() + assert ds1.is_latest is False + + def test_indexes_are_created(self): + """Test that all required indexes are created.""" + with connection.cursor() as cursor: + # Check for index existence on lookup_projects + cursor.execute( + """ + SELECT indexname FROM pg_indexes + WHERE tablename = 'lookup_projects' + AND indexname IN ( + 'idx_lookup_proj_org', + 'idx_lookup_proj_created_by', + 'idx_lookup_proj_updated' + ) + """ + ) + project_indexes = [row[0] for row in cursor.fetchall()] + + assert len(project_indexes) == 3 + + with connection.cursor() as cursor: + # Check for index existence on lookup_data_sources + cursor.execute( + """ + SELECT indexname FROM pg_indexes + WHERE tablename = 'lookup_data_sources' + AND indexname IN ( + 'idx_lookup_ds_project', + 'idx_lookup_ds_latest', + 'idx_lookup_ds_created', + 'idx_lookup_ds_status' + ) + """ + ) + ds_indexes = [row[0] for row in cursor.fetchall()] + + assert len(ds_indexes) == 4 + + def test_foreign_key_constraints(self): + """Test that foreign key relationships work correctly.""" + from django.contrib.auth import get_user_model + from account.models import Organization + from lookup.models import ( + LookupProject, + LookupDataSource, + LookupPromptTemplate, + ) + + User = get_user_model() + + # Create test data + org = Organization.objects.create(name="Test Org") + user = User.objects.create_user( + username="testuser", email="test@example.com", password="password" + ) + + project = LookupProject.objects.create( + name="FK Test Project", + lookup_type="static_data", + llm_provider="openai", + llm_model="gpt-4", + organization=org, + created_by=user, + ) + + # Test cascade delete + data_source = LookupDataSource.objects.create( + project=project, + file_name="test.pdf", + file_path="/path/test.pdf", + file_size=1024, + file_type="pdf", + uploaded_by=user, + ) + + template = LookupPromptTemplate.objects.create( + project=project, + template_text="Test: {{input_data}}", + ) + + # Delete project should cascade + project_id = project.id + project.delete() + + # 
Verify cascaded deletes + assert not LookupDataSource.objects.filter(id=data_source.id).exists() + assert not LookupPromptTemplate.objects.filter(id=template.id).exists() diff --git a/backend/lookup/tests/test_services/test_audit_logger.py b/backend/lookup/tests/test_services/test_audit_logger.py new file mode 100644 index 0000000000..f7f138b193 --- /dev/null +++ b/backend/lookup/tests/test_services/test_audit_logger.py @@ -0,0 +1,455 @@ +""" +Tests for Audit Logger implementation. + +This module tests the AuditLogger class including logging executions, +convenience methods, and statistics retrieval. +""" + +import uuid +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from lookup.services.audit_logger import AuditLogger + + +class TestAuditLogger: + """Test cases for AuditLogger class.""" + + @pytest.fixture + def audit_logger(self): + """Create an AuditLogger instance.""" + return AuditLogger() + + @pytest.fixture + def mock_project(self): + """Create a mock LookupProject.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Test Look-Up" + return project + + @pytest.fixture + def execution_params(self, mock_project): + """Create standard execution parameters.""" + return { + 'execution_id': str(uuid.uuid4()), + 'lookup_project_id': mock_project.id, + 'prompt_studio_project_id': uuid.uuid4(), + 'input_data': {'vendor': 'Slack Technologies'}, + 'reference_data_version': 2, + 'llm_provider': 'openai', + 'llm_model': 'gpt-4', + 'llm_prompt': 'Match vendor Slack Technologies...', + 'llm_response': '{"canonical_vendor": "Slack", "confidence": 0.92}', + 'enriched_output': {'canonical_vendor': 'Slack'}, + 'status': 'success', + 'confidence_score': 0.92, + 'execution_time_ms': 1234, + 'llm_call_time_ms': 890, + 'llm_response_cached': False, + 'error_message': None + } + + # ========== Basic Logging Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_successful_logging( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test successful execution logging.""" + # Setup mocks + mock_project_model.objects.get.return_value = mock_project + mock_audit_instance = MagicMock() + mock_audit_instance.id = uuid.uuid4() + mock_audit_model.objects.create.return_value = mock_audit_instance + + # Log execution + result = audit_logger.log_execution(**execution_params) + + # Verify project was fetched + mock_project_model.objects.get.assert_called_once_with( + id=execution_params['lookup_project_id'] + ) + + # Verify audit was created with correct params + mock_audit_model.objects.create.assert_called_once() + create_call = mock_audit_model.objects.create.call_args + kwargs = create_call.kwargs + + assert kwargs['lookup_project'] == mock_project + assert kwargs['execution_id'] == execution_params['execution_id'] + assert kwargs['status'] == 'success' + assert kwargs['llm_provider'] == 'openai' + assert kwargs['llm_model'] == 'gpt-4' + assert kwargs['confidence_score'] == Decimal('0.92') + + # Verify return value + assert result == mock_audit_instance + + @patch('lookup.services.audit_logger.LookupProject') + def test_project_not_found( + self, mock_project_model, audit_logger, execution_params + ): + """Test handling when Look-Up project doesn't exist.""" + # Make project lookup fail + mock_project_model.objects.get.side_effect = mock_project_model.DoesNotExist() + + # Log execution + result = 
audit_logger.log_execution(**execution_params) + + # Should return None and not raise exception + assert result is None + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_database_error_handling( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test handling of database errors during logging.""" + # Setup mocks + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.side_effect = Exception("Database error") + + # Log execution - should not raise exception + result = audit_logger.log_execution(**execution_params) + + # Should return None + assert result is None + + # ========== Convenience Method Tests ========== + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_success(self, mock_log_execution, audit_logger): + """Test log_success convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + + audit_logger.log_success( + execution_id=execution_id, + project_id=project_id, + input_data={'test': 'data'}, + confidence_score=0.85 + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='success', + input_data={'test': 'data'}, + confidence_score=0.85 + ) + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_failure(self, mock_log_execution, audit_logger): + """Test log_failure convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + error_msg = "LLM timeout" + + audit_logger.log_failure( + execution_id=execution_id, + project_id=project_id, + error=error_msg, + input_data={'test': 'data'} + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='failed', + error_message=error_msg, + input_data={'test': 'data'} + ) + + @patch('lookup.services.audit_logger.AuditLogger.log_execution') + def test_log_partial(self, mock_log_execution, audit_logger): + """Test log_partial convenience method.""" + execution_id = str(uuid.uuid4()) + project_id = uuid.uuid4() + + audit_logger.log_partial( + execution_id=execution_id, + project_id=project_id, + confidence_score=0.35, + error_message='Low confidence' + ) + + mock_log_execution.assert_called_once_with( + execution_id=execution_id, + lookup_project_id=project_id, + status='partial', + confidence_score=0.35, + error_message='Low confidence' + ) + + # ========== Data Validation Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_confidence_score_conversion( + self, mock_project_model, mock_audit_model, + audit_logger, execution_params, mock_project + ): + """Test that confidence score is properly converted to Decimal.""" + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.return_value = MagicMock() + + # Test with float confidence + execution_params['confidence_score'] = 0.456789 + audit_logger.log_execution(**execution_params) + + create_call = mock_audit_model.objects.create.call_args + assert create_call.kwargs['confidence_score'] == Decimal('0.456789') + + # Test with None confidence + execution_params['confidence_score'] = None + audit_logger.log_execution(**execution_params) + + create_call = mock_audit_model.objects.create.call_args + assert create_call.kwargs['confidence_score'] is None + + 
@patch('lookup.services.audit_logger.LookupExecutionAudit') + @patch('lookup.services.audit_logger.LookupProject') + def test_optional_fields( + self, mock_project_model, mock_audit_model, + audit_logger, mock_project + ): + """Test logging with minimal required fields.""" + mock_project_model.objects.get.return_value = mock_project + mock_audit_model.objects.create.return_value = MagicMock() + + # Minimal parameters + minimal_params = { + 'execution_id': str(uuid.uuid4()), + 'lookup_project_id': mock_project.id, + 'prompt_studio_project_id': None, + 'input_data': {}, + 'reference_data_version': 1, + 'llm_provider': 'openai', + 'llm_model': 'gpt-4', + 'llm_prompt': 'test prompt', + 'llm_response': None, + 'enriched_output': None, + 'status': 'failed' + } + + result = audit_logger.log_execution(**minimal_params) + + # Should still create audit record + mock_audit_model.objects.create.assert_called_once() + + # ========== History Retrieval Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_execution_history(self, mock_audit_model, audit_logger): + """Test retrieving execution history.""" + execution_id = str(uuid.uuid4()) + + # Create mock audit records + mock_audits = [] + for i in range(3): + audit = MagicMock() + audit.lookup_project.name = f"Look-Up {i+1}" + audit.status = 'success' if i < 2 else 'failed' + mock_audits.append(audit) + + # Setup mock query + mock_queryset = MagicMock() + mock_queryset.select_related.return_value = mock_queryset + mock_queryset.order_by.return_value = mock_queryset + mock_queryset.__getitem__.return_value = mock_audits + mock_audit_model.objects.filter.return_value = mock_queryset + + # Get history + result = audit_logger.get_execution_history(execution_id, limit=10) + + # Verify query + mock_audit_model.objects.filter.assert_called_once_with( + execution_id=execution_id + ) + mock_queryset.select_related.assert_called_once_with('lookup_project') + mock_queryset.order_by.assert_called_once_with('executed_at') + + # Check result + assert len(result) == 3 + assert result[0].status == 'success' + assert result[2].status == 'failed' + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_execution_history_error_handling( + self, mock_audit_model, audit_logger + ): + """Test error handling in get_execution_history.""" + mock_audit_model.objects.filter.side_effect = Exception("Database error") + + result = audit_logger.get_execution_history('test-id') + + # Should return empty list on error + assert result == [] + + # ========== Statistics Tests ========== + + @patch('lookup.services.audit_logger.LookupExecutionAudit') + def test_get_project_stats(self, mock_audit_model, audit_logger): + """Test getting project statistics.""" + project_id = uuid.uuid4() + + # Create mock audit records + mock_audits = [] + + # 3 successful executions + for i in range(3): + audit = MagicMock() + audit.status = 'success' + audit.execution_time_ms = 1000 + i * 100 + audit.llm_response_cached = (i == 0) # First one cached + audit.confidence_score = Decimal(f'0.{80 + i}') + mock_audits.append(audit) + + # 1 failed execution + audit = MagicMock() + audit.status = 'failed' + audit.execution_time_ms = 500 + audit.llm_response_cached = False + audit.confidence_score = None + mock_audits.append(audit) + + # 1 partial execution + audit = MagicMock() + audit.status = 'partial' + audit.execution_time_ms = 800 + audit.llm_response_cached = False + audit.confidence_score = Decimal('0.40') + mock_audits.append(audit) + + # Setup 
mock query
+        mock_queryset = MagicMock()
+        mock_queryset.order_by.return_value = mock_queryset
+        mock_queryset.__getitem__.return_value = mock_audits
+        mock_audit_model.objects.filter.return_value = mock_queryset
+
+        # Get stats
+        stats = audit_logger.get_project_stats(project_id, limit=100)
+
+        # Verify stats
+        assert stats['total_executions'] == 5
+        assert stats['successful'] == 3
+        assert stats['failed'] == 1
+        assert stats['partial'] == 1
+        assert stats['success_rate'] == 0.6  # 3/5
+        assert stats['cache_hit_rate'] == 0.2  # 1/5
+        assert stats['avg_execution_time_ms'] == 920  # (1000+1100+1200+500+800)/5
+        # avg_confidence = (0.80 + 0.81 + 0.82 + 0.40) / 4 = 0.7075
+        assert abs(stats['avg_confidence'] - 0.7075) < 0.001
+
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    def test_get_project_stats_empty(self, mock_audit_model, audit_logger):
+        """Test getting stats for project with no executions."""
+        mock_queryset = MagicMock()
+        mock_queryset.order_by.return_value = mock_queryset
+        mock_queryset.__getitem__.return_value = []
+        mock_audit_model.objects.filter.return_value = mock_queryset
+
+        stats = audit_logger.get_project_stats(uuid.uuid4())
+
+        assert stats['total_executions'] == 0
+        assert stats['success_rate'] == 0.0
+        assert stats['avg_execution_time_ms'] == 0
+        assert stats['cache_hit_rate'] == 0.0
+        assert stats['avg_confidence'] == 0.0
+
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    def test_get_project_stats_error_handling(
+        self, mock_audit_model, audit_logger
+    ):
+        """Test error handling in get_project_stats."""
+        mock_audit_model.objects.filter.side_effect = Exception("Database error")
+
+        stats = audit_logger.get_project_stats(uuid.uuid4())
+
+        # Should return zero stats on error
+        assert stats['total_executions'] == 0
+        assert stats['success_rate'] == 0.0
+
+    # ========== Integration Tests ==========
+
+    @patch('lookup.services.audit_logger.logger')
+    @patch('lookup.services.audit_logger.LookupExecutionAudit')
+    @patch('lookup.services.audit_logger.LookupProject')
+    def test_logging_messages(
+        self, mock_project_model, mock_audit_model, mock_logger,
+        audit_logger, execution_params, mock_project
+    ):
+        """Test that appropriate log messages are generated."""
+        mock_project_model.objects.get.return_value = mock_project
+        mock_audit_instance = MagicMock()
+        mock_audit_instance.id = uuid.uuid4()
+        mock_audit_model.objects.create.return_value = mock_audit_instance
+
+        audit_logger.log_execution(**execution_params)
+
+        # Should log debug message on success
+        mock_logger.debug.assert_called()
+        debug_message = mock_logger.debug.call_args[0][0]
+        assert 'Logged execution audit' in debug_message
+        assert mock_project.name in debug_message
+
+    @patch('lookup.services.audit_logger.logger')
+    @patch('lookup.services.audit_logger.LookupProject')
+    def test_error_logging(
+        self, mock_project_model, mock_logger,
+        audit_logger, execution_params
+    ):
+        """Test that errors are properly logged."""
+        mock_project_model.objects.get.side_effect = Exception("Database connection lost")
+
+        audit_logger.log_execution(**execution_params)
+
+        # Should log exception
+        mock_logger.exception.assert_called()
+        error_message = mock_logger.exception.call_args[0][0]
+        assert 'Failed to log execution audit' in error_message
+
+    def test_real_world_scenario(self, audit_logger):
+        """Test realistic usage scenario with mock objects."""
+        # This would normally require Django test database
+        # For now, just verify the interface works correctly
+
+        execution_id = 
str(uuid.uuid4()) + project_id = uuid.uuid4() + + # Log various execution types + with patch('lookup.services.audit_logger.AuditLogger.log_execution') as mock_log: + mock_log.return_value = MagicMock() + + # Success + audit_logger.log_success( + execution_id=execution_id, + project_id=project_id, + input_data={'vendor': 'Slack'}, + enriched_output={'canonical': 'Slack'}, + confidence_score=0.95 + ) + + # Failure + audit_logger.log_failure( + execution_id=execution_id, + project_id=project_id, + error='Timeout', + input_data={'vendor': 'Unknown'} + ) + + # Partial + audit_logger.log_partial( + execution_id=execution_id, + project_id=project_id, + confidence_score=0.30 + ) + + # Verify all three logs were attempted + assert mock_log.call_count == 3 diff --git a/backend/lookup/tests/test_services/test_enrichment_merger.py b/backend/lookup/tests/test_services/test_enrichment_merger.py new file mode 100644 index 0000000000..bed77e2bf7 --- /dev/null +++ b/backend/lookup/tests/test_services/test_enrichment_merger.py @@ -0,0 +1,547 @@ +""" +Tests for Enrichment Merger implementation. + +This module tests the EnrichmentMerger class including merging logic, +conflict resolution, and metadata tracking. +""" + +import uuid + +import pytest + +from lookup.services.enrichment_merger import EnrichmentMerger + + +class TestEnrichmentMerger: + """Test cases for EnrichmentMerger class.""" + + @pytest.fixture + def merger(self): + """Create an EnrichmentMerger instance.""" + return EnrichmentMerger() + + @pytest.fixture + def sample_enrichments(self): + """Create sample enrichments for testing.""" + return [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Vendor Matcher', + 'data': { + 'canonical_vendor': 'Slack', + 'vendor_category': 'SaaS' + }, + 'confidence': 0.95, + 'execution_time_ms': 1234, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Product Classifier', + 'data': { + 'product_type': 'Software', + 'license_model': 'Subscription' + }, + 'confidence': 0.88, + 'execution_time_ms': 567, + 'cached': True + } + ] + + # ========== No Conflicts Tests ========== + + def test_merge_no_conflicts(self, merger, sample_enrichments): + """Test merging enrichments with no overlapping fields.""" + result = merger.merge(sample_enrichments) + + # Check merged data has all fields + assert 'canonical_vendor' in result['data'] + assert 'vendor_category' in result['data'] + assert 'product_type' in result['data'] + assert 'license_model' in result['data'] + + # Check values are correct + assert result['data']['canonical_vendor'] == 'Slack' + assert result['data']['vendor_category'] == 'SaaS' + assert result['data']['product_type'] == 'Software' + assert result['data']['license_model'] == 'Subscription' + + # Check no conflicts were resolved + assert result['conflicts_resolved'] == 0 + + # Check enrichment details + assert len(result['enrichment_details']) == 2 + assert result['enrichment_details'][0]['lookup_project_name'] == 'Vendor Matcher' + assert result['enrichment_details'][0]['fields_added'] == ['canonical_vendor', 'vendor_category'] + assert result['enrichment_details'][1]['lookup_project_name'] == 'Product Classifier' + assert result['enrichment_details'][1]['fields_added'] == ['product_type', 'license_model'] + + def test_merge_empty_enrichments(self, merger): + """Test merging empty list of enrichments.""" + result = merger.merge([]) + + assert result['data'] == {} + assert result['conflicts_resolved'] == 0 + assert result['enrichment_details'] == [] + + def 
test_merge_single_enrichment(self, merger): + """Test merging with only one enrichment.""" + enrichment = { + 'project_id': uuid.uuid4(), + 'project_name': 'Solo Lookup', + 'data': {'field1': 'value1', 'field2': 'value2'}, + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + } + + result = merger.merge([enrichment]) + + assert result['data'] == {'field1': 'value1', 'field2': 'value2'} + assert result['conflicts_resolved'] == 0 + assert len(result['enrichment_details']) == 1 + assert result['enrichment_details'][0]['fields_added'] == ['field1', 'field2'] + + # ========== Confidence-Based Resolution Tests ========== + + def test_higher_confidence_wins(self, merger): + """Test that higher confidence value wins in conflict.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Low Confidence', + 'data': {'category': 'Communication'}, + 'confidence': 0.80, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'High Confidence', + 'data': {'category': 'Collaboration'}, + 'confidence': 0.95, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Higher confidence should win + assert result['data']['category'] == 'Collaboration' + assert result['conflicts_resolved'] == 1 + + # Check which lookup contributed the field + details = result['enrichment_details'] + assert details[0]['fields_added'] == [] # Lost the conflict + assert details[1]['fields_added'] == ['category'] # Won the conflict + + def test_equal_confidence_first_wins(self, merger): + """Test that first-complete wins when confidence is equal.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First', + 'data': {'status': 'active'}, + 'confidence': 0.90, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second', + 'data': {'status': 'inactive'}, + 'confidence': 0.90, # Same confidence + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # First should win (first-complete-wins) + assert result['data']['status'] == 'active' + assert result['conflicts_resolved'] == 0 # No resolution needed, kept existing + + def test_confidence_beats_no_confidence(self, merger): + """Test that enrichment with confidence beats one without.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'No Confidence', + 'data': {'vendor': 'Microsoft'}, + 'confidence': None, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Has Confidence', + 'data': {'vendor': 'Slack'}, + 'confidence': 0.75, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Confidence should win + assert result['data']['vendor'] == 'Slack' + assert result['conflicts_resolved'] == 1 + + def test_no_confidence_first_wins(self, merger): + """Test first-complete wins when neither has confidence.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First No Conf', + 'data': {'region': 'US'}, + 'confidence': None, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second No Conf', + 'data': {'region': 'EU'}, + 'confidence': None, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # First should win + assert result['data']['region'] == 'US' + assert result['conflicts_resolved'] == 0 + + # ========== Multiple Conflicts Tests 
========== + + def test_multiple_conflicts_same_enrichments(self, merger): + """Test resolving multiple field conflicts between same enrichments.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Enrichment A', + 'data': { + 'field1': 'A1', + 'field2': 'A2', + 'field3': 'A3' + }, + 'confidence': 0.70, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Enrichment B', + 'data': { + 'field1': 'B1', # Conflict + 'field2': 'B2', # Conflict + 'field4': 'B4' # No conflict + }, + 'confidence': 0.85, + 'execution_time_ms': 200, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Higher confidence (B) should win conflicts + assert result['data']['field1'] == 'B1' + assert result['data']['field2'] == 'B2' + assert result['data']['field3'] == 'A3' # Only from A + assert result['data']['field4'] == 'B4' # Only from B + + assert result['conflicts_resolved'] == 2 + + def test_three_way_conflicts(self, merger): + """Test resolving conflicts among three enrichments.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'First', + 'data': {'category': 'Cat1', 'type': 'Type1'}, + 'confidence': 0.60, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Second', + 'data': {'category': 'Cat2', 'vendor': 'Vendor2'}, + 'confidence': 0.80, + 'execution_time_ms': 200, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Third', + 'data': {'category': 'Cat3', 'type': 'Type3'}, + 'confidence': 0.75, + 'execution_time_ms': 300, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Second should win category (0.80 confidence) + assert result['data']['category'] == 'Cat2' + # Third should win type (0.75 > 0.60) + assert result['data']['type'] == 'Type3' + # Vendor only from Second + assert result['data']['vendor'] == 'Vendor2' + + assert result['conflicts_resolved'] == 2 # category and type conflicts + + # ========== Metadata Tracking Tests ========== + + def test_enrichment_details_tracking(self, merger): + """Test that enrichment details are correctly tracked.""" + project_id1 = uuid.uuid4() + project_id2 = uuid.uuid4() + + enrichments = [ + { + 'project_id': project_id1, + 'project_name': 'Lookup 1', + 'data': {'field1': 'value1'}, + 'confidence': 0.9, + 'execution_time_ms': 1500, + 'cached': True + }, + { + 'project_id': project_id2, + 'project_name': 'Lookup 2', + 'data': {'field2': 'value2'}, + 'confidence': 0.85, + 'execution_time_ms': 800, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + details = result['enrichment_details'] + assert len(details) == 2 + + # Check first enrichment details + assert details[0]['lookup_project_id'] == str(project_id1) + assert details[0]['lookup_project_name'] == 'Lookup 1' + assert details[0]['confidence'] == 0.9 + assert details[0]['cached'] is True + assert details[0]['execution_time_ms'] == 1500 + assert details[0]['fields_added'] == ['field1'] + + # Check second enrichment details + assert details[1]['lookup_project_id'] == str(project_id2) + assert details[1]['lookup_project_name'] == 'Lookup 2' + assert details[1]['confidence'] == 0.85 + assert details[1]['cached'] is False + assert details[1]['execution_time_ms'] == 800 + assert details[1]['fields_added'] == ['field2'] + + def test_fields_added_with_conflicts(self, merger): + """Test fields_added tracking when conflicts are resolved.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 
'project_name': 'Winner', + 'data': {'shared': 'win_value', 'unique1': 'value1'}, + 'confidence': 0.95, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Loser', + 'data': {'shared': 'lose_value', 'unique2': 'value2'}, + 'confidence': 0.50, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + details = result['enrichment_details'] + # Winner should have both its fields + assert 'shared' in details[0]['fields_added'] + assert 'unique1' in details[0]['fields_added'] + + # Loser should only have its unique field + assert 'shared' not in details[1]['fields_added'] + assert 'unique2' in details[1]['fields_added'] + + # ========== Edge Cases Tests ========== + + def test_missing_optional_fields(self, merger): + """Test handling of enrichments with missing optional fields.""" + enrichments = [ + { + 'project_id': None, # Missing ID + 'project_name': 'No ID Lookup', + 'data': {'field1': 'value1'}, + # Missing confidence + # Missing execution_time_ms + # Missing cached + }, + { + # Minimal valid enrichment + 'data': {'field2': 'value2'} + } + ] + + result = merger.merge(enrichments) + + assert result['data']['field1'] == 'value1' + assert result['data']['field2'] == 'value2' + assert result['conflicts_resolved'] == 0 + + # Check defaults are handled + details = result['enrichment_details'] + assert details[0]['lookup_project_id'] is None + assert details[0]['confidence'] is None + assert details[0]['execution_time_ms'] == 0 # Default + assert details[0]['cached'] is False # Default + + assert details[1]['lookup_project_name'] == 'Unknown' # Default + + def test_empty_data_fields(self, merger): + """Test handling enrichments with empty data.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Empty Data', + 'data': {}, # Empty + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Has Data', + 'data': {'field': 'value'}, + 'confidence': 0.8, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + assert result['data'] == {'field': 'value'} + assert result['conflicts_resolved'] == 0 + assert result['enrichment_details'][0]['fields_added'] == [] + assert result['enrichment_details'][1]['fields_added'] == ['field'] + + def test_complex_value_types(self, merger): + """Test merging with complex value types (lists, dicts).""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Complex Types', + 'data': { + 'tags': ['tag1', 'tag2'], + 'metadata': {'key': 'value'}, + 'count': 42 + }, + 'confidence': 0.9, + 'execution_time_ms': 100, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'More Complex', + 'data': { + 'tags': ['tag3', 'tag4'], # Conflict - different list + 'settings': {'option': True} + }, + 'confidence': 0.95, + 'execution_time_ms': 200, + 'cached': False + } + ] + + result = merger.merge(enrichments) + + # Higher confidence wins for tags + assert result['data']['tags'] == ['tag3', 'tag4'] + assert result['data']['metadata'] == {'key': 'value'} + assert result['data']['count'] == 42 + assert result['data']['settings'] == {'option': True} + assert result['conflicts_resolved'] == 1 + + # ========== Integration Tests ========== + + def test_real_world_scenario(self, merger): + """Test a realistic scenario with multiple lookups.""" + enrichments = [ + { + 'project_id': uuid.uuid4(), + 'project_name': 'Vendor Standardization', + 
'data': { + 'canonical_vendor': 'Slack Technologies', + 'vendor_id': 'SLACK-001', + 'vendor_category': 'Communication' + }, + 'confidence': 0.92, + 'execution_time_ms': 1200, + 'cached': False + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Product Mapping', + 'data': { + 'product_name': 'Slack Workspace', + 'product_sku': 'SLK-WS-ENT', + 'vendor_category': 'Collaboration', # Conflict + 'license_type': 'Per User' + }, + 'confidence': 0.88, + 'execution_time_ms': 850, + 'cached': True + }, + { + 'project_id': uuid.uuid4(), + 'project_name': 'Cost Center Assignment', + 'data': { + 'cost_center': 'CC-IT-001', + 'department': 'Information Technology', + 'budget_category': 'Software' + }, + 'confidence': None, # No confidence score + 'execution_time_ms': 450, + 'cached': True + } + ] + + result = merger.merge(enrichments) + + # Check all fields are present + expected_fields = [ + 'canonical_vendor', 'vendor_id', 'vendor_category', + 'product_name', 'product_sku', 'license_type', + 'cost_center', 'department', 'budget_category' + ] + for field in expected_fields: + assert field in result['data'] + + # vendor_category conflict: 0.92 > 0.88, first wins + assert result['data']['vendor_category'] == 'Communication' + + # Should have 1 conflict resolved + assert result['conflicts_resolved'] == 1 + + # Check enrichment details + assert len(result['enrichment_details']) == 3 + assert result['enrichment_details'][0]['fields_added'] == [ + 'canonical_vendor', 'vendor_id', 'vendor_category' + ] + # Product mapping lost vendor_category conflict + assert 'vendor_category' not in result['enrichment_details'][1]['fields_added'] + assert 'product_name' in result['enrichment_details'][1]['fields_added'] diff --git a/backend/lookup/tests/test_services/test_llm_cache.py b/backend/lookup/tests/test_services/test_llm_cache.py new file mode 100644 index 0000000000..422a72dad4 --- /dev/null +++ b/backend/lookup/tests/test_services/test_llm_cache.py @@ -0,0 +1,322 @@ +""" +Tests for LLM Response Cache implementation. + +This module tests the LLMResponseCache class including basic operations, +TTL expiration, cache key generation, and cache management functionality. 
+""" + +import time +from unittest.mock import patch + +import pytest + +from lookup.services.llm_cache import LLMResponseCache + + +class TestLLMResponseCache: + """Test cases for LLMResponseCache class.""" + + @pytest.fixture + def cache(self): + """Create a fresh cache instance for each test.""" + return LLMResponseCache(ttl_hours=1) + + @pytest.fixture + def short_ttl_cache(self): + """Create cache with very short TTL for expiration testing.""" + # 0.001 hours = 3.6 seconds + return LLMResponseCache(ttl_hours=0.001) + + # ========== Basic Operations Tests ========== + + def test_set_stores_value(self, cache): + """Test that set() correctly stores a value.""" + cache.set("test_key", "test_response") + assert "test_key" in cache.cache + stored_value, _ = cache.cache["test_key"] + assert stored_value == "test_response" + + def test_get_retrieves_stored_value(self, cache): + """Test that get() retrieves a stored value.""" + cache.set("test_key", "test_response") + result = cache.get("test_key") + assert result == "test_response" + + def test_get_returns_none_for_missing_key(self, cache): + """Test that get() returns None for non-existent key.""" + result = cache.get("nonexistent_key") + assert result is None + + def test_set_overwrites_existing_value(self, cache): + """Test that set() overwrites existing values.""" + cache.set("test_key", "first_response") + cache.set("test_key", "second_response") + result = cache.get("test_key") + assert result == "second_response" + + # ========== TTL Expiration Tests ========== + + def test_get_returns_value_before_expiration(self, cache): + """Test that get() returns value before TTL expiration.""" + cache.set("test_key", "test_response") + # Should still be valid after immediate retrieval + result = cache.get("test_key") + assert result == "test_response" + + def test_get_returns_none_after_expiration(self, short_ttl_cache): + """Test that get() returns None after TTL expiration.""" + short_ttl_cache.set("test_key", "test_response") + # Wait for expiration (TTL is 3.6 seconds) + time.sleep(4) + result = short_ttl_cache.get("test_key") + assert result is None + + def test_expired_entry_removed_on_access(self, short_ttl_cache): + """Test that expired entries are lazily removed on access.""" + short_ttl_cache.set("test_key", "test_response") + assert "test_key" in short_ttl_cache.cache + + # Wait for expiration + time.sleep(4) + result = short_ttl_cache.get("test_key") + + assert result is None + assert "test_key" not in short_ttl_cache.cache + + @patch('time.time') + def test_ttl_calculation(self, mock_time, cache): + """Test correct TTL calculation using mocked time.""" + # Set initial time + mock_time.return_value = 1000.0 + cache.set("test_key", "test_response") + + # Verify expiry is set correctly (1 hour = 3600 seconds) + _, expiry = cache.cache["test_key"] + assert expiry == 4600.0 # 1000 + 3600 + + # Move time forward but not past expiry + mock_time.return_value = 4599.0 + assert cache.get("test_key") == "test_response" + + # Move time past expiry + mock_time.return_value = 4601.0 + assert cache.get("test_key") is None + + # ========== Cache Key Generation Tests ========== + + def test_cache_key_generation_deterministic(self, cache): + """Test that same inputs generate same cache key.""" + prompt = "Match vendor {{input_data.vendor}}" + ref_data = "Slack\nGoogle\nMicrosoft" + + key1 = cache.generate_cache_key(prompt, ref_data) + key2 = cache.generate_cache_key(prompt, ref_data) + + assert key1 == key2 + + def 
test_different_prompt_different_key(self, cache): + """Test that different prompts generate different keys.""" + ref_data = "Slack\nGoogle\nMicrosoft" + + key1 = cache.generate_cache_key("Prompt 1", ref_data) + key2 = cache.generate_cache_key("Prompt 2", ref_data) + + assert key1 != key2 + + def test_different_ref_data_different_key(self, cache): + """Test that different reference data generates different keys.""" + prompt = "Match vendor {{input_data.vendor}}" + + key1 = cache.generate_cache_key(prompt, "Slack\nGoogle") + key2 = cache.generate_cache_key(prompt, "Slack\nMicrosoft") + + assert key1 != key2 + + def test_cache_key_is_valid_sha256(self, cache): + """Test that cache key is valid SHA256 hex (64 characters).""" + key = cache.generate_cache_key("test prompt", "test ref data") + + assert len(key) == 64 + assert all(c in '0123456789abcdef' for c in key) + + def test_cache_key_handles_unicode(self, cache): + """Test cache key generation with Unicode characters.""" + prompt = "Match vendor: Müller GmbH" + ref_data = "Zürich AG\n北京公司\n東京株式会社" + + key = cache.generate_cache_key(prompt, ref_data) + assert len(key) == 64 + + # ========== Cache Management Tests ========== + + def test_invalidate_removes_specific_key(self, cache): + """Test that invalidate() removes specific key.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + result = cache.invalidate("key1") + + assert result is True + assert cache.get("key1") is None + assert cache.get("key2") == "response2" + + def test_invalidate_returns_false_for_missing_key(self, cache): + """Test that invalidate() returns False for non-existent key.""" + result = cache.invalidate("nonexistent_key") + assert result is False + + def test_invalidate_all_clears_cache(self, cache): + """Test that invalidate_all() clears entire cache.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + cache.set("key3", "response3") + + count = cache.invalidate_all() + + assert count == 3 + assert len(cache.cache) == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + assert cache.get("key3") is None + + def test_invalidate_all_returns_zero_for_empty_cache(self, cache): + """Test that invalidate_all() returns 0 for empty cache.""" + count = cache.invalidate_all() + assert count == 0 + + # ========== Statistics Tests ========== + + def test_get_stats_with_valid_entries(self, cache): + """Test get_stats() with all valid entries.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + stats = cache.get_stats() + + assert stats['total'] == 2 + assert stats['valid'] == 2 + assert stats['expired'] == 0 + + @patch('time.time') + def test_get_stats_with_mixed_entries(self, mock_time, cache): + """Test get_stats() with mix of valid and expired entries.""" + # Set initial time + mock_time.return_value = 1000.0 + + cache.set("key1", "response1") + cache.set("key2", "response2") + + # Manually expire one entry + cache.cache["key1"] = ("response1", 999.0) # Already expired + + stats = cache.get_stats() + + assert stats['total'] == 2 + assert stats['valid'] == 1 + assert stats['expired'] == 1 + + def test_get_stats_empty_cache(self, cache): + """Test get_stats() with empty cache.""" + stats = cache.get_stats() + + assert stats['total'] == 0 + assert stats['valid'] == 0 + assert stats['expired'] == 0 + + # ========== Cleanup Tests ========== + + @patch('time.time') + def test_cleanup_expired_removes_expired_entries(self, mock_time, cache): + """Test cleanup_expired() removes only expired entries.""" + # 
Set initial time + mock_time.return_value = 1000.0 + + cache.set("key1", "response1") + cache.set("key2", "response2") + cache.set("key3", "response3") + + # Manually expire two entries + cache.cache["key1"] = ("response1", 999.0) # Expired + cache.cache["key2"] = ("response2", 999.0) # Expired + # key3 remains valid (expiry at 4600.0) + + removed_count = cache.cleanup_expired() + + assert removed_count == 2 + assert len(cache.cache) == 1 + assert "key3" in cache.cache + assert "key1" not in cache.cache + assert "key2" not in cache.cache + + def test_cleanup_expired_empty_cache(self, cache): + """Test cleanup_expired() with empty cache.""" + removed_count = cache.cleanup_expired() + assert removed_count == 0 + + def test_cleanup_expired_no_expired_entries(self, cache): + """Test cleanup_expired() when no entries are expired.""" + cache.set("key1", "response1") + cache.set("key2", "response2") + + removed_count = cache.cleanup_expired() + + assert removed_count == 0 + assert len(cache.cache) == 2 + + # ========== Integration Tests ========== + + def test_end_to_end_caching_workflow(self, cache): + """Test complete caching workflow.""" + # Generate cache key + prompt = "Match vendor Slack India" + ref_data = "Slack\nGoogle\nMicrosoft" + cache_key = cache.generate_cache_key(prompt, ref_data) + + # Initially no cached value + assert cache.get(cache_key) is None + + # Store response + llm_response = '{"canonical_vendor": "Slack", "confidence": 0.92}' + cache.set(cache_key, llm_response) + + # Retrieve cached value + cached = cache.get(cache_key) + assert cached == llm_response + + # Check stats + stats = cache.get_stats() + assert stats['valid'] == 1 + + # Invalidate + removed = cache.invalidate(cache_key) + assert removed is True + assert cache.get(cache_key) is None + + def test_concurrent_operations(self, cache): + """Test cache handles multiple operations correctly.""" + # Add multiple entries + for i in range(10): + key = f"key_{i}" + response = f"response_{i}" + cache.set(key, response) + + # Verify all entries exist + for i in range(10): + key = f"key_{i}" + assert cache.get(key) == f"response_{i}" + + # Invalidate some entries + for i in range(0, 10, 2): # Even indices + cache.invalidate(f"key_{i}") + + # Verify correct entries remain + for i in range(10): + key = f"key_{i}" + if i % 2 == 0: + assert cache.get(key) is None + else: + assert cache.get(key) == f"response_{i}" + + # Clear all + count = cache.invalidate_all() + assert count == 5 # Only odd indices remained diff --git a/backend/lookup/tests/test_services/test_lookup_executor.py b/backend/lookup/tests/test_services/test_lookup_executor.py new file mode 100644 index 0000000000..a8b663caba --- /dev/null +++ b/backend/lookup/tests/test_services/test_lookup_executor.py @@ -0,0 +1,418 @@ +""" +Tests for Look-Up Executor implementation. + +This module tests the LookUpExecutor class including execution flow, +caching, error handling, and response parsing. 
+""" + +import json +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from lookup.exceptions import ( + ExtractionNotCompleteError, + ParseError, +) +from lookup.services.lookup_executor import LookUpExecutor + + +class TestLookUpExecutor: + """Test cases for LookUpExecutor class.""" + + @pytest.fixture + def mock_variable_resolver(self): + """Create a mock VariableResolver class.""" + mock_class = Mock() + mock_instance = Mock() + mock_instance.resolve.return_value = "Resolved prompt text" + mock_class.return_value = mock_instance + return mock_class + + @pytest.fixture + def mock_cache(self): + """Create a mock LLMResponseCache.""" + cache = MagicMock() + cache.generate_cache_key.return_value = "cache_key_123" + cache.get.return_value = None # Default to cache miss + return cache + + @pytest.fixture + def mock_ref_loader(self): + """Create a mock ReferenceDataLoader.""" + loader = MagicMock() + loader.load_latest_for_project.return_value = { + 'version': 1, + 'content': "Reference data content", + 'files': [], + 'total_size': 1000 + } + return loader + + @pytest.fixture + def mock_llm_client(self): + """Create a mock LLM client.""" + client = MagicMock() + client.generate.return_value = '{"canonical_vendor": "Slack", "confidence": 0.92}' + return client + + @pytest.fixture + def mock_project(self): + """Create a mock LookupProject.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Test Look-Up" + project.llm_config = {'temperature': 0.7} + + # Create mock template + template = MagicMock() + template.template_text = "Match {{input_data.vendor}} with {{reference_data}}" + project.template = template + + return project + + @pytest.fixture + def executor(self, mock_variable_resolver, mock_cache, mock_ref_loader, mock_llm_client): + """Create a LookUpExecutor instance with mocked dependencies.""" + return LookUpExecutor( + variable_resolver=mock_variable_resolver, + cache_manager=mock_cache, + reference_loader=mock_ref_loader, + llm_client=mock_llm_client + ) + + @pytest.fixture + def sample_input_data(self): + """Create sample input data.""" + return { + 'vendor': 'Slack India Pvt Ltd', + 'invoice_amount': 5000 + } + + # ========== Successful Execution Tests ========== + + def test_successful_execution(self, executor, mock_project, sample_input_data, + mock_variable_resolver, mock_cache, mock_llm_client): + """Test complete successful execution flow.""" + result = executor.execute(mock_project, sample_input_data) + + # Check success + assert result['status'] == 'success' + assert result['project_id'] == mock_project.id + assert result['project_name'] == 'Test Look-Up' + + # Check enrichment data + assert result['data'] == {'canonical_vendor': 'Slack'} + assert result['confidence'] == 0.92 + assert result['cached'] is False + assert result['execution_time_ms'] > 0 + + # Verify variable resolver was called + mock_variable_resolver.assert_called_once_with( + sample_input_data, "Reference data content" + ) + mock_variable_resolver.return_value.resolve.assert_called_once_with( + "Match {{input_data.vendor}} with {{reference_data}}" + ) + + # Verify cache was checked and set + mock_cache.get.assert_called_once_with("cache_key_123") + mock_cache.set.assert_called_once_with( + "cache_key_123", + '{"canonical_vendor": "Slack", "confidence": 0.92}' + ) + + # Verify LLM was called + mock_llm_client.generate.assert_called_once_with( + "Resolved prompt text", + {'temperature': 0.7} + ) + + def test_confidence_extraction(self, executor, mock_project, 
sample_input_data, + mock_llm_client): + """Test confidence score extraction from response.""" + mock_llm_client.generate.return_value = '{"field": "value", "confidence": 0.85}' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['confidence'] == 0.85 + assert result['data'] == {'field': 'value'} + + def test_no_confidence_in_response(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling response without confidence score.""" + mock_llm_client.generate.return_value = '{"field": "value"}' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['confidence'] is None + assert result['data'] == {'field': 'value'} + + # ========== Cache Tests ========== + + def test_cache_hit(self, executor, mock_project, sample_input_data, + mock_cache, mock_llm_client): + """Test execution with cache hit.""" + # Set up cache hit + mock_cache.get.return_value = '{"cached_field": "cached_value", "confidence": 0.88}' + + result = executor.execute(mock_project, sample_input_data) + + # Check result + assert result['status'] == 'success' + assert result['data'] == {'cached_field': 'cached_value'} + assert result['confidence'] == 0.88 + assert result['cached'] is True + assert result['execution_time_ms'] == 0 + + # Verify LLM was NOT called + mock_llm_client.generate.assert_not_called() + + # Verify cache.set was NOT called + mock_cache.set.assert_not_called() + + def test_cache_miss(self, executor, mock_project, sample_input_data, + mock_cache, mock_llm_client): + """Test execution with cache miss.""" + # Cache miss (default setup) + mock_cache.get.return_value = None + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['cached'] is False + + # Verify LLM was called + mock_llm_client.generate.assert_called_once() + + # Verify result was cached + mock_cache.set.assert_called_once() + + # ========== Error Handling Tests ========== + + def test_reference_data_not_ready(self, executor, mock_project, sample_input_data, + mock_ref_loader): + """Test handling of incomplete extraction.""" + mock_ref_loader.load_latest_for_project.side_effect = ExtractionNotCompleteError( + ['file1.csv', 'file2.txt'] + ) + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Reference data extraction not complete' in result['error'] + assert 'file1.csv' in result['error'] + assert 'file2.txt' in result['error'] + + def test_missing_template(self, executor, mock_project, sample_input_data): + """Test handling of missing template.""" + mock_project.template = None + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Missing prompt template' in result['error'] + + def test_llm_timeout(self, executor, mock_project, sample_input_data, mock_llm_client): + """Test handling of LLM timeout.""" + mock_llm_client.generate.side_effect = TimeoutError("Request timed out after 30s") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'LLM request timed out' in result['error'] + assert '30s' in result['error'] + + def test_llm_generic_error(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of generic LLM errors.""" + mock_llm_client.generate.side_effect = Exception("API key invalid") + + result = 
executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'LLM request failed' in result['error'] + assert 'API key invalid' in result['error'] + + def test_parse_error_invalid_json(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of invalid JSON response.""" + mock_llm_client.generate.return_value = "Not valid JSON" + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to parse LLM response' in result['error'] + + def test_parse_error_not_object(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of non-object JSON response.""" + mock_llm_client.generate.return_value = '["array", "not", "object"]' + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to parse LLM response' in result['error'] + + def test_reference_loader_error(self, executor, mock_project, sample_input_data, + mock_ref_loader): + """Test handling of reference loader errors.""" + mock_ref_loader.load_latest_for_project.side_effect = Exception("Storage unavailable") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Failed to load reference data' in result['error'] + assert 'Storage unavailable' in result['error'] + + # ========== Variable Resolution Tests ========== + + def test_variable_resolution(self, executor, mock_project, sample_input_data, + mock_variable_resolver): + """Test that variables are correctly resolved.""" + result = executor.execute(mock_project, sample_input_data) + + # Verify resolver was instantiated with correct data + mock_variable_resolver.assert_called_once_with( + sample_input_data, + "Reference data content" + ) + + # Verify resolve was called with template + mock_variable_resolver.return_value.resolve.assert_called_once_with( + "Match {{input_data.vendor}} with {{reference_data}}" + ) + + # ========== Response Parsing Tests ========== + + def test_parse_llm_response_with_confidence(self, executor): + """Test parsing response with confidence score.""" + response = '{"field1": "value1", "field2": "value2", "confidence": 0.95}' + + data, confidence = executor._parse_llm_response(response) + + assert data == {'field1': 'value1', 'field2': 'value2'} + assert confidence == 0.95 + + def test_parse_llm_response_without_confidence(self, executor): + """Test parsing response without confidence score.""" + response = '{"field1": "value1", "field2": "value2"}' + + data, confidence = executor._parse_llm_response(response) + + assert data == {'field1': 'value1', 'field2': 'value2'} + assert confidence is None + + def test_parse_llm_response_invalid_confidence(self, executor): + """Test handling of invalid confidence values.""" + # Confidence outside range (should be clamped) + response = '{"field": "value", "confidence": 1.5}' + + with patch('lookup.services.lookup_executor.logger') as mock_logger: + data, confidence = executor._parse_llm_response(response) + + assert data == {'field': 'value'} + assert confidence == 1.0 # Clamped to max + mock_logger.warning.assert_called() + + def test_parse_llm_response_non_numeric_confidence(self, executor): + """Test handling of non-numeric confidence.""" + response = '{"field": "value", "confidence": "high"}' + + with patch('lookup.services.lookup_executor.logger') as mock_logger: + data, confidence = executor._parse_llm_response(response) + + assert data == 
{'field': 'value'} + assert confidence is None + mock_logger.warning.assert_called() + + def test_parse_llm_response_invalid_json(self, executor): + """Test parsing invalid JSON raises ParseError.""" + response = "This is not JSON" + + with pytest.raises(ParseError) as exc_info: + executor._parse_llm_response(response) + + assert "Invalid JSON response" in str(exc_info.value) + + # ========== Integration Tests ========== + + def test_end_to_end_execution(self, executor, mock_project, sample_input_data): + """Test complete execution flow with realistic data.""" + result = executor.execute(mock_project, sample_input_data) + + # Basic assertions + assert result['status'] == 'success' + assert 'data' in result + assert 'confidence' in result + assert 'cached' in result + assert 'execution_time_ms' in result + + def test_execution_with_empty_input(self, executor, mock_project): + """Test execution with empty input data.""" + result = executor.execute(mock_project, {}) + + # Should still execute successfully + assert result['status'] == 'success' + + def test_execution_time_tracking(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test that execution time is properly tracked.""" + # Add a small delay to LLM call + def delayed_generate(*args, **kwargs): + import time + time.sleep(0.01) # 10ms delay + return '{"result": "data"}' + + mock_llm_client.generate.side_effect = delayed_generate + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['execution_time_ms'] >= 10 # At least 10ms + + def test_failed_execution_time_tracking(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test that execution time is tracked even on failure.""" + mock_llm_client.generate.side_effect = Exception("Error") + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert result['execution_time_ms'] >= 0 + + @patch('lookup.services.lookup_executor.logger') + def test_unexpected_error_logging(self, mock_logger, executor, mock_project, + sample_input_data): + """Test that unexpected errors are logged.""" + # Create an error that will trigger the catch-all + mock_project.template.template_text = None # Will cause AttributeError + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'failed' + assert 'Unexpected error' in result['error'] + mock_logger.exception.assert_called() + + def test_complex_llm_response(self, executor, mock_project, sample_input_data, + mock_llm_client): + """Test handling of complex nested LLM response.""" + complex_response = json.dumps({ + "vendor": { + "canonical_name": "Slack Technologies", + "id": "SLACK-001" + }, + "categories": ["Communication", "SaaS"], + "confidence": 0.88 + }) + mock_llm_client.generate.return_value = complex_response + + result = executor.execute(mock_project, sample_input_data) + + assert result['status'] == 'success' + assert result['data']['vendor']['canonical_name'] == 'Slack Technologies' + assert result['data']['categories'] == ['Communication', 'SaaS'] + assert result['confidence'] == 0.88 diff --git a/backend/lookup/tests/test_services/test_lookup_orchestrator.py b/backend/lookup/tests/test_services/test_lookup_orchestrator.py new file mode 100644 index 0000000000..f1031ef4bc --- /dev/null +++ b/backend/lookup/tests/test_services/test_lookup_orchestrator.py @@ -0,0 +1,521 @@ +""" +Tests for Look-Up Orchestrator implementation. 
+ +This module tests the LookUpOrchestrator class including parallel execution, +timeout handling, result merging, and error recovery. +""" + +import time +import uuid +from concurrent.futures import TimeoutError as FutureTimeoutError +from unittest.mock import MagicMock, patch + +import pytest + +from lookup.services.lookup_orchestrator import LookUpOrchestrator + + +class TestLookUpOrchestrator: + """Test cases for LookUpOrchestrator class.""" + + @pytest.fixture + def mock_executor(self): + """Create a mock LookUpExecutor.""" + executor = MagicMock() + # Default to successful execution + executor.execute.return_value = { + 'status': 'success', + 'project_id': uuid.uuid4(), + 'project_name': 'Test Look-Up', + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + return executor + + @pytest.fixture + def mock_merger(self): + """Create a mock EnrichmentMerger.""" + merger = MagicMock() + merger.merge.return_value = { + 'data': {'merged_field': 'merged_value'}, + 'conflicts_resolved': 0, + 'enrichment_details': [] + } + return merger + + @pytest.fixture + def config(self): + """Create test configuration.""" + return { + 'max_concurrent_executions': 5, + 'queue_timeout_seconds': 10, + 'execution_timeout_seconds': 2 + } + + @pytest.fixture + def orchestrator(self, mock_executor, mock_merger, config): + """Create a LookUpOrchestrator with mocked dependencies.""" + return LookUpOrchestrator( + executor=mock_executor, + merger=mock_merger, + config=config + ) + + @pytest.fixture + def sample_input_data(self): + """Create sample input data.""" + return { + 'vendor': 'Slack Technologies', + 'amount': 5000 + } + + @pytest.fixture + def mock_projects(self): + """Create mock Look-Up projects.""" + projects = [] + for i in range(3): + project = MagicMock() + project.id = uuid.uuid4() + project.name = f"Look-Up {i+1}" + projects.append(project) + return projects + + # ========== Basic Execution Tests ========== + + def test_successful_parallel_execution( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test successful parallel execution of multiple Look-Ups.""" + # Setup executor to return different data for each project + def execute_side_effect(project, input_data): + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {f'field_{project.name}': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + mock_executor.execute.side_effect = execute_side_effect + + # Execute + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Verify execution + assert mock_executor.execute.call_count == 3 + assert mock_merger.merge.call_count == 1 + + # Check metadata + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 3 + assert metadata['lookups_failed'] == 0 + assert 'execution_id' in metadata + assert 'executed_at' in metadata + assert metadata['total_execution_time_ms'] > 0 + + def test_empty_projects_list(self, orchestrator, sample_input_data): + """Test execution with empty projects list.""" + result = orchestrator.execute_lookups(sample_input_data, []) + + assert result['lookup_enrichment'] == {} + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 0 + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 0 + assert metadata['enrichments'] == [] + + def test_single_project_execution( + self, 
orchestrator, sample_input_data, mock_executor + ): + """Test execution with single Look-Up project.""" + project = MagicMock() + project.id = uuid.uuid4() + project.name = "Single Look-Up" + + result = orchestrator.execute_lookups(sample_input_data, [project]) + + assert mock_executor.execute.call_count == 1 + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 1 + assert metadata['lookups_successful'] == 1 + + # ========== Failure Handling Tests ========== + + def test_partial_failures( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test handling of partial failures.""" + # Setup: First succeeds, second fails, third succeeds + def execute_side_effect(project, input_data): + if project.name == "Look-Up 2": + return { + 'status': 'failed', + 'project_id': project.id, + 'project_name': project.name, + 'error': 'Test error', + 'execution_time_ms': 50, + 'cached': False + } + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + mock_executor.execute.side_effect = execute_side_effect + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check results + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 2 + assert metadata['lookups_failed'] == 1 + + # Verify merger was called with only successful enrichments + merge_call_args = mock_merger.merge.call_args[0][0] + assert len(merge_call_args) == 2 # Only successful ones + + def test_all_failures( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test handling when all Look-Ups fail.""" + mock_executor.execute.return_value = { + 'status': 'failed', + 'project_id': uuid.uuid4(), + 'project_name': 'Failed Look-Up', + 'error': 'Test error', + 'execution_time_ms': 50, + 'cached': False + } + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check results + assert result['lookup_enrichment'] == {} # Empty merged data + metadata = result['_lookup_metadata'] + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 3 + assert metadata['conflicts_resolved'] == 0 + + # ========== Timeout Tests ========== + + @patch('lookup.services.lookup_orchestrator.ThreadPoolExecutor') + def test_individual_execution_timeout( + self, mock_executor_class, orchestrator, sample_input_data, + mock_projects + ): + """Test handling of individual execution timeouts.""" + # Setup mock executor + mock_thread_executor = MagicMock() + mock_executor_class.return_value.__enter__.return_value = mock_thread_executor + + # Create futures that will timeout + future1 = MagicMock() + future1.result.side_effect = FutureTimeoutError() + + future2 = MagicMock() + future2.result.return_value = { + 'status': 'success', + 'project_id': mock_projects[1].id, + 'project_name': mock_projects[1].name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + + # Setup as_completed to return futures + with patch('lookup.services.lookup_orchestrator.as_completed') as mock_as_completed: + mock_as_completed.return_value = [future1, future2] + + # Setup submit to return futures + mock_thread_executor.submit.side_effect = [future1, future2, MagicMock()] + + result = orchestrator.execute_lookups(sample_input_data, mock_projects[:2]) + + metadata = 
result['_lookup_metadata'] + # One timeout (failed), one success + assert metadata['lookups_failed'] >= 1 + assert metadata['lookups_successful'] >= 1 + + @patch('lookup.services.lookup_orchestrator.as_completed') + def test_queue_timeout( + self, mock_as_completed, orchestrator, sample_input_data, + mock_projects + ): + """Test handling of overall queue timeout.""" + # Make as_completed raise TimeoutError + mock_as_completed.side_effect = FutureTimeoutError() + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 3 + assert metadata['lookups_successful'] == 0 + assert metadata['lookups_failed'] == 3 # All marked as failed due to queue timeout + + # ========== Concurrency Tests ========== + + def test_max_concurrent_limit( + self, mock_executor, mock_merger, sample_input_data + ): + """Test that max concurrent executions limit is respected.""" + # Create orchestrator with low concurrency limit + config = {'max_concurrent_executions': 2} + orchestrator = LookUpOrchestrator(mock_executor, mock_merger, config) + + # Create many projects + projects = [] + for i in range(10): + project = MagicMock() + project.id = uuid.uuid4() + project.name = f"Look-Up {i+1}" + projects.append(project) + + # Add small delay to executor to simulate work + def slow_execute(project, input_data): + time.sleep(0.01) # Small delay + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 10 + } + + mock_executor.execute.side_effect = slow_execute + + # Execute + result = orchestrator.execute_lookups(sample_input_data, projects) + + # Should complete successfully despite concurrency limit + metadata = result['_lookup_metadata'] + assert metadata['lookups_executed'] == 10 + assert metadata['lookups_successful'] == 10 + + # ========== Result Merging Tests ========== + + def test_successful_merge( + self, orchestrator, sample_input_data, mock_projects, + mock_executor, mock_merger + ): + """Test that successful enrichments are properly merged.""" + # Setup merger to return specific merged data + mock_merger.merge.return_value = { + 'data': { + 'vendor': 'Slack', + 'category': 'SaaS', + 'type': 'Communication' + }, + 'conflicts_resolved': 2, + 'enrichment_details': [ + {'lookup_project_name': 'Look-Up 1', 'fields_added': ['vendor']}, + {'lookup_project_name': 'Look-Up 2', 'fields_added': ['category']}, + {'lookup_project_name': 'Look-Up 3', 'fields_added': ['type']} + ] + } + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Check merged enrichment + assert result['lookup_enrichment'] == { + 'vendor': 'Slack', + 'category': 'SaaS', + 'type': 'Communication' + } + + # Check metadata + metadata = result['_lookup_metadata'] + assert metadata['conflicts_resolved'] == 2 + + # ========== Error Recovery Tests ========== + + def test_executor_exception_handling( + self, orchestrator, sample_input_data, mock_projects, + mock_executor + ): + """Test handling of unexpected exceptions from executor.""" + # Make executor raise exception for one project + def execute_with_exception(project, input_data): + if project.name == "Look-Up 2": + raise ValueError("Unexpected error in executor") + return { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + 
}
+
+        mock_executor.execute.side_effect = execute_with_exception
+
+        result = orchestrator.execute_lookups(sample_input_data, mock_projects)
+
+        # Should handle exception gracefully
+        metadata = result['_lookup_metadata']
+        assert metadata['lookups_executed'] == 3
+        assert metadata['lookups_successful'] == 2
+        assert metadata['lookups_failed'] == 1
+
+        # Check that error is captured (default to None so a missing
+        # enrichment fails the assertion instead of raising StopIteration)
+        failed_enrichment = next(
+            (
+                e for e in metadata['enrichments']
+                if e['status'] == 'failed' and 'Unexpected error' in e['error']
+            ),
+            None,
+        )
+        assert failed_enrichment is not None
+
+    # ========== Metadata Tests ==========
+
+    def test_execution_metadata(
+        self, orchestrator, sample_input_data, mock_projects
+    ):
+        """Test that execution metadata is properly populated."""
+        result = orchestrator.execute_lookups(sample_input_data, mock_projects)
+
+        metadata = result['_lookup_metadata']
+
+        # Check all required metadata fields
+        assert 'execution_id' in metadata
+        assert isinstance(metadata['execution_id'], str)
+        assert len(metadata['execution_id']) == 36  # UUID format
+
+        assert 'executed_at' in metadata
+        assert 'T' in metadata['executed_at']  # ISO8601 format
+
+        assert 'total_execution_time_ms' in metadata
+        assert metadata['total_execution_time_ms'] >= 0
+
+        assert 'enrichments' in metadata
+        assert len(metadata['enrichments']) >= metadata['lookups_successful']
+
+    def test_enrichments_list_includes_all_results(
+        self, orchestrator, sample_input_data, mock_projects,
+        mock_executor
+    ):
+        """Test that enrichments list includes both successful and failed results."""
+        # Setup mixed results
+        def execute_side_effect(project, input_data):
+            if project.name == "Look-Up 2":
+                return {
+                    'status': 'failed',
+                    'project_id': project.id,
+                    'project_name': project.name,
+                    'error': 'Test failure',
+                    'execution_time_ms': 50,
+                    'cached': False
+                }
+            return {
+                'status': 'success',
+                'project_id': project.id,
+                'project_name': project.name,
+                'data': {'field': 'value'},
+                'confidence': 0.9,
+                'cached': False,
+                'execution_time_ms': 100
+            }
+
+        mock_executor.execute.side_effect = execute_side_effect
+
+        result = orchestrator.execute_lookups(sample_input_data, mock_projects)
+
+        metadata = result['_lookup_metadata']
+        enrichments = metadata['enrichments']
+
+        # Should have all 3 enrichments (2 success + 1 failed)
+        assert len(enrichments) == 3
+
+        # Check statuses
+        statuses = [e['status'] for e in enrichments]
+        assert statuses.count('success') == 2
+        assert statuses.count('failed') == 1
+
+    # ========== Configuration Tests ==========
+
+    def test_default_configuration(self, mock_executor, mock_merger):
+        """Test orchestrator with default configuration."""
+        orchestrator = LookUpOrchestrator(mock_executor, mock_merger)
+
+        assert orchestrator.max_concurrent == 10
+        assert orchestrator.queue_timeout == 120
+        assert orchestrator.execution_timeout == 30
+
+    def test_custom_configuration(self, mock_executor, mock_merger):
+        """Test orchestrator with custom configuration."""
+        config = {
+            'max_concurrent_executions': 20,
+            'queue_timeout_seconds': 300,
+            'execution_timeout_seconds': 60
+        }
+
+        orchestrator = LookUpOrchestrator(mock_executor, mock_merger, config)
+
+        assert orchestrator.max_concurrent == 20
+        assert orchestrator.queue_timeout == 300
+        assert orchestrator.execution_timeout == 60
+
+    # ========== Integration Tests ==========
+
+    @patch('lookup.services.lookup_orchestrator.logger')
+    def test_logging(
+        self, mock_logger, orchestrator, sample_input_data,
+        mock_projects
+    ):
+        """Test that appropriate logging is 
performed.""" + orchestrator.execute_lookups(sample_input_data, mock_projects) + + # Should log start and completion + assert mock_logger.info.call_count >= 2 + start_log = mock_logger.info.call_args_list[0][0][0] + assert 'Starting orchestrated execution' in start_log + + completion_log = mock_logger.info.call_args_list[-1][0][0] + assert 'completed' in completion_log + + def test_execution_id_propagation( + self, orchestrator, sample_input_data, mock_projects, + mock_executor + ): + """Test that execution ID is propagated to individual executions.""" + # Capture execution results + captured_results = [] + + def capture_execute(project, input_data): + result = { + 'status': 'success', + 'project_id': project.id, + 'project_name': project.name, + 'data': {'field': 'value'}, + 'confidence': 0.9, + 'cached': False, + 'execution_time_ms': 100 + } + captured_results.append(result) + return result + + mock_executor.execute.side_effect = capture_execute + + result = orchestrator.execute_lookups(sample_input_data, mock_projects) + + execution_id = result['_lookup_metadata']['execution_id'] + + # Check that execution_id is added to enrichments + for enrichment in result['_lookup_metadata']['enrichments']: + if 'execution_id' in enrichment: + assert enrichment['execution_id'] == execution_id diff --git a/backend/lookup/tests/test_services/test_reference_data_loader.py b/backend/lookup/tests/test_services/test_reference_data_loader.py new file mode 100644 index 0000000000..6e3e84490d --- /dev/null +++ b/backend/lookup/tests/test_services/test_reference_data_loader.py @@ -0,0 +1,565 @@ +""" +Tests for Reference Data Loader implementation. + +This module tests the ReferenceDataLoader class including loading latest/specific +versions, concatenation, and extraction validation. 
+""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, Mock + +import pytest +from django.contrib.auth import get_user_model + +from lookup.exceptions import ExtractionNotCompleteError +from lookup.models import LookupDataSource, LookupProject +from lookup.services.reference_data_loader import ReferenceDataLoader + +User = get_user_model() + + +@pytest.mark.django_db +class TestReferenceDataLoader: + """Test cases for ReferenceDataLoader class.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock storage client.""" + storage = MagicMock() + storage.get = Mock(return_value="Default content") + return storage + + @pytest.fixture + def loader(self, mock_storage): + """Create a ReferenceDataLoader instance with mock storage.""" + return ReferenceDataLoader(mock_storage) + + @pytest.fixture + def test_user(self): + """Create a test user.""" + return User.objects.create_user( + username='testuser', + email='test@example.com', + password='testpass123' + ) + + @pytest.fixture + def test_project(self, test_user): + """Create a test Look-Up project.""" + return LookupProject.objects.create( + name="Test Project", + description="Test project for loader", + lookup_type='static_data', + llm_provider='openai', + llm_model='gpt-4', + llm_config={'temperature': 0.7}, + organization_id=uuid.uuid4(), + created_by=test_user + ) + + @pytest.fixture + def data_sources_completed(self, test_project, test_user): + """Create 3 completed data sources for testing.""" + sources = [] + for i in range(3): + source = LookupDataSource.objects.create( + project=test_project, + file_name=f"file{i+1}.csv", + file_path=f"uploads/file{i+1}.csv", + file_size=1000 * (i + 1), + file_type="text/csv", + extracted_content_path=f"extracted/file{i+1}.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + sources.append(source) + return sources + + # ========== Load Latest Tests ========== + + def test_load_latest_success(self, loader, mock_storage, test_project, + data_sources_completed): + """Test successful loading of latest reference data.""" + # Setup mock storage responses + def storage_get(path): + if 'file1' in path: + return "Content from file 1" + elif 'file2' in path: + return "Content from file 2" + elif 'file3' in path: + return "Content from file 3" + return "Unknown file" + + mock_storage.get.side_effect = storage_get + + # Load latest data + result = loader.load_latest_for_project(test_project.id) + + # Verify result structure + assert 'version' in result + assert 'content' in result + assert 'files' in result + assert 'total_size' in result + + # Check version + assert result['version'] == 1 + + # Check concatenated content + expected_content = ( + "=== File: file1.csv ===\n\n" + "Content from file 1\n\n" + "=== File: file2.csv ===\n\n" + "Content from file 2\n\n" + "=== File: file3.csv ===\n\n" + "Content from file 3" + ) + assert result['content'] == expected_content + + # Check files metadata + assert len(result['files']) == 3 + assert result['files'][0]['name'] == 'file1.csv' + assert result['files'][1]['name'] == 'file2.csv' + assert result['files'][2]['name'] == 'file3.csv' + + # Check total size + assert result['total_size'] == 6000 # 1000 + 2000 + 3000 + + # Verify storage was called for each file + assert mock_storage.get.call_count == 3 + + def test_load_latest_incomplete_extraction(self, loader, test_project, + test_user): + """Test loading fails when extraction is incomplete.""" + # Create sources 
with mixed status + LookupDataSource.objects.create( + project=test_project, + file_name="completed.csv", + file_path="uploads/completed.csv", + file_size=1000, + file_type="text/csv", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + LookupDataSource.objects.create( + project=test_project, + file_name="pending.csv", + file_path="uploads/pending.csv", + file_size=2000, + file_type="text/csv", + extraction_status='pending', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + LookupDataSource.objects.create( + project=test_project, + file_name="failed.csv", + file_path="uploads/failed.csv", + file_size=3000, + file_type="text/csv", + extraction_status='failed', + extraction_error='Parse error', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + # Attempt to load should raise exception + with pytest.raises(ExtractionNotCompleteError) as exc_info: + loader.load_latest_for_project(test_project.id) + + # Check error message includes failed files + assert 'pending.csv' in str(exc_info.value) + assert 'failed.csv' in str(exc_info.value) + assert exc_info.value.failed_files == ['pending.csv', 'failed.csv'] + + def test_load_latest_no_data_sources(self, loader): + """Test loading when no data sources exist.""" + non_existent_id = uuid.uuid4() + + with pytest.raises(LookupDataSource.DoesNotExist) as exc_info: + loader.load_latest_for_project(non_existent_id) + + assert str(non_existent_id) in str(exc_info.value) + + def test_load_latest_order_by_upload(self, loader, mock_storage, test_project, + test_user): + """Test that files are concatenated in upload order.""" + # Create sources with specific creation times + source1 = LookupDataSource.objects.create( + project=test_project, + file_name="third.csv", # Named third but uploaded first + file_path="uploads/third.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path="extracted/third.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + source1.created_at = datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc) + source1.save() + + source2 = LookupDataSource.objects.create( + project=test_project, + file_name="first.csv", # Named first but uploaded second + file_path="uploads/first.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/first.txt", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + source2.created_at = datetime(2024, 1, 1, 11, 0, tzinfo=timezone.utc) + source2.save() + + def storage_get(path): + if 'third' in path: + return "Third content" + elif 'first' in path: + return "First content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + result = loader.load_latest_for_project(test_project.id) + + # Verify order is by upload time, not name + assert "=== File: third.csv ===" in result['content'] + assert "=== File: first.csv ===" in result['content'] + assert result['content'].index('third.csv') < result['content'].index('first.csv') + + # ========== Load Specific Version Tests ========== + + def test_load_specific_version(self, loader, mock_storage, test_project, + test_user): + """Test loading a specific version of reference data.""" + # Create multiple versions + # Version 1 + LookupDataSource.objects.create( + project=test_project, + file_name="v1_file.csv", + file_path="uploads/v1_file.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path="extracted/v1_file.txt", + 
extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + # Version 2 + LookupDataSource.objects.create( + project=test_project, + file_name="v2_file.csv", + file_path="uploads/v2_file.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/v2_file.txt", + extraction_status='completed', + version_number=2, + is_latest=False, + uploaded_by=test_user + ) + + # Version 3 (latest) + LookupDataSource.objects.create( + project=test_project, + file_name="v3_file.csv", + file_path="uploads/v3_file.csv", + file_size=3000, + file_type="text/csv", + extracted_content_path="extracted/v3_file.txt", + extraction_status='completed', + version_number=3, + is_latest=True, + uploaded_by=test_user + ) + + def storage_get(path): + if 'v1' in path: + return "Version 1 content" + elif 'v2' in path: + return "Version 2 content" + elif 'v3' in path: + return "Version 3 content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + # Load version 2 + result = loader.load_specific_version(test_project.id, 2) + + assert result['version'] == 2 + assert "Version 2 content" in result['content'] + assert "v2_file.csv" in result['content'] + assert result['files'][0]['name'] == 'v2_file.csv' + + # Load version 1 + result = loader.load_specific_version(test_project.id, 1) + + assert result['version'] == 1 + assert "Version 1 content" in result['content'] + assert "v1_file.csv" in result['content'] + + def test_load_specific_version_not_found(self, loader, test_project): + """Test loading non-existent version.""" + with pytest.raises(LookupDataSource.DoesNotExist) as exc_info: + loader.load_specific_version(test_project.id, 999) + + assert "Version 999 not found" in str(exc_info.value) + + # ========== Concatenation Tests ========== + + def test_concatenate_sources(self, loader, mock_storage, data_sources_completed): + """Test concatenation of multiple sources.""" + def storage_get(path): + if 'file1' in path: + return "Alpha content" + elif 'file2' in path: + return "Beta content" + elif 'file3' in path: + return "Gamma content" + return "Unknown" + + mock_storage.get.side_effect = storage_get + + result = loader.concatenate_sources(data_sources_completed) + + # Check all files are included with headers + assert "=== File: file1.csv ===" in result + assert "=== File: file2.csv ===" in result + assert "=== File: file3.csv ===" in result + + # Check all content is included + assert "Alpha content" in result + assert "Beta content" in result + assert "Gamma content" in result + + # Check double newline separators + assert "\n\n" in result + + def test_concatenate_with_missing_path(self, loader, test_project, test_user): + """Test concatenation when extracted_content_path is missing.""" + source = LookupDataSource.objects.create( + project=test_project, + file_name="no_path.csv", + file_path="uploads/no_path.csv", + file_size=1000, + file_type="text/csv", + extracted_content_path=None, # No extracted path + extraction_status='completed', # But marked as completed + version_number=1, + is_latest=True, + uploaded_by=test_user + ) + + result = loader.concatenate_sources([source]) + + assert "=== File: no_path.csv ===" in result + assert "[No extracted content path]" in result + + def test_concatenate_with_storage_error(self, loader, mock_storage, + data_sources_completed): + """Test concatenation handles storage errors gracefully.""" + mock_storage.get.side_effect = Exception("Storage unavailable") + + result = 
loader.concatenate_sources(data_sources_completed) + + # Should include error messages instead of content + assert "[Error loading file: Storage unavailable]" in result + # But should still have file headers + assert "=== File: file1.csv ===" in result + + # ========== Validation Tests ========== + + def test_validate_extraction_complete_all_success(self, loader, + data_sources_completed): + """Test validation when all extractions are complete.""" + is_complete, failed_files = loader.validate_extraction_complete( + data_sources_completed + ) + + assert is_complete is True + assert failed_files == [] + + def test_validate_extraction_with_failures(self, loader, test_project, test_user): + """Test validation identifies incomplete extractions.""" + sources = [] + + # Completed + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="good.csv", + file_path="uploads/good.csv", + file_size=1000, + file_type="text/csv", + extraction_status='completed', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Pending + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="pending.csv", + file_path="uploads/pending.csv", + file_size=2000, + file_type="text/csv", + extraction_status='pending', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Processing + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="processing.csv", + file_path="uploads/processing.csv", + file_size=3000, + file_type="text/csv", + extraction_status='processing', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + # Failed + sources.append(LookupDataSource.objects.create( + project=test_project, + file_name="failed.csv", + file_path="uploads/failed.csv", + file_size=4000, + file_type="text/csv", + extraction_status='failed', + version_number=1, + is_latest=True, + uploaded_by=test_user + )) + + is_complete, failed_files = loader.validate_extraction_complete(sources) + + assert is_complete is False + assert len(failed_files) == 3 + assert 'pending.csv' in failed_files + assert 'processing.csv' in failed_files + assert 'failed.csv' in failed_files + assert 'good.csv' not in failed_files + + # ========== Integration Tests ========== + + def test_end_to_end_multi_file_loading(self, loader, mock_storage, test_project, + test_user): + """Test complete workflow with multiple files and versions.""" + # Create version 1 with 2 files + v1_file1 = LookupDataSource.objects.create( + project=test_project, + file_name="vendors_v1.csv", + file_path="uploads/vendors_v1.csv", + file_size=5000, + file_type="text/csv", + extracted_content_path="extracted/vendors_v1.txt", + extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + v1_file2 = LookupDataSource.objects.create( + project=test_project, + file_name="products_v1.csv", + file_path="uploads/products_v1.csv", + file_size=3000, + file_type="text/csv", + extracted_content_path="extracted/products_v1.txt", + extraction_status='completed', + version_number=1, + is_latest=False, + uploaded_by=test_user + ) + + # Create version 2 with 3 files (latest) + v2_file1 = LookupDataSource.objects.create( + project=test_project, + file_name="vendors_v2.csv", + file_path="uploads/vendors_v2.csv", + file_size=6000, + file_type="text/csv", + extracted_content_path="extracted/vendors_v2.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + v2_file2 = 
LookupDataSource.objects.create( + project=test_project, + file_name="products_v2.csv", + file_path="uploads/products_v2.csv", + file_size=4000, + file_type="text/csv", + extracted_content_path="extracted/products_v2.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + v2_file3 = LookupDataSource.objects.create( + project=test_project, + file_name="categories.csv", + file_path="uploads/categories.csv", + file_size=2000, + file_type="text/csv", + extracted_content_path="extracted/categories.txt", + extraction_status='completed', + version_number=2, + is_latest=True, + uploaded_by=test_user + ) + + def storage_get(path): + content_map = { + 'vendors_v1': "Slack\nMicrosoft", + 'products_v1': "Slack Workspace\nTeams", + 'vendors_v2': "Slack\nMicrosoft\nGoogle", + 'products_v2': "Slack Workspace\nTeams\nGoogle Workspace", + 'categories': "Communication\nProductivity" + } + for key, value in content_map.items(): + if key in path: + return value + return "Unknown content" + + mock_storage.get.side_effect = storage_get + + # Load latest (v2) + latest_result = loader.load_latest_for_project(test_project.id) + + assert latest_result['version'] == 2 + assert len(latest_result['files']) == 3 + assert latest_result['total_size'] == 12000 # 6000 + 4000 + 2000 + assert "Google" in latest_result['content'] + assert "categories.csv" in latest_result['content'] + + # Load specific version (v1) + v1_result = loader.load_specific_version(test_project.id, 1) + + assert v1_result['version'] == 1 + assert len(v1_result['files']) == 2 + assert v1_result['total_size'] == 8000 # 5000 + 3000 + assert "Google" not in v1_result['content'] # v1 doesn't have Google + assert "categories" not in v1_result['content'] # v1 doesn't have categories diff --git a/backend/lookup/tests/test_variable_resolver.py b/backend/lookup/tests/test_variable_resolver.py new file mode 100644 index 0000000000..826f961b10 --- /dev/null +++ b/backend/lookup/tests/test_variable_resolver.py @@ -0,0 +1,289 @@ +"""Tests for the VariableResolver class.""" + +import json +import pytest +from lookup.services.variable_resolver import VariableResolver + + +class TestVariableResolver: + """Test suite for VariableResolver.""" + + @pytest.fixture + def sample_input_data(self): + """Sample input data for testing.""" + return { + "vendor_name": "Slack India Pvt Ltd", + "contract_value": 50000, + "contract_date": "2024-01-15", + "line_items": [ + {"product": "Slack Pro", "quantity": 100, "price": 500} + ], + "metadata": { + "region": "APAC", + "currency": "INR" + }, + "none_field": None + } + + @pytest.fixture + def sample_reference_data(self): + """Sample reference data for testing.""" + return """Canonical Vendors: +- Slack (variations: Slack Inc, Slack India, Slack Singapore) +- Microsoft (variations: Microsoft Corp, MSFT) +- Google (variations: Google LLC, Google India)""" + + @pytest.fixture + def resolver(self, sample_input_data, sample_reference_data): + """Create a VariableResolver instance for testing.""" + return VariableResolver(sample_input_data, sample_reference_data) + + # ========================================================================== + # Basic Resolution Tests + # ========================================================================== + + def test_simple_variable_replacement(self, resolver): + """Test simple variable replacement.""" + template = "Vendor: {{input_data.vendor_name}}" + result = resolver.resolve(template) + assert result == "Vendor: Slack India Pvt Ltd" + + def 
test_reference_data_replacement(self, resolver): + """Test reference data replacement.""" + template = "Database: {{reference_data}}" + result = resolver.resolve(template) + assert "Canonical Vendors" in result + assert "Slack" in result + + def test_multiple_variables(self, resolver): + """Test multiple variable replacements in one template.""" + template = "Match {{input_data.vendor_name}} from {{reference_data}}" + result = resolver.resolve(template) + assert "Slack India Pvt Ltd" in result + assert "Canonical Vendors" in result + + # ========================================================================== + # Dot Notation Tests + # ========================================================================== + + def test_one_level_dot_notation(self, resolver): + """Test one level dot notation.""" + template = "Value: {{input_data.contract_value}}" + result = resolver.resolve(template) + assert result == "Value: 50000" + + def test_two_level_dot_notation(self, resolver): + """Test two level dot notation.""" + template = "Region: {{input_data.metadata.region}}" + result = resolver.resolve(template) + assert result == "Region: APAC" + + def test_array_indexing(self, resolver): + """Test array indexing with dot notation.""" + template = "Product: {{input_data.line_items.0.product}}" + result = resolver.resolve(template) + assert result == "Product: Slack Pro" + + def test_deep_nesting(self, resolver): + """Test deep nested path resolution.""" + template = "Price: {{input_data.line_items.0.price}}" + result = resolver.resolve(template) + assert result == "Price: 500" + + # ========================================================================== + # Complex Object Serialization Tests + # ========================================================================== + + def test_dict_serialization(self, resolver): + """Test that dicts are serialized to JSON.""" + template = "Metadata: {{input_data.metadata}}" + result = resolver.resolve(template) + # Parse the JSON portion to verify it's valid + json_str = result.replace("Metadata: ", "") + parsed = json.loads(json_str) + assert parsed["region"] == "APAC" + assert parsed["currency"] == "INR" + + def test_list_serialization(self, resolver): + """Test that lists are serialized to JSON.""" + template = "Items: {{input_data.line_items}}" + result = resolver.resolve(template) + # Parse the JSON portion to verify it's valid + json_str = result.replace("Items: ", "") + parsed = json.loads(json_str) + assert len(parsed) == 1 + assert parsed[0]["product"] == "Slack Pro" + + def test_full_input_serialization(self, resolver): + """Test serializing the entire input_data object.""" + template = "Full data: {{input_data}}" + result = resolver.resolve(template) + # Should contain JSON representation + assert '"vendor_name"' in result + assert '"contract_value"' in result + + # ========================================================================== + # Missing Value Tests + # ========================================================================== + + def test_missing_root_variable(self, resolver): + """Test missing root variable returns empty string.""" + template = "Missing: {{missing_data}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_missing_nested_field(self, resolver): + """Test missing nested field returns empty string.""" + template = "Missing: {{input_data.missing.field}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_partial_path_missing(self, resolver): + """Test partially 
missing path returns empty string.""" + template = "Missing: {{input_data.metadata.missing}}" + result = resolver.resolve(template) + assert result == "Missing: " + + def test_none_value(self, resolver): + """Test that None values are converted to empty string.""" + template = "None field: {{input_data.none_field}}" + result = resolver.resolve(template) + assert result == "None field: " + + def test_out_of_bounds_array_index(self, resolver): + """Test out of bounds array index returns empty string.""" + template = "Missing: {{input_data.line_items.999.product}}" + result = resolver.resolve(template) + assert result == "Missing: " + + # ========================================================================== + # Edge Cases Tests + # ========================================================================== + + def test_empty_template(self, resolver): + """Test empty template returns empty string.""" + assert resolver.resolve("") == "" + + def test_no_variables(self, resolver): + """Test template without variables returns unchanged.""" + template = "Plain text without variables" + assert resolver.resolve(template) == template + + def test_malformed_variable(self, resolver): + """Test malformed variable syntax is left unchanged.""" + template = "Malformed: {{unclosed" + result = resolver.resolve(template) + assert result == "Malformed: {{unclosed" + + def test_whitespace_in_variable(self, resolver): + """Test variables with whitespace are handled correctly.""" + template = "Vendor: {{ input_data.vendor_name }}" + result = resolver.resolve(template) + assert result == "Vendor: Slack India Pvt Ltd" + + def test_empty_braces(self, resolver): + """Test empty braces are handled.""" + template = "Empty: {{}}" + result = resolver.resolve(template) + assert result == "Empty: " + + # ========================================================================== + # Variable Detection Tests + # ========================================================================== + + def test_detect_single_variable(self, resolver): + """Test detecting a single variable.""" + template = "{{input_data.vendor_name}}" + variables = resolver.detect_variables(template) + assert variables == ["input_data.vendor_name"] + + def test_detect_multiple_variables(self, resolver): + """Test detecting multiple variables.""" + template = "{{input_data.vendor_name}} and {{reference_data}}" + variables = resolver.detect_variables(template) + assert set(variables) == {"input_data.vendor_name", "reference_data"} + + def test_detect_duplicate_variables(self, resolver): + """Test that duplicate variables are deduplicated.""" + template = "{{var}} and {{var}} and {{var}}" + variables = resolver.detect_variables(template) + assert variables == ["var"] + + def test_detect_no_variables(self, resolver): + """Test detecting no variables in plain text.""" + template = "Plain text without variables" + variables = resolver.detect_variables(template) + assert variables == [] + + def test_detect_variables_with_whitespace(self, resolver): + """Test detecting variables with whitespace.""" + template = "{{ input_data.vendor }} and {{reference_data}}" + variables = resolver.detect_variables(template) + assert set(variables) == {"input_data.vendor", "reference_data"} + + # ========================================================================== + # Variable Validation Tests + # ========================================================================== + + def test_validate_existing_variables(self, resolver): + """Test validation of existing 
variables.""" + template = "{{input_data.vendor_name}} and {{reference_data}}" + validation = resolver.validate_variables(template) + assert validation["input_data.vendor_name"] is True + assert validation["reference_data"] is True + + def test_validate_missing_variables(self, resolver): + """Test validation of missing variables.""" + template = "{{input_data.missing}} and {{nonexistent}}" + validation = resolver.validate_variables(template) + assert validation["input_data.missing"] is False + assert validation["nonexistent"] is False + + def test_get_missing_variables(self, resolver): + """Test getting list of missing variables.""" + template = "{{input_data.vendor_name}} {{missing}} {{input_data.nope}}" + missing = resolver.get_missing_variables(template) + assert set(missing) == {"missing", "input_data.nope"} + + # ========================================================================== + # Integration Tests + # ========================================================================== + + def test_complete_vendor_matching_template(self, resolver): + """Test complete vendor matching template from spec.""" + template = """Match vendor "{{input_data.vendor_name}}" from: +{{reference_data}} + +Contract value: {{input_data.contract_value}} +Region: {{input_data.metadata.region}}""" + + result = resolver.resolve(template) + + # Verify all variables were replaced correctly + assert "Slack India Pvt Ltd" in result + assert "Canonical Vendors" in result + assert "50000" in result + assert "APAC" in result + assert "{{" not in result # No unresolved variables + + def test_complex_nested_template(self, resolver): + """Test complex template with various variable types.""" + template = """Input Summary: +- Vendor: {{input_data.vendor_name}} +- Date: {{input_data.contract_date}} +- First Item: {{input_data.line_items.0.product}} +- Metadata: {{input_data.metadata}} +- Missing: {{input_data.nonexistent}} + +Reference Database: +{{reference_data}}""" + + result = resolver.resolve(template) + + # Verify each part + assert "Slack India Pvt Ltd" in result + assert "2024-01-15" in result + assert "Slack Pro" in result + assert '"region": "APAC"' in result # JSON serialized + assert "Missing: \n" in result # Empty for missing field + assert "Canonical Vendors" in result diff --git a/backend/lookup/urls.py b/backend/lookup/urls.py new file mode 100644 index 0000000000..bc30dfa21d --- /dev/null +++ b/backend/lookup/urls.py @@ -0,0 +1,34 @@ +"""URL configuration for Look-Up API endpoints.""" + +from django.urls import include, path +from rest_framework.routers import DefaultRouter + +from .views import ( + LookupDataSourceViewSet, + LookupDebugView, + LookupExecutionAuditViewSet, + LookupProfileManagerViewSet, + LookupProjectViewSet, + LookupPromptTemplateViewSet, + PromptStudioLookupLinkViewSet, +) + +# Create router for viewsets +router = DefaultRouter() +router.register(r"lookup-projects", LookupProjectViewSet, basename="lookupproject") +router.register( + r"lookup-templates", LookupPromptTemplateViewSet, basename="lookuptemplate" +) +router.register(r"lookup-profiles", LookupProfileManagerViewSet, basename="lookupprofile") +router.register(r"data-sources", LookupDataSourceViewSet, basename="lookupdatasource") +router.register(r"lookup-links", PromptStudioLookupLinkViewSet, basename="lookuplink") +router.register( + r"execution-audits", LookupExecutionAuditViewSet, basename="executionaudit" +) +router.register(r"lookup-debug", LookupDebugView, basename="lookupdebug") + +app_name = "lookup" + +urlpatterns 
= [ + path("", include(router.urls)), +] diff --git a/backend/lookup/views.py b/backend/lookup/views.py new file mode 100644 index 0000000000..9416acc689 --- /dev/null +++ b/backend/lookup/views.py @@ -0,0 +1,1489 @@ +"""Django REST Framework views for Look-Up API. + +This module provides RESTful API endpoints for managing Look-Up projects, +templates, reference data, and executing Look-Ups. +""" + +import logging +import uuid + +from account_v2.custom_exceptions import DuplicateData +from django.db import IntegrityError, transaction +from permissions.permission import IsOwner +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.versioning import URLPathVersioning +from utils.filtering import FilterHelper +from utils.pagination import CustomPagination + +from .constants import LookupProfileManagerErrors, LookupProfileManagerKeys +from .exceptions import ExtractionNotCompleteError +from .models import ( + LookupDataSource, + LookupExecutionAudit, + LookupProfileManager, + LookupProject, + LookupPromptTemplate, + PromptStudioLookupLink, +) +from .serializers import ( + BulkLinkSerializer, + LookupDataSourceSerializer, + LookupExecutionAuditSerializer, + LookupExecutionRequestSerializer, + LookupExecutionResponseSerializer, + LookupProfileManagerSerializer, + LookupProjectSerializer, + LookupPromptTemplateSerializer, + PromptStudioLookupLinkSerializer, + ReferenceDataUploadSerializer, + TemplateValidationSerializer, +) +from .services import ( + AuditLogger, + EnrichmentMerger, + LLMResponseCache, + LookUpExecutor, + LookUpOrchestrator, + ReferenceDataLoader, + VariableResolver, +) + +logger = logging.getLogger(__name__) + + +class LookupProjectViewSet(viewsets.ModelViewSet): + """ViewSet for managing Look-Up projects. + + Provides CRUD operations and additional actions for + executing Look-Ups and managing reference data. + """ + + queryset = LookupProject.objects.all() + serializer_class = LookupProjectSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter projects by organization and active status.""" + # Note: Organization filtering is handled automatically by + # DefaultOrganizationMixin's save() method and queryset filtering + # should be handled by a custom manager if needed (like Prompt Studio) + queryset = super().get_queryset() + + # Filter by active status if requested + is_active = self.request.query_params.get("is_active") + if is_active is not None: + queryset = queryset.filter(is_active=is_active.lower() == "true") + + return queryset.select_related("template") + + def perform_create(self, serializer): + """Set created_by from request.""" + # Note: organization is set automatically by DefaultOrganizationMixin's save() method + serializer.save(created_by=self.request.user) + + def destroy(self, request, *args, **kwargs): + """Delete a Look-Up project. + + Prevents deletion if the project is linked to any Prompt Studio projects. 
+ """ + instance = self.get_object() + + # Check if the project is linked to any Prompt Studio projects + linked_ps_projects = instance.ps_links.all() + if linked_ps_projects.exists(): + # Get linked project IDs for the error message + linked_ids = list( + linked_ps_projects.values_list("prompt_studio_project_id", flat=True) + ) + return Response( + { + "error": "Cannot delete Look-Up project that is linked to " + "Prompt Studio projects", + "detail": f"This Look-Up project is linked to {len(linked_ids)} " + f"Prompt Studio project(s). Please unlink it from all Prompt " + f"Studio projects before deleting.", + "linked_prompt_studio_projects": [str(id) for id in linked_ids], + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Proceed with deletion + return super().destroy(request, *args, **kwargs) + + @action(detail=True, methods=["post"]) + def execute(self, request, pk=None): + """Execute a Look-Up project with provided input data. + + POST /api/v1/lookup-projects/{id}/execute/ + """ + project = self.get_object() + + # Validate request + request_serializer = LookupExecutionRequestSerializer(data=request.data) + request_serializer.is_valid(raise_exception=True) + + input_data = request_serializer.validated_data["input_data"] + use_cache = request_serializer.validated_data["use_cache"] + timeout = request_serializer.validated_data["timeout_seconds"] + + try: + # Get the LLM adapter from Lookup profile + from utils.user_context import UserContext + + from .integrations.file_storage_client import FileStorageClient + from .integrations.unstract_llm_client import UnstractLLMClient + + # Get organization ID for RAG retrieval (must match what was used during indexing) + org_id = UserContext.get_organization_identifier() + + # Get profile for this project + profile = LookupProfileManager.objects.filter(lookup_project=project).first() + + if not profile or not profile.llm: + return Response( + {"error": "No LLM profile configured for this Look-Up project"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Create real LLM client using the profile's adapter + llm_client = UnstractLLMClient(profile.llm) + storage_client = FileStorageClient() + cache = LLMResponseCache() if use_cache else None + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + org_id=org_id, + ) + + orchestrator = LookUpOrchestrator( + executor=executor, + merger=merger, + config={"execution_timeout_seconds": timeout}, + ) + + # Execute Look-Up + result = orchestrator.execute_lookups( + input_data=input_data, lookup_projects=[project] + ) + + # Check if any lookups failed - return error response if so + metadata = result.get("_lookup_metadata", {}) + enrichments = metadata.get("enrichments", []) + failed_enrichments = [e for e in enrichments if e.get("status") == "failed"] + + if failed_enrichments: + # Find context window errors first (more specific) + context_window_error = next( + ( + e + for e in failed_enrichments + if e.get("error_type") == "context_window_exceeded" + ), + None, + ) + + if context_window_error: + return Response( + { + "error": context_window_error.get("error"), + "error_type": "context_window_exceeded", + "token_count": context_window_error.get("token_count"), + "context_limit": context_window_error.get("context_limit"), + "model": context_window_error.get("model"), + "_lookup_metadata": metadata, + }, + 
status=status.HTTP_400_BAD_REQUEST, + ) + else: + # Other failure - return first error + first_error = failed_enrichments[0] + return Response( + { + "error": first_error.get("error", "Look-Up execution failed"), + "_lookup_metadata": metadata, + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Serialize response + response_serializer = LookupExecutionResponseSerializer(data=result) + response_serializer.is_valid(raise_exception=True) + + return Response(response_serializer.validated_data, status=status.HTTP_200_OK) + + except ExtractionNotCompleteError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception: + logger.exception(f"Error executing Look-Up project {project.id}") + return Response( + {"error": "Internal error during execution"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @action(detail=True, methods=["post"], parser_classes=[MultiPartParser]) + def upload_reference_data(self, request, pk=None): + """Upload reference data for a Look-Up project. + + POST /api/v1/lookup-projects/{id}/upload_reference_data/ + """ + project = self.get_object() + + upload_serializer = ReferenceDataUploadSerializer(data=request.data) + upload_serializer.is_valid(raise_exception=True) + + file = upload_serializer.validated_data["file"] + extract_text = upload_serializer.validated_data["extract_text"] + + try: + from utils.file_storage.helpers.prompt_studio_file_helper import ( + PromptStudioFileHelper, + ) + from utils.user_context import UserContext + + # Determine file type from extension + file_ext = file.name.split(".")[-1].lower() + file_type = ( + file_ext + if file_ext in ["pdf", "xlsx", "csv", "docx", "txt", "json"] + else "txt" + ) + + org_id = UserContext.get_organization_identifier() + # Use a fixed user_id for lookup uploads to match PS path structure + # Path: {base_path}/{org_id}/{user_id}/{tool_id}/{filename} + user_id = "lookup" + tool_id = str(project.id) + + logger.info( + f"Upload via PromptStudioFileHelper: org_id={org_id}, " + f"user_id={user_id}, tool_id={tool_id}, file={file.name}" + ) + + # Use PromptStudioFileHelper - exact same code path as working PS upload + PromptStudioFileHelper.upload_for_ide( + org_id=org_id, + user_id=user_id, + tool_id=tool_id, + file_name=file.name, + file_data=file, + ) + + # Build file_path for database record (matching PS helper's path structure) + from pathlib import Path + + from utils.file_storage.constants import FileStorageConstants + + from unstract.core.utilities import UnstractUtils + + base_path = UnstractUtils.get_env( + env_key=FileStorageConstants.REMOTE_PROMPT_STUDIO_FILE_PATH + ) + file_path = str(Path(base_path) / org_id / user_id / tool_id / file.name) + + logger.info(f"Uploaded file to storage: {file_path}") + + # Create a data source record + data_source = LookupDataSource.objects.create( + project=project, + file_name=file.name, + file_path=file_path, + file_size=file.size, + file_type=file_type, + extraction_status="pending" if extract_text else "completed", + uploaded_by=request.user, + ) + + serializer = LookupDataSourceSerializer(data_source) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + except Exception as e: + logger.exception( + f"Error uploading reference data for project {project.id}: {e}" + ) + return Response( + {"error": f"Failed to upload file: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @action(detail=True, methods=["get"]) + def data_sources(self, request, pk=None): + """List all data sources for a Look-Up 
project. + + GET /api/v1/lookup-projects/{id}/data_sources/ + """ + project = self.get_object() + data_sources = project.data_sources.all().order_by("-version_number") + + # Filter by is_latest if requested + is_latest = request.query_params.get("is_latest") + if is_latest is not None: + data_sources = data_sources.filter(is_latest=is_latest.lower() == "true") + + serializer = LookupDataSourceSerializer(data_sources, many=True) + return Response(serializer.data) + + @action(detail=True, methods=["post"]) + def cleanup_stale_indexes(self, request, pk=None): + """Manually trigger cleanup of stale vector DB indexes for a project. + + POST /api/v1/lookup-projects/{id}/cleanup_stale_indexes/ + + Cleans up vector DB nodes that are marked as stale or no longer needed. + This is useful for reclaiming storage and ensuring data consistency. + + Returns: + Summary of cleanup operations performed. + """ + project = self.get_object() + + try: + from lookup.models import LookupIndexManager + from lookup.services.vector_db_cleanup_service import ( + VectorDBCleanupService, + ) + + cleanup_service = VectorDBCleanupService() + total_deleted = 0 + total_failed = 0 + errors = [] + + # Get all index managers for this project that need cleanup + index_managers = LookupIndexManager.objects.filter( + data_source__project=project + ).select_related("profile_manager", "data_source") + + for index_manager in index_managers: + if not index_manager.profile_manager: + continue + + # Clean up stale indexes (keeping only the current one) + result = cleanup_service.cleanup_stale_indexes( + index_manager=index_manager, keep_current=True + ) + total_deleted += result.get("deleted", 0) + total_failed += result.get("failed", 0) + if result.get("errors"): + errors.extend(result["errors"]) + + # Reset reindex_required flag if cleanup was successful + if result.get("success"): + index_manager.reindex_required = False + index_manager.save(update_fields=["reindex_required"]) + + return Response( + { + "message": "Cleanup completed", + "project_id": str(project.id), + "indexes_deleted": total_deleted, + "indexes_failed": total_failed, + "errors": errors[:10] if errors else [], # Limit errors shown + }, + status=status.HTTP_200_OK, + ) + + except Exception as e: + logger.exception(f"Error during cleanup for project {project.id}") + return Response( + {"error": f"Cleanup failed: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @action(detail=True, methods=["post"]) + def index_all(self, request, pk=None): + """Index all reference data using the project's default profile. + + POST /api/v1/lookup-projects/{id}/index_all/ + + Triggers indexing of all completed data sources using the + configured default profile's adapters and settings. + + This calls external extraction and indexing services via PromptTool SDK. 
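As a rough usage sketch (host and auth are placeholders; the path and the `results` shape are the ones returned by this action and reused by `force_reindex` below):

```python
import requests

BASE = "https://unstract.example.com/api/v1"   # placeholder host
HDRS = {"Authorization": "Bearer <token>"}      # placeholder auth

resp = requests.post(
    f"{BASE}/lookup-projects/<lookup-project-uuid>/index_all/", headers=HDRS, timeout=300
)
resp.raise_for_status()
results = resp.json()["results"]  # {"total": n, "success": n, "failed": n, "errors": [...]}
if results["failed"]:
    print("Indexing errors:", results.get("errors", []))
```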
+ """ + project = self.get_object() + + try: + from utils.user_context import UserContext + + from .exceptions import DefaultProfileError + from .services import IndexingService + + # Get organization and user context + org_id = UserContext.get_organization_identifier() + user_id = str(request.user.user_id) if request.user else None + + logger.info( + f"Starting indexing for project {project.id} " + f"(org: {org_id}, user: {user_id})" + ) + + # Index all using default profile + results = IndexingService.index_with_default_profile( + project_id=str(project.id), org_id=org_id, user_id=user_id + ) + + logger.info( + f"Indexing completed for project {project.id}: " + f"{results['success']} successful, {results['failed']} failed" + ) + + return Response( + {"message": "Indexing completed", "results": results}, + status=status.HTTP_200_OK, + ) + + except DefaultProfileError as e: + logger.error(f"Default profile error for project {project.id}: {e}") + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + logger.exception(f"Error indexing reference data for project {project.id}") + return Response( + {"error": f"Failed to index reference data: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class LookupPromptTemplateViewSet(viewsets.ModelViewSet): + """ViewSet for managing Look-Up prompt templates. + + Provides CRUD operations and template validation. + """ + + queryset = LookupPromptTemplate.objects.all() + serializer_class = LookupPromptTemplateSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter templates by active status if requested.""" + queryset = super().get_queryset() + is_active = self.request.query_params.get("is_active") + + if is_active is not None: + queryset = queryset.filter(is_active=is_active.lower() == "true") + + return queryset + + def perform_create(self, serializer): + """Set created_by from request and update project's template reference.""" + template = serializer.save(created_by=self.request.user) + # Update the project's template field to point to this template + if template.project: + template.project.template = template + template.project.save(update_fields=["template"]) + + def perform_update(self, serializer): + """Update template and ensure project reference is maintained.""" + template = serializer.save() + # Ensure the project's template field points to this template + if template.project and template.project.template != template: + template.project.template = template + template.project.save(update_fields=["template"]) + + @action(detail=False, methods=["post"]) + def validate(self, request): + """Validate a template with optional sample data. 
+ + POST /api/v1/lookup-templates/validate/ + """ + validator = TemplateValidationSerializer(data=request.data) + validator.is_valid(raise_exception=True) + + template_text = validator.validated_data["template_text"] + sample_data = validator.validated_data.get("sample_data", {}) + sample_reference = validator.validated_data.get("sample_reference", "") + + try: + # Test variable resolution + resolver = VariableResolver(sample_data, sample_reference) + resolved = resolver.resolve(template_text) + + return Response( + { + "valid": True, + "resolved_template": resolved[:1000], # First 1000 chars + "variables_found": list(resolver.get_all_variables(template_text)), + } + ) + + except Exception as e: + return Response( + {"valid": False, "error": str(e)}, status=status.HTTP_400_BAD_REQUEST + ) + + +class PromptStudioLookupLinkViewSet(viewsets.ModelViewSet): + """ViewSet for managing links between Prompt Studio projects and Look-Ups.""" + + queryset = PromptStudioLookupLink.objects.all() + serializer_class = PromptStudioLookupLinkSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + def get_queryset(self): + """Filter links by PS project if requested.""" + queryset = super().get_queryset() + + ps_project_id = self.request.query_params.get("prompt_studio_project_id") + if ps_project_id: + queryset = queryset.filter(prompt_studio_project_id=ps_project_id) + + lookup_project_id = self.request.query_params.get("lookup_project_id") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + return queryset.select_related("lookup_project") + + @action(detail=False, methods=["post"]) + def bulk_link(self, request): + """Create or remove multiple links at once. + + POST /api/v1/lookup-links/bulk_link/ + """ + serializer = BulkLinkSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + ps_project_id = serializer.validated_data["prompt_studio_project_id"] + lookup_project_ids = serializer.validated_data["lookup_project_ids"] + unlink = serializer.validated_data["unlink"] + + results = [] + + with transaction.atomic(): + for lookup_id in lookup_project_ids: + if unlink: + # Remove link + deleted_count, _ = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id, + lookup_project_id=lookup_id, + ).delete() + results.append( + { + "lookup_project_id": str(lookup_id), + "unlinked": deleted_count > 0, + } + ) + else: + # Create link + link, created = PromptStudioLookupLink.objects.get_or_create( + prompt_studio_project_id=ps_project_id, + lookup_project_id=lookup_id, + ) + results.append( + { + "lookup_project_id": str(lookup_id), + "linked": created, + "link_id": str(link.id) if created else None, + } + ) + + return Response({"results": results, "total_processed": len(results)}) + + +class LookupExecutionAuditViewSet(viewsets.ReadOnlyModelViewSet): + """ViewSet for viewing execution audit records. + + Read-only access to execution history and statistics. 
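A compact sketch of how the audit history might be queried. Host and auth are placeholders, the filter names and the `statistics` path come from the actions below, the `status` value is illustrative, and whether the list response is paginated depends on the project's DRF settings.

```python
import requests

BASE = "https://unstract.example.com/api/v1"   # placeholder host
HDRS = {"Authorization": "Bearer <token>"}      # placeholder auth
PARAMS = {"lookup_project_id": "<lookup-project-uuid>"}

# Failed executions for one Look-Up project.
failed = requests.get(
    f"{BASE}/execution-audits/",
    headers=HDRS,
    params={**PARAMS, "status": "failed"},
    timeout=30,
).json()

# Aggregated statistics for the same project.
stats = requests.get(
    f"{BASE}/execution-audits/statistics/", headers=HDRS, params=PARAMS, timeout=30
).json()
```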
+ """ + + queryset = LookupExecutionAudit.objects.all() + serializer_class = LookupExecutionAuditSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + """Filter audit records by various parameters.""" + queryset = super().get_queryset() + + # Filter by Look-Up project + lookup_project_id = self.request.query_params.get("lookup_project_id") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + # Filter by PS project + ps_project_id = self.request.query_params.get("prompt_studio_project_id") + if ps_project_id: + queryset = queryset.filter(prompt_studio_project_id=ps_project_id) + + # Filter by execution ID + execution_id = self.request.query_params.get("execution_id") + if execution_id: + queryset = queryset.filter(execution_id=execution_id) + + # Filter by status + status_filter = self.request.query_params.get("status") + if status_filter: + queryset = queryset.filter(status=status_filter) + + return queryset.select_related("lookup_project").order_by("-executed_at") + + @action(detail=False, methods=["get"]) + def statistics(self, request): + """Get execution statistics for a project. + + GET /api/v1/execution-audits/statistics/?lookup_project_id={id} + """ + lookup_project_id = request.query_params.get("lookup_project_id") + if not lookup_project_id: + return Response( + {"error": "lookup_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + audit_logger = AuditLogger() + stats = audit_logger.get_project_stats( + project_id=uuid.UUID(lookup_project_id), limit=1000 + ) + return Response(stats) + + except ValueError: + return Response( + {"error": "Invalid UUID format"}, status=status.HTTP_400_BAD_REQUEST + ) + + @action(detail=False, methods=["get"]) + def by_file_execution(self, request): + """Get Look-up audits for a specific file execution. + + GET /api/v1/execution-audits/by_file_execution/?file_execution_id={id} + + This endpoint is used by the Nav Bar Logs page to show Look-up + enrichment details for a specific file processed in ETL/Workflow/API. + """ + file_execution_id = request.query_params.get("file_execution_id") + if not file_execution_id: + return Response( + {"error": "file_execution_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + audits = self.get_queryset().filter(file_execution_id=file_execution_id) + serializer = self.get_serializer(audits, many=True) + return Response(serializer.data) + except ValueError: + return Response( + {"error": "Invalid UUID format"}, status=status.HTTP_400_BAD_REQUEST + ) + + @action(detail=False, methods=["get"]) + def by_workflow_execution(self, request): + """Get Look-up audits for an entire workflow execution. + + GET /api/v1/execution-audits/by_workflow_execution/?workflow_execution_id={id} + + This endpoint returns all Look-up audits across all files + processed in a workflow execution. 
+ """ + workflow_execution_id = request.query_params.get("workflow_execution_id") + if not workflow_execution_id: + return Response( + {"error": "workflow_execution_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + from workflow_manager.file_execution.models import WorkflowFileExecution + + # Get all file execution IDs for this workflow + file_execution_ids = WorkflowFileExecution.objects.filter( + workflow_execution_id=workflow_execution_id + ).values_list("id", flat=True) + + audits = self.get_queryset().filter(file_execution_id__in=file_execution_ids) + serializer = self.get_serializer(audits, many=True) + return Response(serializer.data) + except ValueError: + return Response( + {"error": "Invalid UUID format"}, status=status.HTTP_400_BAD_REQUEST + ) + + +class LookupDebugView(viewsets.ViewSet): + """Debug endpoints for testing Look-Up execution with Prompt Studio.""" + + permission_classes = [IsAuthenticated] + + @action(detail=False, methods=["post"]) + def enrich_ps_output(self, request): + """Enrich Prompt Studio extracted output with linked Look-Ups. + + Uses real LLM clients configured in each Look-Up project's profile. + + POST /api/v1/lookup-debug/enrich_ps_output/ + + Request body: + { + "prompt_studio_project_id": "uuid", + "extracted_data": {"vendor_name": "Amzn Web Services Inc", ...} + } + + Response: + { + "original_data": {...}, + "enriched_data": {...}, + "lookup_enrichment": {...}, + "_lookup_metadata": {...} + } + """ + ps_project_id = request.data.get("prompt_studio_project_id") + extracted_data = request.data.get("extracted_data", {}) + + if not ps_project_id: + return Response( + {"error": "prompt_studio_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if not extracted_data: + return Response( + {"error": "extracted_data is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + # Get linked Look-Ups + links = ( + PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id + ) + .select_related("lookup_project") + .order_by("execution_order") + ) + + if not links: + return Response( + { + "original_data": extracted_data, + "enriched_data": extracted_data, + "lookup_enrichment": {}, + "_lookup_metadata": { + "lookups_executed": 0, + "message": "No Look-Ups linked to this Prompt Studio project", + }, + } + ) + + # Get Look-Up projects (already ordered by execution_order from query) + lookup_projects = [link.lookup_project for link in links] + + # Initialize services with real clients + from utils.user_context import UserContext + + from .integrations.file_storage_client import FileStorageClient + from .integrations.unstract_llm_client import UnstractLLMClient + + # Get organization ID for RAG retrieval + org_id = UserContext.get_organization_identifier() + + storage_client = FileStorageClient() + cache = LLMResponseCache() + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + # Build project order mapping for sorting results later + # This ensures enrichments are merged in execution_order priority + project_order = { + str(project.id): idx for idx, project in enumerate(lookup_projects) + } + + # Execute each Look-Up with its own LLM profile + # Collect results with project IDs for proper ordering + enrichment_results = [] + all_metadata = {"lookups_executed": 0, "lookup_details": []} + + for project in lookup_projects: + # Get profile for this project + profile = LookupProfileManager.objects.filter( + lookup_project=project + ).first() + + if not profile or not 
profile.llm: + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "skipped", + "reason": "No LLM profile configured", + } + ) + continue + + try: + # Create LLM client for this project's profile + llm_client = UnstractLLMClient(profile.llm) + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + org_id=org_id, + ) + + orchestrator = LookUpOrchestrator(executor=executor, merger=merger) + + # Execute Look-Up with extracted data as input + result = orchestrator.execute_lookups( + input_data=extracted_data, lookup_projects=[project] + ) + + # Collect results with project ID for ordering + if result.get("lookup_enrichment"): + enrichment_results.append( + { + "project_id": str(project.id), + "enrichment": result["lookup_enrichment"], + } + ) + + all_metadata["lookups_executed"] += 1 + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "success", + "enrichment_keys": list( + result.get("lookup_enrichment", {}).keys() + ), + } + ) + + except Exception as e: + logger.exception(f"Error executing Look-Up {project.id}") + all_metadata["lookup_details"].append( + { + "project_id": str(project.id), + "project_name": project.name, + "status": "error", + "error": str(e), + } + ) + + # Sort enrichment results by execution order (first lookup has priority) + enrichment_results.sort( + key=lambda x: project_order.get(x.get("project_id"), 999) + ) + + # Merge enrichments in REVERSE order so first lookup wins + # (later updates overwrite, so process highest priority last) + all_enrichment = {} + for result in reversed(enrichment_results): + all_enrichment.update(result["enrichment"]) + + # Merge enrichment into extracted data + enriched_data = {**extracted_data, **all_enrichment} + + return Response( + { + "original_data": extracted_data, + "enriched_data": enriched_data, + "lookup_enrichment": all_enrichment, + "_lookup_metadata": all_metadata, + } + ) + + except Exception as e: + logger.exception(f"Error enriching PS output for project {ps_project_id}") + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + @action(detail=False, methods=["post"]) + def test_with_ps_project(self, request): + """Test Look-Up execution with a Prompt Studio project context. 
+ + POST /api/v1/lookup-debug/test_with_ps_project/ + """ + ps_project_id = request.data.get("prompt_studio_project_id") + input_data = request.data.get("input_data", {}) + + if not ps_project_id: + return Response( + {"error": "prompt_studio_project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + # Get linked Look-Ups + links = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=ps_project_id + ).select_related("lookup_project") + + if not links: + return Response( + { + "message": "No Look-Ups linked to this Prompt Studio project", + "lookup_enrichment": {}, + "_lookup_metadata": {"lookups_executed": 0}, + } + ) + + # Get Look-Up projects + lookup_projects = [link.lookup_project for link in links] + + # Initialize services + from utils.user_context import UserContext + + from .services.mock_clients import MockLLMClient, MockStorageClient + + # Get organization ID for RAG retrieval + org_id = UserContext.get_organization_identifier() + + llm_client = MockLLMClient() + storage_client = MockStorageClient() + cache = LLMResponseCache() + ref_loader = ReferenceDataLoader(storage_client) + merger = EnrichmentMerger() + + executor = LookUpExecutor( + variable_resolver=VariableResolver, + cache_manager=cache, + reference_loader=ref_loader, + llm_client=llm_client, + org_id=org_id, + ) + + orchestrator = LookUpOrchestrator(executor=executor, merger=merger) + + # Execute Look-Ups + result = orchestrator.execute_lookups( + input_data=input_data, lookup_projects=lookup_projects + ) + + return Response(result) + + except Exception as e: + logger.exception(f"Error in debug execution for PS project {ps_project_id}") + return Response( + {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + @action(detail=False, methods=["post"]) + def check_indexing_status(self, request): + """Check indexing status and optionally test vector DB retrieval. + + POST /api/v1/lookup-debug/check_indexing_status/ + + Request body: + { + "project_id": "uuid", + "test_query": "optional query to test retrieval" + } + + Returns detailed status of data sources, index managers, and vector DB. 
+ """ + from lookup.models import LookupIndexManager + + project_id = request.data.get("project_id") + test_query = request.data.get("test_query") + + if not project_id: + return Response( + {"error": "project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + project = LookupProject.objects.get(id=project_id) + except LookupProject.DoesNotExist: + return Response( + {"error": f"Project not found: {project_id}"}, + status=status.HTTP_404_NOT_FOUND, + ) + + # Get profile + try: + profile = LookupProfileManager.get_default_profile(project) + except Exception as e: + return Response( + {"error": f"No default profile: {e}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + result = { + "project": { + "id": str(project.id), + "name": project.name, + }, + "profile": { + "name": profile.profile_name, + "chunk_size": profile.chunk_size, + "chunk_overlap": profile.chunk_overlap, + "similarity_top_k": profile.similarity_top_k, + "vector_store_id": str(profile.vector_store.id), + "embedding_model_id": str(profile.embedding_model.id), + "rag_enabled": profile.chunk_size > 0, + }, + "data_sources": [], + "index_managers": [], + "retrieval_test": None, + } + + # Check data sources + data_sources = LookupDataSource.objects.filter(project_id=project_id).order_by( + "-created_at" + ) + + for ds in data_sources: + result["data_sources"].append( + { + "id": str(ds.id), + "file_name": ds.file_name, + "extraction_status": ds.extraction_status, + "is_latest": ds.is_latest, + "file_path": ds.file_path, + } + ) + + # Check index managers + index_managers = LookupIndexManager.objects.filter( + data_source__project_id=project_id, + profile_manager=profile, + ).select_related("data_source") + + for im in index_managers: + result["index_managers"].append( + { + "data_source": im.data_source.file_name, + "raw_index_id": im.raw_index_id, + "has_index": im.raw_index_id is not None, + "extraction_status": im.extraction_status, + "index_ids_history": im.index_ids_history, + } + ) + + # Test retrieval if query provided + if test_query and profile.chunk_size > 0: + try: + from utils.user_context import UserContext + + from lookup.services.lookup_retrieval_service import ( + LookupRetrievalService, + ) + + org_id = UserContext.get_organization_identifier() + service = LookupRetrievalService(profile, org_id=org_id) + context = service.retrieve_context(test_query, str(project.id)) + + result["retrieval_test"] = { + "query": test_query, + "success": bool(context), + "context_length": len(context) if context else 0, + "context_preview": context[:500] if context else None, + } + except Exception as e: + logger.exception("Retrieval test failed") + result["retrieval_test"] = { + "query": test_query, + "success": False, + "error": str(e), + } + + return Response(result) + + @action(detail=False, methods=["post"]) + def force_reindex(self, request): + """Force re-indexing of all data sources for a project. 
+ + POST /api/v1/lookup-debug/force_reindex/ + + Request body: + { + "project_id": "uuid" + } + """ + from utils.user_context import UserContext + + from lookup.services.indexing_service import IndexingService + + project_id = request.data.get("project_id") + + if not project_id: + return Response( + {"error": "project_id is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + org_id = UserContext.get_organization_identifier() + user_id = str(request.user.id) if request.user else None + + result = IndexingService.index_with_default_profile( + project_id=project_id, + org_id=org_id, + user_id=user_id, + ) + + return Response( + { + "status": "success", + "total": result["total"], + "success": result["success"], + "failed": result["failed"], + "errors": result.get("errors", []), + } + ) + + except Exception as e: + logger.exception(f"Re-indexing failed for project {project_id}") + return Response( + {"error": str(e)}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class LookupDataSourceViewSet(viewsets.ModelViewSet): + """ViewSet for managing LookupDataSource instances. + + Provides CRUD operations for Look-Up data sources (reference data files). + Supports listing, retrieving, and deleting data sources. + """ + + queryset = LookupDataSource.objects.all() + serializer_class = LookupDataSourceSerializer + permission_classes = [IsAuthenticated] + pagination_class = CustomPagination + + @action(detail=True, methods=["post"]) + def cleanup_indexes(self, request, pk=None): + """Clean up vector DB indexes for a specific data source. + + POST /api/v1/data-sources/{id}/cleanup_indexes/ + + Cleans up all vector DB nodes associated with this data source. + Use this before re-uploading a data source or when indexes are corrupted. + + Returns: + Summary of cleanup operations performed. 
+ """ + data_source = self.get_object() + + try: + from lookup.models import LookupIndexManager + from lookup.services.vector_db_cleanup_service import ( + VectorDBCleanupService, + ) + + cleanup_service = VectorDBCleanupService() + total_deleted = 0 + total_failed = 0 + errors = [] + + # Get all index managers for this data source + index_managers = LookupIndexManager.objects.filter( + data_source=data_source + ).select_related("profile_manager") + + for index_manager in index_managers: + if not index_manager.profile_manager: + continue + + if index_manager.index_ids_history: + result = cleanup_service.cleanup_index_ids( + index_ids=index_manager.index_ids_history, + vector_db_instance_id=str( + index_manager.profile_manager.vector_store_id + ), + ) + total_deleted += result.get("deleted", 0) + total_failed += result.get("failed", 0) + if result.get("errors"): + errors.extend(result["errors"]) + + # Clear history if cleanup was successful + if result.get("success"): + index_manager.index_ids_history = [] + index_manager.raw_index_id = None + index_manager.reindex_required = True + index_manager.save( + update_fields=[ + "index_ids_history", + "raw_index_id", + "reindex_required", + ] + ) + + return Response( + { + "message": "Cleanup completed", + "data_source_id": str(data_source.id), + "file_name": data_source.file_name, + "indexes_deleted": total_deleted, + "indexes_failed": total_failed, + "errors": errors[:10] if errors else [], + }, + status=status.HTTP_200_OK, + ) + + except Exception as e: + logger.exception(f"Error during cleanup for data source {data_source.id}") + return Response( + {"error": f"Cleanup failed: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def get_queryset(self): + """Filter data sources by project if specified.""" + queryset = super().get_queryset() + + # Filter by project if provided + project_id = self.request.query_params.get("project") + if project_id: + queryset = queryset.filter(project_id=project_id) + + # Filter by is_latest if requested + is_latest = self.request.query_params.get("is_latest") + if is_latest is not None: + queryset = queryset.filter(is_latest=is_latest.lower() == "true") + + return queryset.select_related("project", "uploaded_by").order_by( + "-version_number" + ) + + def destroy(self, request, *args, **kwargs): + """Delete a data source and its associated files from storage. 
+ + DELETE /api/v1/unstract/{org_id}/data-sources/{id}/ + """ + instance = self.get_object() + + try: + from utils.file_storage.constants import FileStorageKeys + + from unstract.sdk1.file_storage.constants import StorageType + from unstract.sdk1.file_storage.env_helper import EnvHelper + + # Get file storage instance + fs_instance = EnvHelper.get_storage( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + # Delete the file from storage if it exists + if instance.file_path: + try: + if fs_instance.exists(instance.file_path): + fs_instance.rm(instance.file_path) + logger.info(f"Deleted file from storage: {instance.file_path}") + except Exception as e: + logger.warning(f"Failed to delete file from storage: {e}") + + # Delete extracted content if it exists + if instance.extracted_content_path: + try: + if fs_instance.exists(instance.extracted_content_path): + fs_instance.rm(instance.extracted_content_path) + logger.info( + f"Deleted extracted content: {instance.extracted_content_path}" + ) + except Exception as e: + logger.warning(f"Failed to delete extracted content: {e}") + + # Delete associated index manager if exists + try: + from .models import LookupIndexManager + + LookupIndexManager.objects.filter(data_source=instance).delete() + logger.info(f"Deleted index manager for data source: {instance.id}") + except Exception as e: + logger.warning(f"Failed to delete index manager: {e}") + + # Delete the database record + instance.delete() + + logger.info(f"Successfully deleted data source: {instance.id}") + return Response(status=status.HTTP_204_NO_CONTENT) + + except Exception as e: + logger.exception(f"Error deleting data source {instance.id}") + return Response( + {"error": f"Failed to delete data source: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +class LookupProfileManagerViewSet(viewsets.ModelViewSet): + """ViewSet for managing LookupProfileManager instances. + + Provides CRUD operations for Look-Up project profiles. + Each profile defines the set of adapters to use for a Look-Up project. + + Follows the same pattern as Prompt Studio's ProfileManagerView. + """ + + versioning_class = URLPathVersioning + permission_classes = [IsOwner] + serializer_class = LookupProfileManagerSerializer + pagination_class = CustomPagination + + def get_queryset(self): + """Filter queryset by created_by if specified in query params. + Otherwise return all profiles the user has access to. + """ + filter_args = FilterHelper.build_filter_args( + self.request, + LookupProfileManagerKeys.CREATED_BY, + ) + if filter_args: + queryset = LookupProfileManager.objects.filter(**filter_args) + else: + queryset = LookupProfileManager.objects.all() + + # Filter by lookup_project if provided in query params + lookup_project_id = self.request.query_params.get("lookup_project") + if lookup_project_id: + queryset = queryset.filter(lookup_project_id=lookup_project_id) + + return queryset.order_by("-created_at") + + def create(self, request, *args, **kwargs): + """Create a new profile. + + Handles IntegrityError for duplicate profile names within the same project. 
+ """ + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + try: + self.perform_create(serializer) + except IntegrityError: + raise DuplicateData(LookupProfileManagerErrors.PROFILE_NAME_EXISTS) + + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=False, methods=["get"], url_path="default") + def get_default(self, request): + """Get the default profile for a lookup project. + + Query params: + - lookup_project: UUID of the lookup project (required) + + Returns: + Profile data or 404 if no default profile exists + """ + lookup_project_id = request.query_params.get("lookup_project") + + if not lookup_project_id: + return Response( + {"error": "lookup_project query parameter is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + project = LookupProject.objects.get(id=lookup_project_id) + profile = LookupProfileManager.get_default_profile(project) + serializer = self.get_serializer(profile) + return Response(serializer.data) + except LookupProject.DoesNotExist: + return Response( + {"error": "Lookup project not found"}, status=status.HTTP_404_NOT_FOUND + ) + except Exception as e: + return Response({"error": str(e)}, status=status.HTTP_404_NOT_FOUND) + + @action(detail=True, methods=["post"], url_path="set-default") + def set_default(self, request, pk=None): + """Set a profile as the default for its project. + + Unsets any existing default profile for the same project. + """ + profile = self.get_object() + + with transaction.atomic(): + # Unset existing default for this project + LookupProfileManager.objects.filter( + lookup_project=profile.lookup_project, is_default=True + ).update(is_default=False) + + # Set this profile as default + profile.is_default = True + profile.save() + + serializer = self.get_serializer(profile) + return Response(serializer.data) + + def partial_update(self, request, *args, **kwargs): + """Update a profile and mark indexes as stale if RAG settings changed. + + When chunk_size, chunk_overlap, embedding_model, or vector_store + are changed, existing indexes become stale and need re-indexing. 
+ """ + profile = self.get_object() + + # Track original values for RAG-relevant fields + original_values = { + "chunk_size": profile.chunk_size, + "chunk_overlap": profile.chunk_overlap, + "embedding_model": str(profile.embedding_model_id) + if profile.embedding_model + else None, + "vector_store": str(profile.vector_store_id) + if profile.vector_store + else None, + } + + # Perform the update + response = super().partial_update(request, *args, **kwargs) + + # Check if any RAG-relevant fields changed + if response.status_code == status.HTTP_200_OK: + profile.refresh_from_db() + + new_values = { + "chunk_size": profile.chunk_size, + "chunk_overlap": profile.chunk_overlap, + "embedding_model": str(profile.embedding_model_id) + if profile.embedding_model + else None, + "vector_store": str(profile.vector_store_id) + if profile.vector_store + else None, + } + + # Determine if re-indexing is needed + rag_settings_changed = original_values != new_values + was_rag_mode = ( + original_values["chunk_size"] and original_values["chunk_size"] > 0 + ) + is_rag_mode = new_values["chunk_size"] and new_values["chunk_size"] > 0 + + if rag_settings_changed and (was_rag_mode or is_rag_mode): + # Mark all indexes for this profile as requiring re-index + from lookup.models import LookupIndexManager + + updated_count = LookupIndexManager.objects.filter( + profile_manager=profile + ).update(reindex_required=True) + + if updated_count > 0: + logger.info( + f"Marked {updated_count} index(es) as requiring re-index " + f"for profile {profile.profile_name}" + ) + + # If switching from RAG to full context mode, clean up old indexes + if was_rag_mode and not is_rag_mode: + from lookup.services.vector_db_cleanup_service import ( + VectorDBCleanupService, + ) + + cleanup_service = VectorDBCleanupService() + cleanup_result = cleanup_service.cleanup_for_profile( + str(profile.profile_id) + ) + logger.info( + f"Cleaned up {cleanup_result['deleted']} index(es) " + f"after switching profile {profile.profile_name} to full context mode" + ) + + return response diff --git a/backend/prompt_studio/prompt_studio_core_v2/views.py b/backend/prompt_studio/prompt_studio_core_v2/views.py index 5e1f0d2a3f..0083e1ea9a 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/views.py +++ b/backend/prompt_studio/prompt_studio_core_v2/views.py @@ -399,6 +399,10 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: Returns: Response """ + from prompt_studio.prompt_studio_output_manager_v2.output_manager_helper import ( + LookupEnrichmentError, + ) + custom_tool = self.get_object() tool_id: str = str(custom_tool.tool_id) document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) @@ -408,16 +412,25 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: if not run_id: # Generate a run_id run_id = CommonUtils.generate_uuid() - response: dict[str, Any] = PromptStudioHelper.prompt_responder( - id=id, - tool_id=tool_id, - org_id=UserSessionUtils.get_organization_id(request), - user_id=custom_tool.created_by.user_id, - document_id=document_id, - run_id=run_id, - profile_manager_id=profile_manager, - ) - return Response(response, status=status.HTTP_200_OK) + try: + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + id=id, + tool_id=tool_id, + org_id=UserSessionUtils.get_organization_id(request), + user_id=custom_tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager, + ) + return Response(response, status=status.HTTP_200_OK) + 
except LookupEnrichmentError as e: + # Return error response for critical lookup failures + error_response = { + "error": str(e), + "error_type": e.error_type, + **e.details, + } + return Response(error_response, status=status.HTTP_400_BAD_REQUEST) @action(detail=True, methods=["post"]) def single_pass_extraction(self, request: HttpRequest, pk: uuid) -> Response: diff --git a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py index 405b91e00f..d244a015f3 100644 --- a/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py +++ b/backend/prompt_studio/prompt_studio_output_manager_v2/output_manager_helper.py @@ -2,7 +2,9 @@ import logging from typing import Any +from account_v2.constants import Common from django.core.exceptions import ObjectDoesNotExist +from utils.local_context import StateStore from prompt_studio.prompt_profile_manager_v2.models import ProfileManager from prompt_studio.prompt_studio_core_v2.exceptions import ( @@ -25,6 +27,166 @@ logger = logging.getLogger(__name__) +def _build_prompt_lookup_map( + prompts: list[ToolStudioPrompt], +) -> dict[str, str]: + """Build mapping of prompt_key to lookup_project_id for prompts with lookups. + + This function determines which prompts have lookup enrichment enabled. + Lookup is enabled for a prompt when `lookup_project` is assigned. + Prompts without `lookup_project` will be skipped (no enrichment). + + Args: + prompts: List of ToolStudioPrompt instances + + Returns: + Dict mapping prompt_key to lookup_project_id (as string) for prompts + that have a lookup_project assigned (lookup enabled). + """ + prompt_lookup_map: dict[str, str] = {} + skipped_prompts: list[str] = [] + + for prompt in prompts: + if prompt.lookup_project_id: + prompt_lookup_map[prompt.prompt_key] = str(prompt.lookup_project_id) + else: + skipped_prompts.append(prompt.prompt_key) + + if skipped_prompts: + logger.debug( + f"Prompts without lookup enabled (no lookup_project): {skipped_prompts}" + ) + if prompt_lookup_map: + logger.info(f"Prompts with lookup enabled: {list(prompt_lookup_map.keys())}") + + return prompt_lookup_map + + +class LookupEnrichmentError(Exception): + """Exception raised when lookup enrichment fails critically. + + This exception is raised for errors that should stop prompt execution + and be displayed to the user, such as context window exceeded errors. + """ + + def __init__( + self, + message: str, + error_type: str | None = None, + details: dict[str, Any] | None = None, + ): + super().__init__(message) + self.error_type = error_type + self.details = details or {} + + +def _try_lookup_enrichment( + tool_id: str, + extracted_data: dict[str, Any], + run_id: str | None = None, + session_id: str | None = None, + doc_name: str | None = None, + prompt_lookup_map: dict[str, str] | None = None, +) -> dict[str, Any]: + """Attempt Lookup enrichment if available. + + This function safely attempts to enrich extracted data using linked + Lookup projects. Returns empty dict if Lookup app is not available. + + Supports prompt-level lookups: if a field has a specific lookup assigned + via prompt_lookup_map, only that lookup will enrich it. Fields without + specific lookups will be SKIPPED (no enrichment applied). 
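For illustration, the mapping produced by `_build_prompt_lookup_map` above looks roughly like this; keys are prompt keys, values are `lookup_project` UUIDs, and the field names here are made up.

```python
# Three prompts on a tool; only two have a lookup_project assigned.
prompt_lookup_map = {
    "vendor_name": "6f1c6c1e-0000-0000-0000-000000000001",
    "currency_code": "6f1c6c1e-0000-0000-0000-000000000002",
}
# "invoice_total" has no lookup_project, is absent from the map, and is
# therefore skipped by enrichment.
```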
+ + Args: + tool_id: Prompt Studio project (CustomTool) UUID + extracted_data: Dict of extracted field values from prompts + run_id: Optional execution run ID for tracking + session_id: Optional WebSocket session ID for real-time log emission + doc_name: Optional document name being processed + prompt_lookup_map: Optional mapping of field names (prompt_key) to + specific lookup_project_id for prompt-level lookup support + + Returns: + Dict with 'lookup_enrichment' and '_lookup_metadata' keys, + or empty dict if Lookup is not available or no links exist. + + Raises: + LookupEnrichmentError: When lookup fails with a critical error that + should stop execution (e.g., context window exceeded). + """ + try: + from utils.user_context import UserContext + + from lookup.services.lookup_integration_service import ( + LookupIntegrationService, + ) + + # Get organization ID from user context for RAG retrieval + organization_id = UserContext.get_organization_identifier() + + result = LookupIntegrationService.enrich_if_linked( + prompt_studio_project_id=tool_id, + extracted_data=extracted_data, + run_id=run_id, + session_id=session_id, + doc_name=doc_name, + organization_id=organization_id, + prompt_lookup_map=prompt_lookup_map, + ) + + # Check if any lookups failed with critical errors + metadata = result.get("_lookup_metadata", {}) + enrichments = metadata.get("enrichments", []) + + logger.info(f"Checking {len(enrichments)} enrichments for critical errors") + + for enrichment in enrichments: + logger.info( + f"Enrichment status={enrichment.get('status')}, " + f"error_type={enrichment.get('error_type')}, " + f"error={enrichment.get('error', '')[:100]}" + ) + if enrichment.get("status") == "failed": + error_type = enrichment.get("error_type") + error_msg = enrichment.get("error", "Unknown lookup error") + + # Context window exceeded is a critical error - raise it + if error_type == "context_window_exceeded": + logger.error( + f"Context window exceeded error detected! " + f"Raising LookupEnrichmentError: {error_msg}" + ) + raise LookupEnrichmentError( + message=error_msg, + error_type=error_type, + details={ + "token_count": enrichment.get("token_count"), + "context_limit": enrichment.get("context_limit"), + "model": enrichment.get("model"), + "project_name": enrichment.get("project_name"), + }, + ) + + return result + except ImportError: + # Lookup app not installed + logger.debug("Lookup app not available, skipping enrichment") + return {} + except LookupEnrichmentError: + # Re-raise critical lookup errors to be handled by caller + raise + except Exception as e: + # Don't let non-critical Lookup errors break PS execution + logger.warning(f"Lookup enrichment failed (non-fatal): {e}") + return { + "lookup_enrichment": {}, + "_lookup_metadata": { + "status": "error", + "message": str(e), + }, + } + + class OutputManagerHelper: @staticmethod def handle_prompt_output_update( @@ -39,6 +201,17 @@ def handle_prompt_output_update( """Handles updating prompt outputs in the database and returns serialized data. + This method processes extraction outputs, saves them to the database, + and applies lookup enrichment as a post-processing step. 
+ + Lookup Enrichment Behavior: + - Lookups run as POST-PROCESSING after extraction completes + - Only prompts with `lookup_project` assigned will be enriched + - Prompts without `lookup_project` are skipped (no enrichment) + - Each prompt can have a different lookup project assigned + - Response includes `_lookup_status` for each prompt indicating + whether lookup was enabled and if enrichment was applied + Args: run_id (str): ID of the run. prompts (list[ToolStudioPrompt]): List of prompts to update. @@ -46,11 +219,15 @@ def handle_prompt_output_update( document_id (str): ID of the document. profile_manager_id (Optional[str]): UUID of the profile manager. is_single_pass_extract (bool): Flag indicating if single pass - extract is active. + extract is active. metadata (dict[str, Any]): Metadata for the update. Returns: list[dict[str, Any]]: List of serialized prompt output data. + Each item includes `_lookup_status` with: + - enabled: Whether lookup was configured for this prompt + - lookup_project_id: The assigned lookup project UUID (or None) + - was_enriched: Whether the output was actually enriched """ def update_or_create_prompt_output( @@ -193,6 +370,124 @@ def update_or_create_prompt_output( serializer = PromptStudioOutputSerializer(prompt_output) serialized_data.append(serializer.data) + # Post-processing: Lookup enrichment integration + # Build extracted data dict from all prompt outputs for enrichment + extracted_data_for_lookup: dict[str, Any] = {} + logger.info( + f"Building extracted_data_for_lookup from {len(prompts)} prompts, " + f"outputs keys: {list(outputs.keys())}" + ) + for prompt in prompts: + if prompt.prompt_type == PSOMKeys.NOTES: + logger.debug(f"Skipping NOTES prompt: {prompt.prompt_key}") + continue + output_value = outputs.get(prompt.prompt_key) + logger.info( + f"Prompt {prompt.prompt_key}: output_value={output_value!r} " + f"(type={type(output_value).__name__})" + ) + if output_value is not None: + extracted_data_for_lookup[prompt.prompt_key] = output_value + + logger.info(f"extracted_data_for_lookup: {extracted_data_for_lookup}") + + # Initialize lookup_result for later status tracking + lookup_result: dict[str, Any] = {} + + # Execute Lookup enrichment if linked projects exist + if extracted_data_for_lookup: + tool_id_str = str(tool.tool_id) + logger.info( + f"Calling Lookup enrichment for tool {tool_id_str} " + f"with data: {extracted_data_for_lookup}" + ) + # Get session_id for WebSocket log emission + session_id = StateStore.get(Common.LOG_EVENTS_ID) + doc_name = metadata.get("file_name") or document_manager.document_name + + # Build prompt-level lookup mapping for per-prompt lookup support + prompt_lookup_map = _build_prompt_lookup_map(prompts) + if prompt_lookup_map: + logger.info(f"Using prompt-level lookups: {prompt_lookup_map}") + + lookup_result = _try_lookup_enrichment( + tool_id=tool_id_str, + extracted_data=extracted_data_for_lookup, + run_id=run_id, + session_id=session_id, + doc_name=doc_name, + prompt_lookup_map=prompt_lookup_map, + ) + logger.info(f"Lookup enrichment result: {lookup_result}") + + # Replace output values with enriched values where applicable + if lookup_result: + lookup_enrichment = lookup_result.get("lookup_enrichment", {}) + lookup_metadata = lookup_result.get("_lookup_metadata", {}) + logger.info( + f"Applying lookup_enrichment={lookup_enrichment} " + f"to {len(serialized_data)} items" + ) + + for item in serialized_data: + prompt_key = item.get("prompt_key") + # If this prompt's field was enriched, replace the output value 
+ if prompt_key and prompt_key in lookup_enrichment: + enriched_value = lookup_enrichment[prompt_key] + if enriched_value is not None: + original_value = item.get("output") + logger.info( + f"Replacing {prompt_key} output: " + f"'{original_value}' -> '{enriched_value}'" + ) + # Store original value and enriched value for UI display + item["lookup_replacement"] = { + "original_value": original_value, + "enriched_value": enriched_value, + "field_name": prompt_key, + } + item["output"] = enriched_value + + # Update the database record with the enriched value + # so combined output also shows the correct lookup data + output_manager_id = item.get("prompt_output_id") + if output_manager_id: + try: + PromptStudioOutputManager.objects.filter( + prompt_output_id=output_manager_id + ).update(output=enriched_value) + logger.info( + f"Updated DB record {output_manager_id} " + f"with enriched value for {prompt_key}" + ) + except Exception as db_err: + logger.warning( + f"Failed to update DB with enriched value " + f"for {prompt_key}: {db_err}" + ) + # Add metadata for tracking + item["_lookup_metadata"] = lookup_metadata + + # Add lookup status to each serialized item for debugging/UI + # This indicates whether lookup was enabled and if enrichment was applied + prompt_by_key = {p.prompt_key: p for p in prompts} + lookup_enrichment_keys = set(lookup_result.get("lookup_enrichment", {}).keys()) + for item in serialized_data: + prompt_key = item.get("prompt_key") + prompt = prompt_by_key.get(prompt_key) + if prompt: + lookup_enabled = prompt.lookup_project_id is not None + was_enriched = prompt_key in lookup_enrichment_keys + item["_lookup_status"] = { + "enabled": lookup_enabled, + "lookup_project_id": ( + str(prompt.lookup_project_id) + if prompt.lookup_project_id + else None + ), + "was_enriched": was_enriched, + } + return serialized_data @staticmethod diff --git a/backend/prompt_studio/prompt_studio_output_manager_v2/serializers.py b/backend/prompt_studio/prompt_studio_output_manager_v2/serializers.py index 275e4a0956..c53313712f 100644 --- a/backend/prompt_studio/prompt_studio_output_manager_v2/serializers.py +++ b/backend/prompt_studio/prompt_studio_output_manager_v2/serializers.py @@ -18,6 +18,9 @@ class Meta: def to_representation(self, instance): data = super().to_representation(instance) + # Include prompt_key for frontend to match lookup enrichment + if instance.prompt_id: + data["prompt_key"] = instance.prompt_id.prompt_key try: token_usage = UsageHelper.get_aggregated_token_count(instance.run_id) except Exception as e: diff --git a/backend/prompt_studio/prompt_studio_v2/migrations/0014_toolstudioprompt_lookup_project.py b/backend/prompt_studio/prompt_studio_v2/migrations/0014_toolstudioprompt_lookup_project.py new file mode 100644 index 0000000000..0317e76183 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_v2/migrations/0014_toolstudioprompt_lookup_project.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.1 on 2025-01-21 09:15 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("lookup", "0001_initial"), + ( + "prompt_studio_v2", + "0013_toolstudioprompt_enable_postprocessing_webhook_and_more", + ), + ] + + operations = [ + migrations.AddField( + model_name="toolstudioprompt", + name="lookup_project", + field=models.ForeignKey( + blank=True, + db_comment="Lookup project for this prompt's enrichment. 
" + "Must be linked at project level via PromptStudioLookupLink.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="linked_prompts", + to="lookup.lookupproject", + ), + ), + ] diff --git a/backend/prompt_studio/prompt_studio_v2/models.py b/backend/prompt_studio/prompt_studio_v2/models.py index 74739bfe5a..ceebacd8a3 100644 --- a/backend/prompt_studio/prompt_studio_v2/models.py +++ b/backend/prompt_studio/prompt_studio_v2/models.py @@ -139,6 +139,16 @@ class RequiredType(models.TextChoices): postprocessing_webhook_url = models.TextField( blank=True, null=True, db_comment="URL endpoint for postprocessing webhook" ) + # Prompt-level lookup association + lookup_project = models.ForeignKey( + "lookup.LookupProject", + on_delete=models.SET_NULL, + related_name="linked_prompts", + null=True, + blank=True, + db_comment="Lookup project for this prompt's enrichment. " + "Must be linked at project level via PromptStudioLookupLink.", + ) # Eval settings for the prompt # NOTE: # - Field name format is eval__ diff --git a/backend/prompt_studio/prompt_studio_v2/serializers.py b/backend/prompt_studio/prompt_studio_v2/serializers.py index e1adddc33c..9ade56e2a8 100644 --- a/backend/prompt_studio/prompt_studio_v2/serializers.py +++ b/backend/prompt_studio/prompt_studio_v2/serializers.py @@ -6,10 +6,59 @@ class ToolStudioPromptSerializer(AuditSerializer): + """Serializer for ToolStudioPrompt model with lookup project validation.""" + + lookup_project_details = serializers.SerializerMethodField(read_only=True) + class Meta: model = ToolStudioPrompt fields = "__all__" + def get_lookup_project_details(self, obj: ToolStudioPrompt) -> dict | None: + """Return lookup project name and id if set.""" + if obj.lookup_project: + return { + "id": str(obj.lookup_project.id), + "name": obj.lookup_project.name, + } + return None + + def validate_lookup_project(self, value): + """Validate that the lookup project is linked to the PS project. + + The selected lookup project must be linked at the project level + via PromptStudioLookupLink before it can be assigned to a prompt. + """ + if value is None: + return value + + # Get tool_id from instance (update) or initial data (create) + tool_id = None + if self.instance: + tool_id = self.instance.tool_id_id + elif "tool_id" in self.initial_data: + tool_id = self.initial_data["tool_id"] + + if tool_id: + try: + from lookup.models import PromptStudioLookupLink + + link_exists = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=tool_id, + lookup_project=value, + ).exists() + + if not link_exists: + raise serializers.ValidationError( + "Selected lookup project must be linked to this " + "Prompt Studio project at the project level first." 
+ ) + except ImportError: + # Lookup app not installed, skip validation + pass + + return value + class ToolStudioIndexSerializer(serializers.Serializer): file_name = serializers.CharField() diff --git a/backend/prompt_studio/prompt_studio_v2/urls.py b/backend/prompt_studio/prompt_studio_v2/urls.py index 23e5f02438..767c7f209e 100644 --- a/backend/prompt_studio/prompt_studio_v2/urls.py +++ b/backend/prompt_studio/prompt_studio_v2/urls.py @@ -13,6 +13,7 @@ ) reorder_prompts = ToolStudioPromptView.as_view({"post": "reorder_prompts"}) +available_lookups = ToolStudioPromptView.as_view({"get": "available_lookups"}) urlpatterns = format_suffix_patterns( [ @@ -26,5 +27,10 @@ reorder_prompts, name="reorder_prompts", ), + path( + "prompt/available_lookups/", + available_lookups, + name="available_lookups", + ), ] ) diff --git a/backend/prompt_studio/prompt_studio_v2/views.py b/backend/prompt_studio/prompt_studio_v2/views.py index f120baf15f..95483a1fd4 100644 --- a/backend/prompt_studio/prompt_studio_v2/views.py +++ b/backend/prompt_studio/prompt_studio_v2/views.py @@ -53,3 +53,50 @@ def reorder_prompts(self, request: Request) -> Response: """ prompt_studio_controller = PromptStudioController() return prompt_studio_controller.reorder_prompts(request, ToolStudioPrompt) + + @action(detail=False, methods=["get"]) + def available_lookups(self, request: Request) -> Response: + """Get lookup projects linked to a Prompt Studio project. + + Returns the list of lookup projects that are linked at the project level + and can be assigned to individual prompts for enrichment. + + Query Parameters: + tool_id: UUID of the Prompt Studio project (CustomTool) + + Returns: + Response: List of available lookup projects with id, name, and is_ready status + """ + tool_id = request.query_params.get("tool_id") + if not tool_id: + return Response( + {"error": "tool_id query parameter is required"}, + status=400, + ) + + try: + from lookup.models import PromptStudioLookupLink + + links = PromptStudioLookupLink.objects.filter( + prompt_studio_project_id=tool_id + ).select_related("lookup_project") + + available_lookups = [ + { + "id": str(link.lookup_project.id), + "name": link.lookup_project.name, + "is_ready": link.lookup_project.is_ready, + } + for link in links + if link.lookup_project.is_active + ] + + return Response(available_lookups) + except ImportError: + # Lookup app not installed + return Response([]) + except Exception as e: + return Response( + {"error": f"Failed to fetch available lookups: {str(e)}"}, + status=500, + ) diff --git a/backend/workflow_manager/workflow_v2/file_execution_tasks.py b/backend/workflow_manager/workflow_v2/file_execution_tasks.py index a9b3ab8717..d9e7833579 100644 --- a/backend/workflow_manager/workflow_v2/file_execution_tasks.py +++ b/backend/workflow_manager/workflow_v2/file_execution_tasks.py @@ -1033,6 +1033,14 @@ def _process_final_output( error=processing_error, ) + # Execute Look-up enrichment if configured and no processing error + if output_result and not processing_error: + output_result = cls._try_lookup_enrichment( + workflow=workflow, + output_result=output_result, + file_execution_id=file_execution_id, + ) + if destination.is_api: execution_metadata = destination.get_metadata(file_history) if cls._should_create_file_history( @@ -1058,6 +1066,84 @@ def _process_final_output( output=output_result, metadata=execution_metadata, error=None ) + @classmethod + def _try_lookup_enrichment( + cls, + workflow: Workflow, + output_result: str | None, + file_execution_id: str, + ) -> 
str | None: + """Attempt Look-up enrichment for the extraction output. + + This method integrates with the Look-up system to enrich extracted + data with reference data matching. It gracefully degrades if + Look-up is not configured or fails. + + Args: + workflow: The workflow being executed + output_result: The extraction result from the tool + file_execution_id: File execution ID for tracking + + Returns: + Enriched output result if Look-up was successful, + original output_result otherwise. + """ + logger.info( + f"[LOOKUP] _try_lookup_enrichment called for workflow {workflow.id}, " + f"file_execution {file_execution_id}, output_result type: {type(output_result)}" + ) + + if not output_result: + logger.info("[LOOKUP] No output_result, skipping enrichment") + return output_result + + try: + from lookup.services.workflow_integration import LookupWorkflowIntegration + + logger.info( + "[LOOKUP] Calling LookupWorkflowIntegration.process_workflow_enrichment" + ) + enriched_output, was_enriched = ( + LookupWorkflowIntegration.process_workflow_enrichment( + workflow_id=str(workflow.id), + original_output=output_result, + file_execution_id=file_execution_id, + ) + ) + + logger.info( + f"[LOOKUP] process_workflow_enrichment returned: was_enriched={was_enriched}" + ) + + if was_enriched: + logger.info( + f"[LOOKUP] Look-up enrichment applied for workflow {workflow.id}, " + f"file execution {file_execution_id}" + ) + # Convert back to string if needed for storage + if isinstance(enriched_output, dict): + import json + + return json.dumps(enriched_output) + return enriched_output + + return output_result + + except ImportError as e: + # Look-up module not available - gracefully skip + logger.info( + f"[LOOKUP] Look-up module not available, skipping enrichment: {e}" + ) + return output_result + except Exception as e: + # Log error but don't fail the workflow + logger.warning( + f"[LOOKUP] Look-up enrichment failed for workflow {workflow.id}, " + f"file execution {file_execution_id}: {e}", + exc_info=True, + ) + return output_result + @classmethod def _should_create_file_history( cls, diff --git a/docs/LOOKUP_ARCHITECTURE.md b/docs/LOOKUP_ARCHITECTURE.md new file mode 100644 index 0000000000..711f5e61fe --- /dev/null +++ b/docs/LOOKUP_ARCHITECTURE.md @@ -0,0 +1,650 @@ +# Lookup System - Architecture Documentation + +## Profile-Based Adapter System (Aligned with Prompt Studio) + +**Last Updated**: 2025-02-05 +**Status**: Implemented +**Based On**: Prompt Studio's ProfileManager pattern + +--- + +## Executive Summary + +The Lookup system enables **reference data enrichment** in document extraction workflows. Users upload reference data (CSV, JSON, PDF), which is indexed into a vector database. During prompt execution, extracted values are matched against reference data to provide standardized/enriched values. + +### Key Features + +1. **LookupProject** - Container for lookup configurations and reference data +2. **LookupProfileManager** - Adapter configurations (X2Text, Embedding, VectorDB, LLM) +3. **LookupDataSource** - Reference data file storage with version management +4. **LookupIndexManager** - Vector DB index tracking with reindex capabilities +5. **PromptStudioLookupLink** - Links Prompt Studio prompts to Lookup projects +6. **LookupExecutionAudit** - Comprehensive execution logging + +--- + +## Data Models + +### LookupProject + +Container for a lookup configuration with LLM settings and organization association. 
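+
+A minimal usage sketch, assuming `LookupProject` is importable from `lookup.models` (the same module the integration code imports `PromptStudioLookupLink` from) and that the `is_active` field and `is_ready` property match the model sketch below; the helper name is illustrative:
+
+```python
+from lookup.models import LookupProject
+
+
+def can_enrich(project_id: str) -> bool:
+    """Return True only when the project exists, is active, and its reference data is fully indexed."""
+    project = LookupProject.objects.filter(id=project_id, is_active=True).first()
+    return bool(project and project.is_ready)
+```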
+ +```python +class LookupProject(DefaultOrganizationMixin, BaseModel): + """Represents a Look-Up project for static data-based enrichment.""" + + LOOKUP_TYPE_CHOICES = [("static_data", "Static Data")] + LLM_PROVIDER_CHOICES = [ + ("openai", "OpenAI"), + ("anthropic", "Anthropic"), + ("azure", "Azure OpenAI"), + ("custom", "Custom Provider"), + ] + + id = UUIDField(primary_key=True) + name = CharField(max_length=255) + description = TextField(blank=True, null=True) + lookup_type = CharField(choices=LOOKUP_TYPE_CHOICES, default="static_data") + + # Template and status + template = ForeignKey("LookupPromptTemplate", SET_NULL, null=True) + is_active = BooleanField(default=True) + metadata = JSONField(default=dict) + + # LLM Configuration + llm_provider = CharField(choices=LLM_PROVIDER_CHOICES, null=True) + llm_model = CharField(max_length=100, null=True) + llm_config = JSONField(default=dict) + + # Ownership + created_by = ForeignKey(User, RESTRICT) + + class Meta: + db_table = "lookup_projects" + + @property + def is_ready(self) -> bool: + """Check if project has completed reference data.""" + ... +``` + +### LookupProfileManager + +Profile manager for adapter configurations - mirrors Prompt Studio's ProfileManager. + +```python +class LookupProfileManager(BaseModel): + """Model to store adapter configuration profiles for Look-Up projects.""" + + profile_id = UUIDField(primary_key=True) + profile_name = TextField(blank=False, null=False) + + # Foreign key to LookupProject + lookup_project = ForeignKey("LookupProject", CASCADE, related_name="profiles") + + # Required Adapters - All must be configured + vector_store = ForeignKey(AdapterInstance, PROTECT, related_name="lookup_profiles_vector_store") + embedding_model = ForeignKey(AdapterInstance, PROTECT, related_name="lookup_profiles_embedding_model") + llm = ForeignKey(AdapterInstance, PROTECT, related_name="lookup_profiles_llm") + x2text = ForeignKey(AdapterInstance, PROTECT, related_name="lookup_profiles_x2text") + + # Configuration fields + chunk_size = IntegerField(default=1000) + chunk_overlap = IntegerField(default=200) + similarity_top_k = IntegerField(default=5) + + # Flags + is_default = BooleanField(default=False) + reindex = BooleanField(default=False) + + # Audit + created_by = ForeignKey(User, SET_NULL, null=True) + modified_by = ForeignKey(User, SET_NULL, null=True) + + class Meta: + db_table = "lookup_profile_manager" + constraints = [ + UniqueConstraint(fields=["lookup_project", "profile_name"]) + ] + + @staticmethod + def get_default_profile(project) -> "LookupProfileManager": + """Get default profile for a Look-Up project.""" + ... +``` + +### LookupDataSource + +Reference data file storage with automatic version management. 
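+
+A hedged query sketch for the version flags, assuming `LookupDataSource` is exported from `lookup.models` and that the `is_latest`, `extraction_status`, and `version_number` fields behave as in the model sketch below; `project` stands in for a previously fetched `LookupProject`:
+
+```python
+from lookup.models import LookupDataSource
+
+# Latest completed reference data for a project, if any.
+latest_source = (
+    LookupDataSource.objects.filter(
+        project=project, is_latest=True, extraction_status="completed"
+    )
+    .order_by("-version_number")
+    .first()
+)
+```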
+ +```python +class LookupDataSource(BaseModel): + """Represents a reference data source with version management.""" + + EXTRACTION_STATUS_CHOICES = [ + ("pending", "Pending"), + ("processing", "Processing"), + ("completed", "Completed"), + ("failed", "Failed"), + ] + FILE_TYPE_CHOICES = [ + ("pdf", "PDF"), ("xlsx", "Excel"), ("csv", "CSV"), + ("docx", "Word"), ("txt", "Text"), ("json", "JSON"), + ] + + id = UUIDField(primary_key=True) + project = ForeignKey("LookupProject", CASCADE, related_name="data_sources") + + # File Information + file_name = CharField(max_length=255) + file_path = TextField() # Path in object storage (MinIO) + file_size = BigIntegerField() + file_type = CharField(choices=FILE_TYPE_CHOICES) + + # Extracted Content + extracted_content_path = TextField(blank=True, null=True) + extraction_status = CharField(choices=EXTRACTION_STATUS_CHOICES, default="pending") + extraction_error = TextField(blank=True, null=True) + + # Version Management (auto-managed via signals) + version_number = IntegerField(default=1) + is_latest = BooleanField(default=True) + + # Upload Information + uploaded_by = ForeignKey(User, RESTRICT) + + class Meta: + db_table = "lookup_data_sources" + unique_together = [["project", "version_number"]] +``` + +**Version Management Signals:** +- `pre_save`: Auto-increments version number, marks previous versions as not latest +- `post_delete`: Promotes previous version to latest when current latest is deleted + +### LookupIndexManager + +Tracks indexed reference data in Vector DB. + +```python +class LookupIndexManager(BaseModel): + """Model to store indexing details for Look-Up reference data.""" + + index_manager_id = UUIDField(primary_key=True) + + # References + data_source = ForeignKey("LookupDataSource", CASCADE, related_name="index_managers") + profile_manager = ForeignKey("LookupProfileManager", SET_NULL, null=True, related_name="index_managers") + + # Vector DB index ID + raw_index_id = CharField(max_length=255, null=True) + index_ids_history = JSONField(default=list) # For cleanup on deletion + + # Status tracking + extraction_status = JSONField(default=dict) # Per X2Text config + status = JSONField(default=dict) # Legacy: {extracted, indexed, error} + reindex_required = BooleanField(default=False) + + # Audit + created_by = ForeignKey(User, SET_NULL, null=True) + modified_by = ForeignKey(User, SET_NULL, null=True) + + class Meta: + db_table = "lookup_index_manager" + constraints = [ + UniqueConstraint(fields=["data_source", "profile_manager"]) + ] +``` + +**Cleanup Signal:** +- `pre_delete`: Cleans up vector DB entries when index manager is deleted + +### PromptStudioLookupLink + +Many-to-many relationship between Prompt Studio projects and Look-Up projects. + +```python +class PromptStudioLookupLink(Model): + """Links Prompt Studio projects with Look-Up projects.""" + + id = UUIDField(primary_key=True) + prompt_studio_project_id = UUIDField() # PS project reference + lookup_project = ForeignKey("LookupProject", CASCADE, related_name="ps_links") + execution_order = PositiveIntegerField(default=0) + created_at = DateTimeField(auto_now_add=True) + + class Meta: + db_table = "prompt_studio_lookup_links" + unique_together = [["prompt_studio_project_id", "lookup_project"]] +``` + +### LookupPromptTemplate + +Prompt template with variable detection and validation. 
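+
+A self-contained sketch of what `detect_variables()` amounts to, reusing the `VARIABLE_PATTERN` from the model sketch below; whitespace stripping and de-duplication are assumptions rather than confirmed behaviour:
+
+```python
+import re
+
+VARIABLE_PATTERN = r"\{\{([^}]+)\}\}"
+
+
+def detect_variables(template_text: str) -> list[str]:
+    """Return unique {{variable}} names in order of first appearance."""
+    names: list[str] = []
+    for raw in re.findall(VARIABLE_PATTERN, template_text):
+        name = raw.strip()
+        if name not in names:
+            names.append(name)
+    return names
+
+
+# detect_variables("Match {{vendor_name}} in {{country}} against {{vendor_name}}")
+# returns ["vendor_name", "country"]
+```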
+ +```python +class LookupPromptTemplate(BaseModel): + """Represents a prompt template with {{variable}} placeholders.""" + + VARIABLE_PATTERN = r"\{\{([^}]+)\}\}" + + id = UUIDField(primary_key=True) + project = OneToOneField("LookupProject", CASCADE, related_name="prompt_template_link") + + name = CharField(max_length=255) + template_text = TextField() # Contains {{variable}} placeholders + llm_config = JSONField(default=dict) + is_active = BooleanField(default=True) + created_by = ForeignKey(User, RESTRICT) + variable_mappings = JSONField(default=dict) + + class Meta: + db_table = "lookup_prompt_templates" + + def detect_variables(self) -> list[str]: + """Extract all {{variable}} references from template.""" + ... + + def validate_syntax(self) -> bool: + """Validate matching braces and no nested placeholders.""" + ... +``` + +### LookupExecutionAudit + +Comprehensive audit log for Look-Up executions. + +```python +class LookupExecutionAudit(Model): + """Audit log for Look-Up executions.""" + + STATUS_CHOICES = [ + ("success", "Success"), + ("partial", "Partial Success"), + ("failed", "Failed"), + ] + + id = UUIDField(primary_key=True) + + # Execution Context + lookup_project = ForeignKey("LookupProject", CASCADE, related_name="execution_audits") + prompt_studio_project_id = UUIDField(null=True) + execution_id = UUIDField() # Groups all Look-Ups in a batch + file_execution_id = UUIDField(null=True) # Workflow tracking for API/ETL + + # Input/Output + input_data = JSONField() + reference_data_version = IntegerField() + enriched_output = JSONField(null=True) + + # LLM Details + llm_provider = CharField(max_length=50) + llm_model = CharField(max_length=100) + llm_prompt = TextField() + llm_response = TextField(null=True) + llm_response_cached = BooleanField(default=False) + + # Performance Metrics + execution_time_ms = IntegerField(null=True) + llm_call_time_ms = IntegerField(null=True) + + # Status & Errors + status = CharField(choices=STATUS_CHOICES) + error_message = TextField(null=True) + confidence_score = DecimalField(max_digits=3, decimal_places=2, null=True) + + executed_at = DateTimeField(auto_now_add=True) + + class Meta: + db_table = "lookup_execution_audit" +``` + +--- + +## Service Layer + +### Core Services + +| Service | Purpose | +|---------|---------| +| `IndexingService` | Document indexing with chunking and vector embeddings | +| `LookUpExecutor` | Executes lookups using RAG retrieval | +| `LookUpOrchestrator` | Coordinates lookup workflow orchestration | +| `LookupRetrievalService` | Vector DB search and retrieval | +| `VectorDBCleanupService` | Manages vector DB lifecycle and cleanup | + +### Supporting Services + +| Service | Purpose | +|---------|---------| +| `AuditLogger` | Execution logging and audit trail | +| `LLMResponseCache` | Caches LLM responses for performance | +| `ReferenceDataLoader` | Loads and parses reference data files | +| `VariableResolver` | Resolves template variables with actual values | +| `EnrichmentMerger` | Merges lookup results into extraction output | +| `LookupIndexHelper` | Helper functions for index operations | +| `LogEmitter` | Emits logs for execution tracking | + +### Integration Services + +| Service | Purpose | +|---------|---------| +| `LookupDocumentIndexingService` | High-level document indexing orchestration | +| `LookupIntegrationService` | Integration with external systems | +| `WorkflowIntegration` | Integration with workflow execution | + +--- + +## API Endpoints + +Base URL: `/api/v2/unstract/{org_id}/lookup/` + +### 
Project Management + +``` +GET /lookup-projects/ # List all projects +POST /lookup-projects/ # Create new project +GET /lookup-projects/{id}/ # Get project details +PUT /lookup-projects/{id}/ # Update project +DELETE /lookup-projects/{id}/ # Delete project +``` + +### Profile Management + +``` +GET /lookup-profiles/ # List profiles +POST /lookup-profiles/ # Create profile +GET /lookup-profiles/{id}/ # Get profile details +PUT /lookup-profiles/{id}/ # Update profile +DELETE /lookup-profiles/{id}/ # Delete profile +POST /lookup-profiles/{id}/set-default/ # Set as default profile +``` + +### Data Source Management + +``` +GET /data-sources/ # List data sources +POST /data-sources/ # Upload new reference data +GET /data-sources/{id}/ # Get data source details +DELETE /data-sources/{id}/ # Delete data source +POST /data-sources/{id}/reindex/ # Trigger reindexing +``` + +### Template Management + +``` +GET /lookup-templates/ # List templates +POST /lookup-templates/ # Create template +GET /lookup-templates/{id}/ # Get template details +PUT /lookup-templates/{id}/ # Update template +DELETE /lookup-templates/{id}/ # Delete template +``` + +### Linking & Execution + +``` +GET /lookup-links/ # List PS project links +POST /lookup-links/ # Create link +DELETE /lookup-links/{id}/ # Remove link + +GET /execution-audits/ # List execution history +GET /execution-audits/{id}/ # Get execution details + +POST /lookup-debug/test/ # Test lookup execution +``` + +--- + +## Frontend Components + +### Page Structure + +``` +/lookups → LookUpProjectList +/lookups/:projectId → LookUpProjectDetail + ├── Reference Data Tab → ReferenceDataTab + ├── Templates Tab → TemplateTab + ├── Profiles Tab → ProfileManagementTab + │ └── Profile Modal → ProfileFormModal + ├── Linked Projects Tab → LinkedProjectsTab + ├── Execution History Tab → ExecutionHistoryTab + └── Debug Tab → DebugTab +``` + +### Component Descriptions + +| Component | Purpose | +|-----------|---------| +| `LookUpProjectList` | Lists all lookup projects with create/delete actions | +| `LookUpProjectDetail` | Project detail view with tabbed navigation | +| `CreateProjectModal` | Modal for creating new lookup projects | +| `ReferenceDataTab` | Upload and manage reference data files | +| `TemplateTab` | Configure prompt templates with variables | +| `ProfileManagementTab` | Manage adapter profiles | +| `ProfileFormModal` | Create/edit profiles with adapter dropdowns | +| `LinkedProjectsTab` | Link/unlink Prompt Studio projects | +| `ExecutionHistoryTab` | View execution audit logs | +| `DebugTab` | Test lookup execution manually | + +--- + +## System Workflows + +### 1. Reference Data Indexing Workflow + +``` +User uploads reference file (CSV/JSON/PDF) + ↓ +Create LookupDataSource (version auto-incremented) + ↓ +Get default LookupProfileManager for project + ↓ +Extract text using profile.x2text adapter + ↓ +Store extracted text in MinIO + ↓ +Chunk text (profile.chunk_size, profile.chunk_overlap) + ↓ +Generate embeddings using profile.embedding_model + ↓ +Store vectors in VectorDB using profile.vector_store + ↓ +Create/update LookupIndexManager entry + ↓ +Update data source status to 'completed' +``` + +### 2. 
Lookup Execution Workflow + +``` +Prompt Studio executes with lookup variable + ↓ +Get linked LookupProject via PromptStudioLookupLink + ↓ +Get default profile (LookupProfileManager.get_default_profile) + ↓ +Generate query embedding using profile.embedding_model + ↓ +Search VectorDB using profile.vector_store + (returns top_k similar results based on profile.similarity_top_k) + ↓ +Optional: Use profile.llm for best match selection + ↓ +Create LookupExecutionAudit record + ↓ +Return standardized value to Prompt Studio +``` + +### 3. Profile Change & Reindexing Workflow + +``` +User updates profile settings (chunk_size, adapters, etc.) + ↓ +Set reindex_required=True on associated LookupIndexManager entries + ↓ +User triggers reindex (or automatic reindex on next execution) + ↓ +Delete old vector DB indexes (using index_ids_history) + ↓ +Re-run indexing workflow with new profile settings + ↓ +Update LookupIndexManager with new index IDs + ↓ +Set reindex_required=False +``` + +### 4. Cleanup Workflows + +**On LookupDataSource deletion:** +``` +Cascade delete LookupIndexManager entries + ↓ +pre_delete signal on LookupIndexManager + ↓ +VectorDBCleanupService.cleanup_index_ids() + ↓ +Remove vectors from VectorDB +``` + +**On LookupProfileManager deletion:** +``` +pre_delete signal on LookupProfileManager + ↓ +For each associated LookupIndexManager: + ↓ +VectorDBCleanupService.cleanup_index_ids() + ↓ +Remove all vectors indexed with this profile +``` + +--- + +## Integration with Prompt Studio + +### Prompt-Level Lookup Configuration + +The `ToolStudioPrompt` model has been extended with a `lookup_project` field: + +```python +class ToolStudioPrompt(BaseModel): + # ... existing fields ... + + lookup_project = ForeignKey( + "lookup.LookupProject", + on_delete=SET_NULL, + null=True, + blank=True, + related_name="linked_prompts", + ) +``` + +### Frontend Integration + +- **Lookup Replacement Indicator**: Visual indicator on prompts with lookup configured +- **Prompt Card Header**: Shows lookup project linkage +- **Output Display**: Shows lookup-enriched values in combined output + +--- + +## Database Tables + +| Table | Description | +|-------|-------------| +| `lookup_projects` | Lookup project configurations | +| `lookup_data_sources` | Reference data file metadata | +| `lookup_profile_manager` | Adapter profile configurations | +| `lookup_index_manager` | Vector DB index tracking | +| `lookup_prompt_templates` | Prompt templates with variables | +| `prompt_studio_lookup_links` | PS-to-Lookup project links | +| `lookup_execution_audit` | Execution history and metrics | + +--- + +## Key Design Principles + +### 1. Consistency with Prompt Studio +- Same model structure (FK to project, unique constraint on name) +- Same adapter fields (x2text, embedding_model, vector_store, llm) +- Same naming conventions (ProfileManager, is_default, reindex) +- Same API patterns (ViewSet, Serializer, permissions) + +### 2. Profile Ownership +- Profiles belong to projects (not standalone) +- Each project can have multiple profiles +- One profile must be marked as default +- Unique profile names within a project + +### 3. Adapter Protection +- All 4 adapter types required for completeness +- Adapters protected from deletion if in use (PROTECT) +- Users select from configured adapters via dropdowns + +### 4. 
Separation of Concerns +- **LookupProfileManager**: Adapter configuration storage +- **LookupIndexManager**: Indexing state tracking +- **LookupDataSource**: Reference data file metadata +- **LookUpExecutor**: Runtime execution logic + +### 5. Automatic Cleanup +- Vector DB cleanup on index/profile deletion via signals +- Index history tracking for complete cleanup +- Version promotion on data source deletion + +### 6. Comprehensive Auditing +- LookupExecutionAudit captures all execution details +- file_execution_id for workflow tracking in API/ETL +- LLM prompts, responses, and performance metrics logged + +--- + +## Environment Configuration + +The lookup system uses existing adapter configurations. No new environment variables required. + +**Used Configuration:** +- MinIO/S3 for file storage (via existing filesystem configuration) +- Redis for caching (via existing Redis configuration) +- Existing adapter instances for X2Text, Embedding, VectorDB, LLM + +--- + +## Testing + +### Unit Tests Location + +``` +backend/lookup/tests/ +├── test_api/ +│ ├── test_execution_api.py +│ ├── test_linking_api.py +│ ├── test_profile_manager_api.py +│ ├── test_project_api.py +│ └── test_template_api.py +├── test_integrations/ +│ ├── test_llm_integration.py +│ ├── test_llmwhisperer_integration.py +│ ├── test_redis_cache_integration.py +│ └── test_storage_integration.py +├── test_services/ +│ ├── test_audit_logger.py +│ ├── test_enrichment_merger.py +│ ├── test_llm_cache.py +│ ├── test_lookup_executor.py +│ ├── test_lookup_orchestrator.py +│ └── test_reference_data_loader.py +├── test_migrations.py +└── test_variable_resolver.py +``` + +### Test Coverage Areas +- Model CRUD operations +- API endpoint functionality +- Service layer logic +- Integration with adapters +- Vector DB cleanup signals +- Version management +- Cache operations + +--- + +**END OF ARCHITECTURE DOCUMENT** diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 9954af83a7..0feb2e872a 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -9,6 +9,8 @@ "version": "0.1.0", "dependencies": { "@ant-design/icons": "^5.1.4", + "@codemirror/lang-json": "^6.0.2", + "@codemirror/theme-one-dark": "^6.1.3", "@monaco-editor/react": "^4.7.0", "@react-awesome-query-builder/antd": "^6.6.10", "@react-pdf-viewer/core": "^3.12.0", @@ -23,6 +25,7 @@ "@testing-library/jest-dom": "^5.16.5", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", + "@uiw/react-codemirror": "^4.25.4", "antd": "^5.5.1", "axios": "^1.4.0", "cron-validator": "^1.3.1", @@ -2143,6 +2146,109 @@ "resolved": "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz", "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==" }, + "node_modules/@codemirror/autocomplete": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.0.tgz", + "integrity": "sha512-bOwvTOIJcG5FVo5gUUupiwYh8MioPLQ4UcqbcRf7UQ98X90tCa9E1kZ3Z7tqwpZxYyOvh1YTYbmZE9RTfTp5hg==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@codemirror/commands": { + "version": "6.10.0", + "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.0.tgz", + "integrity": "sha512-2xUIc5mHXQzT16JnyOFkh8PvfeXuIut3pslWGfsGOhxP/lpgRm9HOl/mpzLErgt5mXDovqA0d11P21gofRLb9w==", + "license": "MIT", + 
"dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.4.0", + "@codemirror/view": "^6.27.0", + "@lezer/common": "^1.1.0" + } + }, + "node_modules/@codemirror/lang-json": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/@codemirror/lang-json/-/lang-json-6.0.2.tgz", + "integrity": "sha512-x2OtO+AvwEHrEwR0FyyPtfDUiloG3rnVTSZV1W8UteaLL8/MajQd8DpvUb2YVzC+/T18aSDv0H9mu+xw0EStoQ==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@lezer/json": "^1.0.0" + } + }, + "node_modules/@codemirror/language": { + "version": "6.11.3", + "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.11.3.tgz", + "integrity": "sha512-9HBM2XnwDj7fnu0551HkGdrUrrqmYq/WC5iv6nbY2WdicXdGbhR/gfbZOH73Aqj4351alY1+aoG9rCNfiwS1RA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.23.0", + "@lezer/common": "^1.1.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0", + "style-mod": "^4.0.0" + } + }, + "node_modules/@codemirror/lint": { + "version": "6.9.2", + "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.2.tgz", + "integrity": "sha512-sv3DylBiIyi+xKwRCJAAsBZZZWo82shJ/RTMymLabAdtbkV5cSKwWDeCgtUq3v8flTaXS2y1kKkICuRYtUswyQ==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.35.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/search": { + "version": "6.5.11", + "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.5.11.tgz", + "integrity": "sha512-KmWepDE6jUdL6n8cAAqIpRmLPBZ5ZKnicE8oGU/s3QrAVID+0VhLFrzUucVKHG5035/BSykhExDL/Xm7dHthiA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/state": { + "version": "6.5.2", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.2.tgz", + "integrity": "sha512-FVqsPqtPWKVVL3dPSxy8wEF/ymIEuVzF1PK3VbUgrxXpJUSHQWWZz4JMToquRxnkw+36LTamCZG2iua2Ptq0fA==", + "license": "MIT", + "dependencies": { + "@marijn/find-cluster-break": "^1.0.0" + } + }, + "node_modules/@codemirror/theme-one-dark": { + "version": "6.1.3", + "resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.3.tgz", + "integrity": "sha512-NzBdIvEJmx6fjeremiGp3t/okrLPYT0d9orIc7AFun8oZcRk58aejkqhv6spnz4MLAevrKNPMQYXEWMg4s+sKA==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "@lezer/highlight": "^1.0.0" + } + }, + "node_modules/@codemirror/view": { + "version": "6.39.4", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.39.4.tgz", + "integrity": "sha512-xMF6OfEAUVY5Waega4juo1QGACfNkNF+aJLqpd8oUJz96ms2zbfQ9Gh35/tI3y8akEV31FruKfj7hBnIU/nkqA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.5.0", + "crelt": "^1.0.6", + "style-mod": "^4.1.0", + "w3c-keyname": "^2.2.4" + } + }, "node_modules/@csstools/normalize.css": { "version": "12.0.0", "resolved": "https://registry.npmjs.org/@csstools/normalize.css/-/normalize.css-12.0.0.tgz", @@ -3484,6 +3590,41 @@ "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==" }, + "node_modules/@lezer/common": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.4.0.tgz", + "integrity": 
"sha512-DVeMRoGrgn/k45oQNu189BoW4SZwgZFzJ1+1TV5j2NJ/KFC83oa/enRqZSGshyeMk5cPWMhsKs9nx+8o0unwGg==", + "license": "MIT" + }, + "node_modules/@lezer/highlight": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz", + "integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.3.0" + } + }, + "node_modules/@lezer/json": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@lezer/json/-/json-1.0.3.tgz", + "integrity": "sha512-BP9KzdF9Y35PDpv04r0VeSTKDeox5vVr3efE7eBbx3r4s3oNLfunchejZhjArmeieBH+nVOpgIiBJpEAv8ilqQ==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0" + } + }, + "node_modules/@lezer/lr": { + "version": "1.4.5", + "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.5.tgz", + "integrity": "sha512-/YTRKP5yPPSo1xImYQk7AZZMAgap0kegzqCSYHjAL9x1AZ0ZQW+IpcEzMKagCsbTsLnVeWkxYrCNeXG8xEPrjg==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.0.0" + } + }, "node_modules/@mapbox/node-pre-gyp": { "version": "1.0.11", "resolved": "https://registry.npmjs.org/@mapbox/node-pre-gyp/-/node-pre-gyp-1.0.11.tgz", @@ -3537,6 +3678,12 @@ "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", "optional": true }, + "node_modules/@marijn/find-cluster-break": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", + "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==", + "license": "MIT" + }, "node_modules/@monaco-editor/loader": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz", @@ -5510,6 +5657,59 @@ "url": "https://opencollective.com/typescript-eslint" } }, + "node_modules/@uiw/codemirror-extensions-basic-setup": { + "version": "4.25.4", + "resolved": "https://registry.npmjs.org/@uiw/codemirror-extensions-basic-setup/-/codemirror-extensions-basic-setup-4.25.4.tgz", + "integrity": "sha512-YzNwkm0AbPv1EXhCHYR5v0nqfemG2jEB0Z3Att4rBYqKrlG7AA9Rhjc3IyBaOzsBu18wtrp9/+uhTyu7TXSRng==", + "license": "MIT", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + }, + "funding": { + "url": "https://jaywcjlove.github.io/#/sponsor" + }, + "peerDependencies": { + "@codemirror/autocomplete": ">=6.0.0", + "@codemirror/commands": ">=6.0.0", + "@codemirror/language": ">=6.0.0", + "@codemirror/lint": ">=6.0.0", + "@codemirror/search": ">=6.0.0", + "@codemirror/state": ">=6.0.0", + "@codemirror/view": ">=6.0.0" + } + }, + "node_modules/@uiw/react-codemirror": { + "version": "4.25.4", + "resolved": "https://registry.npmjs.org/@uiw/react-codemirror/-/react-codemirror-4.25.4.tgz", + "integrity": "sha512-ipO067oyfUw+DVaXhQCxkB0ZD9b7RnY+ByrprSYSKCHaULvJ3sqWYC/Zen6zVQ8/XC4o5EPBfatGiX20kC7XGA==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.18.6", + "@codemirror/commands": "^6.1.0", + "@codemirror/state": "^6.1.1", + "@codemirror/theme-one-dark": "^6.0.0", + "@uiw/codemirror-extensions-basic-setup": "4.25.4", + "codemirror": "^6.0.0" + }, + "funding": { + "url": "https://jaywcjlove.github.io/#/sponsor" + 
}, + "peerDependencies": { + "@babel/runtime": ">=7.11.0", + "@codemirror/state": ">=6.0.0", + "@codemirror/theme-one-dark": ">=6.0.0", + "@codemirror/view": ">=6.0.0", + "codemirror": ">=6.0.0", + "react": ">=17.0.0", + "react-dom": ">=17.0.0" + } + }, "node_modules/@webassemblyjs/ast": { "version": "1.14.1", "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", @@ -7167,6 +7367,21 @@ "node": ">=4" } }, + "node_modules/codemirror": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz", + "integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==", + "license": "MIT", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, "node_modules/collect-v8-coverage": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/collect-v8-coverage/-/collect-v8-coverage-1.0.1.tgz", @@ -7450,6 +7665,12 @@ "node": ">=10" } }, + "node_modules/crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==", + "license": "MIT" + }, "node_modules/cron-validator": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/cron-validator/-/cron-validator-1.3.1.tgz", @@ -19913,6 +20134,12 @@ "webpack": "^5.0.0" } }, + "node_modules/style-mod": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz", + "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==", + "license": "MIT" + }, "node_modules/style-to-object": { "version": "0.4.4", "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-0.4.4.tgz", @@ -21156,6 +21383,12 @@ "browser-process-hrtime": "^1.0.0" } }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", + "license": "MIT" + }, "node_modules/w3c-xmlserializer": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-2.0.0.tgz", @@ -23423,6 +23656,100 @@ "resolved": "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz", "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==" }, + "@codemirror/autocomplete": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.0.tgz", + "integrity": "sha512-bOwvTOIJcG5FVo5gUUupiwYh8MioPLQ4UcqbcRf7UQ98X90tCa9E1kZ3Z7tqwpZxYyOvh1YTYbmZE9RTfTp5hg==", + "requires": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0" + } + }, + "@codemirror/commands": { + "version": "6.10.0", + "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.0.tgz", + "integrity": "sha512-2xUIc5mHXQzT16JnyOFkh8PvfeXuIut3pslWGfsGOhxP/lpgRm9HOl/mpzLErgt5mXDovqA0d11P21gofRLb9w==", + "requires": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.4.0", + "@codemirror/view": "^6.27.0", + "@lezer/common": "^1.1.0" + } + }, + 
"@codemirror/lang-json": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/@codemirror/lang-json/-/lang-json-6.0.2.tgz", + "integrity": "sha512-x2OtO+AvwEHrEwR0FyyPtfDUiloG3rnVTSZV1W8UteaLL8/MajQd8DpvUb2YVzC+/T18aSDv0H9mu+xw0EStoQ==", + "requires": { + "@codemirror/language": "^6.0.0", + "@lezer/json": "^1.0.0" + } + }, + "@codemirror/language": { + "version": "6.11.3", + "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.11.3.tgz", + "integrity": "sha512-9HBM2XnwDj7fnu0551HkGdrUrrqmYq/WC5iv6nbY2WdicXdGbhR/gfbZOH73Aqj4351alY1+aoG9rCNfiwS1RA==", + "requires": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.23.0", + "@lezer/common": "^1.1.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0", + "style-mod": "^4.0.0" + } + }, + "@codemirror/lint": { + "version": "6.9.2", + "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.2.tgz", + "integrity": "sha512-sv3DylBiIyi+xKwRCJAAsBZZZWo82shJ/RTMymLabAdtbkV5cSKwWDeCgtUq3v8flTaXS2y1kKkICuRYtUswyQ==", + "requires": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.35.0", + "crelt": "^1.0.5" + } + }, + "@codemirror/search": { + "version": "6.5.11", + "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.5.11.tgz", + "integrity": "sha512-KmWepDE6jUdL6n8cAAqIpRmLPBZ5ZKnicE8oGU/s3QrAVID+0VhLFrzUucVKHG5035/BSykhExDL/Xm7dHthiA==", + "requires": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "crelt": "^1.0.5" + } + }, + "@codemirror/state": { + "version": "6.5.2", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.2.tgz", + "integrity": "sha512-FVqsPqtPWKVVL3dPSxy8wEF/ymIEuVzF1PK3VbUgrxXpJUSHQWWZz4JMToquRxnkw+36LTamCZG2iua2Ptq0fA==", + "requires": { + "@marijn/find-cluster-break": "^1.0.0" + } + }, + "@codemirror/theme-one-dark": { + "version": "6.1.3", + "resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.3.tgz", + "integrity": "sha512-NzBdIvEJmx6fjeremiGp3t/okrLPYT0d9orIc7AFun8oZcRk58aejkqhv6spnz4MLAevrKNPMQYXEWMg4s+sKA==", + "requires": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0", + "@lezer/highlight": "^1.0.0" + } + }, + "@codemirror/view": { + "version": "6.39.4", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.39.4.tgz", + "integrity": "sha512-xMF6OfEAUVY5Waega4juo1QGACfNkNF+aJLqpd8oUJz96ms2zbfQ9Gh35/tI3y8akEV31FruKfj7hBnIU/nkqA==", + "requires": { + "@codemirror/state": "^6.5.0", + "crelt": "^1.0.6", + "style-mod": "^4.1.0", + "w3c-keyname": "^2.2.4" + } + }, "@csstools/normalize.css": { "version": "12.0.0", "resolved": "https://registry.npmjs.org/@csstools/normalize.css/-/normalize.css-12.0.0.tgz", @@ -24429,6 +24756,37 @@ "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==" }, + "@lezer/common": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.4.0.tgz", + "integrity": "sha512-DVeMRoGrgn/k45oQNu189BoW4SZwgZFzJ1+1TV5j2NJ/KFC83oa/enRqZSGshyeMk5cPWMhsKs9nx+8o0unwGg==" + }, + "@lezer/highlight": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz", + "integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==", + "requires": { + "@lezer/common": "^1.3.0" + } + }, + "@lezer/json": { + "version": "1.0.3", + 
"resolved": "https://registry.npmjs.org/@lezer/json/-/json-1.0.3.tgz", + "integrity": "sha512-BP9KzdF9Y35PDpv04r0VeSTKDeox5vVr3efE7eBbx3r4s3oNLfunchejZhjArmeieBH+nVOpgIiBJpEAv8ilqQ==", + "requires": { + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0" + } + }, + "@lezer/lr": { + "version": "1.4.5", + "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.5.tgz", + "integrity": "sha512-/YTRKP5yPPSo1xImYQk7AZZMAgap0kegzqCSYHjAL9x1AZ0ZQW+IpcEzMKagCsbTsLnVeWkxYrCNeXG8xEPrjg==", + "requires": { + "@lezer/common": "^1.0.0" + } + }, "@mapbox/node-pre-gyp": { "version": "1.0.11", "resolved": "https://registry.npmjs.org/@mapbox/node-pre-gyp/-/node-pre-gyp-1.0.11.tgz", @@ -24472,6 +24830,11 @@ } } }, + "@marijn/find-cluster-break": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", + "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==" + }, "@monaco-editor/loader": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz", @@ -25869,6 +26232,33 @@ "eslint-visitor-keys": "^3.3.0" } }, + "@uiw/codemirror-extensions-basic-setup": { + "version": "4.25.4", + "resolved": "https://registry.npmjs.org/@uiw/codemirror-extensions-basic-setup/-/codemirror-extensions-basic-setup-4.25.4.tgz", + "integrity": "sha512-YzNwkm0AbPv1EXhCHYR5v0nqfemG2jEB0Z3Att4rBYqKrlG7AA9Rhjc3IyBaOzsBu18wtrp9/+uhTyu7TXSRng==", + "requires": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, + "@uiw/react-codemirror": { + "version": "4.25.4", + "resolved": "https://registry.npmjs.org/@uiw/react-codemirror/-/react-codemirror-4.25.4.tgz", + "integrity": "sha512-ipO067oyfUw+DVaXhQCxkB0ZD9b7RnY+ByrprSYSKCHaULvJ3sqWYC/Zen6zVQ8/XC4o5EPBfatGiX20kC7XGA==", + "requires": { + "@babel/runtime": "^7.18.6", + "@codemirror/commands": "^6.1.0", + "@codemirror/state": "^6.1.1", + "@codemirror/theme-one-dark": "^6.0.0", + "@uiw/codemirror-extensions-basic-setup": "4.25.4", + "codemirror": "^6.0.0" + } + }, "@webassemblyjs/ast": { "version": "1.14.1", "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", @@ -27074,6 +27464,20 @@ } } }, + "codemirror": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz", + "integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==", + "requires": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, "collect-v8-coverage": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/collect-v8-coverage/-/collect-v8-coverage-1.0.1.tgz", @@ -27296,6 +27700,11 @@ "yaml": "^1.10.0" } }, + "crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" + }, "cron-validator": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/cron-validator/-/cron-validator-1.3.1.tgz", @@ -36010,6 +36419,11 @@ "integrity": 
"sha512-53BiGLXAcll9maCYtZi2RCQZKa8NQQai5C4horqKyRmHj9H7QmcUyucrH+4KW/gBQbXM2AsB0axoEcFZPlfPcw==", "requires": {} }, + "style-mod": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz", + "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==" + }, "style-to-object": { "version": "0.4.4", "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-0.4.4.tgz", @@ -36917,6 +37331,11 @@ "browser-process-hrtime": "^1.0.0" } }, + "w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==" + }, "w3c-xmlserializer": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-2.0.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index d8bd36af34..0abab9ba85 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -4,6 +4,8 @@ "private": true, "dependencies": { "@ant-design/icons": "^5.1.4", + "@codemirror/lang-json": "^6.0.2", + "@codemirror/theme-one-dark": "^6.1.3", "@monaco-editor/react": "^4.7.0", "@react-awesome-query-builder/antd": "^6.6.10", "@react-pdf-viewer/core": "^3.12.0", @@ -18,6 +20,7 @@ "@testing-library/jest-dom": "^5.16.5", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", + "@uiw/react-codemirror": "^4.25.4", "antd": "^5.5.1", "axios": "^1.4.0", "cron-validator": "^1.3.1", diff --git a/frontend/src/assets/lookups.svg b/frontend/src/assets/lookups.svg new file mode 100644 index 0000000000..db41980604 --- /dev/null +++ b/frontend/src/assets/lookups.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/components/custom-tools/combined-output/CombinedOutput.css b/frontend/src/components/custom-tools/combined-output/CombinedOutput.css index 13302b7e92..24c43a1104 100644 --- a/frontend/src/components/custom-tools/combined-output/CombinedOutput.css +++ b/frontend/src/components/custom-tools/combined-output/CombinedOutput.css @@ -1,39 +1,54 @@ /* Styles for CombinedOutput */ .combined-op-layout { - display: flex; - flex-direction: column; - height: 100%; - overflow-y: hidden; + display: flex; + flex-direction: column; + height: 100%; + overflow-y: hidden; } .combined-op-header { - display: flex; - margin-top: 7px; + display: flex; + margin-top: 7px; } .combined-op-segment { - margin-left: auto; + margin-left: auto; } .combined-op-body { - flex: 1; - overflow-y: auto; + flex: 1; + overflow-y: auto; } .combined-op-divider { - margin-bottom: 10px; + margin-bottom: 10px; } .code-snippet { - border: 1px solid #ECEFF3; + border: 1px solid #eceff3; } .code-snippet > .language-javascript { - margin: 0px !important; - height: 100%; + margin: 0px !important; + height: 100%; } .combined-op-layout .gap { - margin-bottom: 12px; + margin-bottom: 12px; +} + +/* Enrichment info bar styling */ +.enrichment-info-bar { + padding: 8px 12px; + background-color: #f6ffed; + border: 1px solid #b7eb8f; + border-radius: 4px; + margin-bottom: 10px; +} + +.combined-op-segment { + display: flex; + align-items: center; + gap: 8px; } diff --git a/frontend/src/components/custom-tools/combined-output/JsonView.jsx b/frontend/src/components/custom-tools/combined-output/JsonView.jsx index aee6c297a5..8ef0caf885 100644 --- a/frontend/src/components/custom-tools/combined-output/JsonView.jsx +++ 
b/frontend/src/components/custom-tools/combined-output/JsonView.jsx @@ -34,7 +34,6 @@ function JsonView({ /> ))} -
{ + switch (stage) { + case "LOOKUP": + return { + color: "purple", + className: "display-logs-stage-lookup", + }; + case "RUN": + return { + color: "blue", + className: "display-logs-stage-run", + }; + case "TOOL": + return { + color: "cyan", + className: "display-logs-stage-tool", + }; + default: + return { + color: "default", + className: "display-logs-stage-default", + }; + } +}; + function DisplayLogs() { const bottomRef = useRef(null); const { messages } = useSocketCustomToolStore(); @@ -20,10 +51,14 @@ function DisplayLogs() { return (
{messages.map((message) => { + const stageStyle = getStageStyle(message?.stage); + const isLookupLog = message?.stage === "LOOKUP"; + const rowClassName = isLookupLog ? "display-logs-row-lookup" : ""; + return ( -
+
- + {getDateTimeString(message?.timestamp)} @@ -34,21 +69,35 @@ function DisplayLogs() { - - {message?.state} - + {message?.stage ? ( + + {isLookupLog && ( + + )} + {message?.stage} + + ) : ( + + {message?.state} + + )} - {message?.component?.prompt_key} + {message?.component?.prompt_key || + message?.component?.lookup_project || + ""} - {message?.component?.doc_name} + {message?.component?.doc_name || ""} - + { @@ -151,13 +163,84 @@ function Header({ setWebhookUrl ); }; + + // Fetch available lookup projects for this PS project (lazy load on dropdown open) + const fetchAvailableLookups = useCallback(async () => { + if (!details?.tool_id || !sessionDetails?.orgId) return; + if (lookupsFetched) return; // Already fetched, skip + + setLookupLoading(true); + try { + const response = await axiosPrivate.get( + `/api/v1/unstract/${sessionDetails.orgId}/prompt-studio/prompt/available_lookups/`, + { params: { tool_id: details.tool_id } } + ); + setAvailableLookups(response?.data || []); + setLookupsFetched(true); + } catch (error) { + console.error("Failed to fetch available lookups:", error); + setAvailableLookups([]); + } finally { + setLookupLoading(false); + } + }, [details?.tool_id, sessionDetails?.orgId, axiosPrivate, lookupsFetched]); + + // Handle lookup enabled checkbox change + const handleLookupEnabledChange = (e) => { + const newValue = e.target.checked; + setLookupEnabled(newValue); + if (!newValue) { + // When disabling, clear the selected lookup + setSelectedLookup(null); + handleChange( + null, + promptDetails?.prompt_id, + "lookup_project", + true, + true + ).catch(() => { + // Rollback on error + setLookupEnabled(true); + setSelectedLookup(promptDetails?.lookup_project || null); + }); + } + }; + + // Handle lookup project selection change + const handleLookupChange = (value) => { + const newValue = value || null; + setSelectedLookup(newValue); + if (!newValue) { + // If clearing the selection, also disable lookup + setLookupEnabled(false); + } + handleChange( + newValue, + promptDetails?.prompt_id, + "lookup_project", + true, + true + ).catch(() => { + // Rollback on error + setLookupEnabled(!!promptDetails?.lookup_project); + setSelectedLookup(promptDetails?.lookup_project || null); + }); + }; + useEffect(() => { setIsDisablePrompt(promptDetails?.active); setRequired(promptDetails?.required); setWebhookEnabled(promptDetails?.enable_postprocessing_webhook || false); setWebhookUrl(promptDetails?.postprocessing_webhook_url || ""); + setLookupEnabled(!!promptDetails?.lookup_project); + setSelectedLookup(promptDetails?.lookup_project || null); }, [promptDetails, details]); + // Reset lookupsFetched when tool changes so we refetch on next dropdown open + useEffect(() => { + setLookupsFetched(false); + }, [details?.tool_id]); + useEffect(() => { const dropdownItems = [ { @@ -248,6 +331,65 @@ function Header({ ), key: "required", }, + { + label: ( +
+ e.stopPropagation()} + disabled={availableLookups.length === 0 && !lookupEnabled} + > + Enable Look-up{" "} + + + + + {lookupEnabled && availableLookups.length > 0 && ( +
+ +
+ )} + {availableLookups.length === 0 && !lookupLoading && ( +
+ No lookup projects linked at project level. +
+ Link lookups in Settings → Lookups first. +
+ )} +
+ ), + key: "lookup", + }, + { + type: "divider", + }, { label: ( @@ -400,7 +552,16 @@ function Header({ promptDetails={promptDetails} /> )} - + { + if (open) { + fetchAvailableLookups(); + } + }} + >
+ ); + + return ( + + + + ); +} + +LookupReplacementIndicator.propTypes = { + lookupReplacement: PropTypes.shape({ + original_value: PropTypes.oneOfType([PropTypes.string, PropTypes.any]), + enriched_value: PropTypes.oneOfType([PropTypes.string, PropTypes.any]), + field_name: PropTypes.string, + }), +}; + +export { LookupReplacementIndicator }; diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.css b/frontend/src/components/custom-tools/prompt-card/PromptCard.css index 9b58a9b7ea..49aa39a0e2 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.css +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.css @@ -310,7 +310,7 @@ } .json-value.clickable { - color: #0097D8; + color: #0097d8; cursor: pointer; } @@ -319,9 +319,52 @@ } .json-value.selected { - color: #5A8300; + color: #5a8300; } -.prompt-output-result{ +.prompt-output-result { font-size: 12px; } + +/* Lookup Enrichment Styles */ +.lookup-enrichment-container { + margin-top: 8px; + border-top: 1px dashed #d9d9d9; + padding-top: 8px; +} + +.lookup-enrichment-container .ant-collapse { + background-color: #f0f7ff; + border: 1px solid #91caff; +} + +.lookup-enrichment-container .ant-collapse-header { + padding: 6px 12px !important; +} + +.lookup-enrichment-container .ant-collapse-content-box { + padding: 8px !important; + background-color: #ffffff; +} + +.lookup-enrichment-table .ant-table-thead > tr > th { + padding: 4px 8px; + font-size: 11px; + background-color: #fafafa; +} + +.lookup-enrichment-table .ant-table-tbody > tr > td { + padding: 4px 8px; + font-size: 12px; +} + +.lookup-info-icon { + color: #1890ff; + cursor: pointer; + font-size: 12px; +} + +.lookup-enrichment-error { + margin-top: 8px; + padding: 4px 0; +} diff --git a/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx b/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx index 3286db0799..1dfc0cc6b7 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx @@ -35,6 +35,7 @@ import { DisplayPromptResult } from "./DisplayPromptResult"; import usePromptOutput from "../../../hooks/usePromptOutput"; import { PromptRunTimer } from "./PromptRunTimer"; import { PromptRunCost } from "./PromptRunCost"; +import { LookupReplacementIndicator } from "./LookupReplacementIndicator"; let TableOutput; try { @@ -112,6 +113,9 @@ function PromptOutput({ setOpenExpandModal={setOpenExpandModal} />
+ copyOutputToClipboard( @@ -204,6 +208,9 @@ function PromptOutput({ } />
+ @@ -439,6 +446,11 @@ function PromptOutput({ promptDetails={promptDetails} />
+ diff --git a/frontend/src/components/custom-tools/settings-modal/SettingsModal.jsx b/frontend/src/components/custom-tools/settings-modal/SettingsModal.jsx index 4c3a1eac8a..209d49f96b 100644 --- a/frontend/src/components/custom-tools/settings-modal/SettingsModal.jsx +++ b/frontend/src/components/custom-tools/settings-modal/SettingsModal.jsx @@ -7,6 +7,7 @@ import { DiffOutlined, FileTextOutlined, MessageOutlined, + SearchOutlined, } from "@ant-design/icons"; import SpaceWrapper from "../../widgets/space-wrapper/SpaceWrapper"; @@ -22,6 +23,7 @@ let SummarizeManager = null; const EvaluationManager = null; let ChallengeManager = null; let HighlightManager = null; +let LookupManager = null; try { SummarizeManager = require("../../../plugins/summarize-manager/SummarizeManager").SummarizeManager; @@ -29,6 +31,8 @@ try { require("../../../plugins/challenge-manager/ChallengeManager").ChallengeManager; HighlightManager = require("../../../plugins/highlight-manager/HighlightManager").HighlightManager; + LookupManager = + require("../../../plugins/lookup-manager/LookupManager").LookupManager; } catch { // Component will remain null if it is not present. } @@ -112,6 +116,12 @@ function SettingsModal({ open, setOpen, handleUpdateTool }) { /> ); } + if (LookupManager) { + items.push(getMenuItem("Lookups", 9, )); + listOfComponents[9] = ( + + ); + } setMenuItems(items); setComponents(listOfComponents); }, []); diff --git a/frontend/src/components/logging/log-modal/LogModal.jsx b/frontend/src/components/logging/log-modal/LogModal.jsx index 18de897bc6..261685d7ce 100644 --- a/frontend/src/components/logging/log-modal/LogModal.jsx +++ b/frontend/src/components/logging/log-modal/LogModal.jsx @@ -1,7 +1,7 @@ -import { Button, Modal, Table, Tooltip } from "antd"; +import { Button, Modal, Table, Tag, Tooltip } from "antd"; import { useEffect, useState } from "react"; import PropTypes from "prop-types"; -import { CopyOutlined } from "@ant-design/icons"; +import { CopyOutlined, SearchOutlined } from "@ant-design/icons"; import "./LogModal.css"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; @@ -40,6 +40,22 @@ function LogModal({ }); const filterOptions = ["INFO", "WARN", "DEBUG", "ERROR"]; + // Get stage-specific styling for visual differentiation + const getStageStyle = (stage) => { + switch (stage) { + case "LOOKUP": + return { color: "purple", icon: }; + case "RUN": + return { color: "blue", icon: null }; + case "EXTRACT": + return { color: "green", icon: null }; + case "INDEX": + return { color: "orange", icon: null }; + default: + return { color: "default", icon: null }; + } + }; + const fetchExecutionFileLogs = async (executionId, fileId, page) => { try { const url = getUrl(`/execution/${executionId}/logs/`); @@ -100,6 +116,14 @@ function LogModal({ title: "Event Stage", dataIndex: "eventStage", key: "stage", + render: (stage) => { + const stageStyle = getStageStyle(stage); + return ( + + {stage} + + ); + }, }, { title: "Log Level", @@ -186,6 +210,12 @@ function LogModal({ loading={loading} onChange={handleTableChange} sortDirections={["ascend", "descend", "ascend"]} + rowClassName={(record) => { + if (record.eventStage === "LOOKUP") { + return "log-modal-row-lookup"; + } + return ""; + }} /> ); diff --git a/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx b/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx new file mode 100644 index 0000000000..4edd19c484 --- /dev/null +++ 
b/frontend/src/components/lookups/create-project-modal/CreateProjectModal.jsx @@ -0,0 +1,70 @@ +import { Form, Input, Modal } from "antd"; +import PropTypes from "prop-types"; + +const { TextArea } = Input; + +export function CreateProjectModal({ open, onCancel, onCreate }) { + const [form] = Form.useForm(); + + const handleSubmit = async () => { + try { + const values = await form.validateFields(); + await onCreate(values); + form.resetFields(); + } catch (error) { + console.error("Validation failed:", error); + } + }; + + return ( + +
+ + + + + +