From 4c063eca4c81e37755a6dc71c1af105cc9683b53 Mon Sep 17 00:00:00 2001 From: Charith Nuwan Bimsara <59943919+nuwangeek@users.noreply.github.com> Date: Mon, 16 Mar 2026 22:00:21 +0530 Subject: [PATCH] Context based response generation workflow (#327) * remove unwanted file * updated changes * fixed requested changes * fixed issue * service workflow implementation without calling service endpoints * fixed requested changes * fixed issues * protocol related requested changes * fixed requested changes * update time tracking * added time tracking and reloacate input guardrail before toolclassifiier * fixed issue * fixed issue * added hybrid search for the service detection * update tool classifier * fixing merge conflicts * fixed issue * optimize first user query response generation time * fixed pr reviewed issues * context based response generation flow * fixed pr review suggested issues * fixed issues --------- Co-authored-by: Thiru Dinesh <56014038+Thirunayan22@users.noreply.github.com> --- docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md | 323 +++++ src/llm_orchestration_service.py | 5 + src/llm_orchestration_service_api.py | 2 +- src/llm_orchestrator_config/stream_config.py | 7 +- src/tool_classifier/classifier.py | 15 +- src/tool_classifier/constants.py | 6 +- src/tool_classifier/context_analyzer.py | 1038 +++++++++++++++++ src/tool_classifier/greeting_constants.py | 40 + .../workflows/context_workflow.py | 387 +++++- src/tool_classifier/workflows/rag_workflow.py | 48 +- .../workflows/service_workflow.py | 17 + src/utils/rate_limiter.py | 138 ++- tests/conftest.py | 9 + tests/test_context_analyzer.py | 979 ++++++++++++++++ tests/test_context_workflow.py | 698 +++++++++++ tests/test_context_workflow_integration.py | 851 ++++++++++++++ 16 files changed, 4427 insertions(+), 136 deletions(-) create mode 100644 docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md create mode 100644 src/tool_classifier/context_analyzer.py create mode 100644 src/tool_classifier/greeting_constants.py 
create mode 100644 tests/test_context_analyzer.py create mode 100644 tests/test_context_workflow.py create mode 100644 tests/test_context_workflow_integration.py diff --git a/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md b/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md new file mode 100644 index 00000000..8a67e841 --- /dev/null +++ b/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md @@ -0,0 +1,323 @@ +# Context Workflow: Greeting Detection and Conversation History Analysis + +## Overview + +The **Context Workflow (Layer 2)** intercepts user queries that can be answered without searching the knowledge base. It handles two categories: + +1. **Greetings** — Detects and responds to social exchanges (hello, goodbye, thanks) in multiple languages +2. **Conversation history references** — Answers follow-up questions that refer to information already discussed in the session + +When the context workflow can answer, a response is returned immediately, bypassing the RAG pipeline entirely. When it cannot answer, the query falls through to the RAG workflow (Layer 3). 
+ +--- + +## Architecture + +### Position in the Classifier Chain + +``` +User Query + ↓ +Layer 1: SERVICE → External API calls + ↓ (cannot handle) +Layer 2: CONTEXT → Greetings + conversation history ←── This document + ↓ (cannot handle) +Layer 3: RAG → Knowledge base retrieval + ↓ (cannot handle) +Layer 4: OOD → Out-of-domain fallback +``` + +### Key Components + +| Component | File | Responsibility | +|-----------|------|----------------| +| `ContextAnalyzer` | `src/tool_classifier/context_analyzer.py` | LLM-based greeting detection and context analysis | +| `ContextWorkflowExecutor` | `src/tool_classifier/workflows/context_workflow.py` | Orchestrates the workflow, handles streaming/non-streaming | +| `ToolClassifier` | `src/tool_classifier/classifier.py` | Invokes `ContextAnalyzer` during classification and routes to `ContextWorkflowExecutor` | +| `greeting_constants.py` | `src/tool_classifier/greeting_constants.py` | Fallback greeting responses for Estonian and English | + +--- + +## Full Request Flow + +``` +User Query + Conversation History + ↓ +ToolClassifier.classify() + ├─ Layer 1 (SERVICE): Embedding-based intent routing + │ └─ If no service tool matches → route to CONTEXT workflow + │ + └─ ClassificationResult(workflow=CONTEXT) + +ToolClassifier.route_to_workflow() + ├─ Non-streaming → ContextWorkflowExecutor.execute_async() + │ ├─ Phase 1: _detect() → context_analyzer.detect_context() [classification only] + │ ├─ If greeting → return greeting OrchestrationResponse + │ ├─ If can_answer → _generate_response_async() → context_analyzer.generate_context_response() + │ └─ Otherwise → return None (RAG fallback) + │ + └─ Streaming → ContextWorkflowExecutor.execute_streaming() + ├─ Phase 1: _detect() → context_analyzer.detect_context() [classification only] + ├─ If greeting → _stream_greeting() async generator + ├─ If can_answer → _create_history_stream() → context_analyzer.stream_context_response() + └─ Otherwise → return None (RAG fallback) +``` + +--- + +## 
Phase 1: Detection (Classify Only) + +### LLM Task + +Every query is checked against the **most recent 10 conversation turns** using a single LLM call (`detect_context()`). This phase **does not generate an answer** — it only classifies the query and extracts a relevant context snippet for Phase 2. + +The `ContextDetectionSignature` DSPy signature instructs the LLM to: + +1. Detect if the query is a greeting in any supported language +2. Check if the query references something discussed in the last 10 turns +3. If the query can be answered from history, extract the relevant snippet +4. Do **not** generate the final answer here — detection only + +### LLM Output Format + +The LLM returns a JSON object parsed into `ContextDetectionResult`: + +```json +{ + "is_greeting": false, + "can_answer_from_context": true, + "reasoning": "User is asking about tax rate discussed earlier", + "context_snippet": "Bot confirmed the flat rate is 20%, applying equally to all income brackets." +} +``` + +| Field | Type | Description | +|-------|------|-------------| +| `is_greeting` | `bool` | Whether the query is a greeting | +| `can_answer_from_context` | `bool` | Whether the query can be answered from conversation history | +| `reasoning` | `str` | Brief explanation of the detection decision | +| `context_snippet` | `str \| null` | Relevant excerpt from history for use in Phase 2, or `null` | + +> **Internal field**: `answered_from_summary` (bool, default `False`) is reserved for future summary-based detection paths. 
+ +### Decision After Phase 1 + +``` +is_greeting=True → Phase 2: return greeting response (no LLM call) +can_answer_from_context=True AND snippet set → Phase 2: generate answer from snippet +Otherwise → Fall back to RAG +``` + +--- + +## Phase 2: Response Generation + +### Non-Streaming (`_generate_response_async`) + +Calls `generate_context_response(query, context_snippet)` which uses `ContextResponseGenerationSignature` to produce a complete answer in a single LLM call. Output guardrails are applied before returning the `OrchestrationResponse`. + +### Streaming (`_create_history_stream` → `stream_context_response`) + +Calls `stream_context_response(query, context_snippet)` which uses DSPy native streaming (`dspy.streamify`) with `ContextResponseGenerationSignature`. Tokens are yielded in real time and passed through NeMo Guardrails before being SSE-formatted. + +--- + +--- + +## Greeting Detection + +### Supported Languages + +| Language | Code | +|----------|------| +| Estonian | `et` | +| English | `en` | + +### Supported Greeting Types + +| Type | Estonian Examples | English Examples | +|------|-------------------|-----------------| +| `hello` | Tere, Hei, Tervist, Moi | Hello, Hi, Hey, Good morning | +| `goodbye` | Nägemist, Tšau | Bye, Goodbye, See you, Good night | +| `thanks` | Tänan, Aitäh, Tänud | Thank you, Thanks | +| `casual` | Tere, Tervist | Hey | + +### Greeting Response Generation + +Greeting detection is handled in **Phase 1 (`detect_context`)**, where the LLM classifies whether the query is a greeting and, if so, identifies the language and greeting type. This phase does **not** generate the final natural-language reply. +In **Phase 2**, `ContextWorkflowExecutor` calls `get_greeting_response(...)`, which returns a response based on predefined static templates in `greeting_constants.py`, ensuring the reply is in the detected language. 
If greeting detection fails or the greeting type is unsupported, the query falls through to the next workflow layer instead of attempting LLM-based greeting generation. +**Greeting response templates (`greeting_constants.py`):** + +```python +GREETINGS_ET = { + "hello": "Tere! Kuidas ma saan sind aidata?", + "goodbye": "Nägemist! Head päeva!", + "thanks": "Palun! Kui on veel küsimusi, küsi julgelt.", + "casual": "Tere! Mida ma saan sinu jaoks teha?", +} + +GREETINGS_EN = { + "hello": "Hello! How can I help you?", + "goodbye": "Goodbye! Have a great day!", + "thanks": "You're welcome! Feel free to ask if you have more questions.", + "casual": "Hey! What can I do for you?", +} +``` + +The fallback greeting type is determined by keyword matching in `_detect_greeting_type()` — checking for `thank/tänan/aitäh`, `bye/goodbye/nägemist/tšau`, before defaulting to `hello`. + +--- + +## Streaming Support + +The context workflow supports both response modes: + +### Non-Streaming (`execute_async`) + +Returns a complete `OrchestrationResponse` object with the answer as a single string. Output guardrails are applied before the response is returned. + +### Streaming (`execute_streaming`) + +Returns an `AsyncIterator[str]` that yields SSE (Server-Sent Events) chunks. + +**Greeting responses** are yielded as a single SSE chunk followed by `END`. + +**History responses** use DSPy native streaming (`dspy.streamify`) with `ContextResponseGenerationSignature`. Tokens are emitted in real time as they arrive from the LLM, then passed through NeMo Guardrails (`stream_with_guardrails`) before being SSE-formatted. If a guardrail violation is detected in a chunk, streaming stops and the violation message is sent instead. + +**SSE Format:** +``` +data: {"chatId": "abc123", "payload": {"content": "Tere! 
Kuidas ma"}, "timestamp": "...", "sentTo": []} + +data: {"chatId": "abc123", "payload": {"content": " saan sind aidata?"}, "timestamp": "...", "sentTo": []} + +data: {"chatId": "abc123", "payload": {"content": "END"}, "timestamp": "...", "sentTo": []} +``` + +--- + +## Cost Tracking + +LLM token usage and cost is tracked via `get_lm_usage_since()` and stored in `costs_metric` within the workflow executor. Costs are logged via `orchestration_service.log_costs()` at the end of each execution path. + +Two cost keys are tracked separately: + +```python +costs_metric = { + "context_detection": { + # Phase 1: detect_context() — single LLM call + "total_cost": 0.0012, + "total_tokens": 180, + "total_prompt_tokens": 150, + "total_completion_tokens": 30, + "num_calls": 1, + }, + "context_response": { + # Phase 2: generate_context_response() or stream_context_response() + "total_cost": 0.003, + "total_tokens": 140, + "total_prompt_tokens": 100, + "total_completion_tokens": 40, + "num_calls": 1, + }, +} +``` + +Greeting responses skip Phase 2, so only `"context_detection"` cost is populated. 
+ +--- + +--- + +## Error Handling and Fallback + +| Failure Point | Behaviour | +|---------------|-----------| +| Phase 1 LLM call raises exception | `can_answer_from_context=False` → falls back to RAG | +| Phase 1 returns invalid JSON | Logged as warning, all flags default to `False` → falls back to RAG | +| Phase 2 LLM call raises exception | Logged as error, `_generate_response_async` returns `None` → falls back to RAG | +| Phase 2 returns empty answer | Logged as warning → falls back to RAG | +| Output guardrails fail | Logged as warning, response returned without guardrail check | +| Guardrail violation in streaming | `OUTPUT_GUARDRAIL_VIOLATION_MESSAGE` sent, stream terminated | +| `orchestration_service` unavailable | History streaming skipped → `None` returned → RAG fallback | +| `guardrails_adapter` not a `NeMoRailsAdapter` | Logged as warning → cannot stream → RAG fallback | +| Any unhandled exception in executor | Error logged, `execute_async/execute_streaming` returns `None` → RAG fallback via classifier | + +--- + +## Logging + +Key log entries emitted during a request: + +| Level | Message | When | +|-------|---------|------| +| `INFO` | `CONTEXT WORKFLOW (NON-STREAMING) \| Query: '...'` | `execute_async()` entry | +| `INFO` | `CONTEXT WORKFLOW (STREAMING) \| Query: '...'` | `execute_streaming()` entry | +| `INFO` | `CONTEXT DETECTOR: Phase 1 \| Query: '...' \| History: N turns` | `detect_context()` entry | +| `INFO` | `DETECTION RESULT \| Greeting: ... \| Can Answer: ... \| Has snippet: ...` | Phase 1 LLM response parsed | +| `INFO` | `Detection cost \| Total: $... \| Tokens: N` | After Phase 1 cost tracked | +| `INFO` | `Detection: greeting=... 
can_answer=...` | After `_detect()` returns in executor | +| `INFO` | `CONTEXT GENERATOR: Phase 2 non-streaming \| Query: '...'` | `generate_context_response()` entry | +| `INFO` | `CONTEXT GENERATOR: Phase 2 streaming \| Query: '...'` | `stream_context_response()` entry | +| `INFO` | `Context response streaming complete (final Prediction received)` | DSPy streaming finished | +| `WARNING` | `[chatId] Phase 2 empty answer — fallback to RAG` | Phase 2 returned no content | +| `WARNING` | `[chatId] Guardrails violation in context streaming` | Violation detected mid-stream | +| `WARNING` | `[chatId] Cannot answer from context — falling back to RAG` | Neither phase could answer | + +--- + +## Data Models + +### `ContextDetectionResult` (Phase 1 output) + +```python +class ContextDetectionResult(BaseModel): + is_greeting: bool # True if query is a greeting + can_answer_from_context: bool # True if query can be answered from last 10 turns + reasoning: str # LLM's brief explanation + answered_from_summary: bool # Reserved; always False in current workflow + context_snippet: Optional[str] # Relevant excerpt for Phase 2 generation, or None +``` + +### `ContextDetectionSignature` (DSPy — Phase 1) + +| Field | Type | Description | +|-------|------|-------------| +| `conversation_history` | Input | Last 10 turns formatted as JSON | +| `user_query` | Input | Current user query | +| `detection_result` | Output | JSON with `is_greeting`, `can_answer_from_context`, `reasoning`, `context_snippet` | + +> Detection only — **no answer generated here**. 
+ +### `ContextResponseGenerationSignature` (DSPy — Phase 2) + +| Field | Type | Description | +|-------|------|-------------| +| `context_snippet` | Input | Relevant excerpt from Phase 1 | +| `user_query` | Input | Current user query | +| `answer` | Output | Natural language response in the same language as the query | + +--- + +## Decision Summary Table + +| Scenario | Phase 1 LLM Calls | Phase 2 LLM Calls | Outcome | +|----------|--------------------|--------------------|---------| +| Greeting detected | 1 (`detect_context`) | 0 (static response) | Context responds (greeting) | +| Follow-up answerable from last 10 turns | 1 (`detect_context`) | 1 (`generate_context_response` or `stream_context_response`) | Context responds | +| Cannot answer from last 10 turns | 1 (`detect_context`) | 0 | Falls back to RAG | +| Phase 1 LLM error / JSON parse failure | — | 0 | Falls back to RAG | +| Phase 2 LLM error or empty answer | 1 | — | Falls back to RAG | + +--- + +## File Reference + +| File | Purpose | +|------|---------| +| `src/tool_classifier/context_analyzer.py` | Core LLM analysis logic (all three steps) | +| `src/tool_classifier/workflows/context_workflow.py` | Workflow executor (streaming + non-streaming) | +| `src/tool_classifier/classifier.py` | Classification layer that invokes context analysis | +| `src/tool_classifier/greeting_constants.py` | Static fallback greeting responses (ET/EN) | +| `tests/test_context_analyzer.py` | Unit tests for `ContextAnalyzer` | +| `tests/test_context_workflow.py` | Unit tests for `ContextWorkflowExecutor` | +| `tests/test_context_workflow_integration.py` | Integration tests for the full classify → route → execute chain | \ No newline at end of file diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 7f7432fc..78899870 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -639,11 +639,13 @@ async def stream_orchestration_response( ) # Classify query to determine 
workflow + start_time = time.time() classification = await self.tool_classifier.classify( query=request.message, conversation_history=request.conversationHistory, language=detected_language, ) + time_metric["classifier.classify"] = time.time() - start_time logger.info( f"[{request.chatId}] [{stream_ctx.stream_id}] Classification: {classification.workflow.value} " @@ -652,11 +654,14 @@ async def stream_orchestration_response( # Route to appropriate workflow (streaming) # route_to_workflow returns AsyncIterator[str] when is_streaming=True + start_time = time.time() stream_result = await self.tool_classifier.route_to_workflow( classification=classification, request=request, is_streaming=True, + time_metric=time_metric, ) + time_metric["classifier.route"] = time.time() - start_time async for sse_chunk in stream_result: yield sse_chunk diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 0e9b1273..110c2991 100644 --- a/src/llm_orchestration_service_api.py +++ b/src/llm_orchestration_service_api.py @@ -71,7 +71,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: if StreamConfig.RATE_LIMIT_ENABLED: app.state.rate_limiter = RateLimiter( requests_per_minute=StreamConfig.RATE_LIMIT_REQUESTS_PER_MINUTE, - tokens_per_second=StreamConfig.RATE_LIMIT_TOKENS_PER_SECOND, + tokens_per_minute=StreamConfig.RATE_LIMIT_TOKENS_PER_MINUTE, ) logger.info("Rate limiter initialized successfully") else: diff --git a/src/llm_orchestrator_config/stream_config.py b/src/llm_orchestrator_config/stream_config.py index ad193387..84e5edd5 100644 --- a/src/llm_orchestrator_config/stream_config.py +++ b/src/llm_orchestrator_config/stream_config.py @@ -21,8 +21,7 @@ class StreamConfig: # Rate Limiting Configuration RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting - RATE_LIMIT_REQUESTS_PER_MINUTE: int = 10 # Max requests per user per minute - RATE_LIMIT_TOKENS_PER_SECOND: int = ( - 100 # Max tokens per user per second (burst 
control) - ) + RATE_LIMIT_REQUESTS_PER_MINUTE: int = 20 # Max requests per user per minute + RATE_LIMIT_TOKENS_PER_MINUTE: int = 40_000 # Max tokens per user per minute RATE_LIMIT_CLEANUP_INTERVAL: int = 300 # Cleanup old entries every 5 minutes + RATE_LIMIT_TOKEN_WINDOW_SECONDS: int = 60 # Sliding window size for token tracking diff --git a/src/tool_classifier/classifier.py b/src/tool_classifier/classifier.py index f18ef3ec..1ada8940 100644 --- a/src/tool_classifier/classifier.py +++ b/src/tool_classifier/classifier.py @@ -57,9 +57,9 @@ class ToolClassifier: def __init__( self, - llm_manager: Any, - orchestration_service: Any, - ): + llm_manager: Any, # noqa: ANN401 + orchestration_service: Any, # noqa: ANN401 + ) -> None: """ Initialize tool classifier with required dependencies. @@ -88,6 +88,7 @@ def __init__( ) self.context_workflow = ContextWorkflowExecutor( llm_manager=llm_manager, + orchestration_service=orchestration_service, ) self.rag_workflow = RAGWorkflowExecutor( orchestration_service=orchestration_service, @@ -622,7 +623,7 @@ def _get_workflow_executor(self, workflow_type: WorkflowType) -> Any: async def _execute_with_fallback_async( self, - workflow: Any, + workflow: Any, # noqa: ANN401 request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, @@ -696,11 +697,11 @@ async def _execute_with_fallback_async( if rag_result is not None: return rag_result else: - raise RuntimeError("RAG workflow returned None unexpectedly") + raise RuntimeError("RAG workflow returned None unexpectedly") from e async def _execute_with_fallback_streaming( self, - workflow: Any, + workflow: Any, # noqa: ANN401 request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, @@ -782,4 +783,4 @@ async def _execute_with_fallback_streaming( async for chunk in streaming_result: yield chunk else: - raise RuntimeError("RAG workflow returned None unexpectedly") + raise RuntimeError("RAG workflow returned None unexpectedly") from e diff --git 
a/src/tool_classifier/constants.py b/src/tool_classifier/constants.py index 65f30332..d839e2cf 100644 --- a/src/tool_classifier/constants.py +++ b/src/tool_classifier/constants.py @@ -70,13 +70,15 @@ DENSE_SEARCH_TOP_K = 3 """Number of top results from dense-only search for relevance scoring.""" -DENSE_MIN_THRESHOLD = 0.38 +# DENSE_MIN_THRESHOLD = 0.38 +DENSE_MIN_THRESHOLD = 0.5 """Minimum dense cosine similarity to consider a result as a potential match. Below this → skip SERVICE entirely, go to CONTEXT/RAG. Note: Multilingual embeddings (Estonian/short queries) typically yield lower cosine scores (0.25-0.40) than English. Tune based on observed scores.""" -DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.40 +# DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.40 +DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.55 """Dense cosine similarity for high-confidence service classification. Above this AND score gap is large → SERVICE without LLM confirmation.""" diff --git a/src/tool_classifier/context_analyzer.py b/src/tool_classifier/context_analyzer.py new file mode 100644 index 00000000..3584683a --- /dev/null +++ b/src/tool_classifier/context_analyzer.py @@ -0,0 +1,1038 @@ +"""Context analyzer for greeting detection and conversation history analysis.""" + +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, List, Optional +import json +import dspy +import dspy.streaming +from dspy.streaming import StreamListener +from loguru import logger +from pydantic import BaseModel, Field + +from src.utils.cost_utils import get_lm_usage_since +from src.tool_classifier.greeting_constants import get_greeting_response + + +class ContextAnalysisResult(BaseModel): + """Result of context analysis.""" + + is_greeting: bool = Field( + ..., description="Whether the query is a greeting (hello, goodbye, thanks)" + ) + can_answer_from_context: bool = Field( + ..., description="Whether the query can be answered from conversation history" + ) + answer: Optional[str] = Field( + None, 
description="Generated response (greeting or context-based answer)" + ) + reasoning: str = Field(..., description="Brief explanation of the analysis") + answered_from_summary: bool = Field( + default=False, + description="Whether the answer was derived from a conversation summary (older turns beyond recent 10)", + ) + + +class ContextAnalysisSignature(dspy.Signature): + """Analyze user query for greeting detection and conversation history references. + + This signature instructs the LLM to: + 1. Detect greetings in multiple languages (Estonian, English) + 2. Check if query references conversation history + 3. Generate appropriate responses or extract answers from history + + Supported greeting types: + - hello: Tere, Hello, Hi, Hei, Hey, Moi, Good morning, Good afternoon, Good evening + - goodbye: Nägemist, Bye, Goodbye, See you, Good night + - thanks: Tänan, Aitäh, Thank you, Thanks, Much appreciated + - casual: Tervist, Tšau, Moikka + + The LLM should respond in the SAME language as the user's query. + """ + + conversation_history: str = dspy.InputField( + desc="Recent conversation history (last 10 turns) formatted as JSON" + ) + user_query: str = dspy.InputField( + desc="Current user query to analyze for greetings or context references" + ) + analysis_result: str = dspy.OutputField( + desc='JSON object with: {"is_greeting": bool, "can_answer_from_context": bool, "answer": str|null, "reasoning": str}. ' + "For greetings, generate a friendly response in the same language. " + "For context references, extract the answer from conversation history if available." + ) + + +class ConversationSummarySignature(dspy.Signature): + """Generate a concise summary of conversation history. + + Summarize the key topics, facts, decisions, and information discussed + in the conversation. Preserve specific details like numbers, names, + dates, and other factual information that might be referenced later. + + The summary should be in the SAME language as the conversation. 
+ """ + + conversation_history: str = dspy.InputField( + desc="Conversation history formatted as JSON to summarize" + ) + summary: str = dspy.OutputField( + desc="Concise summary capturing key topics, facts, and information discussed. " + "Preserve specific details (numbers, names, dates) that could be referenced later." + ) + + +class SummaryAnalysisSignature(dspy.Signature): + """Analyze if a user query can be answered from a conversation summary. + + Given a summary of earlier conversation and the current user query, + determine if the query references information from the summarized conversation. + If yes, generate an appropriate answer based on the summary. + + The response should be in the SAME language as the user's query. + """ + + conversation_summary: str = dspy.InputField( + desc="Summary of earlier conversation history" + ) + user_query: str = dspy.InputField( + desc="Current user query to check against the conversation summary" + ) + analysis_result: str = dspy.OutputField( + desc='JSON object with: {"can_answer_from_context": bool, "answer": str|null, "reasoning": str}. ' + "If the query references information from the summary, extract/generate the answer. " + "If the summary does not contain relevant information, set can_answer_from_context to false." 
+ ) + + +class ContextDetectionResult(BaseModel): + """Result of Phase 1 context detection (classify only, no answer generation).""" + + is_greeting: bool = Field(..., description="Whether the query is a greeting") + greeting_type: str = Field( + default="hello", + description="Type of greeting: hello, goodbye, thanks, or casual", + ) + can_answer_from_context: bool = Field( + ..., description="Whether the query can be answered from conversation history" + ) + reasoning: str = Field(..., description="Brief explanation of the detection") + answered_from_summary: bool = Field( + default=False, + description="Whether summary analysis was used for detection", + ) + # Relevant context snippet extracted for use in Phase 2 generation + context_snippet: Optional[str] = Field( + default=None, + description="The relevant part of history/summary to answer from, for Phase 2", + ) + + +class ContextDetectionSignature(dspy.Signature): + """Detect if a user query is a greeting or can be answered from conversation history. + + Phase 1 (detection only): classify the query WITHOUT generating the answer. + + Supported greeting types: + - hello: Tere, Hello, Hi, Hei, Hey, Moi, Good morning/afternoon/evening + - goodbye: Nägemist, Bye, Goodbye, See you, Good night + - thanks: Tänan, Aitäh, Thank you, Thanks, Much appreciated + - casual: Tervist, Tšau, Moikka + + Do NOT generate the answer here — only detect and extract a relevant context snippet. + """ + + conversation_history: str = dspy.InputField( + desc="Recent conversation history (last 10 turns) formatted as JSON" + ) + user_query: str = dspy.InputField(desc="Current user query to classify") + detection_result: str = dspy.OutputField( + desc='JSON object with: {"is_greeting": bool, "greeting_type": str, "can_answer_from_context": bool, ' + '"reasoning": str, "context_snippet": str|null}. 
' + 'greeting_type must be one of: "hello", "goodbye", "thanks", "casual" — ' + 'set it only when is_greeting is true, defaulting to "hello" otherwise. ' + "context_snippet should contain the relevant excerpt from history if can_answer_from_context is true, " + "or null otherwise. Do NOT generate the final answer — only detect and extract." + ) + + +class ContextResponseGenerationSignature(dspy.Signature): + """Generate a response to a user query based on conversation history context. + + Phase 2 (generation): given the user query and relevant context, generate a helpful answer. + Respond in the SAME language as the user query. + """ + + context_snippet: str = dspy.InputField( + desc="Relevant excerpt from conversation history or summary that contains the answer" + ) + user_query: str = dspy.InputField(desc="Current user query to answer") + answer: str = dspy.OutputField( + desc="A helpful, natural response to the user query based on the provided context. " + "Respond in the same language as the user query." + ) + + +class ContextAnalyzer: + """ + Analyzer for greeting detection and context-based question answering. + + This class uses an LLM to intelligently detect: + - Greetings in multiple languages (Estonian, English) + - Questions that reference conversation history + - Generate appropriate responses based on context + + Example Usage: + analyzer = ContextAnalyzer(llm_manager) + result = await analyzer.analyze_context( + query="Tere!", + conversation_history=[], + language="et" + ) + # result.is_greeting = True + # result.answer = "Tere! Kuidas ma saan sind aidata?" + """ + + def __init__(self, llm_manager: Any) -> None: # noqa: ANN401 + """ + Initialize the context analyzer. 
+ + Args: + llm_manager: LLM manager instance for making LLM calls + """ + self.llm_manager = llm_manager + self._module: Optional[dspy.Module] = None + self._summary_module: Optional[dspy.Module] = None + self._summary_analysis_module: Optional[dspy.Module] = None + # Phase 1 & 2 modules for two-phase detection+generation flow + self._detection_module: Optional[dspy.Module] = None + self._response_generation_module: Optional[dspy.Module] = None + logger.info("Context analyzer initialized") + + def _format_conversation_history( + self, conversation_history: List[Dict[str, Any]], max_turns: int = 10 + ) -> str: + """ + Format conversation history for LLM consumption. + + Args: + conversation_history: List of conversation items with authorRole, message, timestamp + max_turns: Maximum number of turns to include (default: 10) + + Returns: + Formatted conversation history as JSON string + """ + # Take last N turns + recent_history = ( + conversation_history[-max_turns:] if conversation_history else [] + ) + + # Format as readable JSON + formatted_history = [ + { + "role": item.get("authorRole", "unknown"), + "message": item.get("message", ""), + "timestamp": item.get("timestamp", ""), + } + for item in recent_history + ] + + if not formatted_history: + return "[]" + + return json.dumps(formatted_history, ensure_ascii=False, indent=2) + + @staticmethod + def _merge_cost_dicts( + cost1: Dict[str, Any], cost2: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Merge two cost dictionaries by summing numeric values. 
+ + Args: + cost1: First cost dictionary + cost2: Second cost dictionary + + Returns: + Merged cost dictionary with summed values + """ + return { + "total_cost": cost1.get("total_cost", 0) + cost2.get("total_cost", 0), + "total_tokens": cost1.get("total_tokens", 0) + cost2.get("total_tokens", 0), + "total_prompt_tokens": cost1.get("total_prompt_tokens", 0) + + cost2.get("total_prompt_tokens", 0), + "total_completion_tokens": cost1.get("total_completion_tokens", 0) + + cost2.get("total_completion_tokens", 0), + "num_calls": cost1.get("num_calls", 0) + cost2.get("num_calls", 0), + } + + async def detect_context( + self, + query: str, + conversation_history: List[Dict[str, Any]], + ) -> tuple[ContextDetectionResult, Dict[str, Any]]: + """ + Phase 1: Detect if query is a greeting or can be answered from history. + + Classify-only — no answer generated here. Returns a ContextDetectionResult + with is_greeting/can_answer_from_context flags and a context_snippet for + Phase 2 generation. + + Args: + query: User query to classify + conversation_history: Full conversation history + + Returns: + Tuple of (ContextDetectionResult, cost_dict) + """ + total_turns = len(conversation_history) + logger.info( + f"CONTEXT DETECTOR: Phase 1 | Query: '{query[:100]}' | " + f"History: {total_turns} turns" + ) + + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning(f"Failed to get LM history length for detection: {e}") + + formatted_history = self._format_conversation_history(conversation_history) + + self.llm_manager.ensure_global_config() + try: + with self.llm_manager.use_task_local(): + if self._detection_module is None: + self._detection_module = dspy.ChainOfThought( + ContextDetectionSignature + ) + response = self._detection_module( + conversation_history=formatted_history, + user_query=query, + ) + + try: + detection_data = 
json.loads(response.detection_result) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse detection response: {response.detection_result[:100]}" + ) + detection_data = { + "is_greeting": False, + "can_answer_from_context": False, + "reasoning": "Failed to parse detection response", + "context_snippet": None, + } + + result = ContextDetectionResult( + is_greeting=detection_data.get("is_greeting", False), + greeting_type=detection_data.get("greeting_type", "hello"), + can_answer_from_context=detection_data.get( + "can_answer_from_context", False + ), + reasoning=detection_data.get("reasoning", "Detection completed"), + context_snippet=detection_data.get("context_snippet"), + ) + logger.info( + f"DETECTION RESULT | Greeting: {result.is_greeting} | " + f"Can Answer: {result.can_answer_from_context} | " + f"Has snippet: {result.context_snippet is not None}" + ) + + except Exception as e: + logger.error(f"Context detection failed: {e}", exc_info=True) + result = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + reasoning=f"Detection error: {str(e)}", + ) + + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Detection cost | Total: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)}" + ) + return result, cost_dict + + async def detect_context_with_summary_fallback( + self, + query: str, + conversation_history: List[Dict[str, Any]], + ) -> tuple[ContextDetectionResult, Dict[str, Any]]: + """ + Phase 1 with summary fallback: detect if query can be answered from history. + + Implements a 3-step flow: + 1. Check the last 10 turns via detect_context(). + 2. If cannot answer AND total history > 10 turns: + - Generate a concise summary of the older turns (everything before the last 10). + - Check whether the query can be answered from that summary. + 3. If still cannot answer, return can_answer=False (workflow falls back to RAG). 
+ + When the summary path succeeds, the returned ContextDetectionResult has: + - can_answer_from_context=True + - answered_from_summary=True + - context_snippet set to the answer extracted from the summary, so that + Phase 2 (stream_context_response / generate_context_response) can use it + directly as the context for response generation. + + Args: + query: User query to classify + conversation_history: Full conversation history + + Returns: + Tuple of (ContextDetectionResult, cost_dict) + """ + total_turns = len(conversation_history) + + # Step 1: check the most recent 10 turns + result, cost_dict = await self.detect_context( + query=query, conversation_history=conversation_history + ) + + # If already answered or it's a greeting, return immediately + if result.is_greeting or result.can_answer_from_context: + return result, cost_dict + + # Step 2 & 3: if history exceeds 10 turns, try summary-based detection + if total_turns > 10: + logger.info( + f"History has {total_turns} turns (> 10) | " + f"Cannot answer from recent 10 | Attempting summary-based detection" + ) + older_history = conversation_history[:-10] + logger.info(f"Summarizing {len(older_history)} older turns") + + try: + summary, summary_cost = await self._generate_conversation_summary( + older_history + ) + cost_dict = self._merge_cost_dicts(cost_dict, summary_cost) + + if summary: + summary_result, analysis_cost = await self._analyze_from_summary( + query=query, summary=summary + ) + cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost) + + if summary_result.can_answer_from_context and summary_result.answer: + logger.info( + f"DETECTION: Can answer from summary | " + f"Reasoning: {summary_result.reasoning}" + ) + # Surface the summary-derived answer as context_snippet so + # Phase 2 can generate a polished response from it. 
+ return ContextDetectionResult( + is_greeting=False, + can_answer_from_context=True, + reasoning=summary_result.reasoning, + context_snippet=summary_result.answer, + answered_from_summary=True, + ), cost_dict + + logger.info( + "Cannot answer from summary either | Falling back to RAG" + ) + else: + logger.warning( + "Summary generation returned empty | Falling back to RAG" + ) + + except Exception as e: + logger.error(f"Summary-based detection failed: {e}", exc_info=True) + else: + logger.info( + f"History has {total_turns} turns (<= 10) | " + f"No summary needed | Falling back to RAG" + ) + + return result, cost_dict + + @staticmethod + def _yield_in_chunks(text: str, chunk_size: int = 5) -> list[str]: + """Split text into word-group chunks for simulated streaming.""" + words = text.split() + chunks = [] + for i in range(0, len(words), chunk_size): + group = words[i : i + chunk_size] + trailing = " " if i + chunk_size < len(words) else "" + chunks.append(" ".join(group) + trailing) + return chunks + + async def stream_context_response( + self, + query: str, + context_snippet: str, + ) -> AsyncIterator[str]: + """ + Phase 2 (streaming): Stream a generated answer using DSPy native streaming. + + Creates a fresh streamify predictor per call (avoids stale StreamListener + issues that occur when the cached predictor is reused across calls). + + Fallback chain: + 1. DSPy streamify → yield StreamResponse tokens as they arrive. + 2. If no stream tokens received but final Prediction has an answer, + yield it in word-group chunks. + 3. If that is also empty, call generate_context_response() directly + and yield the result in word-group chunks. 
async def stream_context_response(
    self,
    query: str,
    context_snippet: str,
) -> AsyncIterator[str]:
    """
    Phase 2 (streaming): stream the generated answer via DSPy native streaming.

    A fresh streamify predictor is built per call so the StreamListener's
    internal state is never reused across requests (stale-listener issues).

    Fallback chain:
        1. DSPy streamify -> yield StreamResponse tokens as they arrive.
        2. No stream tokens but the final Prediction carries an answer ->
           yield that answer in word-group chunks.
        3. Otherwise -> call generate_context_response() and yield its
           result in word-group chunks.

    Args:
        query: The user query to answer
        context_snippet: Relevant context extracted during Phase 1 detection

    Yields:
        Token strings as they arrive from the LLM (or simulated chunks)
    """
    logger.info(f"CONTEXT GENERATOR: Phase 2 streaming | Query: '{query[:100]}'")

    self.llm_manager.ensure_global_config()
    token_stream = None
    tokens_emitted = False
    final_answer: Optional[str] = None
    try:
        with self.llm_manager.use_task_local():
            # Brand-new listener + streamified predictor: clean state per call.
            listener = StreamListener(signature_field_name="answer")
            predictor: Any = dspy.streamify(
                dspy.Predict(ContextResponseGenerationSignature),
                stream_listeners=[listener],
            )
            token_stream = predictor(
                context_snippet=context_snippet,
                user_query=query,
            )

            async for event in token_stream:
                if isinstance(event, dspy.streaming.StreamResponse):
                    if event.signature_field_name == "answer":
                        tokens_emitted = True
                        yield event.chunk
                elif isinstance(event, dspy.Prediction):
                    logger.info(
                        "Context response streaming complete (final Prediction received)"
                    )
                    if not tokens_emitted:
                        # Tokens never streamed — grab the answer off the
                        # Prediction before leaving the LM context.
                        final_answer = getattr(event, "answer", "") or ""

    except GeneratorExit:
        raise
    except Exception as e:
        logger.error(f"Error during context response streaming: {e}")
        raise
    finally:
        if token_stream is not None:
            try:
                await token_stream.aclose()
            except Exception as cleanup_error:
                logger.debug(
                    f"Error during context stream cleanup: {cleanup_error}"
                )

    if tokens_emitted:
        return

    # Fallback 1: answer arrived only on the final Prediction.
    if final_answer:
        logger.warning(
            "Stream tokens not received — yielding answer from final Prediction in chunks."
        )
        for piece in self._yield_in_chunks(final_answer):
            yield piece
        return

    # Fallback 2: nothing from streamify at all — regenerate non-streaming.
    logger.warning(
        "No answer from streamify — falling back to generate_context_response."
    )
    regenerated, _ = await self.generate_context_response(
        query=query, context_snippet=context_snippet
    )
    if regenerated:
        for piece in self._yield_in_chunks(regenerated):
            yield piece
    else:
        logger.error("All Phase 2 streaming fallbacks exhausted — empty response.")
async def generate_context_response(
    self,
    query: str,
    context_snippet: str,
) -> tuple[str, Dict[str, Any]]:
    """
    Phase 2 (non-streaming): produce a complete answer from a context snippet.

    Used in non-streaming mode once Phase 1 detection has confirmed the
    context can answer the query.

    Args:
        query: The user query to answer
        context_snippet: Relevant context extracted during Phase 1 detection

    Returns:
        Tuple of (answer_text, cost_dict); answer_text is "" on failure
    """
    logger.info(
        f"CONTEXT GENERATOR: Phase 2 non-streaming | Query: '{query[:100]}'"
    )

    # Snapshot LM call-history length for cost accounting.
    lm_history_start = 0
    try:
        active_lm = dspy.settings.lm
        if active_lm and hasattr(active_lm, "history"):
            lm_history_start = len(active_lm.history)
    except Exception as e:
        logger.warning(f"Failed to get LM history length for generation: {e}")

    self.llm_manager.ensure_global_config()
    answer = ""
    try:
        with self.llm_manager.use_task_local():
            if self._response_generation_module is None:
                self._response_generation_module = dspy.ChainOfThought(
                    ContextResponseGenerationSignature
                )
            prediction = self._response_generation_module(
                context_snippet=context_snippet,
                user_query=query,
            )
            answer = getattr(prediction, "answer", "") or ""
            logger.info(
                f"Context response generated: {len(answer)} chars | "
                f"Preview: '{answer[:150]}'"
            )
    except Exception as e:
        logger.error(f"Context response generation failed: {e}", exc_info=True)

    cost_dict = get_lm_usage_since(lm_history_start)
    logger.info(
        f"Generation cost | Total: ${cost_dict.get('total_cost', 0):.6f} | "
        f"Tokens: {cost_dict.get('total_tokens', 0)}"
    )
    return answer, cost_dict

async def _generate_conversation_summary(
    self,
    older_history: List[Dict[str, Any]],
) -> tuple[str, Dict[str, Any]]:
    """
    Summarize conversation turns that fall outside the recent-10 window.

    Args:
        older_history: Conversation turns older than the recent 10

    Returns:
        Tuple of (summary_text, cost_dict); summary_text is "" on failure
    """
    logger.info(f"SUMMARY GENERATION: Summarizing {len(older_history)} older turns")

    # Snapshot LM call-history length for cost accounting.
    lm_history_start = 0
    try:
        active_lm = dspy.settings.lm
        if active_lm and hasattr(active_lm, "history"):
            lm_history_start = len(active_lm.history)
    except Exception as e:
        logger.warning(f"Failed to get LM history length for summary: {e}")

    # Every older turn participates in the summary.
    older_json = self._format_conversation_history(
        older_history, max_turns=len(older_history)
    )

    try:
        self.llm_manager.ensure_global_config()
        with self.llm_manager.use_task_local():
            if self._summary_module is None:
                self._summary_module = dspy.ChainOfThought(
                    ConversationSummarySignature
                )
            prediction = self._summary_module(
                conversation_history=older_json,
            )
            summary = prediction.summary
            logger.info(
                f"Summary generated: {len(summary)} chars | "
                f"Preview: '{summary[:150]}...'"
            )
    except Exception as e:
        logger.error(f"Summary generation failed: {e}", exc_info=True)
        summary = ""

    cost_dict = get_lm_usage_since(lm_history_start)
    logger.info(
        f"Summary cost | Total: ${cost_dict.get('total_cost', 0):.6f} | "
        f"Tokens: {cost_dict.get('total_tokens', 0)}"
    )

    return summary, cost_dict
async def _analyze_from_summary(
    self,
    query: str,
    summary: str,
) -> tuple[ContextAnalysisResult, Dict[str, Any]]:
    """
    Decide whether the query is answerable from a conversation summary.

    Args:
        query: User query to check
        summary: Summary of older conversation turns

    Returns:
        Tuple of (ContextAnalysisResult, cost_dict)
    """
    logger.info(
        f"SUMMARY ANALYSIS: Checking query against summary | Query: '{query[:100]}'"
    )

    # Ensure DSPy is configured; run the analysis inside a task-local LM context.
    self.llm_manager.ensure_global_config()
    lm_history_start = 0
    with self.llm_manager.use_task_local():
        # Snapshot LM call-history length for cost accounting.
        try:
            active_lm = dspy.settings.lm
            if active_lm and hasattr(active_lm, "history"):
                lm_history_start = len(active_lm.history)
        except Exception as e:
            logger.warning(
                f"Failed to get LM history length for summary analysis: {e}"
            )
        # Lazily build the summary-analysis predictor.
        if self._summary_analysis_module is None:
            self._summary_analysis_module = dspy.ChainOfThought(
                SummaryAnalysisSignature
            )
        try:
            prediction = self._summary_analysis_module(
                conversation_summary=summary,
                user_query=query,
            )
            raw = prediction.analysis_result
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse summary analysis response: "
                    f"{raw[:100]}"
                )
                payload = {
                    "can_answer_from_context": False,
                    "answer": None,
                    "reasoning": "Failed to parse summary analysis response",
                }

            llm_can_answer = payload.get("can_answer_from_context", False)
            extracted_answer = payload.get("answer")
            reasoning = payload.get("reasoning", "Summary analysis completed")
            logger.debug(
                f"Raw summary analysis parsed | "
                f"can_answer_from_context={llm_can_answer} | "
                f"has_answer={extracted_answer is not None}"
            )
            # Answerable only when the LLM flag is set AND an answer exists.
            answerable = bool(llm_can_answer and extracted_answer)
            result = ContextAnalysisResult(
                is_greeting=False,
                can_answer_from_context=answerable,
                answer=extracted_answer,
                reasoning=reasoning,
                answered_from_summary=answerable,
            )
            logger.info(
                "SUMMARY ANALYSIS RESULT | "
                f"Can answer from summary: {llm_can_answer} | "
                f"Can answer from context: {answerable} | "
                f"Has answer: {extracted_answer is not None} | Reasoning: {reasoning}"
            )
        except Exception as e:
            logger.error(f"Summary analysis failed: {e}", exc_info=True)
            result = ContextAnalysisResult(
                is_greeting=False,
                can_answer_from_context=False,
                answer=None,
                reasoning=f"Summary analysis error: {str(e)}",
            )

    cost_dict = get_lm_usage_since(lm_history_start)
    logger.info(
        f"Summary analysis cost | Total: ${cost_dict.get('total_cost', 0):.6f} | "
        f"Tokens: {cost_dict.get('total_tokens', 0)}"
    )

    return result, cost_dict
async def analyze_context(
    self,
    query: str,
    conversation_history: List[Dict[str, Any]],
    language: str = "et",
) -> tuple[ContextAnalysisResult, Dict[str, Any]]:
    """
    Single-phase analysis: greeting detection plus history answer extraction.

    Flow:
        1. Analyze the recent 10 turns for greetings / answerable queries.
        2. If unanswered and more than 10 turns exist, summarize the older
           turns and check the query against that summary.
        3. If still unanswered, return the cannot-answer result (the caller
           then falls through to RAG).

    Args:
        query: User query to analyze
        conversation_history: List of conversation items
        language: Language code (et, en) for response generation

    Returns:
        Tuple of (ContextAnalysisResult, cost_dict)
    """
    turn_count = len(conversation_history)
    logger.info(
        f"CONTEXT ANALYZER: Starting analysis | Query: '{query[:100]}' | "
        f"History: {turn_count} turns | Language: {language}"
    )

    # STEP 1: recent-history analysis (last 10 turns).
    result, cost_dict = await self._analyze_recent_history(
        query=query,
        conversation_history=conversation_history,
        language=language,
    )

    # Greeting or answerable from recent history: done.
    if (result.is_greeting or result.can_answer_from_context) and result.answer:
        logger.info(
            f"Answered from recent history | "
            f"Greeting: {result.is_greeting} | From context: {result.can_answer_from_context}"
        )
        return result, cost_dict

    # STEP 2 & 3: summary fallback when the history is long enough.
    if turn_count > 10:
        logger.info(
            f"History exceeds 10 turns ({turn_count} total) | "
            f"Cannot answer from recent 10 | Attempting summary-based analysis"
        )
        # Everything before the last 10 turns feeds the summary.
        older_turns = conversation_history[:-10]
        logger.info(f"Older history: {len(older_turns)} turns to summarize")

        try:
            summary, summary_cost = await self._generate_conversation_summary(
                older_turns
            )
            cost_dict = self._merge_cost_dicts(cost_dict, summary_cost)

            if summary:
                summary_result, analysis_cost = await self._analyze_from_summary(
                    query=query,
                    summary=summary,
                )
                cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost)

                if summary_result.can_answer_from_context and summary_result.answer:
                    logger.info(
                        f"Answered from conversation summary | "
                        f"Reasoning: {summary_result.reasoning}"
                    )
                    return summary_result, cost_dict

                logger.info(
                    "Cannot answer from summary either | Falling back to RAG"
                )
            else:
                logger.warning(
                    "Summary generation returned empty | Falling back to RAG"
                )

        except Exception as e:
            logger.error(f"Summary-based analysis failed: {e}", exc_info=True)
    else:
        logger.info(
            f"History has {turn_count} turns (<= 10) | "
            f"No summary needed | Falling back to RAG"
        )

    # Cannot answer from context at all.
    can_respond = (
        result.can_answer_from_context or result.is_greeting
    ) and result.answer
    logger.info(
        f"CONTEXT ANALYZER FINAL DECISION | "
        f"can_answer_from_context={result.can_answer_from_context} | "
        f"is_greeting={result.is_greeting} | "
        f"answered_from_summary={result.answered_from_summary} | "
        f"has_answer={result.answer is not None} | "
        f"action={'RESPOND' if can_respond else 'FALLBACK_TO_RAG'}"
    )
    return result, cost_dict
async def _analyze_recent_history(
    self,
    query: str,
    conversation_history: List[Dict[str, Any]],
    language: str = "et",
) -> tuple[ContextAnalysisResult, Dict[str, Any]]:
    """
    Analyze the query against the most recent conversation turns (last 10).

    Checks for greetings and history-answerable queries. When a greeting is
    detected but the LLM produced no answer, a canned greeting response is
    substituted.

    Args:
        query: User query to analyze
        conversation_history: Full conversation history (last 10 will be used)
        language: Language code for response generation

    Returns:
        Tuple of (ContextAnalysisResult, cost_dict)
    """
    logger.info("STEP 1: Analyzing recent history (last 10 turns)")

    # Snapshot LM call-history length for cost accounting.
    lm_history_start = 0
    try:
        active_lm = dspy.settings.lm
        if active_lm and hasattr(active_lm, "history"):
            lm_history_start = len(active_lm.history)
    except Exception as e:
        logger.warning(f"Failed to get LM history length: {e}")

    history_json = self._format_conversation_history(conversation_history)

    self.llm_manager.ensure_global_config()
    try:
        with self.llm_manager.use_task_local():
            if self._module is None:
                self._module = dspy.ChainOfThought(ContextAnalysisSignature)
            logger.info(
                "Calling LLM for context analysis (greeting/history check)..."
            )
            prediction = self._module(
                conversation_history=history_json,
                user_query=query,
            )

            raw = prediction.analysis_result
            try:
                payload = json.loads(raw)
                logger.debug(
                    f"Raw LLM response parsed | "
                    f"can_answer_from_context={payload.get('can_answer_from_context')} | "
                    f"is_greeting={payload.get('is_greeting')} | "
                    f"has_answer={payload.get('answer') is not None}"
                )
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse LLM response as JSON: {raw[:100]}"
                )
                # Fallback: treat as cannot answer.
                payload = {
                    "is_greeting": False,
                    "can_answer_from_context": False,
                    "answer": None,
                    "reasoning": "Failed to parse LLM response",
                }

            result = ContextAnalysisResult(
                is_greeting=payload.get("is_greeting", False),
                can_answer_from_context=payload.get(
                    "can_answer_from_context", False
                ),
                answer=payload.get("answer"),
                reasoning=payload.get("reasoning", "Analysis completed"),
            )

            logger.info(
                f"ANALYSIS RESULT | Greeting: {result.is_greeting} | "
                f"Can Answer from Context: {result.can_answer_from_context} | "
                f"Answer: {result.answer[:100] if result.answer else None} | "
                f"Reasoning: {result.reasoning}"
            )

            # Greeting with no generated answer: substitute a canned greeting.
            if result.is_greeting and result.answer is None:
                canned = get_greeting_response(
                    self._detect_greeting_type(query), language
                )
                result = ContextAnalysisResult(
                    is_greeting=result.is_greeting,
                    can_answer_from_context=result.can_answer_from_context,
                    answer=canned,
                    reasoning=result.reasoning,
                )

    except Exception as e:
        logger.error(f"Context analysis failed: {e}", exc_info=True)
        result = ContextAnalysisResult(
            is_greeting=False,
            can_answer_from_context=False,
            answer=None,
            reasoning=f"Analysis error: {str(e)}",
        )

    cost_dict = get_lm_usage_since(lm_history_start)
    logger.info(
        f"Cost tracking | Total cost: ${cost_dict.get('total_cost', 0):.6f} | "
        f"Tokens: {cost_dict.get('total_tokens', 0)} | "
        f"Calls: {cost_dict.get('num_calls', 0)}"
    )

    return result, cost_dict
def _detect_greeting_type(self, query: str) -> str:
    """
    Classify the greeting type of a query via keyword matching.

    Fix over the previous substring matching: keywords are now matched as
    whole words, so ordinary words no longer trigger false positives
    ("maybe" is not "goodbye" via the embedded "bye", "their" is not
    "casual" via "hei"). Multi-word phrases ("head aega") are still
    matched as phrases.

    Args:
        query: User query string

    Returns:
        Greeting type: 'thanks', 'goodbye', 'casual', or 'hello' (default)
    """
    import re

    query_lower = query.lower().strip()
    # \w is Unicode-aware, so accented Estonian keywords tokenize correctly.
    tokens = set(re.findall(r"\w+", query_lower))

    categories = (
        ("thanks", ["thank", "thanks", "tänan", "aitäh", "tänud"]),
        ("goodbye", ["bye", "goodbye", "nägemist", "tsau", "tšau", "head aega"]),
        ("casual", ["hei", "hey", "moi", "moikka"]),
    )
    for greeting_type, keywords in categories:
        for kw in keywords:
            # Phrases (containing a space) still use substring matching.
            if (" " in kw and kw in query_lower) or kw in tokens:
                return greeting_type
    return "hello"

def get_fallback_greeting_response(self, language: str = "et") -> str:
    """
    Return a canned greeting without any LLM call.

    Used when LLM-based greeting detection fails but we still want to
    provide a friendly response.

    Args:
        language: Language code (et, en)

    Returns:
        Greeting message in the specified language
    """
    greetings = {
        "et": "Tere! Kuidas ma saan sind aidata?",
        "en": "Hello! How can I help you?",
    }
    return greetings.get(language, greetings["et"])


__all__ = ["_detect_greeting_type", "get_fallback_greeting_response"]
"""Constants for greeting responses in multiple languages."""

from typing import Dict

# Estonian greeting responses
GREETINGS_ET: Dict[str, str] = {
    "hello": "Tere! Kuidas ma saan sind aidata?",
    "goodbye": "Nägemist! Head päeva!",
    "thanks": "Palun! Kui on veel küsimusi, küsi julgelt.",
    "casual": "Tere! Mida ma saan sinu jaoks teha?",
}

# English greeting responses
GREETINGS_EN: Dict[str, str] = {
    "hello": "Hello! How can I help you?",
    "goodbye": "Goodbye! Have a great day!",
    "thanks": "You're welcome! Feel free to ask if you have more questions.",
    "casual": "Hey! What can I do for you?",
}

# Language code -> greeting table
GREETINGS_BY_LANGUAGE: Dict[str, Dict[str, str]] = {
    "et": GREETINGS_ET,
    "en": GREETINGS_EN,
}


def get_greeting_response(greeting_type: str = "hello", language: str = "et") -> str:
    """
    Look up the greeting for a given type and language.

    Unknown languages fall back to English; unknown greeting types fall
    back to the "hello" message of the resolved language.

    Args:
        greeting_type: Type of greeting (hello, goodbye, thanks, casual)
        language: Language code (et, en)

    Returns:
        Greeting message in the specified language
    """
    table = GREETINGS_BY_LANGUAGE.get(language, GREETINGS_EN)
    return table.get(greeting_type, table["hello"])
src.llm_orchestrator_config.llm_manager import LLMManager +from src.utils.cost_utils import get_lm_usage_since +from src.utils.language_detector import detect_language +from src.llm_orchestrator_config.llm_ochestrator_constants import ( + GUARDRAILS_BLOCKED_PHRASES, + OUTPUT_GUARDRAIL_VIOLATION_MESSAGE, +) class ContextWorkflowExecutor(BaseWorkflow): @@ -12,24 +24,231 @@ class ContextWorkflowExecutor(BaseWorkflow): Handles greetings and conversation history queries (Layer 2). Detects: - - Greetings: "Hello", "Thanks", "Goodbye" + - Greetings: "Hello", "Thanks", "Goodbye" (multilingual: Estonian, English) - History references: "What did you say earlier?", "Can you repeat that?" Uses LLM for semantic detection (multilingual), no regex patterns. - Status: SKELETON - Returns None (fallback to RAG) - TODO: Implement greeting/context detection, answer extraction, guardrails + Implementation Strategy: + 1. Detect language from user query + 2. Use ContextAnalyzer (LLM-based) to check if: + - Query is a greeting -> generate friendly response + - Query references conversation history -> extract answer + 3. If can answer -> return response + 4. Otherwise -> return None (fallback to RAG) + + Cost Tracking: + - Tracks LLM costs for context analysis + - Logs via orchestration_service.log_costs() (same as service/RAG workflows) """ - def __init__(self, llm_manager: Any): + def __init__( + self, + llm_manager: LLMManager, + orchestration_service: Optional[LLMServiceProtocol] = None, + ) -> None: """ Initialize context workflow executor. 
Args: llm_manager: LLM manager for context analysis + orchestration_service: Reference to LLMOrchestrationService for cost logging """ self.llm_manager = llm_manager - logger.info("Context workflow executor initialized (skeleton)") + self.orchestration_service = orchestration_service + self.context_analyzer = ContextAnalyzer(llm_manager) + logger.info("Context workflow executor initialized") + + @staticmethod + def _build_history(request: OrchestrationRequest) -> list[Dict[str, Any]]: + return [ + { + "authorRole": item.authorRole, + "message": item.message, + "timestamp": item.timestamp, + } + for item in request.conversationHistory + ] + + async def _detect( + self, + message: str, + history: list[Dict[str, Any]], + time_metric: Dict[str, float], + costs_metric: Dict[str, Dict[str, Any]], + ) -> Optional[ContextDetectionResult]: + """Phase 1: run context detection with summary fallback. + + Checks the last 10 conversation turns first. If the query cannot be + answered from those and the history exceeds 10 turns, falls back to a + summary-based check over the older turns. Returns None on error so the + caller falls through to RAG. 
async def _detect(
    self,
    message: str,
    history: list[Dict[str, Any]],
    time_metric: Dict[str, float],
    costs_metric: Dict[str, Dict[str, Any]],
) -> Optional[ContextDetectionResult]:
    """Phase 1: run context detection with summary fallback.

    Checks the last 10 conversation turns first. If the query cannot be
    answered from those and the history exceeds 10 turns, falls back to a
    summary-based check over the older turns. Returns None on error so the
    caller falls through to RAG.
    """
    try:
        started_at = time.time()
        detection, detection_cost = (
            await self.context_analyzer.detect_context_with_summary_fallback(
                query=message, conversation_history=history
            )
        )
        time_metric["context.detection"] = time.time() - started_at
        costs_metric["context_detection"] = detection_cost
        return detection
    except Exception as e:
        logger.error(f"Phase 1 detection failed: {e}", exc_info=True)
        return None

def _log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None:
    """Forward accumulated costs to the orchestration service, if present."""
    service = self.orchestration_service
    if service:
        service.log_costs(costs_metric)

@staticmethod
def _is_guardrail_violation(chunk: str) -> bool:
    """Return True if the chunk matches a known guardrail blocked phrase."""
    candidate = chunk.strip().lower()
    for phrase in GUARDRAILS_BLOCKED_PHRASES:
        blocked = phrase.lower()
        # Length bound keeps long legitimate answers that merely contain a
        # phrase from being flagged.
        if blocked in candidate and len(candidate) <= len(blocked) + 20:
            return True
    return False

async def _generate_response_async(
    self,
    request: OrchestrationRequest,
    context_snippet: str,
    time_metric: Dict[str, float],
    costs_metric: Dict[str, Dict[str, Any]],
) -> Optional[OrchestrationResponse]:
    """Non-streaming Phase 2: generate the answer and run output guardrails."""
    try:
        started_at = time.time()
        answer, generation_cost = (
            await self.context_analyzer.generate_context_response(
                query=request.message, context_snippet=context_snippet
            )
        )
        time_metric["context.generation"] = time.time() - started_at
        costs_metric["context_response"] = generation_cost
    except Exception as e:
        logger.error(f"Phase 2 generation failed: {e}", exc_info=True)
        self._log_costs(costs_metric)
        return None

    if not answer:
        logger.warning(f"[{request.chatId}] Phase 2 empty answer — fallback to RAG")
        self._log_costs(costs_metric)
        return None

    response = OrchestrationResponse(
        chatId=request.chatId,
        llmServiceActive=True,
        questionOutOfLLMScope=False,
        inputGuardFailed=False,
        content=answer,
    )
    if self.orchestration_service:
        try:
            components = self.orchestration_service._initialize_service_components(
                request
            )
            response = await self.orchestration_service.handle_output_guardrails(
                guardrails_adapter=components.get("guardrails_adapter"),
                generated_response=response,
                request=request,
                costs_metric=costs_metric,
            )
        except Exception as e:
            logger.warning(
                f"[{request.chatId}] Output guardrails check failed: {e}"
            )
    self._log_costs(costs_metric)
    return response

async def _stream_history_generator(
    self,
    chat_id: str,
    query: str,
    context_snippet: str,
    history_length_before: int,
    guardrails_adapter: NeMoRailsAdapter,
    costs_metric: Dict[str, Dict[str, Any]],
) -> AsyncIterator[str]:
    """Stream the history-based answer through NeMo Guardrails as SSE."""
    bot_generator = self.context_analyzer.stream_context_response(
        query=query, context_snippet=context_snippet
    )
    service = self.orchestration_service
    if service is None:
        return
    async for validated_chunk in guardrails_adapter.stream_with_guardrails(
        user_message=query, bot_message_generator=bot_generator
    ):
        if isinstance(validated_chunk, str) and self._is_guardrail_violation(
            validated_chunk
        ):
            # Violation: emit the canned message, close the stream, log costs.
            logger.warning(f"[{chat_id}] Guardrails violation in context streaming")
            yield service.format_sse(chat_id, OUTPUT_GUARDRAIL_VIOLATION_MESSAGE)
            yield service.format_sse(chat_id, "END")
            costs_metric["context_response"] = get_lm_usage_since(
                history_length_before
            )
            service.log_costs(costs_metric)
            return
        yield service.format_sse(chat_id, validated_chunk)
    yield service.format_sse(chat_id, "END")
    logger.info(f"[{chat_id}] Context streaming complete")
    costs_metric["context_response"] = get_lm_usage_since(history_length_before)
    service.log_costs(costs_metric)
async def _create_history_stream(
    self,
    request: OrchestrationRequest,
    context_snippet: str,
    costs_metric: Dict[str, Dict[str, Any]],
) -> Optional[AsyncIterator[str]]:
    """Set up the guardrails adapter and return the history streaming generator.

    Returns None (after logging any accumulated costs) when the
    orchestration service or a usable NeMoRailsAdapter is unavailable, so
    the caller falls back to RAG.
    """
    service = self.orchestration_service
    if not service:
        logger.warning(
            f"[{request.chatId}] No orchestration_service — cannot stream with guardrails"
        )
        return None
    try:
        components = service._initialize_service_components(request)
        guardrails_adapter = components.get("guardrails_adapter")
    except Exception as e:
        logger.error(
            f"[{request.chatId}] Failed to initialize components: {e}",
            exc_info=True,
        )
        self._log_costs(costs_metric)
        return None

    if not isinstance(guardrails_adapter, NeMoRailsAdapter):
        logger.warning(
            f"[{request.chatId}] guardrails_adapter unavailable — cannot stream"
        )
        self._log_costs(costs_metric)
        return None

    # Snapshot LM call history so streaming costs can be diffed afterwards.
    history_length_before = 0
    try:
        lm = dspy.settings.lm
        if lm and hasattr(lm, "history"):
            history_length_before = len(lm.history)
    except Exception:
        pass

    return self._stream_history_generator(
        chat_id=request.chatId,
        query=request.message,
        context_snippet=context_snippet,
        history_length_before=history_length_before,
        guardrails_adapter=guardrails_adapter,
        costs_metric=costs_metric,
    )

async def execute_async(
    self,
    request: OrchestrationRequest,
    context: Dict[str, Any],
    time_metric: Optional[Dict[str, float]] = None,
) -> Optional[OrchestrationResponse]:
    """
    Execute context workflow in non-streaming mode (two-phase).

    Phase 1: Detect if query is a greeting or can be answered from history.
    Phase 2: Generate response (greetings: pre-built; history: LLM + guardrails).

    Returns:
        OrchestrationResponse or None to fallback to RAG
    """
    logger.info(
        f"[{request.chatId}] CONTEXT WORKFLOW (NON-STREAMING) | "
        f"Query: '{request.message[:100]}'"
    )
    costs_metric: Dict[str, Dict[str, Any]] = {}
    if time_metric is None:
        time_metric = {}

    language = detect_language(request.message)
    history = self._build_history(request)

    # Reuse a detection result pre-computed upstream (e.g. by the classifier).
    pre_computed = context.get("analysis_result")
    has_detection_shape = (
        pre_computed is not None
        and hasattr(pre_computed, "is_greeting")
        and hasattr(pre_computed, "can_answer_from_context")
    )
    if has_detection_shape:
        detection_result: ContextDetectionResult = cast(
            ContextDetectionResult, pre_computed
        )
        costs_metric.setdefault(
            "context_detection",
            {"total_cost": 0.0, "total_tokens": 0, "num_calls": 0},
        )
    else:
        detected = await self._detect(
            request.message, history, time_metric, costs_metric
        )
        if detected is None:
            self._log_costs(costs_metric)
            context["costs_dict"] = costs_metric
            return None
        detection_result = detected

    logger.info(
        f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} "
        f"can_answer={detection_result.can_answer_from_context}"
    )

    if detection_result.is_greeting:
        from src.tool_classifier.greeting_constants import get_greeting_response

        greeting = get_greeting_response(
            greeting_type=detection_result.greeting_type, language=language
        )
        self._log_costs(costs_metric)
        context["costs_dict"] = costs_metric
        return OrchestrationResponse(
            chatId=request.chatId,
            llmServiceActive=True,
            questionOutOfLLMScope=False,
            inputGuardFailed=False,
            content=greeting,
        )

    if (
        detection_result.can_answer_from_context
        and detection_result.context_snippet
    ):
        context["costs_dict"] = costs_metric
        return await self._generate_response_async(
            request, detection_result.context_snippet, time_metric, costs_metric
        )

    logger.warning(
        f"[{request.chatId}] Cannot answer from context — falling back to RAG"
    )
    self._log_costs(costs_metric)
    context["costs_dict"] = costs_metric
    return None
Returns: - AsyncIterator yielding SSE strings or None to fallback + AsyncIterator yielding SSE strings or None to fallback to RAG """ - logger.debug( - f"[{request.chatId}] Context workflow execute_streaming called " - f"(not implemented - returning None)" + logger.info( + f"[{request.chatId}] CONTEXT WORKFLOW (STREAMING) | " + f"Query: '{request.message[:100]}'" + ) + costs_metric: Dict[str, Dict[str, Any]] = {} + if time_metric is None: + time_metric = {} + + language = detect_language(request.message) + history = self._build_history(request) + + detection_result = await self._detect( + request.message, history, time_metric, costs_metric + ) + if detection_result is None: + self._log_costs(costs_metric) + return None + + logger.info( + f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} " + f"can_answer={detection_result.can_answer_from_context}" ) - # TODO: Implement context streaming logic here - # For now, return None to trigger fallback to next layer (RAG) + if detection_result.is_greeting: + from src.tool_classifier.greeting_constants import get_greeting_response + + greeting = get_greeting_response( + greeting_type=detection_result.greeting_type, language=language + ) + orchestration_service = self.orchestration_service + if orchestration_service is None: + self._log_costs(costs_metric) + return None + chat_id = request.chatId + + async def _stream_greeting() -> AsyncIterator[str]: + yield orchestration_service.format_sse(chat_id, greeting) + yield orchestration_service.format_sse(chat_id, "END") + orchestration_service.log_costs(costs_metric) + + return _stream_greeting() + + if ( + detection_result.can_answer_from_context + and detection_result.context_snippet + ): + return await self._create_history_stream( + request, detection_result.context_snippet, costs_metric + ) + + logger.warning( + f"[{request.chatId}] Cannot answer from context — falling back to RAG" + ) + self._log_costs(costs_metric) return None diff --git 
a/src/tool_classifier/workflows/rag_workflow.py b/src/tool_classifier/workflows/rag_workflow.py index b5da35b1..9c983ced 100644 --- a/src/tool_classifier/workflows/rag_workflow.py +++ b/src/tool_classifier/workflows/rag_workflow.py @@ -64,7 +64,7 @@ async def execute_async( Args: request: Orchestration request with user query - context: Unused (RAG doesn't need classification metadata) + context: May contain pre-initialized "components" to avoid duplicate init time_metric: Optional timing dictionary from parent (for unified tracking) Returns: @@ -79,8 +79,12 @@ async def execute_async( if time_metric is None: time_metric = {} - # Initialize service components - components = self.orchestration_service._initialize_service_components(request) + # Reuse components from context if available, otherwise initialize + components = context.get("components") + if components is None: + components = self.orchestration_service._initialize_service_components( + request + ) # Call existing RAG pipeline with "rag" prefix for namespacing response = await self.orchestration_service._execute_orchestration_pipeline( @@ -105,6 +109,11 @@ async def execute_streaming( """ Execute RAG workflow in streaming mode. + Coroutine that returns an AsyncIterator so callers can safely use + ``await workflow.execute_streaming(...)`` and then iterate over the + returned stream without hitting a TypeError from awaiting an async + generator. 
+ Delegates to existing streaming pipeline which handles: - Prompt refinement (blocking) - Chunk retrieval (blocking) @@ -118,7 +127,7 @@ async def execute_streaming( Args: request: Orchestration request with user query - context: Unused (RAG doesn't need classification metadata) + context: May contain pre-initialized "components" and "stream_ctx" time_metric: Optional timing dictionary from parent (for unified tracking) Returns: @@ -143,8 +152,7 @@ async def execute_streaming( # Get stream context from context if provided, otherwise create minimal tracking stream_ctx = context.get("stream_ctx") if stream_ctx is None: - # Create minimal stream context when called via tool classifier - # In production flow, this is provided by stream_orchestration_response + class MinimalStreamContext: """Minimal stream context for RAG workflow when called directly.""" @@ -154,25 +162,29 @@ def __init__(self, chat_id: str) -> None: self.bot_generator = None def mark_completed(self) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass def mark_cancelled(self) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass def mark_error(self, error_id: str) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass stream_ctx = MinimalStreamContext(request.chatId) - # Delegate to core RAG pipeline (bypasses classifier to avoid recursion) - async for sse_chunk in self.orchestration_service._stream_rag_pipeline( - request=request, - components=components, - stream_ctx=stream_ctx, - costs_metric=costs_metric, - time_metric=time_metric, - ): - yield sse_chunk + # Return an inner async generator so this method stays a 
coroutine. + # This avoids the TypeError when callers do ``await execute_streaming(...)``. + async def _stream() -> AsyncIterator[str]: + async for sse_chunk in self.orchestration_service._stream_rag_pipeline( + request=request, + components=components, + stream_ctx=stream_ctx, + costs_metric=costs_metric, + time_metric=time_metric, + ): + yield sse_chunk + + return _stream() diff --git a/src/tool_classifier/workflows/service_workflow.py b/src/tool_classifier/workflows/service_workflow.py index bb72f785..78825502 100644 --- a/src/tool_classifier/workflows/service_workflow.py +++ b/src/tool_classifier/workflows/service_workflow.py @@ -6,6 +6,7 @@ import httpx from loguru import logger +from src.guardrails.nemo_rails_adapter import NeMoRailsAdapter from src.utils.cost_utils import get_lm_usage_since from models.request_models import ( @@ -73,6 +74,22 @@ def log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None: """ ... + def _initialize_service_components( + self, request: OrchestrationRequest + ) -> Dict[str, Any]: + """Initialize and return service components dictionary.""" + ... + + async def handle_output_guardrails( + self, + guardrails_adapter: Optional[NeMoRailsAdapter], + generated_response: OrchestrationResponse, + request: OrchestrationRequest, + costs_metric: Dict[str, Dict[str, Any]], + ) -> OrchestrationResponse: + """Apply output guardrails to the generated response.""" + ... 
+ class ServiceWorkflowExecutor(BaseWorkflow): """Executes external service calls via Ruuter endpoints (Layer 1).""" diff --git a/src/utils/rate_limiter.py b/src/utils/rate_limiter.py index 4b88d9d7..d86829f8 100644 --- a/src/utils/rate_limiter.py +++ b/src/utils/rate_limiter.py @@ -1,8 +1,8 @@ -"""Rate limiter for streaming endpoints with sliding window and token bucket algorithms.""" +"""Rate limiter for streaming endpoints with sliding window algorithms.""" import time from collections import defaultdict, deque -from typing import Dict, Deque, Tuple, Optional, Any +from typing import Dict, Deque, Optional, Any from threading import Lock from loguru import logger @@ -31,11 +31,11 @@ class RateLimitResult(BaseModel): class RateLimiter: """ - In-memory rate limiter with sliding window (requests/minute) and token bucket (tokens/second). + In-memory rate limiter using sliding windows for both requests and tokens. Features: - Sliding window for request rate limiting (e.g., 10 requests per minute) - - Token bucket for burst control (e.g., 100 tokens per second) + - Sliding window for token rate limiting (e.g., 40,000 tokens per minute) - Per-user tracking with authorId - Automatic cleanup of old entries to prevent memory leaks - Thread-safe operations @@ -43,7 +43,7 @@ class RateLimiter: Usage: rate_limiter = RateLimiter( requests_per_minute=10, - tokens_per_second=100 + tokens_per_minute=40_000, ) result = rate_limiter.check_rate_limit( @@ -59,28 +59,32 @@ class RateLimiter: def __init__( self, requests_per_minute: int = StreamConfig.RATE_LIMIT_REQUESTS_PER_MINUTE, - tokens_per_second: int = StreamConfig.RATE_LIMIT_TOKENS_PER_SECOND, + tokens_per_minute: int = StreamConfig.RATE_LIMIT_TOKENS_PER_MINUTE, cleanup_interval: int = StreamConfig.RATE_LIMIT_CLEANUP_INTERVAL, + token_window_seconds: int = StreamConfig.RATE_LIMIT_TOKEN_WINDOW_SECONDS, ): """ Initialize rate limiter. 
Args: requests_per_minute: Maximum requests per user per minute (sliding window) - tokens_per_second: Maximum tokens per user per second (token bucket) + tokens_per_minute: Maximum tokens per user per minute (sliding window) cleanup_interval: Seconds between automatic cleanup of old entries + token_window_seconds: Sliding window size in seconds for token tracking """ self.requests_per_minute = requests_per_minute - self.tokens_per_second = tokens_per_second + self.tokens_per_minute = tokens_per_minute self.cleanup_interval = cleanup_interval + self.token_window_seconds = token_window_seconds + # Scale the per-minute limit to the actual window size so the + # sliding-window comparison is consistent regardless of window length. + self.tokens_per_window = int(tokens_per_minute * token_window_seconds / 60) # Sliding window: Track request timestamps per user - # Format: {author_id: deque([timestamp1, timestamp2, ...])} self._request_history: Dict[str, Deque[float]] = defaultdict(deque) - # Token bucket: Track token consumption per user - # Format: {author_id: (last_refill_time, available_tokens)} - self._token_buckets: Dict[str, Tuple[float, float]] = {} + # Sliding window: Track token usage per user + self._token_history: Dict[str, Deque[tuple[float, int]]] = defaultdict(deque) # Thread safety self._lock = Lock() @@ -91,7 +95,7 @@ def __init__( logger.info( f"RateLimiter initialized - " f"requests_per_minute: {requests_per_minute}, " - f"tokens_per_second: {tokens_per_second}" + f"tokens_per_minute: {tokens_per_minute}" ) def check_rate_limit( @@ -121,7 +125,7 @@ def check_rate_limit( if not request_result.allowed: return request_result - # Check 2: Token bucket (tokens per second) + # Check 2: Sliding window (tokens per minute) if estimated_tokens > 0: token_result = self._check_token_limit( author_id, estimated_tokens, current_time @@ -186,12 +190,11 @@ def _check_token_limit( current_time: float, ) -> RateLimitResult: """ - Check token bucket limit. 
+ Check sliding window token limit. - Token bucket algorithm: - - Bucket refills at constant rate (tokens_per_second) - - Burst allowed up to bucket capacity - - Request denied if insufficient tokens + Sliding window algorithm: + - Track cumulative tokens consumed within the window + - Reject if adding estimated tokens would exceed the limit Args: author_id: User identifier @@ -201,38 +204,42 @@ def _check_token_limit( Returns: RateLimitResult for token limit check """ - bucket_capacity = self.tokens_per_second - - # Get or initialize bucket for user - if author_id not in self._token_buckets: - # New user - start with full bucket - self._token_buckets[author_id] = (current_time, bucket_capacity) - - last_refill, available_tokens = self._token_buckets[author_id] - - # Refill tokens based on time elapsed - time_elapsed = current_time - last_refill - refill_amount = time_elapsed * self.tokens_per_second - available_tokens = min(bucket_capacity, available_tokens + refill_amount) - - # Check if enough tokens available - if available_tokens < estimated_tokens: - # Calculate time needed to refill enough tokens - tokens_needed = estimated_tokens - available_tokens - retry_after = int(tokens_needed / self.tokens_per_second) + 1 + token_history = self._token_history[author_id] + window_start = current_time - self.token_window_seconds + + # Remove entries outside the sliding window + while token_history and token_history[0][0] < window_start: + token_history.popleft() + + # Sum tokens consumed in the current window + current_token_usage = sum(tokens for _, tokens in token_history) + + # Check if adding this request would exceed the scaled window limit + if current_token_usage + estimated_tokens > self.tokens_per_window: + # Calculate retry_after based on oldest entry in window + if token_history: + oldest_timestamp = token_history[0][0] + retry_after = ( + int(oldest_timestamp + self.token_window_seconds - current_time) + 1 + ) + else: + retry_after = 1 logger.warning( 
f"Token rate limit exceeded for {author_id} - " - f"needed: {estimated_tokens}, available: {available_tokens:.0f} " - f"(retry after {retry_after}s)" + f"needed: {estimated_tokens}, " + f"current_usage: {current_token_usage}/{self.tokens_per_window} " + f"(window: {self.token_window_seconds}s, " + f"rate: {self.tokens_per_minute}/min, " + f"retry after {retry_after}s)" ) return RateLimitResult( allowed=False, retry_after=retry_after, limit_type="tokens", - current_usage=int(bucket_capacity - available_tokens), - limit=self.tokens_per_second, + current_usage=current_token_usage, + limit=self.tokens_per_window, ) return RateLimitResult(allowed=True) @@ -254,20 +261,9 @@ def _record_request( # Record request timestamp for sliding window self._request_history[author_id].append(current_time) - # Deduct tokens from bucket - if tokens_consumed > 0 and author_id in self._token_buckets: - last_refill, available_tokens = self._token_buckets[author_id] - - # Refill before deducting - time_elapsed = current_time - last_refill - refill_amount = time_elapsed * self.tokens_per_second - available_tokens = min( - self.tokens_per_second, available_tokens + refill_amount - ) - - # Deduct tokens - available_tokens -= tokens_consumed - self._token_buckets[author_id] = (current_time, available_tokens) + # Record token usage for sliding window + if tokens_consumed > 0: + self._token_history[author_id].append((current_time, tokens_consumed)) def _cleanup_old_entries(self, current_time: float) -> None: """ @@ -294,23 +290,25 @@ def _cleanup_old_entries(self, current_time: float) -> None: for author_id in users_to_remove: del self._request_history[author_id] - # Clean up token buckets (remove entries inactive for 5 minutes) - inactive_threshold = current_time - 300 - buckets_to_remove: list[str] = [] + # Clean up token history (remove entries outside window + inactive users) + token_window_start = current_time - self.token_window_seconds + token_users_to_remove: list[str] = [] - for 
author_id, (last_refill, _) in self._token_buckets.items(): - if last_refill < inactive_threshold: - buckets_to_remove.append(author_id) + for author_id, token_history in self._token_history.items(): + while token_history and token_history[0][0] < token_window_start: + token_history.popleft() + if not token_history: + token_users_to_remove.append(author_id) - for author_id in buckets_to_remove: - del self._token_buckets[author_id] + for author_id in token_users_to_remove: + del self._token_history[author_id] self._last_cleanup = current_time - if users_to_remove or buckets_to_remove: + if users_to_remove or token_users_to_remove: logger.debug( f"Cleaned up {len(users_to_remove)} request histories and " - f"{len(buckets_to_remove)} token buckets" + f"{len(token_users_to_remove)} token histories" ) def get_stats(self) -> Dict[str, Any]: @@ -323,9 +321,9 @@ def get_stats(self) -> Dict[str, Any]: with self._lock: return { "total_users_tracked": len(self._request_history), - "total_token_buckets": len(self._token_buckets), + "total_token_histories": len(self._token_history), "requests_per_minute_limit": self.requests_per_minute, - "tokens_per_second_limit": self.tokens_per_second, + "tokens_per_minute_limit": self.tokens_per_minute, "last_cleanup": self._last_cleanup, } @@ -339,7 +337,7 @@ def reset_user(self, author_id: str) -> None: with self._lock: if author_id in self._request_history: del self._request_history[author_id] - if author_id in self._token_buckets: - del self._token_buckets[author_id] + if author_id in self._token_history: + del self._token_history[author_id] logger.info(f"Reset rate limits for user: {author_id}") diff --git a/tests/conftest.py b/tests/conftest.py index d1633b76..e26acfc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,3 +6,12 @@ # Add the project root to Python path so tests can import from src project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) + +# Add src directory to Python path for direct 
module imports +src_dir = project_root / "src" +sys.path.insert(0, str(src_dir)) + +# Add models directory (sibling to src) for backward compatibility +models_dir = project_root / "models" +if models_dir.exists(): + sys.path.insert(0, str(models_dir.parent)) diff --git a/tests/test_context_analyzer.py b/tests/test_context_analyzer.py new file mode 100644 index 00000000..094b8a47 --- /dev/null +++ b/tests/test_context_analyzer.py @@ -0,0 +1,979 @@ +"""Unit tests for context analyzer - greeting detection and context analysis.""" + +import pytest +from collections.abc import Generator +from unittest.mock import MagicMock, patch +import json +import dspy + +from src.tool_classifier.context_analyzer import ( + ContextAnalyzer, +) +from src.tool_classifier.greeting_constants import get_greeting_response + + +@pytest.fixture(autouse=True) +def mock_dspy_lm() -> Generator[MagicMock, None, None]: + """Mock DSPy LM to prevent 'No LM is loaded' errors.""" + mock_lm = MagicMock() + mock_lm.history = [] + with patch("dspy.settings") as mock_settings: + mock_settings.lm = mock_lm + # Configure DSPy with mock LM + dspy.configure(lm=mock_lm) + yield mock_lm + + +class TestContextAnalyzerInit: + """Test ContextAnalyzer initialization.""" + + def test_init_creates_analyzer(self) -> None: + """ContextAnalyzer should initialize with LLM manager.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + assert analyzer.llm_manager is llm_manager + assert analyzer._module is None + assert analyzer._summary_module is None + assert analyzer._summary_analysis_module is None + + +class TestConversationHistoryFormatting: + """Test conversation history formatting.""" + + def test_format_empty_history(self) -> None: + """Empty history should return empty JSON array.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + result = analyzer._format_conversation_history([]) + + assert result == "[]" + + def test_format_single_turn(self) -> None: + """Single 
conversation turn should be formatted correctly.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "Hello", + "timestamp": "2024-01-01T12:00:00", + } + ] + + result = analyzer._format_conversation_history(history) + parsed = json.loads(result) + + assert len(parsed) == 1 + assert parsed[0]["role"] == "user" + assert parsed[0]["message"] == "Hello" + + def test_format_multiple_turns(self) -> None: + """Multiple conversation turns should be formatted correctly.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is tax?", + "timestamp": "2024-01-01T12:00:00", + }, + { + "authorRole": "bot", + "message": "Tax is a mandatory financial charge.", + "timestamp": "2024-01-01T12:00:01", + }, + { + "authorRole": "user", + "message": "Thank you", + "timestamp": "2024-01-01T12:00:02", + }, + ] + + result = analyzer._format_conversation_history(history) + parsed = json.loads(result) + + assert len(parsed) == 3 + assert parsed[0]["role"] == "user" + assert parsed[1]["role"] == "bot" + assert parsed[2]["role"] == "user" + + def test_format_truncates_to_max_turns(self) -> None: + """History should be truncated to last 10 turns.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Create 15 turns + history = [ + { + "authorRole": "user" if i % 2 == 0 else "bot", + "message": f"Message {i}", + "timestamp": f"2024-01-01T12:00:{i:02d}", + } + for i in range(15) + ] + + result = analyzer._format_conversation_history(history, max_turns=10) + parsed = json.loads(result) + + assert len(parsed) == 10 + # Should have last 10 turns (indices 5-14) + assert parsed[0]["message"] == "Message 5" + assert parsed[-1]["message"] == "Message 14" + + +class TestGreetingDetection: + """Test greeting detection functionality.""" + + @pytest.mark.asyncio + async def test_detect_estonian_greeting(self) -> None: + 
"""Should detect Estonian greeting 'Tere' and generate response.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module response + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Tere! Kuidas ma saan sind aidata?", + "reasoning": "User said hello in Estonian", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, cost_dict = await analyzer.analyze_context( + query="Tere!", + conversation_history=[], + language="et", + ) + + assert result.is_greeting is True + assert result.can_answer_from_context is False + assert "Tere" in result.answer + assert cost_dict["total_cost"] == 0.001 + + @pytest.mark.asyncio + async def test_detect_english_greeting(self) -> None: + """Should detect English greeting 'Hello' and generate response.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module response + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Hello! 
How can I help you?", + "reasoning": "User said hello in English", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, cost_dict = await analyzer.analyze_context( + query="Hello!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + assert "Hello" in result.answer or "hello" in result.answer.lower() + + @pytest.mark.asyncio + async def test_detect_goodbye(self) -> None: + """Should detect goodbye greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Goodbye! Have a great day!", + "reasoning": "User said goodbye", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Bye!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + + @pytest.mark.asyncio + async def test_detect_thanks(self) -> None: + """Should detect thank you greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "You're welcome! 
Feel free to ask if you have more questions.", + "reasoning": "User said thank you", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Thank you!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + + +class TestContextBasedAnswering: + """Test context-based question answering.""" + + @pytest.mark.asyncio + async def test_answer_from_conversation_history(self) -> None: + """Should extract answer from conversation history when query references it.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is the tax rate?", + "timestamp": "2024-01-01T12:00:00", + }, + { + "authorRole": "bot", + "message": "The tax rate is 20%.", + "timestamp": "2024-01-01T12:00:01", + }, + ] + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": True, + "answer": "I mentioned that the tax rate is 20%.", + "reasoning": "User is asking about previously mentioned tax rate", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What was the rate you mentioned?", + conversation_history=history, + language="en", + ) + + assert result.is_greeting is False + assert result.can_answer_from_context is True + assert result.answer is not None + assert "20%" in result.answer + + @pytest.mark.asyncio + 
async def test_cannot_answer_from_context(self) -> None: + """Should return cannot answer when query doesn't reference history.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is the weather?", + "timestamp": "2024-01-01T12:00:00", + }, + ] + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Query is about taxes, not previous weather discussion", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What is the tax rate?", + conversation_history=history, + language="en", + ) + + assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + + +class TestErrorHandling: + """Test error handling in context analyzer.""" + + @pytest.mark.asyncio + async def test_handles_llm_json_parse_error(self) -> None: + """Should handle invalid JSON response from LLM gracefully.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module to return invalid JSON + mock_response = MagicMock() + mock_response.analysis_result = "Invalid JSON response" + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Hello", + conversation_history=[], + language="en", + ) + + # Should fallback to safe default + 
assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + assert "Failed to parse" in result.reasoning + + @pytest.mark.asyncio + async def test_handles_llm_exception(self) -> None: + """Should handle LLM call exceptions gracefully.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module to raise exception + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + result, _ = await analyzer.analyze_context( + query="Hello", + conversation_history=[], + language="en", + ) + + # Should fallback to safe default + assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + assert "error" in result.reasoning.lower() + + +class TestFallbackGreeting: + """Test fallback greeting responses.""" + + def test_fallback_estonian_greeting(self) -> None: + """Should return Estonian fallback greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("et") + + assert "Tere" in response + + def test_fallback_english_greeting(self) -> None: + """Should return English fallback greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("en") + + assert "Hello" in response or "hello" in response + + def test_fallback_unknown_language_defaults_to_estonian(self) -> None: + """Should default to Estonian for unknown language codes.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("xx") + + assert "Tere" in response or "tere" in response.lower() + + +class 
TestGreetingConstants: + """Test greeting constants and helper functions.""" + + def test_get_estonian_hello(self) -> None: + """Should return Estonian hello greeting.""" + response = get_greeting_response("hello", "et") + assert "Tere" in response + + def test_get_english_goodbye(self) -> None: + """Should return English goodbye greeting.""" + response = get_greeting_response("goodbye", "en") + assert "Goodbye" in response or "goodbye" in response + + def test_get_estonian_thanks(self) -> None: + """Should return Estonian thanks greeting.""" + response = get_greeting_response("thanks", "et") + assert "Palun" in response + + def test_unknown_greeting_type_defaults_to_hello(self) -> None: + """Should default to hello for unknown greeting types.""" + response = get_greeting_response("unknown", "en") + assert "Hello" in response or "hello" in response + + +def _make_history(num_turns: int) -> list[dict[str, str]]: + """Helper to create a conversation history with the specified number of turns.""" + return [ + { + "authorRole": "user" if i % 2 == 0 else "bot", + "message": f"Message {i}", + "timestamp": f"2024-01-01T12:00:{i:02d}", + } + for i in range(num_turns) + ] + + +class TestCostMerging: + """Test cost dictionary merging.""" + + def test_merge_cost_dicts(self) -> None: + """Should sum all numeric values from two cost dicts.""" + cost1 = { + "total_cost": 0.001, + "total_tokens": 50, + "total_prompt_tokens": 30, + "total_completion_tokens": 20, + "num_calls": 1, + } + cost2 = { + "total_cost": 0.002, + "total_tokens": 100, + "total_prompt_tokens": 60, + "total_completion_tokens": 40, + "num_calls": 1, + } + + merged = ContextAnalyzer._merge_cost_dicts(cost1, cost2) + + assert merged["total_cost"] == pytest.approx(0.003) + assert merged["total_tokens"] == 150 + assert merged["total_prompt_tokens"] == 90 + assert merged["total_completion_tokens"] == 60 + assert merged["num_calls"] == 2 + + def test_merge_cost_dicts_with_empty(self) -> None: + """Should handle 
merging with an empty cost dict.""" + cost1 = { + "total_cost": 0.001, + "total_tokens": 50, + "total_prompt_tokens": 30, + "total_completion_tokens": 20, + "num_calls": 1, + } + + merged = ContextAnalyzer._merge_cost_dicts(cost1, {}) + + assert merged["total_cost"] == 0.001 + assert merged["total_tokens"] == 50 + assert merged["num_calls"] == 1 + + +class TestConversationSummary: + """Test conversation summary generation.""" + + @pytest.mark.asyncio + async def test_generate_summary_from_older_turns(self) -> None: + """Should generate summary from older conversation turns.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + older_history = _make_history(6) + + mock_response = MagicMock() + mock_response.summary = "User discussed messages 0-5 about various topics." + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + summary, cost_dict = await analyzer._generate_conversation_summary( + older_history + ) + + assert summary == "User discussed messages 0-5 about various topics." 
+ assert cost_dict["total_cost"] == 0.001 + + @pytest.mark.asyncio + async def test_generate_summary_handles_exception(self) -> None: + """Should return empty string when summary generation fails.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + summary, _ = await analyzer._generate_conversation_summary( + _make_history(5) + ) + + assert summary == "" + + @pytest.mark.asyncio + async def test_analyze_from_summary_can_answer(self) -> None: + """Should answer from summary when information is available.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "can_answer_from_context": True, + "answer": "The tax rate discussed earlier was 20%.", + "reasoning": "Summary contains tax rate information", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, cost_dict = await analyzer._analyze_from_summary( + query="What was the tax rate?", + summary="User asked about tax. 
Bot replied: tax rate is 20%.", + ) + + assert result.can_answer_from_context is True + assert result.answered_from_summary is True + assert result.answer is not None + assert "20%" in result.answer + + @pytest.mark.asyncio + async def test_analyze_from_summary_cannot_answer(self) -> None: + """Should return cannot answer when summary doesn't contain relevant info.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "can_answer_from_context": False, + "answer": None, + "reasoning": "Summary does not contain information about weather", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer._analyze_from_summary( + query="What is the weather?", + summary="User discussed tax rates and filing.", + ) + + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + @pytest.mark.asyncio + async def test_analyze_from_summary_handles_exception(self) -> None: + """Should return safe fallback when summary analysis fails.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + result, _ = await analyzer._analyze_from_summary( + query="test", summary="test summary" + ) + + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + +class 
TestSummaryFlow: + """Test the full analyze_context flow with summary logic.""" + + @pytest.mark.asyncio + async def test_short_history_skips_summary(self) -> None: + """With <= 10 turns, should use recent history only, no summary.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Cannot answer from recent history, but only 8 turns - should NOT trigger summary + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Cannot answer from context", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What is digital signature?", + conversation_history=_make_history(8), + language="en", + ) + + # Should not answer (no summary triggered for <= 10 turns) + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + @pytest.mark.asyncio + async def test_long_history_answers_from_recent(self) -> None: + """With > 10 turns, if recent 10 can answer, should not trigger summary.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Can answer from recent history + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": True, + "answer": "The rate is 20%.", + "reasoning": "Found in recent history", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 
50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What was the rate?", + conversation_history=_make_history(15), + language="en", + ) + + assert result.can_answer_from_context is True + assert result.answered_from_summary is False + assert result.answer == "The rate is 20%." + + @pytest.mark.asyncio + async def test_long_history_answers_from_summary(self) -> None: + """With > 10 turns, if recent can't answer but summary can, should return summary answer.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Step 1: Recent history cannot answer + recent_response = MagicMock() + recent_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Not in recent history", + } + ) + + # Step 2: Summary generation + summary_response = MagicMock() + summary_response.summary = ( + "User asked about tax rates. Bot said the tax rate is 20%." + ) + + # Step 3: Summary analysis can answer + summary_analysis_response = MagicMock() + summary_analysis_response.analysis_result = json.dumps( + { + "can_answer_from_context": True, + "answer": "Based on our earlier discussion, the tax rate is 20%.", + "reasoning": "Found tax rate in conversation summary", + } + ) + + # Chain of Thought is called 3 times: recent analysis, summary gen, summary analysis + call_count = 0 + mock_modules = [ + MagicMock(return_value=recent_response), + MagicMock(return_value=summary_response), + MagicMock(return_value=summary_analysis_response), + ] + + def chain_of_thought_factory(*args: object, **kwargs: object) -> MagicMock: + nonlocal call_count + module = mock_modules[call_count] + call_count += 1 + return module + + with patch.object(dspy, "ChainOfThought", side_effect=chain_of_thought_factory): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + 
"num_calls": 1, + } + + result, cost_dict = await analyzer.analyze_context( + query="What was the tax rate we discussed?", + conversation_history=_make_history(15), + language="en", + ) + + assert result.can_answer_from_context is True + assert result.answered_from_summary is True + assert result.answer is not None + assert "20%" in result.answer + # Costs should be merged from all 3 calls + assert cost_dict["num_calls"] == 3 + + @pytest.mark.asyncio + async def test_long_history_falls_to_rag(self) -> None: + """With > 10 turns, if neither recent nor summary can answer, should fall through.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Step 1: Recent history cannot answer + recent_response = MagicMock() + recent_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Not in recent history", + } + ) + + # Step 2: Summary generation + summary_response = MagicMock() + summary_response.summary = "User discussed weather and greetings." 
+ + # Step 3: Summary analysis cannot answer + summary_analysis_response = MagicMock() + summary_analysis_response.analysis_result = json.dumps( + { + "can_answer_from_context": False, + "answer": None, + "reasoning": "Summary does not contain tax information", + } + ) + + call_count = 0 + mock_modules = [ + MagicMock(return_value=recent_response), + MagicMock(return_value=summary_response), + MagicMock(return_value=summary_analysis_response), + ] + + def chain_of_thought_factory(*args: object, **kwargs: object) -> MagicMock: + nonlocal call_count + module = mock_modules[call_count] + call_count += 1 + return module + + with patch.object(dspy, "ChainOfThought", side_effect=chain_of_thought_factory): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What is the tax rate?", + conversation_history=_make_history(15), + language="en", + ) + + # Should not be able to answer -> falls to RAG + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + @pytest.mark.asyncio + async def test_answered_from_summary_flag_is_false_for_recent(self) -> None: + """The answered_from_summary flag should be False for recent history answers.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": True, + "answer": "The answer from recent history.", + "reasoning": "Found in recent conversation", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + 
"num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What did you say?", + conversation_history=_make_history(5), + language="en", + ) + + assert result.answered_from_summary is False diff --git a/tests/test_context_workflow.py b/tests/test_context_workflow.py new file mode 100644 index 00000000..1362a72d --- /dev/null +++ b/tests/test_context_workflow.py @@ -0,0 +1,698 @@ +"""Unit tests for context workflow executor.""" + +import pytest +from collections.abc import AsyncGenerator, Generator +from unittest.mock import AsyncMock, MagicMock, patch +import dspy + +from src.tool_classifier.workflows.context_workflow import ContextWorkflowExecutor +from src.tool_classifier.context_analyzer import ContextDetectionResult +from models.request_models import ( + OrchestrationRequest, + OrchestrationResponse, + ConversationItem, +) + + +@pytest.fixture +def mock_dspy_lm() -> Generator[MagicMock, None, None]: + """Mock DSPy LM to prevent 'No LM is loaded' errors.""" + mock_lm = MagicMock() + mock_lm.history = [] + with patch("dspy.settings") as mock_settings: + mock_settings.lm = mock_lm + # Configure DSPy with mock LM + dspy.configure(lm=mock_lm) + yield mock_lm + + +@pytest.fixture +def mock_orchestration_service() -> MagicMock: + """Create mock orchestration service for streaming tests.""" + import json as _json + import time as _time + + service = MagicMock() + + def _format_sse_impl(chat_id: str, content: str) -> str: + payload = { + "chatId": chat_id, + "payload": {"content": content}, + "timestamp": int(_time.time() * 1000), + } + return f"data: {_json.dumps(payload)}\n\n" + + service.format_sse = _format_sse_impl + service.log_costs = MagicMock() + return service + + +@pytest.fixture +def llm_manager() -> MagicMock: + """Create mock LLM manager.""" + return MagicMock() + + +@pytest.fixture +def context_workflow( + llm_manager: MagicMock, + mock_orchestration_service: MagicMock, + mock_dspy_lm: MagicMock, +) -> ContextWorkflowExecutor: + 
"""Create ContextWorkflowExecutor instance.""" + return ContextWorkflowExecutor( + llm_manager, orchestration_service=mock_orchestration_service + ) + + +@pytest.fixture +def sample_request() -> OrchestrationRequest: + """Create sample orchestration request.""" + return OrchestrationRequest( + chatId="test-chat-123", + message="Hello!", + authorId="test-user", + conversationHistory=[], + url="https://example.com", + environment="testing", + connection_id="test-connection", + ) + + +class TestContextWorkflowInit: + """Test context workflow initialization.""" + + def test_init_creates_workflow(self, llm_manager: MagicMock) -> None: + """ContextWorkflowExecutor should initialize with LLM manager.""" + workflow = ContextWorkflowExecutor(llm_manager) + + assert workflow.llm_manager is llm_manager + assert workflow.context_analyzer is not None + + +class TestExecuteAsyncGreeting: + """Test execute_async with greeting queries.""" + + @pytest.mark.asyncio + async def test_execute_async_greeting_estonian( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should handle Estonian greeting and return response.""" + sample_request.message = "Tere!" 
+ + # Mock context analyzer + mock_analysis = ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Greeting detected", + context_snippet=None, + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + context_dict = {} + response = await context_workflow.execute_async( + sample_request, context_dict + ) + + assert response is not None + assert isinstance(response, OrchestrationResponse) + assert response.chatId == "test-chat-123" + assert "Tere" in response.content + assert response.llmServiceActive is True + assert response.questionOutOfLLMScope is False + assert response.inputGuardFailed is False + + # Check cost tracking + assert "costs_dict" in context_dict + assert "context_detection" in context_dict["costs_dict"] + + @pytest.mark.asyncio + async def test_execute_async_greeting_english( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should handle English greeting and return response.""" + sample_request.message = "Hello!" 
+ + mock_analysis = ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="English greeting detected", + context_snippet=None, + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is not None + assert "Hello" in response.content or "hello" in response.content.lower() + + +class TestExecuteAsyncContextBased: + """Test execute_async with context-based queries.""" + + @pytest.mark.asyncio + async def test_execute_async_context_answer( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should answer from conversation history when possible.""" + # Add conversation history + sample_request.conversationHistory = [ + ConversationItem( + authorRole="user", + message="What is the tax rate?", + timestamp="2024-01-01T12:00:00", + ), + ConversationItem( + authorRole="bot", + message="The tax rate is 20%.", + timestamp="2024-01-01T12:00:01", + ), + ] + sample_request.message = "What was the rate you mentioned?" 
+ + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Referring to previous conversation about tax rate", + context_snippet="The tax rate is 20%.", + ) + + with ( + patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.002, "total_tokens": 100, "num_calls": 1}, + ), + ), + patch.object( + context_workflow.context_analyzer, + "generate_context_response", + new_callable=AsyncMock, + return_value=( + "The tax rate is 20%.", + {"total_cost": 0.003, "num_calls": 1}, + ), + ), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is not None + assert "20%" in response.content + + @pytest.mark.asyncio + async def test_execute_async_cannot_answer_from_context( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return None when cannot answer from context (fallback to RAG).""" + sample_request.message = "What is digital signature?" 
+ + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Query requires knowledge base search", + context_snippet=None, + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is None + + @pytest.mark.asyncio + async def test_execute_async_answer_is_none( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return None when can_answer_from_context=True but context_snippet is absent.""" + mock_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=True, + context_snippet=None, # No snippet → cannot generate answer + reasoning="No relevant snippet found in history", + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is None + + +class TestExecuteAsyncErrorHandling: + """Test error handling in execute_async.""" + + @pytest.mark.asyncio + async def test_execute_async_handles_analyzer_exception( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return None when context analyzer raises exception.""" + with patch.object( + context_workflow.context_analyzer, + "detect_context", + side_effect=Exception("Analysis failed"), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is None + + +class TestExecuteStreaming: + """Test execute_streaming functionality.""" + + @pytest.mark.asyncio + async def test_execute_streaming_greeting( + self, + context_workflow: 
ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should stream greeting response.""" + sample_request.message = "Hello!" + + mock_analysis = ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Greeting detected", + context_snippet=None, + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + stream_gen = await context_workflow.execute_streaming(sample_request, {}) + + assert stream_gen is not None + + # Collect streamed chunks + chunks = [chunk async for chunk in stream_gen] + + # Should have multiple chunks + END marker + assert len(chunks) > 1 + + # Last chunk should be END marker + last_chunk = chunks[-1] + assert "END" in last_chunk + + # All chunks should be valid SSE format + for chunk in chunks: + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + + @pytest.mark.asyncio + async def test_execute_streaming_context_answer( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should stream context-based answer.""" + sample_request.message = "What did you say earlier?" + sample_request.conversationHistory = [ + ConversationItem( + authorRole="bot", + message="The rate is 20%.", + timestamp="2024-01-01T12:00:00", + ), + ] + + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Referring to previous message", + context_snippet="I mentioned that the rate is 20%.", + ) + + async def _fake_history_stream( + *args: object, **kwargs: object + ) -> AsyncGenerator[str, None]: + yield context_workflow.orchestration_service.format_sse( + sample_request.chatId, "I mentioned that the rate is 20%." 
+ ) + yield context_workflow.orchestration_service.format_sse( + sample_request.chatId, "END" + ) + + with ( + patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.002, "total_tokens": 100, "num_calls": 1}, + ), + ), + patch.object( + context_workflow, + "_create_history_stream", + new_callable=AsyncMock, + return_value=_fake_history_stream(), + ), + ): + stream_gen = await context_workflow.execute_streaming(sample_request, {}) + + assert stream_gen is not None + + chunks = [chunk async for chunk in stream_gen] + + assert len(chunks) > 0 + # Verify END marker + assert "END" in chunks[-1] + + @pytest.mark.asyncio + async def test_execute_streaming_cannot_answer( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return None when cannot answer (fallback to RAG).""" + sample_request.message = "What is digital signature?" + + mock_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + reasoning="Requires knowledge base", + ) + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + stream_gen = await context_workflow.execute_streaming(sample_request, {}) + + assert stream_gen is None + + @pytest.mark.asyncio + async def test_execute_streaming_handles_exception( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return None when analyzer raises exception.""" + with patch.object( + context_workflow.context_analyzer, + "detect_context", + side_effect=Exception("Analysis failed"), + ): + stream_gen = await context_workflow.execute_streaming(sample_request, {}) + + assert stream_gen is None + + +class TestCostTracking: + """Test cost tracking functionality.""" + + @pytest.mark.asyncio + async def 
test_cost_tracking_in_context_dict( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should track costs in context dictionary.""" + mock_analysis = ContextDetectionResult( + is_greeting=True, + can_answer_from_context=False, + reasoning="Greeting", + ) + + cost_dict = { + "total_cost": 0.0015, + "total_tokens": 75, + "total_prompt_tokens": 50, + "total_completion_tokens": 25, + "num_calls": 1, + } + + with patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=(mock_analysis, cost_dict), + ): + context_dict = {} + await context_workflow.execute_async(sample_request, context_dict) + + assert "costs_dict" in context_dict + assert "context_detection" in context_dict["costs_dict"] + assert context_dict["costs_dict"]["context_detection"]["total_cost"] == 0.0015 + assert context_dict["costs_dict"]["context_detection"]["total_tokens"] == 75 + + +class TestLanguageDetection: + """Test language detection integration.""" + + @pytest.mark.asyncio + async def test_detects_estonian_language( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should detect Estonian language from query.""" + sample_request.message = "Tere! Kuidas läheb?" + + mock_analysis = ContextDetectionResult( + is_greeting=True, + can_answer_from_context=False, + reasoning="Estonian greeting", + ) + + with ( + patch.object( + context_workflow.context_analyzer, "detect_context" + ) as mock_detect, + patch( + "src.tool_classifier.greeting_constants.get_greeting_response" + ) as mock_greeting, + ): + mock_detect.return_value = ( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ) + mock_greeting.return_value = "Tere! Kuidas ma saan sind aidata?" 
+ + await context_workflow.execute_async(sample_request, {}) + + # Verify Estonian language was used for greeting response + mock_greeting.assert_called_with(greeting_type="hello", language="et") + + @pytest.mark.asyncio + async def test_detects_english_language( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should detect English language from query.""" + sample_request.message = "Hello! How are you?" + + mock_analysis = ContextDetectionResult( + is_greeting=True, + can_answer_from_context=False, + reasoning="English greeting", + ) + + with ( + patch.object( + context_workflow.context_analyzer, "detect_context" + ) as mock_detect, + patch( + "src.tool_classifier.greeting_constants.get_greeting_response" + ) as mock_greeting, + ): + mock_detect.return_value = ( + mock_analysis, + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ) + mock_greeting.return_value = "Hello! How can I help you?" + + await context_workflow.execute_async(sample_request, {}) + + # Verify English language was used for greeting response + mock_greeting.assert_called_with(greeting_type="hello", language="en") + + +class TestExecuteAsyncSummaryBased: + """Test execute_async with summary-based answers.""" + + @pytest.mark.asyncio + async def test_execute_async_summary_answer( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should return response when answer comes from conversation summary.""" + sample_request.message = "What was the tax rate we discussed earlier?" 
+ + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Found in conversation summary", + context_snippet="Based on our earlier discussion, the tax rate is 20%.", + answered_from_summary=True, + ) + + with ( + patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.005, "total_tokens": 200, "num_calls": 3}, + ), + ), + patch.object( + context_workflow.context_analyzer, + "generate_context_response", + new_callable=AsyncMock, + return_value=( + "Based on our earlier discussion, the tax rate is 20%.", + {"total_cost": 0.003, "num_calls": 1}, + ), + ), + ): + response = await context_workflow.execute_async(sample_request, {}) + + assert response is not None + assert isinstance(response, OrchestrationResponse) + assert "20%" in response.content + assert response.llmServiceActive is True + + @pytest.mark.asyncio + async def test_execute_streaming_summary_answer( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should stream summary-based answer correctly.""" + sample_request.message = "What was the tax rate we discussed earlier?" + + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Found in conversation summary", + context_snippet="Based on our earlier discussion, the tax rate is 20%.", + answered_from_summary=True, + ) + + async def _fake_summary_stream( + *args: object, **kwargs: object + ) -> AsyncGenerator[str, None]: + yield context_workflow.orchestration_service.format_sse( + sample_request.chatId, "The tax rate is 20%." 
+ ) + yield context_workflow.orchestration_service.format_sse( + sample_request.chatId, "END" + ) + + with ( + patch.object( + context_workflow.context_analyzer, + "detect_context", + return_value=( + mock_analysis, + {"total_cost": 0.005, "total_tokens": 200, "num_calls": 3}, + ), + ), + patch.object( + context_workflow, + "_create_history_stream", + new_callable=AsyncMock, + return_value=_fake_summary_stream(), + ), + ): + stream_gen = await context_workflow.execute_streaming(sample_request, {}) + + assert stream_gen is not None + + chunks = [chunk async for chunk in stream_gen] + + # Should have multiple chunks + END marker + assert len(chunks) > 1 + assert "END" in chunks[-1] + + @pytest.mark.asyncio + async def test_pre_computed_summary_analysis( + self, + context_workflow: ContextWorkflowExecutor, + sample_request: OrchestrationRequest, + ) -> None: + """Should use pre-computed summary analysis from classifier.""" + sample_request.message = "What was the tax rate?" + + mock_analysis = ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Found in summary", + context_snippet="The tax rate is 20%.", + answered_from_summary=True, + ) + + # Pre-computed analysis (from classifier) + context = {"analysis_result": mock_analysis} + + with patch.object( + context_workflow.context_analyzer, + "generate_context_response", + new_callable=AsyncMock, + return_value=( + "The tax rate is 20%.", + {"total_cost": 0.003, "num_calls": 1}, + ), + ): + response = await context_workflow.execute_async(sample_request, context) + + assert response is not None + assert "20%" in response.content diff --git a/tests/test_context_workflow_integration.py b/tests/test_context_workflow_integration.py new file mode 100644 index 00000000..bca2af2e --- /dev/null +++ b/tests/test_context_workflow_integration.py @@ -0,0 +1,851 @@ +"""Integration tests for context workflow. 
+ +Tests the full classify -> route -> execute chain with real component wiring. +Only the LLM layer (dspy) and RAG orchestration service are mocked. + +These tests verify: +- ToolClassifier.classify() correctly routes greetings to CONTEXT workflow +- ToolClassifier.route_to_workflow() executes the context workflow end-to-end +- Fallback from CONTEXT to RAG when context cannot answer +- Streaming mode for context workflow responses +- Cost tracking propagation through the classify -> execute chain +- Error resilience (LLM failures, JSON parse errors) +""" + +import pytest +from collections.abc import AsyncGenerator, Generator +from contextlib import AbstractContextManager +from unittest.mock import AsyncMock, MagicMock, patch +import json +import dspy + +from src.tool_classifier.classifier import ToolClassifier +from src.tool_classifier.context_analyzer import ContextDetectionResult +from src.tool_classifier.models import ClassificationResult +from src.models.request_models import ( + OrchestrationRequest, + OrchestrationResponse, + ConversationItem, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def mock_dspy_lm() -> Generator[MagicMock, None, None]: + """Mock DSPy LM to prevent 'No LM is loaded' errors.""" + mock_lm = MagicMock() + mock_lm.history = [] + with patch("dspy.settings") as mock_settings: + mock_settings.lm = mock_lm + # Configure DSPy with mock LM + dspy.configure(lm=mock_lm) + yield mock_lm + + +@pytest.fixture +def mock_orchestration_service() -> MagicMock: + """Create mock orchestration service for RAG workflow fallback.""" + import json as _json + import time as _time + + service = MagicMock() + + # Non-streaming RAG fallback returns a valid response + async def mock_execute_pipeline(**kwargs: object) -> OrchestrationResponse: + return OrchestrationResponse( + 
chatId=kwargs["request"].chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content="RAG fallback answer.", + ) + + service._execute_orchestration_pipeline = AsyncMock( + side_effect=mock_execute_pipeline + ) + service._initialize_service_components = MagicMock(return_value={}) + service._log_costs = MagicMock() + service.log_costs = MagicMock() + + def _format_sse_impl(chat_id: str, content: str) -> str: + payload = { + "chatId": chat_id, + "payload": {"content": content}, + "timestamp": int(_time.time() * 1000), + } + return f"data: {_json.dumps(payload)}\n\n" + + service.format_sse = _format_sse_impl + + # Streaming RAG fallback + async def mock_stream_pipeline(**kwargs: object) -> AsyncGenerator[str, None]: + yield 'data: {"chatId":"test","payload":{"content":"RAG stream"}}\n\n' + yield 'data: {"chatId":"test","payload":{"content":"END"}}\n\n' + + service._stream_rag_pipeline = mock_stream_pipeline + + return service + + +@pytest.fixture +def llm_manager() -> MagicMock: + """Create mock LLM manager.""" + return MagicMock() + + +@pytest.fixture +def classifier( + llm_manager: MagicMock, mock_orchestration_service: MagicMock +) -> ToolClassifier: + """Create a real ToolClassifier with real workflow executors.""" + return ToolClassifier( + llm_manager=llm_manager, + orchestration_service=mock_orchestration_service, + ) + + +def _make_request( + message: str, + chat_id: str = "integration-test-chat", + history: list | None = None, +) -> OrchestrationRequest: + """Helper to build an OrchestrationRequest.""" + return OrchestrationRequest( + chatId=chat_id, + message=message, + authorId="test-user", + conversationHistory=history or [], + url="https://example.com", + environment="testing", + connection_id="test-conn", + ) + + +def _mock_dspy_greeting(answer_text: str) -> AbstractContextManager[MagicMock]: + """Return a patch context manager that makes dspy return a greeting analysis.""" + mock_response = MagicMock() + 
mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": answer_text, + "reasoning": "Greeting detected", + } + ) + return patch( + "dspy.ChainOfThought", + return_value=MagicMock(return_value=mock_response), + ) + + +def _mock_dspy_context_answer( + answer_text: str, reasoning: str = "History reference" +) -> AbstractContextManager[MagicMock]: + """Return a patch that makes dspy return a context-based answer.""" + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": True, + "answer": answer_text, + "reasoning": reasoning, + } + ) + return patch( + "dspy.ChainOfThought", + return_value=MagicMock(return_value=mock_response), + ) + + +def _mock_dspy_no_match() -> AbstractContextManager[MagicMock]: + """Return a patch that makes dspy indicate neither greeting nor context match.""" + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Requires knowledge base search", + } + ) + return patch( + "dspy.ChainOfThought", + return_value=MagicMock(return_value=mock_response), + ) + + +def _patch_cost_utils() -> AbstractContextManager[MagicMock]: + """Patch cost tracking to avoid dspy settings dependency. + + Patches at both possible module paths to handle Python's module identity + behaviour when src/ is on sys.path (module may be loaded as either + ``tool_classifier.context_analyzer`` or ``src.tool_classifier.context_analyzer``). 
+ """ + cost_return = { + "total_cost": 0.001, + "total_tokens": 50, + "total_prompt_tokens": 30, + "total_completion_tokens": 20, + "num_calls": 1, + } + + import sys + + # Determine which module key is actually loaded at runtime + if "tool_classifier.context_analyzer" in sys.modules: + target = "tool_classifier.context_analyzer.get_lm_usage_since" + else: + target = "src.tool_classifier.context_analyzer.get_lm_usage_since" + + return patch(target, return_value=cost_return) + + +# --------------------------------------------------------------------------- +# Integration: classify -> route -> execute (non-streaming) +# --------------------------------------------------------------------------- + + +class TestClassifyAndRouteGreeting: + """Test full classify -> route chain for greeting queries.""" + + @pytest.mark.asyncio + async def test_greeting_classify_returns_context_workflow( + self, classifier: ToolClassifier + ) -> None: + """classify() should return CONTEXT workflow for greeting queries. + + With the hybrid-search classifier, classify() uses Qdrant to detect + service queries. When no service matches (or embedding fails in tests), + it falls back to CONTEXT. The analysis_result is produced later inside + the context workflow executor during route_to_workflow. + """ + with ( + _mock_dspy_greeting("Tere! Kuidas ma saan sind aidata?"), + _patch_cost_utils(), + ): + result = await classifier.classify( + query="Tere!", + conversation_history=[], + language="et", + ) + + # Hybrid classifier routes non-service queries to CONTEXT + assert result.workflow.value == "context" + # analysis_result is now populated during route_to_workflow, not classify + assert result.metadata is not None + + @pytest.mark.asyncio + async def test_greeting_end_to_end_non_streaming( + self, classifier: ToolClassifier + ) -> None: + """Full chain: classify greeting -> route to context workflow -> get response.""" + with _mock_dspy_greeting("Hello! 
How can I help you?"), _patch_cost_utils(): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + request = _make_request("Hello!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Greeting detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert response.chatId == "integration-test-chat" + assert "Hello" in response.content + assert response.llmServiceActive is True + assert response.questionOutOfLLMScope is False + + @pytest.mark.asyncio + async def test_estonian_greeting_end_to_end( + self, classifier: ToolClassifier + ) -> None: + """Full chain for Estonian greeting.""" + with ( + _mock_dspy_greeting("Tere! 
Kuidas ma saan sind aidata?"), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Tere!", + conversation_history=[], + language="et", + ) + + request = _make_request("Tere!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Estonian greeting detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "Tere" in response.content + + @pytest.mark.asyncio + async def test_goodbye_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain for goodbye greeting.""" + with _mock_dspy_greeting("Goodbye! Have a great day!"), _patch_cost_utils(): + classification = await classifier.classify( + query="Goodbye!", + conversation_history=[], + language="en", + ) + + request = _make_request("Goodbye!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="goodbye", + can_answer_from_context=False, + reasoning="Goodbye detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "Goodbye" in response.content + + @pytest.mark.asyncio + async def test_thanks_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain for thanks greeting.""" + with ( + _mock_dspy_greeting("You're welcome! 
Feel free to ask more."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Thank you!", + conversation_history=[], + language="en", + ) + + request = _make_request("Thank you!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="thanks", + can_answer_from_context=False, + reasoning="Thanks detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "welcome" in response.content.lower() + + +class TestClassifyAndRouteContextAnswer: + """Test full classify -> route chain for context-based answers.""" + + @pytest.mark.asyncio + async def test_context_answer_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain: classify history query -> route to context -> get answer.""" + history = [ + ConversationItem( + authorRole="user", + message="What is the tax rate?", + timestamp="2024-01-01T12:00:00", + ), + ConversationItem( + authorRole="bot", + message="The tax rate is 20%.", + timestamp="2024-01-01T12:00:01", + ), + ] + + with ( + _mock_dspy_context_answer("I mentioned the tax rate is 20%."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="What was the rate?", + conversation_history=history, + language="en", + ) + + request = _make_request("What was the rate?", history=history) + with ( + patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Tax rate referenced in history", + context_snippet="The tax rate is 20%.", + ), + {"total_cost": 0.002, 
"total_tokens": 100, "num_calls": 1}, + ), + ), + patch.object( + classifier.context_workflow.context_analyzer, + "generate_context_response", + new_callable=AsyncMock, + return_value=( + "I mentioned the tax rate is 20%.", + {"total_cost": 0.003, "num_calls": 1}, + ), + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert classification.workflow.value == "context" + assert isinstance(response, OrchestrationResponse) + assert "20%" in response.content + + @pytest.mark.asyncio + async def test_context_answer_with_long_history( + self, classifier: ToolClassifier + ) -> None: + """Should pass last 10 turns to the analyzer even with longer history.""" + history = [ + ConversationItem( + authorRole="user" if i % 2 == 0 else "bot", + message=f"Message {i}", + timestamp=f"2024-01-01T12:00:{i:02d}", + ) + for i in range(15) + ] + + with ( + _mock_dspy_context_answer("Based on our conversation, here's the answer."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="What did we discuss?", + conversation_history=history, + language="en", + ) + + request = _make_request("What did we discuss?", history=history) + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert classification.workflow.value == "context" + assert isinstance(response, OrchestrationResponse) + assert response.content is not None + + +# --------------------------------------------------------------------------- +# Integration: fallback from CONTEXT to RAG +# --------------------------------------------------------------------------- + + +class TestContextToRAGFallback: + """Test that context workflow falls back to RAG when it cannot answer.""" + + @pytest.mark.asyncio + async def test_classify_defaults_to_rag_when_no_context_match( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + 
) -> None: + """When context analyzer can't answer, the full route chain ends at RAG. + + With the hybrid-search classifier, classify() returns CONTEXT for + non-service queries. The RAG fallback is triggered inside + route_to_workflow when the context workflow returns None. + """ + with _mock_dspy_no_match(), _patch_cost_utils(): + classification = await classifier.classify( + query="What is a digital signature?", + conversation_history=[], + language="en", + ) + + # Classifier routes non-service queries to CONTEXT first + assert classification.workflow.value == "context" + + # Full route: context can't answer → falls back to RAG + request = _make_request("What is a digital signature?") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_fallback_to_rag_end_to_end( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """Full chain: context can't answer -> falls back to RAG -> gets RAG response.""" + with _mock_dspy_no_match(), _patch_cost_utils(): + classification = await classifier.classify( + query="What is a digital signature?", + conversation_history=[], + language="en", + ) + + # Hybrid classifier routes to CONTEXT first; RAG is via fallback + assert classification.workflow.value == "context" + + request = _make_request("What is a digital signature?") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + # RAG mock returns "RAG fallback answer." 
+ assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_context_workflow_returns_none_triggers_rag_fallback( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """When context workflow returns None during routing, RAG fallback is used.""" + # Force classification to CONTEXT but with an analysis that will produce None + no_answer_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + answer=None, + reasoning="Cannot answer", + ) + + # Use the WorkflowType from the same module path the classifier uses + from tool_classifier.enums import WorkflowType as _WorkflowType + + forced_classification = ClassificationResult( + workflow=_WorkflowType.CONTEXT, + confidence=0.95, + metadata={"analysis_result": no_answer_analysis}, + reasoning="Forced for test", + ) + + request = _make_request("Something that context can't answer") + response = await classifier.route_to_workflow( + classification=forced_classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + # Should have fallen through to RAG + assert "RAG" in response.content + + +# --------------------------------------------------------------------------- +# Integration: streaming mode +# --------------------------------------------------------------------------- + + +class TestStreamingIntegration: + """Test the full classify -> route -> stream chain.""" + + @pytest.mark.asyncio + async def test_streaming_greeting_end_to_end( + self, classifier: ToolClassifier + ) -> None: + """Full chain: classify greeting -> route streaming -> collect SSE chunks.""" + with _mock_dspy_greeting("Hello! 
How can I help you?"), _patch_cost_utils(): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + request = _make_request("Hello!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Greeting detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + stream = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=True, + ) + + # Collect chunks inside the mock context so the dspy patch is active + # when the async generator body executes (lazy evaluation). + chunks = [chunk async for chunk in stream] + + # Should have content chunks + END marker + assert len(chunks) >= 2 + for chunk in chunks: + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + + # Last chunk should contain END + last_payload = json.loads(chunks[-1][6:-2]) + assert last_payload["payload"]["content"] == "END" + + # Reconstruct content from non-END chunks + content_parts = [] + for chunk in chunks[:-1]: + payload = json.loads(chunk[6:-2]) + content_parts.append(payload["payload"]["content"]) + full_content = "".join(content_parts) + assert "Hello" in full_content + + @pytest.mark.asyncio + async def test_streaming_context_answer_end_to_end( + self, classifier: ToolClassifier + ) -> None: + """Full chain: classify history query -> route streaming -> collect answer.""" + history = [ + ConversationItem( + authorRole="bot", + message="The deadline is March 31st.", + timestamp="2024-01-01T12:00:00", + ), + ] + + with ( + _mock_dspy_context_answer("The deadline is March 31st."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="When is the deadline?", + conversation_history=history, + language="en", + ) + + request = 
_make_request("When is the deadline?", history=history) + stream = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=True, + ) + + chunks = [chunk async for chunk in stream] + + assert len(chunks) >= 2 + last_payload = json.loads(chunks[-1][6:-2]) + assert last_payload["payload"]["content"] == "END" + + @pytest.mark.asyncio + async def test_streaming_fallback_to_rag( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """Streaming: context can't answer -> falls back to RAG streaming.""" + # Force classification to CONTEXT with no answer + no_answer_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + answer=None, + reasoning="Cannot answer", + ) + + from tool_classifier.enums import WorkflowType as _WorkflowType + + forced_classification = ClassificationResult( + workflow=_WorkflowType.CONTEXT, + confidence=0.95, + metadata={"analysis_result": no_answer_analysis}, + reasoning="Forced for test", + ) + + request = _make_request("Something needing RAG") + stream = await classifier.route_to_workflow( + classification=forced_classification, + request=request, + is_streaming=True, + ) + + chunks = [chunk async for chunk in stream] + + # Should have received RAG streaming output + assert len(chunks) >= 1 + + +# --------------------------------------------------------------------------- +# Integration: cost tracking across the chain +# --------------------------------------------------------------------------- + + +class TestCostTrackingIntegration: + """Test that cost data flows through the full classify -> execute chain.""" + + @pytest.mark.asyncio + async def test_costs_propagated_through_classification( + self, classifier: ToolClassifier + ) -> None: + """Cost dict from context analysis should be tracked during workflow execution. 
+ + With the hybrid-search classifier, costs are tracked inside the context + workflow executor (execute_async/execute_streaming), not in classify(). + The cost dict is stored in the workflow's internal context dictionary. + """ + with _mock_dspy_greeting("Hello!"), _patch_cost_utils(): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # Verify classify succeeded and routes to CONTEXT + assert classification.workflow.value == "context" + + # Execute the workflow to trigger cost tracking + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + # Verify workflow ran successfully (costs tracked internally) + assert isinstance(response, OrchestrationResponse) + assert response.chatId == "integration-test-chat" + + +# --------------------------------------------------------------------------- +# Integration: error resilience +# --------------------------------------------------------------------------- + + +class TestErrorResilience: + """Test that errors in context analysis gracefully fall back to RAG.""" + + @pytest.mark.asyncio + async def test_llm_exception_falls_back_to_rag( + self, classifier: ToolClassifier + ) -> None: + """If context analyzer LLM call raises, the route chain falls back to RAG. + + With the hybrid-search classifier, classify() returns CONTEXT for + non-service queries. When the context workflow LLM call raises, the + context workflow returns None and route_to_workflow falls back to RAG. 
+ """ + with ( + patch( + "dspy.ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM unavailable")), + ), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # classify() returns CONTEXT (non-service query) + assert classification.workflow.value == "context" + + # Full route: context LLM fails → falls back to RAG gracefully + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_json_parse_error_falls_back_to_rag( + self, classifier: ToolClassifier + ) -> None: + """If LLM returns invalid JSON, the route chain falls back to RAG. + + JSON parse failure causes context analysis to return is_greeting=False, + answer=None. The context workflow then returns None and the fallback + chain routes to RAG. + """ + mock_response = MagicMock() + mock_response.analysis_result = "not valid json at all" + + with ( + patch( + "dspy.ChainOfThought", + return_value=MagicMock(return_value=mock_response), + ), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # classify() returns CONTEXT (non-service query) + assert classification.workflow.value == "context" + + # Full route: JSON parse fails → context returns None → RAG fallback + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content