diff --git a/src/tool_classifier/context_analyzer.py b/src/tool_classifier/context_analyzer.py index 4572aef..3584683 100644 --- a/src/tool_classifier/context_analyzer.py +++ b/src/tool_classifier/context_analyzer.py @@ -207,7 +207,6 @@ def __init__(self, llm_manager: Any) -> None: # noqa: ANN401 # Phase 1 & 2 modules for two-phase detection+generation flow self._detection_module: Optional[dspy.Module] = None self._response_generation_module: Optional[dspy.Module] = None - self._stream_predictor: Optional[Any] = None logger.info("Context analyzer initialized") def _format_conversation_history( @@ -357,6 +356,111 @@ async def detect_context( ) return result, cost_dict + async def detect_context_with_summary_fallback( + self, + query: str, + conversation_history: List[Dict[str, Any]], + ) -> tuple[ContextDetectionResult, Dict[str, Any]]: + """ + Phase 1 with summary fallback: detect if query can be answered from history. + + Implements a 3-step flow: + 1. Check the last 10 turns via detect_context(). + 2. If cannot answer AND total history > 10 turns: + - Generate a concise summary of the older turns (everything before the last 10). + - Check whether the query can be answered from that summary. + 3. If still cannot answer, return can_answer=False (workflow falls back to RAG). + + When the summary path succeeds, the returned ContextDetectionResult has: + - can_answer_from_context=True + - answered_from_summary=True + - context_snippet set to the answer extracted from the summary, so that + Phase 2 (stream_context_response / generate_context_response) can use it + directly as the context for response generation. 
+ + Args: + query: User query to classify + conversation_history: Full conversation history + + Returns: + Tuple of (ContextDetectionResult, cost_dict) + """ + total_turns = len(conversation_history) + + # Step 1: check the most recent 10 turns + result, cost_dict = await self.detect_context( + query=query, conversation_history=conversation_history + ) + + # If already answered or it's a greeting, return immediately + if result.is_greeting or result.can_answer_from_context: + return result, cost_dict + + # Step 2 & 3: if history exceeds 10 turns, try summary-based detection + if total_turns > 10: + logger.info( + f"History has {total_turns} turns (> 10) | " + f"Cannot answer from recent 10 | Attempting summary-based detection" + ) + older_history = conversation_history[:-10] + logger.info(f"Summarizing {len(older_history)} older turns") + + try: + summary, summary_cost = await self._generate_conversation_summary( + older_history + ) + cost_dict = self._merge_cost_dicts(cost_dict, summary_cost) + + if summary: + summary_result, analysis_cost = await self._analyze_from_summary( + query=query, summary=summary + ) + cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost) + + if summary_result.can_answer_from_context and summary_result.answer: + logger.info( + f"DETECTION: Can answer from summary | " + f"Reasoning: {summary_result.reasoning}" + ) + # Surface the summary-derived answer as context_snippet so + # Phase 2 can generate a polished response from it. 
+ return ContextDetectionResult( + is_greeting=False, + can_answer_from_context=True, + reasoning=summary_result.reasoning, + context_snippet=summary_result.answer, + answered_from_summary=True, + ), cost_dict + + logger.info( + "Cannot answer from summary either | Falling back to RAG" + ) + else: + logger.warning( + "Summary generation returned empty | Falling back to RAG" + ) + + except Exception as e: + logger.error(f"Summary-based detection failed: {e}", exc_info=True) + else: + logger.info( + f"History has {total_turns} turns (<= 10) | " + f"No summary needed | Falling back to RAG" + ) + + return result, cost_dict + + @staticmethod + def _yield_in_chunks(text: str, chunk_size: int = 5) -> list[str]: + """Split text into word-group chunks for simulated streaming.""" + words = text.split() + chunks = [] + for i in range(0, len(words), chunk_size): + group = words[i : i + chunk_size] + trailing = " " if i + chunk_size < len(words) else "" + chunks.append(" ".join(group) + trailing) + return chunks + async def stream_context_response( self, query: str, @@ -365,30 +469,39 @@ async def stream_context_response( """ Phase 2 (streaming): Stream a generated answer using DSPy native streaming. - Uses ContextResponseGenerationSignature with DSPy's streamify() so tokens - are yielded in real time — same mechanism as ResponseGeneratorAgent.stream_response(). + Creates a fresh streamify predictor per call (avoids stale StreamListener + issues that occur when the cached predictor is reused across calls). + + Fallback chain: + 1. DSPy streamify → yield StreamResponse tokens as they arrive. + 2. If no stream tokens received but final Prediction has an answer, + yield it in word-group chunks. + 3. If that is also empty, call generate_context_response() directly + and yield the result in word-group chunks. 
Args: query: The user query to answer context_snippet: Relevant context extracted during Phase 1 detection Yields: - Token strings as they arrive from the LLM + Token strings as they arrive from the LLM (or simulated chunks) """ logger.info(f"CONTEXT GENERATOR: Phase 2 streaming | Query: '{query[:100]}'") self.llm_manager.ensure_global_config() output_stream = None stream_started = False + prediction_answer: Optional[str] = None try: with self.llm_manager.use_task_local(): - if self._stream_predictor is None: - answer_listener = StreamListener(signature_field_name="answer") - self._stream_predictor = dspy.streamify( - dspy.Predict(ContextResponseGenerationSignature), - stream_listeners=[answer_listener], - ) - output_stream = self._stream_predictor( + # Always create a fresh StreamListener + streamified predictor so that + # the listener's internal state is clean for this call. + answer_listener = StreamListener(signature_field_name="answer") + stream_predictor: Any = dspy.streamify( + dspy.Predict(ContextResponseGenerationSignature), + stream_listeners=[answer_listener], + ) + output_stream = stream_predictor( context_snippet=context_snippet, user_query=query, ) @@ -402,11 +515,11 @@ async def stream_context_response( logger.info( "Context response streaming complete (final Prediction received)" ) + if not stream_started: + # Tokens didn't stream — extract answer from the Prediction + # directly as first fallback before leaving the LM context. + prediction_answer = getattr(chunk, "answer", "") or "" - if not stream_started: - logger.warning( - "Context streaming finished but no 'answer' tokens received." 
- ) except GeneratorExit: raise except Exception as e: @@ -421,6 +534,31 @@ async def stream_context_response( f"Error during context stream cleanup: {cleanup_error}" ) + if stream_started: + return + + # Fallback 1: answer was in the final Prediction but didn't stream as tokens + if prediction_answer: + logger.warning( + "Stream tokens not received — yielding answer from final Prediction in chunks." + ) + for text_chunk in self._yield_in_chunks(prediction_answer): + yield text_chunk + return + + # Fallback 2: Prediction had no answer either — call generate_context_response + logger.warning( + "No answer from streamify — falling back to generate_context_response." + ) + fallback_answer, _ = await self.generate_context_response( + query=query, context_snippet=context_snippet + ) + if fallback_answer: + for text_chunk in self._yield_in_chunks(fallback_answer): + yield text_chunk + else: + logger.error("All Phase 2 streaming fallbacks exhausted — empty response.") + async def generate_context_response( self, query: str, diff --git a/src/tool_classifier/workflows/context_workflow.py b/src/tool_classifier/workflows/context_workflow.py index 8d69675..0aa7fb2 100644 --- a/src/tool_classifier/workflows/context_workflow.py +++ b/src/tool_classifier/workflows/context_workflow.py @@ -1,6 +1,6 @@ """Context workflow executor - Layer 2: Conversation history and greetings.""" -from typing import Any, AsyncIterator, Dict, Optional +from typing import Any, AsyncIterator, Dict, Optional, cast import time import dspy from loguru import logger @@ -77,10 +77,19 @@ async def _detect( time_metric: Dict[str, float], costs_metric: Dict[str, Dict[str, Any]], ) -> Optional[ContextDetectionResult]: - """Phase 1: run context detection. Returns ContextDetectionResult or None on error.""" + """Phase 1: run context detection with summary fallback. + + Checks the last 10 conversation turns first. 
If the query cannot be + answered from those and the history exceeds 10 turns, falls back to a + summary-based check over the older turns. Returns None on error so the + caller falls through to RAG. + """ try: start = time.time() - result, cost = await self.context_analyzer.detect_context( + ( + result, + cost, + ) = await self.context_analyzer.detect_context_with_summary_fallback( query=message, conversation_history=history ) time_metric["context.detection"] = time.time() - start @@ -267,12 +276,29 @@ async def execute_async( language = detect_language(request.message) history = self._build_history(request) - detection_result = await self._detect( - request.message, history, time_metric, costs_metric - ) - if detection_result is None: - self._log_costs(costs_metric) - return None + # Check if analysis is pre-computed (e.g. from classifier classify step) + pre_computed = context.get("analysis_result") + if ( + pre_computed is not None + and hasattr(pre_computed, "is_greeting") + and hasattr(pre_computed, "can_answer_from_context") + ): + detection_result: ContextDetectionResult = cast( + ContextDetectionResult, pre_computed + ) + costs_metric.setdefault( + "context_detection", + {"total_cost": 0.0, "total_tokens": 0, "num_calls": 0}, + ) + else: + _detected = await self._detect( + request.message, history, time_metric, costs_metric + ) + if _detected is None: + self._log_costs(costs_metric) + context["costs_dict"] = costs_metric + return None + detection_result = _detected logger.info( f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} " @@ -286,6 +312,7 @@ async def execute_async( greeting_type=detection_result.greeting_type, language=language ) self._log_costs(costs_metric) + context["costs_dict"] = costs_metric return OrchestrationResponse( chatId=request.chatId, llmServiceActive=True, @@ -298,6 +325,7 @@ async def execute_async( detection_result.can_answer_from_context and detection_result.context_snippet ): + context["costs_dict"] = 
costs_metric return await self._generate_response_async( request, detection_result.context_snippet, time_metric, costs_metric ) @@ -306,6 +334,7 @@ async def execute_async( f"[{request.chatId}] Cannot answer from context — falling back to RAG" ) self._log_costs(costs_metric) + context["costs_dict"] = costs_metric return None async def execute_streaming( diff --git a/tests/conftest.py b/tests/conftest.py index d1633b7..e26acfc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,3 +6,12 @@ # Add the project root to Python path so tests can import from src project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) + +# Add src directory to Python path for direct module imports +src_dir = project_root / "src" +sys.path.insert(0, str(src_dir)) + +# Add models directory (sibling to src) for backward compatibility +models_dir = project_root / "models" +if models_dir.exists(): + sys.path.insert(0, str(models_dir.parent)) diff --git a/tests/test_context_analyzer.py b/tests/test_context_analyzer.py new file mode 100644 index 0000000..094b8a4 --- /dev/null +++ b/tests/test_context_analyzer.py @@ -0,0 +1,979 @@ +"""Unit tests for context analyzer - greeting detection and context analysis.""" + +import pytest +from collections.abc import Generator +from unittest.mock import MagicMock, patch +import json +import dspy + +from src.tool_classifier.context_analyzer import ( + ContextAnalyzer, +) +from src.tool_classifier.greeting_constants import get_greeting_response + + +@pytest.fixture(autouse=True) +def mock_dspy_lm() -> Generator[MagicMock, None, None]: + """Mock DSPy LM to prevent 'No LM is loaded' errors.""" + mock_lm = MagicMock() + mock_lm.history = [] + with patch("dspy.settings") as mock_settings: + mock_settings.lm = mock_lm + # Configure DSPy with mock LM + dspy.configure(lm=mock_lm) + yield mock_lm + + +class TestContextAnalyzerInit: + """Test ContextAnalyzer initialization.""" + + def test_init_creates_analyzer(self) -> None: + 
"""ContextAnalyzer should initialize with LLM manager.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + assert analyzer.llm_manager is llm_manager + assert analyzer._module is None + assert analyzer._summary_module is None + assert analyzer._summary_analysis_module is None + + +class TestConversationHistoryFormatting: + """Test conversation history formatting.""" + + def test_format_empty_history(self) -> None: + """Empty history should return empty JSON array.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + result = analyzer._format_conversation_history([]) + + assert result == "[]" + + def test_format_single_turn(self) -> None: + """Single conversation turn should be formatted correctly.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "Hello", + "timestamp": "2024-01-01T12:00:00", + } + ] + + result = analyzer._format_conversation_history(history) + parsed = json.loads(result) + + assert len(parsed) == 1 + assert parsed[0]["role"] == "user" + assert parsed[0]["message"] == "Hello" + + def test_format_multiple_turns(self) -> None: + """Multiple conversation turns should be formatted correctly.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is tax?", + "timestamp": "2024-01-01T12:00:00", + }, + { + "authorRole": "bot", + "message": "Tax is a mandatory financial charge.", + "timestamp": "2024-01-01T12:00:01", + }, + { + "authorRole": "user", + "message": "Thank you", + "timestamp": "2024-01-01T12:00:02", + }, + ] + + result = analyzer._format_conversation_history(history) + parsed = json.loads(result) + + assert len(parsed) == 3 + assert parsed[0]["role"] == "user" + assert parsed[1]["role"] == "bot" + assert parsed[2]["role"] == "user" + + def test_format_truncates_to_max_turns(self) -> None: + """History should be truncated to last 10 
turns.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Create 15 turns + history = [ + { + "authorRole": "user" if i % 2 == 0 else "bot", + "message": f"Message {i}", + "timestamp": f"2024-01-01T12:00:{i:02d}", + } + for i in range(15) + ] + + result = analyzer._format_conversation_history(history, max_turns=10) + parsed = json.loads(result) + + assert len(parsed) == 10 + # Should have last 10 turns (indices 5-14) + assert parsed[0]["message"] == "Message 5" + assert parsed[-1]["message"] == "Message 14" + + +class TestGreetingDetection: + """Test greeting detection functionality.""" + + @pytest.mark.asyncio + async def test_detect_estonian_greeting(self) -> None: + """Should detect Estonian greeting 'Tere' and generate response.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module response + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Tere! 
Kuidas ma saan sind aidata?", + "reasoning": "User said hello in Estonian", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, cost_dict = await analyzer.analyze_context( + query="Tere!", + conversation_history=[], + language="et", + ) + + assert result.is_greeting is True + assert result.can_answer_from_context is False + assert "Tere" in result.answer + assert cost_dict["total_cost"] == 0.001 + + @pytest.mark.asyncio + async def test_detect_english_greeting(self) -> None: + """Should detect English greeting 'Hello' and generate response.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module response + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Hello! 
How can I help you?", + "reasoning": "User said hello in English", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, cost_dict = await analyzer.analyze_context( + query="Hello!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + assert "Hello" in result.answer or "hello" in result.answer.lower() + + @pytest.mark.asyncio + async def test_detect_goodbye(self) -> None: + """Should detect goodbye greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "Goodbye! Have a great day!", + "reasoning": "User said goodbye", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Bye!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + + @pytest.mark.asyncio + async def test_detect_thanks(self) -> None: + """Should detect thank you greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": True, + "can_answer_from_context": False, + "answer": "You're welcome! 
Feel free to ask if you have more questions.", + "reasoning": "User said thank you", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Thank you!", + conversation_history=[], + language="en", + ) + + assert result.is_greeting is True + + +class TestContextBasedAnswering: + """Test context-based question answering.""" + + @pytest.mark.asyncio + async def test_answer_from_conversation_history(self) -> None: + """Should extract answer from conversation history when query references it.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is the tax rate?", + "timestamp": "2024-01-01T12:00:00", + }, + { + "authorRole": "bot", + "message": "The tax rate is 20%.", + "timestamp": "2024-01-01T12:00:01", + }, + ] + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": True, + "answer": "I mentioned that the tax rate is 20%.", + "reasoning": "User is asking about previously mentioned tax rate", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What was the rate you mentioned?", + conversation_history=history, + language="en", + ) + + assert result.is_greeting is False + assert result.can_answer_from_context is True + assert result.answer is not None + assert "20%" in result.answer + + @pytest.mark.asyncio + 
async def test_cannot_answer_from_context(self) -> None: + """Should return cannot answer when query doesn't reference history.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + history = [ + { + "authorRole": "user", + "message": "What is the weather?", + "timestamp": "2024-01-01T12:00:00", + }, + ] + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Query is about taxes, not previous weather discussion", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="What is the tax rate?", + conversation_history=history, + language="en", + ) + + assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + + +class TestErrorHandling: + """Test error handling in context analyzer.""" + + @pytest.mark.asyncio + async def test_handles_llm_json_parse_error(self) -> None: + """Should handle invalid JSON response from LLM gracefully.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module to return invalid JSON + mock_response = MagicMock() + mock_response.analysis_result = "Invalid JSON response" + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + result, _ = await analyzer.analyze_context( + query="Hello", + conversation_history=[], + language="en", + ) + + # Should fallback to safe default + 
assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + assert "Failed to parse" in result.reasoning + + @pytest.mark.asyncio + async def test_handles_llm_exception(self) -> None: + """Should handle LLM call exceptions gracefully.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + # Mock DSPy module to raise exception + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + result, _ = await analyzer.analyze_context( + query="Hello", + conversation_history=[], + language="en", + ) + + # Should fallback to safe default + assert result.is_greeting is False + assert result.can_answer_from_context is False + assert result.answer is None + assert "error" in result.reasoning.lower() + + +class TestFallbackGreeting: + """Test fallback greeting responses.""" + + def test_fallback_estonian_greeting(self) -> None: + """Should return Estonian fallback greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("et") + + assert "Tere" in response + + def test_fallback_english_greeting(self) -> None: + """Should return English fallback greeting.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("en") + + assert "Hello" in response or "hello" in response + + def test_fallback_unknown_language_defaults_to_estonian(self) -> None: + """Should default to Estonian for unknown language codes.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + response = analyzer.get_fallback_greeting_response("xx") + + assert "Tere" in response or "tere" in response.lower() + + +class 
TestGreetingConstants: + """Test greeting constants and helper functions.""" + + def test_get_estonian_hello(self) -> None: + """Should return Estonian hello greeting.""" + response = get_greeting_response("hello", "et") + assert "Tere" in response + + def test_get_english_goodbye(self) -> None: + """Should return English goodbye greeting.""" + response = get_greeting_response("goodbye", "en") + assert "Goodbye" in response or "goodbye" in response + + def test_get_estonian_thanks(self) -> None: + """Should return Estonian thanks greeting.""" + response = get_greeting_response("thanks", "et") + assert "Palun" in response + + def test_unknown_greeting_type_defaults_to_hello(self) -> None: + """Should default to hello for unknown greeting types.""" + response = get_greeting_response("unknown", "en") + assert "Hello" in response or "hello" in response + + +def _make_history(num_turns: int) -> list[dict[str, str]]: + """Helper to create a conversation history with the specified number of turns.""" + return [ + { + "authorRole": "user" if i % 2 == 0 else "bot", + "message": f"Message {i}", + "timestamp": f"2024-01-01T12:00:{i:02d}", + } + for i in range(num_turns) + ] + + +class TestCostMerging: + """Test cost dictionary merging.""" + + def test_merge_cost_dicts(self) -> None: + """Should sum all numeric values from two cost dicts.""" + cost1 = { + "total_cost": 0.001, + "total_tokens": 50, + "total_prompt_tokens": 30, + "total_completion_tokens": 20, + "num_calls": 1, + } + cost2 = { + "total_cost": 0.002, + "total_tokens": 100, + "total_prompt_tokens": 60, + "total_completion_tokens": 40, + "num_calls": 1, + } + + merged = ContextAnalyzer._merge_cost_dicts(cost1, cost2) + + assert merged["total_cost"] == pytest.approx(0.003) + assert merged["total_tokens"] == 150 + assert merged["total_prompt_tokens"] == 90 + assert merged["total_completion_tokens"] == 60 + assert merged["num_calls"] == 2 + + def test_merge_cost_dicts_with_empty(self) -> None: + """Should handle 
merging with an empty cost dict.""" + cost1 = { + "total_cost": 0.001, + "total_tokens": 50, + "total_prompt_tokens": 30, + "total_completion_tokens": 20, + "num_calls": 1, + } + + merged = ContextAnalyzer._merge_cost_dicts(cost1, {}) + + assert merged["total_cost"] == 0.001 + assert merged["total_tokens"] == 50 + assert merged["num_calls"] == 1 + + +class TestConversationSummary: + """Test conversation summary generation.""" + + @pytest.mark.asyncio + async def test_generate_summary_from_older_turns(self) -> None: + """Should generate summary from older conversation turns.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + older_history = _make_history(6) + + mock_response = MagicMock() + mock_response.summary = "User discussed messages 0-5 about various topics." + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.001, + "total_tokens": 50, + "num_calls": 1, + } + + summary, cost_dict = await analyzer._generate_conversation_summary( + older_history + ) + + assert summary == "User discussed messages 0-5 about various topics." 
+ assert cost_dict["total_cost"] == 0.001 + + @pytest.mark.asyncio + async def test_generate_summary_handles_exception(self) -> None: + """Should return empty string when summary generation fails.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + summary, _ = await analyzer._generate_conversation_summary( + _make_history(5) + ) + + assert summary == "" + + @pytest.mark.asyncio + async def test_analyze_from_summary_can_answer(self) -> None: + """Should answer from summary when information is available.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "can_answer_from_context": True, + "answer": "The tax rate discussed earlier was 20%.", + "reasoning": "Summary contains tax rate information", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, cost_dict = await analyzer._analyze_from_summary( + query="What was the tax rate?", + summary="User asked about tax. 
Bot replied: tax rate is 20%.", + ) + + assert result.can_answer_from_context is True + assert result.answered_from_summary is True + assert result.answer is not None + assert "20%" in result.answer + + @pytest.mark.asyncio + async def test_analyze_from_summary_cannot_answer(self) -> None: + """Should return cannot answer when summary doesn't contain relevant info.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + mock_response = MagicMock() + mock_response.analysis_result = json.dumps( + { + "can_answer_from_context": False, + "answer": None, + "reasoning": "Summary does not contain information about weather", + } + ) + + with patch.object( + dspy, "ChainOfThought", return_value=MagicMock(return_value=mock_response) + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.002, + "total_tokens": 100, + "num_calls": 1, + } + + result, _ = await analyzer._analyze_from_summary( + query="What is the weather?", + summary="User discussed tax rates and filing.", + ) + + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + @pytest.mark.asyncio + async def test_analyze_from_summary_handles_exception(self) -> None: + """Should return safe fallback when summary analysis fails.""" + llm_manager = MagicMock() + analyzer = ContextAnalyzer(llm_manager) + + with patch.object( + dspy, + "ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM error")), + ): + with patch( + "src.tool_classifier.context_analyzer.get_lm_usage_since" + ) as mock_cost: + mock_cost.return_value = { + "total_cost": 0.0, + "total_tokens": 0, + "num_calls": 0, + } + + result, _ = await analyzer._analyze_from_summary( + query="test", summary="test summary" + ) + + assert result.can_answer_from_context is False + assert result.answered_from_summary is False + assert result.answer is None + + +class 
class TestSummaryFlow:
    """Exercise the full analyze_context flow, including the summary fallback."""

    @pytest.mark.asyncio
    async def test_short_history_skips_summary(self) -> None:
        """With <= 10 turns, should use recent history only, no summary."""
        analyzer = ContextAnalyzer(MagicMock())

        # Recent-history detection says "cannot answer"; with only 8 turns the
        # summary fallback must NOT be triggered.
        detection = MagicMock()
        detection.analysis_result = json.dumps(
            {
                "is_greeting": False,
                "can_answer_from_context": False,
                "answer": None,
                "reasoning": "Cannot answer from context",
            }
        )

        with (
            patch.object(
                dspy, "ChainOfThought", return_value=MagicMock(return_value=detection)
            ),
            patch(
                "src.tool_classifier.context_analyzer.get_lm_usage_since"
            ) as usage_mock,
        ):
            usage_mock.return_value = {
                "total_cost": 0.001,
                "total_tokens": 50,
                "num_calls": 1,
            }
            result, _ = await analyzer.analyze_context(
                query="What is digital signature?",
                conversation_history=_make_history(8),
                language="en",
            )

        # Should not answer (no summary triggered for <= 10 turns).
        assert result.can_answer_from_context is False
        assert result.answered_from_summary is False
        assert result.answer is None

    @pytest.mark.asyncio
    async def test_long_history_answers_from_recent(self) -> None:
        """With > 10 turns, if recent 10 can answer, should not trigger summary."""
        analyzer = ContextAnalyzer(MagicMock())

        # Recent history already contains the answer.
        detection = MagicMock()
        detection.analysis_result = json.dumps(
            {
                "is_greeting": False,
                "can_answer_from_context": True,
                "answer": "The rate is 20%.",
                "reasoning": "Found in recent history",
            }
        )

        with (
            patch.object(
                dspy, "ChainOfThought", return_value=MagicMock(return_value=detection)
            ),
            patch(
                "src.tool_classifier.context_analyzer.get_lm_usage_since"
            ) as usage_mock,
        ):
            usage_mock.return_value = {
                "total_cost": 0.001,
                "total_tokens": 50,
                "num_calls": 1,
            }
            result, _ = await analyzer.analyze_context(
                query="What was the rate?",
                conversation_history=_make_history(15),
                language="en",
            )

        assert result.can_answer_from_context is True
        assert result.answered_from_summary is False
        assert result.answer == "The rate is 20%."

    @pytest.mark.asyncio
    async def test_long_history_answers_from_summary(self) -> None:
        """With > 10 turns, if recent can't answer but summary can, should return summary answer."""
        analyzer = ContextAnalyzer(MagicMock())

        # Step 1: recent history cannot answer.
        recent = MagicMock()
        recent.analysis_result = json.dumps(
            {
                "is_greeting": False,
                "can_answer_from_context": False,
                "answer": None,
                "reasoning": "Not in recent history",
            }
        )

        # Step 2: summary generation output.
        summary = MagicMock()
        summary.summary = (
            "User asked about tax rates. Bot said the tax rate is 20%."
        )

        # Step 3: summary analysis can answer.
        from_summary = MagicMock()
        from_summary.analysis_result = json.dumps(
            {
                "can_answer_from_context": True,
                "answer": "Based on our earlier discussion, the tax rate is 20%.",
                "reasoning": "Found tax rate in conversation summary",
            }
        )

        # ChainOfThought is built 3 times (recent analysis, summary generation,
        # summary analysis); a side_effect list serves one canned module per call.
        modules = [
            MagicMock(return_value=resp)
            for resp in (recent, summary, from_summary)
        ]

        with (
            patch.object(dspy, "ChainOfThought", side_effect=modules),
            patch(
                "src.tool_classifier.context_analyzer.get_lm_usage_since"
            ) as usage_mock,
        ):
            usage_mock.return_value = {
                "total_cost": 0.001,
                "total_tokens": 50,
                "num_calls": 1,
            }
            result, cost_dict = await analyzer.analyze_context(
                query="What was the tax rate we discussed?",
                conversation_history=_make_history(15),
                language="en",
            )

        assert result.can_answer_from_context is True
        assert result.answered_from_summary is True
        assert result.answer is not None
        assert "20%" in result.answer
        # Costs should be merged from all 3 calls.
        assert cost_dict["num_calls"] == 3

    @pytest.mark.asyncio
    async def test_long_history_falls_to_rag(self) -> None:
        """With > 10 turns, if neither recent nor summary can answer, should fall through."""
        analyzer = ContextAnalyzer(MagicMock())

        # Step 1: recent history cannot answer.
        recent = MagicMock()
        recent.analysis_result = json.dumps(
            {
                "is_greeting": False,
                "can_answer_from_context": False,
                "answer": None,
                "reasoning": "Not in recent history",
            }
        )

        # Step 2: summary generation output (irrelevant topics).
        summary = MagicMock()
        summary.summary = "User discussed weather and greetings."

        # Step 3: summary analysis cannot answer either.
        from_summary = MagicMock()
        from_summary.analysis_result = json.dumps(
            {
                "can_answer_from_context": False,
                "answer": None,
                "reasoning": "Summary does not contain tax information",
            }
        )

        modules = [
            MagicMock(return_value=resp)
            for resp in (recent, summary, from_summary)
        ]

        with (
            patch.object(dspy, "ChainOfThought", side_effect=modules),
            patch(
                "src.tool_classifier.context_analyzer.get_lm_usage_since"
            ) as usage_mock,
        ):
            usage_mock.return_value = {
                "total_cost": 0.001,
                "total_tokens": 50,
                "num_calls": 1,
            }
            result, _ = await analyzer.analyze_context(
                query="What is the tax rate?",
                conversation_history=_make_history(15),
                language="en",
            )

        # Should not be able to answer -> falls to RAG.
        assert result.can_answer_from_context is False
        assert result.answered_from_summary is False
        assert result.answer is None

    @pytest.mark.asyncio
    async def test_answered_from_summary_flag_is_false_for_recent(self) -> None:
        """The answered_from_summary flag should be False for recent history answers."""
        analyzer = ContextAnalyzer(MagicMock())

        detection = MagicMock()
        detection.analysis_result = json.dumps(
            {
                "is_greeting": False,
                "can_answer_from_context": True,
                "answer": "The answer from recent history.",
                "reasoning": "Found in recent conversation",
            }
        )

        with (
            patch.object(
                dspy, "ChainOfThought", return_value=MagicMock(return_value=detection)
            ),
            patch(
                "src.tool_classifier.context_analyzer.get_lm_usage_since"
            ) as usage_mock,
        ):
            usage_mock.return_value = {
                "total_cost": 0.001,
                "total_tokens": 50,
                "num_calls": 1,
            }
            result, _ = await analyzer.analyze_context(
                query="What did you say?",
                conversation_history=_make_history(5),
                language="en",
            )

        assert result.answered_from_summary is False
"""Unit tests for context workflow executor."""

import pytest
from collections.abc import AsyncGenerator, Generator
from unittest.mock import AsyncMock, MagicMock, patch
import dspy

from src.tool_classifier.workflows.context_workflow import ContextWorkflowExecutor
from src.tool_classifier.context_analyzer import ContextDetectionResult
from models.request_models import (
    OrchestrationRequest,
    OrchestrationResponse,
    ConversationItem,
)


@pytest.fixture
def mock_dspy_lm() -> Generator[MagicMock, None, None]:
    """Mock DSPy LM to prevent 'No LM is loaded' errors."""
    fake_lm = MagicMock()
    fake_lm.history = []
    with patch("dspy.settings") as settings:
        settings.lm = fake_lm
        # Make the mock the active LM for the duration of the test.
        dspy.configure(lm=fake_lm)
        yield fake_lm


@pytest.fixture
def mock_orchestration_service() -> MagicMock:
    """Create mock orchestration service for streaming tests."""
    import json as _json
    import time as _time

    service = MagicMock()

    def _format_sse_impl(chat_id: str, content: str) -> str:
        # Mirror the real SSE envelope shape: data: {...}\n\n
        payload = {
            "chatId": chat_id,
            "payload": {"content": content},
            "timestamp": int(_time.time() * 1000),
        }
        return f"data: {_json.dumps(payload)}\n\n"

    service.format_sse = _format_sse_impl
    service.log_costs = MagicMock()
    return service


@pytest.fixture
def llm_manager() -> MagicMock:
    """Create mock LLM manager."""
    return MagicMock()


@pytest.fixture
def context_workflow(
    llm_manager: MagicMock,
    mock_orchestration_service: MagicMock,
    mock_dspy_lm: MagicMock,
) -> ContextWorkflowExecutor:
    """Create ContextWorkflowExecutor instance."""
    return ContextWorkflowExecutor(
        llm_manager, orchestration_service=mock_orchestration_service
    )


@pytest.fixture
def sample_request() -> OrchestrationRequest:
    """Create sample orchestration request."""
    return OrchestrationRequest(
        chatId="test-chat-123",
        message="Hello!",
        authorId="test-user",
        conversationHistory=[],
        url="https://example.com",
        environment="testing",
        connection_id="test-connection",
    )


class TestContextWorkflowInit:
    """Test context workflow initialization."""

    def test_init_creates_workflow(self, llm_manager: MagicMock) -> None:
        """ContextWorkflowExecutor should initialize with LLM manager."""
        workflow = ContextWorkflowExecutor(llm_manager)

        assert workflow.llm_manager is llm_manager
        assert workflow.context_analyzer is not None
class TestExecuteAsyncGreeting:
    """Test execute_async with greeting queries."""

    @pytest.mark.asyncio
    async def test_execute_async_greeting_estonian(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should handle Estonian greeting and return response."""
        sample_request.message = "Tere!"

        # Canned analyzer verdict: a greeting, nothing answerable from history.
        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            greeting_type="hello",
            can_answer_from_context=False,
            reasoning="Greeting detected",
            context_snippet=None,
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                greeting_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            context_dict = {}
            response = await context_workflow.execute_async(
                sample_request, context_dict
            )

        assert response is not None
        assert isinstance(response, OrchestrationResponse)
        assert response.chatId == "test-chat-123"
        assert "Tere" in response.content
        assert response.llmServiceActive is True
        assert response.questionOutOfLLMScope is False
        assert response.inputGuardFailed is False

        # Detection costs must be recorded on the shared context dict.
        assert "costs_dict" in context_dict
        assert "context_detection" in context_dict["costs_dict"]

    @pytest.mark.asyncio
    async def test_execute_async_greeting_english(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should handle English greeting and return response."""
        sample_request.message = "Hello!"

        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            greeting_type="hello",
            can_answer_from_context=False,
            reasoning="English greeting detected",
            context_snippet=None,
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                greeting_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is not None
        assert "Hello" in response.content or "hello" in response.content.lower()
class TestExecuteAsyncContextBased:
    """Test execute_async with context-based queries."""

    @pytest.mark.asyncio
    async def test_execute_async_context_answer(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should answer from conversation history when possible."""
        # Seed the request with a prior Q/A pair the query refers back to.
        sample_request.conversationHistory = [
            ConversationItem(
                authorRole="user",
                message="What is the tax rate?",
                timestamp="2024-01-01T12:00:00",
            ),
            ConversationItem(
                authorRole="bot",
                message="The tax rate is 20%.",
                timestamp="2024-01-01T12:00:01",
            ),
        ]
        sample_request.message = "What was the rate you mentioned?"

        history_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=True,
            reasoning="Referring to previous conversation about tax rate",
            context_snippet="The tax rate is 20%.",
        )

        with (
            patch.object(
                context_workflow.context_analyzer,
                "detect_context",
                return_value=(
                    history_analysis,
                    {"total_cost": 0.002, "total_tokens": 100, "num_calls": 1},
                ),
            ),
            patch.object(
                context_workflow.context_analyzer,
                "generate_context_response",
                new_callable=AsyncMock,
                return_value=(
                    "The tax rate is 20%.",
                    {"total_cost": 0.003, "num_calls": 1},
                ),
            ),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is not None
        assert "20%" in response.content

    @pytest.mark.asyncio
    async def test_execute_async_cannot_answer_from_context(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return None when cannot answer from context (fallback to RAG)."""
        sample_request.message = "What is digital signature?"

        no_match_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=False,
            reasoning="Query requires knowledge base search",
            context_snippet=None,
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                no_match_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is None

    @pytest.mark.asyncio
    async def test_execute_async_answer_is_none(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return None when can_answer_from_context=True but context_snippet is absent."""
        snippetless_analysis = ContextDetectionResult(
            is_greeting=False,
            can_answer_from_context=True,
            context_snippet=None,  # No snippet -> cannot generate answer
            reasoning="No relevant snippet found in history",
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                snippetless_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is None


class TestExecuteAsyncErrorHandling:
    """Test error handling in execute_async."""

    @pytest.mark.asyncio
    async def test_execute_async_handles_analyzer_exception(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return None when context analyzer raises exception."""
        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            side_effect=Exception("Analysis failed"),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is None
class TestExecuteStreaming:
    """Test execute_streaming functionality."""

    @pytest.mark.asyncio
    async def test_execute_streaming_greeting(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should stream greeting response."""
        sample_request.message = "Hello!"

        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            greeting_type="hello",
            can_answer_from_context=False,
            reasoning="Greeting detected",
            context_snippet=None,
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                greeting_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            stream_gen = await context_workflow.execute_streaming(sample_request, {})

            assert stream_gen is not None

            # Drain the async generator.
            chunks = [chunk async for chunk in stream_gen]

        # Should have multiple chunks + END marker.
        assert len(chunks) > 1
        assert "END" in chunks[-1]

        # All chunks should be valid SSE frames.
        for chunk in chunks:
            assert chunk.startswith("data: ")
            assert chunk.endswith("\n\n")

    @pytest.mark.asyncio
    async def test_execute_streaming_context_answer(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should stream context-based answer."""
        sample_request.message = "What did you say earlier?"
        sample_request.conversationHistory = [
            ConversationItem(
                authorRole="bot",
                message="The rate is 20%.",
                timestamp="2024-01-01T12:00:00",
            ),
        ]

        history_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=True,
            reasoning="Referring to previous message",
            context_snippet="I mentioned that the rate is 20%.",
        )

        async def _fake_history_stream(
            *args: object, **kwargs: object
        ) -> AsyncGenerator[str, None]:
            # Emit the answer followed by the END sentinel, SSE-framed.
            yield context_workflow.orchestration_service.format_sse(
                sample_request.chatId, "I mentioned that the rate is 20%."
            )
            yield context_workflow.orchestration_service.format_sse(
                sample_request.chatId, "END"
            )

        with (
            patch.object(
                context_workflow.context_analyzer,
                "detect_context",
                return_value=(
                    history_analysis,
                    {"total_cost": 0.002, "total_tokens": 100, "num_calls": 1},
                ),
            ),
            patch.object(
                context_workflow,
                "_create_history_stream",
                new_callable=AsyncMock,
                return_value=_fake_history_stream(),
            ),
        ):
            stream_gen = await context_workflow.execute_streaming(sample_request, {})

            assert stream_gen is not None

            chunks = [chunk async for chunk in stream_gen]

        assert len(chunks) > 0
        # Verify END marker.
        assert "END" in chunks[-1]

    @pytest.mark.asyncio
    async def test_execute_streaming_cannot_answer(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return None when cannot answer (fallback to RAG)."""
        sample_request.message = "What is digital signature?"

        no_match_analysis = ContextDetectionResult(
            is_greeting=False,
            can_answer_from_context=False,
            reasoning="Requires knowledge base",
        )

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(
                no_match_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            stream_gen = await context_workflow.execute_streaming(sample_request, {})

        assert stream_gen is None

    @pytest.mark.asyncio
    async def test_execute_streaming_handles_exception(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return None when analyzer raises exception."""
        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            side_effect=Exception("Analysis failed"),
        ):
            stream_gen = await context_workflow.execute_streaming(sample_request, {})

        assert stream_gen is None
class TestCostTracking:
    """Test cost tracking functionality."""

    @pytest.mark.asyncio
    async def test_cost_tracking_in_context_dict(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should track costs in context dictionary."""
        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            can_answer_from_context=False,
            reasoning="Greeting",
        )

        detection_cost = {
            "total_cost": 0.0015,
            "total_tokens": 75,
            "total_prompt_tokens": 50,
            "total_completion_tokens": 25,
            "num_calls": 1,
        }

        with patch.object(
            context_workflow.context_analyzer,
            "detect_context",
            return_value=(greeting_analysis, detection_cost),
        ):
            context_dict = {}
            await context_workflow.execute_async(sample_request, context_dict)

        # The detection costs must be surfaced verbatim under costs_dict.
        assert "costs_dict" in context_dict
        assert "context_detection" in context_dict["costs_dict"]
        assert context_dict["costs_dict"]["context_detection"]["total_cost"] == 0.0015
        assert context_dict["costs_dict"]["context_detection"]["total_tokens"] == 75


class TestLanguageDetection:
    """Test language detection integration."""

    @pytest.mark.asyncio
    async def test_detects_estonian_language(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should detect Estonian language from query."""
        sample_request.message = "Tere! Kuidas läheb?"

        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            can_answer_from_context=False,
            reasoning="Estonian greeting",
        )

        with (
            patch.object(
                context_workflow.context_analyzer, "detect_context"
            ) as mock_detect,
            patch(
                "src.tool_classifier.greeting_constants.get_greeting_response"
            ) as mock_greeting,
        ):
            mock_detect.return_value = (
                greeting_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            )
            mock_greeting.return_value = "Tere! Kuidas ma saan sind aidata?"

            await context_workflow.execute_async(sample_request, {})

            # Verify Estonian language was used for greeting response.
            mock_greeting.assert_called_with(greeting_type="hello", language="et")

    @pytest.mark.asyncio
    async def test_detects_english_language(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should detect English language from query."""
        sample_request.message = "Hello! How are you?"

        greeting_analysis = ContextDetectionResult(
            is_greeting=True,
            can_answer_from_context=False,
            reasoning="English greeting",
        )

        with (
            patch.object(
                context_workflow.context_analyzer, "detect_context"
            ) as mock_detect,
            patch(
                "src.tool_classifier.greeting_constants.get_greeting_response"
            ) as mock_greeting,
        ):
            mock_detect.return_value = (
                greeting_analysis,
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            )
            mock_greeting.return_value = "Hello! How can I help you?"

            await context_workflow.execute_async(sample_request, {})

            # Verify English language was used for greeting response.
            mock_greeting.assert_called_with(greeting_type="hello", language="en")
class TestExecuteAsyncSummaryBased:
    """Test execute_async with summary-based answers."""

    @pytest.mark.asyncio
    async def test_execute_async_summary_answer(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should return response when answer comes from conversation summary."""
        sample_request.message = "What was the tax rate we discussed earlier?"

        summary_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=True,
            reasoning="Found in conversation summary",
            context_snippet="Based on our earlier discussion, the tax rate is 20%.",
            answered_from_summary=True,
        )

        with (
            patch.object(
                context_workflow.context_analyzer,
                "detect_context",
                return_value=(
                    summary_analysis,
                    {"total_cost": 0.005, "total_tokens": 200, "num_calls": 3},
                ),
            ),
            patch.object(
                context_workflow.context_analyzer,
                "generate_context_response",
                new_callable=AsyncMock,
                return_value=(
                    "Based on our earlier discussion, the tax rate is 20%.",
                    {"total_cost": 0.003, "num_calls": 1},
                ),
            ),
        ):
            response = await context_workflow.execute_async(sample_request, {})

        assert response is not None
        assert isinstance(response, OrchestrationResponse)
        assert "20%" in response.content
        assert response.llmServiceActive is True

    @pytest.mark.asyncio
    async def test_execute_streaming_summary_answer(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should stream summary-based answer correctly."""
        sample_request.message = "What was the tax rate we discussed earlier?"

        summary_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=True,
            reasoning="Found in conversation summary",
            context_snippet="Based on our earlier discussion, the tax rate is 20%.",
            answered_from_summary=True,
        )

        async def _fake_summary_stream(
            *args: object, **kwargs: object
        ) -> AsyncGenerator[str, None]:
            yield context_workflow.orchestration_service.format_sse(
                sample_request.chatId, "The tax rate is 20%."
            )
            yield context_workflow.orchestration_service.format_sse(
                sample_request.chatId, "END"
            )

        with (
            patch.object(
                context_workflow.context_analyzer,
                "detect_context",
                return_value=(
                    summary_analysis,
                    {"total_cost": 0.005, "total_tokens": 200, "num_calls": 3},
                ),
            ),
            patch.object(
                context_workflow,
                "_create_history_stream",
                new_callable=AsyncMock,
                return_value=_fake_summary_stream(),
            ),
        ):
            stream_gen = await context_workflow.execute_streaming(sample_request, {})

            assert stream_gen is not None

            chunks = [chunk async for chunk in stream_gen]

        # Should have multiple chunks + END marker.
        assert len(chunks) > 1
        assert "END" in chunks[-1]

    @pytest.mark.asyncio
    async def test_pre_computed_summary_analysis(
        self,
        context_workflow: ContextWorkflowExecutor,
        sample_request: OrchestrationRequest,
    ) -> None:
        """Should use pre-computed summary analysis from classifier."""
        sample_request.message = "What was the tax rate?"

        summary_analysis = ContextDetectionResult(
            is_greeting=False,
            greeting_type="hello",
            can_answer_from_context=True,
            reasoning="Found in summary",
            context_snippet="The tax rate is 20%.",
            answered_from_summary=True,
        )

        # Pre-computed analysis (from classifier) short-circuits detection.
        context = {"analysis_result": summary_analysis}

        with patch.object(
            context_workflow.context_analyzer,
            "generate_context_response",
            new_callable=AsyncMock,
            return_value=(
                "The tax rate is 20%.",
                {"total_cost": 0.003, "num_calls": 1},
            ),
        ):
            response = await context_workflow.execute_async(sample_request, context)

        assert response is not None
        assert "20%" in response.content
"""Integration tests for context workflow.

Tests the full classify -> route -> execute chain with real component wiring.
Only the LLM layer (dspy) and RAG orchestration service are mocked.

These tests verify:
- ToolClassifier.classify() correctly routes greetings to CONTEXT workflow
- ToolClassifier.route_to_workflow() executes the context workflow end-to-end
- Fallback from CONTEXT to RAG when context cannot answer
- Streaming mode for context workflow responses
- Cost tracking propagation through the classify -> execute chain
- Error resilience (LLM failures, JSON parse errors)
"""

import pytest
from collections.abc import AsyncGenerator, Generator
from contextlib import AbstractContextManager
from unittest.mock import AsyncMock, MagicMock, patch
import json
import dspy

from src.tool_classifier.classifier import ToolClassifier
from src.tool_classifier.context_analyzer import ContextDetectionResult
from src.tool_classifier.models import ClassificationResult
from src.models.request_models import (
    OrchestrationRequest,
    OrchestrationResponse,
    ConversationItem,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def mock_dspy_lm() -> Generator[MagicMock, None, None]:
    """Mock DSPy LM to prevent 'No LM is loaded' errors."""
    fake_lm = MagicMock()
    fake_lm.history = []
    with patch("dspy.settings") as settings:
        settings.lm = fake_lm
        # Make the mock the active LM for the duration of the test.
        dspy.configure(lm=fake_lm)
        yield fake_lm


@pytest.fixture
def mock_orchestration_service() -> MagicMock:
    """Create mock orchestration service for RAG workflow fallback."""
    import json as _json
    import time as _time

    service = MagicMock()

    # Non-streaming RAG fallback returns a valid response.
    async def mock_execute_pipeline(**kwargs: object) -> OrchestrationResponse:
        return OrchestrationResponse(
            chatId=kwargs["request"].chatId,
            llmServiceActive=True,
            questionOutOfLLMScope=False,
            inputGuardFailed=False,
            content="RAG fallback answer.",
        )

    service._execute_orchestration_pipeline = AsyncMock(
        side_effect=mock_execute_pipeline
    )
    service._initialize_service_components = MagicMock(return_value={})
    service._log_costs = MagicMock()
    service.log_costs = MagicMock()

    def _format_sse_impl(chat_id: str, content: str) -> str:
        payload = {
            "chatId": chat_id,
            "payload": {"content": content},
            "timestamp": int(_time.time() * 1000),
        }
        return f"data: {_json.dumps(payload)}\n\n"

    service.format_sse = _format_sse_impl

    # Streaming RAG fallback.
    async def mock_stream_pipeline(**kwargs: object) -> AsyncGenerator[str, None]:
        yield 'data: {"chatId":"test","payload":{"content":"RAG stream"}}\n\n'
        yield 'data: {"chatId":"test","payload":{"content":"END"}}\n\n'

    service._stream_rag_pipeline = mock_stream_pipeline

    return service


@pytest.fixture
def llm_manager() -> MagicMock:
    """Create mock LLM manager."""
    return MagicMock()


@pytest.fixture
def classifier(
    llm_manager: MagicMock, mock_orchestration_service: MagicMock
) -> ToolClassifier:
    """Create a real ToolClassifier with real workflow executors."""
    return ToolClassifier(
        llm_manager=llm_manager,
        orchestration_service=mock_orchestration_service,
    )


def _make_request(
    message: str,
    chat_id: str = "integration-test-chat",
    history: list | None = None,
) -> OrchestrationRequest:
    """Helper to build an OrchestrationRequest."""
    return OrchestrationRequest(
        chatId=chat_id,
        message=message,
        authorId="test-user",
        conversationHistory=history or [],
        url="https://example.com",
        environment="testing",
        connection_id="test-conn",
    )


def _mock_dspy_greeting(answer_text: str) -> AbstractContextManager[MagicMock]:
    """Return a patch context manager that makes dspy return a greeting analysis."""
    canned = MagicMock()
    canned.analysis_result = json.dumps(
        {
            "is_greeting": True,
            "can_answer_from_context": False,
            "answer": answer_text,
            "reasoning": "Greeting detected",
        }
    )
    return patch(
        "dspy.ChainOfThought",
        return_value=MagicMock(return_value=canned),
    )


def _mock_dspy_context_answer(
    answer_text: str, reasoning: str = "History reference"
) -> AbstractContextManager[MagicMock]:
    """Return a patch that makes dspy return a context-based answer."""
    canned = MagicMock()
    canned.analysis_result = json.dumps(
        {
            "is_greeting": False,
            "can_answer_from_context": True,
            "answer": answer_text,
            "reasoning": reasoning,
        }
    )
    return patch(
        "dspy.ChainOfThought",
        return_value=MagicMock(return_value=canned),
    )


def _mock_dspy_no_match() -> AbstractContextManager[MagicMock]:
    """Return a patch that makes dspy indicate neither greeting nor context match."""
    canned = MagicMock()
    canned.analysis_result = json.dumps(
        {
            "is_greeting": False,
            "can_answer_from_context": False,
            "answer": None,
            "reasoning": "Requires knowledge base search",
        }
    )
    return patch(
        "dspy.ChainOfThought",
        return_value=MagicMock(return_value=canned),
    )
def _patch_cost_utils() -> AbstractContextManager[MagicMock]:
    """Patch cost tracking to avoid dspy settings dependency.

    Patches at whichever module path is actually loaded, to handle Python's
    module identity behaviour when src/ is on sys.path (the module may be
    imported as either ``tool_classifier.context_analyzer`` or
    ``src.tool_classifier.context_analyzer``).
    """
    canned_usage = {
        "total_cost": 0.001,
        "total_tokens": 50,
        "total_prompt_tokens": 30,
        "total_completion_tokens": 20,
        "num_calls": 1,
    }

    import sys

    # Determine which module key is actually loaded at runtime.
    if "tool_classifier.context_analyzer" in sys.modules:
        target = "tool_classifier.context_analyzer.get_lm_usage_since"
    else:
        target = "src.tool_classifier.context_analyzer.get_lm_usage_since"

    return patch(target, return_value=canned_usage)


# ---------------------------------------------------------------------------
# Integration: classify -> route -> execute (non-streaming)
# ---------------------------------------------------------------------------


class TestClassifyAndRouteGreeting:
    """Test full classify -> route chain for greeting queries."""

    @pytest.mark.asyncio
    async def test_greeting_classify_returns_context_workflow(
        self, classifier: ToolClassifier
    ) -> None:
        """classify() should return CONTEXT workflow for greeting queries.

        With the hybrid-search classifier, classify() uses Qdrant to detect
        service queries. When no service matches (or embedding fails in tests),
        it falls back to CONTEXT. The analysis_result is produced later inside
        the context workflow executor during route_to_workflow.
        """
        with (
            _mock_dspy_greeting("Tere! Kuidas ma saan sind aidata?"),
            _patch_cost_utils(),
        ):
            result = await classifier.classify(
                query="Tere!",
                conversation_history=[],
                language="et",
            )

        # Hybrid classifier routes non-service queries to CONTEXT.
        assert result.workflow.value == "context"
        # analysis_result is now populated during route_to_workflow, not classify.
        assert result.metadata is not None

    @pytest.mark.asyncio
    async def test_greeting_end_to_end_non_streaming(
        self, classifier: ToolClassifier
    ) -> None:
        """Full chain: classify greeting -> route to context workflow -> get response."""
        with _mock_dspy_greeting("Hello! How can I help you?"), _patch_cost_utils():
            classification = await classifier.classify(
                query="Hello!",
                conversation_history=[],
                language="en",
            )

        request = _make_request("Hello!")
        with patch.object(
            classifier.context_workflow.context_analyzer,
            "detect_context",
            new_callable=AsyncMock,
            return_value=(
                ContextDetectionResult(
                    is_greeting=True,
                    greeting_type="hello",
                    can_answer_from_context=False,
                    reasoning="Greeting detected",
                ),
                {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1},
            ),
        ):
            response = await classifier.route_to_workflow(
                classification=classification,
                request=request,
                is_streaming=False,
            )

        assert isinstance(response, OrchestrationResponse)
        assert response.chatId == "integration-test-chat"
        assert "Hello" in response.content
        assert response.llmServiceActive is True
        assert response.questionOutOfLLMScope is False
Kuidas ma saan sind aidata?"), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Tere!", + conversation_history=[], + language="et", + ) + + request = _make_request("Tere!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Estonian greeting detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "Tere" in response.content + + @pytest.mark.asyncio + async def test_goodbye_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain for goodbye greeting.""" + with _mock_dspy_greeting("Goodbye! Have a great day!"), _patch_cost_utils(): + classification = await classifier.classify( + query="Goodbye!", + conversation_history=[], + language="en", + ) + + request = _make_request("Goodbye!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="goodbye", + can_answer_from_context=False, + reasoning="Goodbye detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "Goodbye" in response.content + + @pytest.mark.asyncio + async def test_thanks_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain for thanks greeting.""" + with ( + _mock_dspy_greeting("You're welcome! 
Feel free to ask more."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Thank you!", + conversation_history=[], + language="en", + ) + + request = _make_request("Thank you!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="thanks", + can_answer_from_context=False, + reasoning="Thanks detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "welcome" in response.content.lower() + + +class TestClassifyAndRouteContextAnswer: + """Test full classify -> route chain for context-based answers.""" + + @pytest.mark.asyncio + async def test_context_answer_end_to_end(self, classifier: ToolClassifier) -> None: + """Full chain: classify history query -> route to context -> get answer.""" + history = [ + ConversationItem( + authorRole="user", + message="What is the tax rate?", + timestamp="2024-01-01T12:00:00", + ), + ConversationItem( + authorRole="bot", + message="The tax rate is 20%.", + timestamp="2024-01-01T12:00:01", + ), + ] + + with ( + _mock_dspy_context_answer("I mentioned the tax rate is 20%."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="What was the rate?", + conversation_history=history, + language="en", + ) + + request = _make_request("What was the rate?", history=history) + with ( + patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=False, + greeting_type="hello", + can_answer_from_context=True, + reasoning="Tax rate referenced in history", + context_snippet="The tax rate is 20%.", + ), + {"total_cost": 0.002, 
"total_tokens": 100, "num_calls": 1}, + ), + ), + patch.object( + classifier.context_workflow.context_analyzer, + "generate_context_response", + new_callable=AsyncMock, + return_value=( + "I mentioned the tax rate is 20%.", + {"total_cost": 0.003, "num_calls": 1}, + ), + ), + ): + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert classification.workflow.value == "context" + assert isinstance(response, OrchestrationResponse) + assert "20%" in response.content + + @pytest.mark.asyncio + async def test_context_answer_with_long_history( + self, classifier: ToolClassifier + ) -> None: + """Should pass last 10 turns to the analyzer even with longer history.""" + history = [ + ConversationItem( + authorRole="user" if i % 2 == 0 else "bot", + message=f"Message {i}", + timestamp=f"2024-01-01T12:00:{i:02d}", + ) + for i in range(15) + ] + + with ( + _mock_dspy_context_answer("Based on our conversation, here's the answer."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="What did we discuss?", + conversation_history=history, + language="en", + ) + + request = _make_request("What did we discuss?", history=history) + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert classification.workflow.value == "context" + assert isinstance(response, OrchestrationResponse) + assert response.content is not None + + +# --------------------------------------------------------------------------- +# Integration: fallback from CONTEXT to RAG +# --------------------------------------------------------------------------- + + +class TestContextToRAGFallback: + """Test that context workflow falls back to RAG when it cannot answer.""" + + @pytest.mark.asyncio + async def test_classify_defaults_to_rag_when_no_context_match( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + 
) -> None: + """When context analyzer can't answer, the full route chain ends at RAG. + + With the hybrid-search classifier, classify() returns CONTEXT for + non-service queries. The RAG fallback is triggered inside + route_to_workflow when the context workflow returns None. + """ + with _mock_dspy_no_match(), _patch_cost_utils(): + classification = await classifier.classify( + query="What is a digital signature?", + conversation_history=[], + language="en", + ) + + # Classifier routes non-service queries to CONTEXT first + assert classification.workflow.value == "context" + + # Full route: context can't answer → falls back to RAG + request = _make_request("What is a digital signature?") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_fallback_to_rag_end_to_end( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """Full chain: context can't answer -> falls back to RAG -> gets RAG response.""" + with _mock_dspy_no_match(), _patch_cost_utils(): + classification = await classifier.classify( + query="What is a digital signature?", + conversation_history=[], + language="en", + ) + + # Hybrid classifier routes to CONTEXT first; RAG is via fallback + assert classification.workflow.value == "context" + + request = _make_request("What is a digital signature?") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + # RAG mock returns "RAG fallback answer." 
+ assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_context_workflow_returns_none_triggers_rag_fallback( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """When context workflow returns None during routing, RAG fallback is used.""" + # Force classification to CONTEXT but with an analysis that will produce None + no_answer_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + answer=None, + reasoning="Cannot answer", + ) + + # Use the WorkflowType from the same module path the classifier uses + from tool_classifier.enums import WorkflowType as _WorkflowType + + forced_classification = ClassificationResult( + workflow=_WorkflowType.CONTEXT, + confidence=0.95, + metadata={"analysis_result": no_answer_analysis}, + reasoning="Forced for test", + ) + + request = _make_request("Something that context can't answer") + response = await classifier.route_to_workflow( + classification=forced_classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + # Should have fallen through to RAG + assert "RAG" in response.content + + +# --------------------------------------------------------------------------- +# Integration: streaming mode +# --------------------------------------------------------------------------- + + +class TestStreamingIntegration: + """Test the full classify -> route -> stream chain.""" + + @pytest.mark.asyncio + async def test_streaming_greeting_end_to_end( + self, classifier: ToolClassifier + ) -> None: + """Full chain: classify greeting -> route streaming -> collect SSE chunks.""" + with _mock_dspy_greeting("Hello! 
How can I help you?"), _patch_cost_utils(): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + request = _make_request("Hello!") + with patch.object( + classifier.context_workflow.context_analyzer, + "detect_context", + new_callable=AsyncMock, + return_value=( + ContextDetectionResult( + is_greeting=True, + greeting_type="hello", + can_answer_from_context=False, + reasoning="Greeting detected", + ), + {"total_cost": 0.001, "total_tokens": 50, "num_calls": 1}, + ), + ): + stream = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=True, + ) + + # Collect chunks inside the mock context so the dspy patch is active + # when the async generator body executes (lazy evaluation). + chunks = [chunk async for chunk in stream] + + # Should have content chunks + END marker + assert len(chunks) >= 2 + for chunk in chunks: + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + + # Last chunk should contain END + last_payload = json.loads(chunks[-1][6:-2]) + assert last_payload["payload"]["content"] == "END" + + # Reconstruct content from non-END chunks + content_parts = [] + for chunk in chunks[:-1]: + payload = json.loads(chunk[6:-2]) + content_parts.append(payload["payload"]["content"]) + full_content = "".join(content_parts) + assert "Hello" in full_content + + @pytest.mark.asyncio + async def test_streaming_context_answer_end_to_end( + self, classifier: ToolClassifier + ) -> None: + """Full chain: classify history query -> route streaming -> collect answer.""" + history = [ + ConversationItem( + authorRole="bot", + message="The deadline is March 31st.", + timestamp="2024-01-01T12:00:00", + ), + ] + + with ( + _mock_dspy_context_answer("The deadline is March 31st."), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="When is the deadline?", + conversation_history=history, + language="en", + ) + + request = 
_make_request("When is the deadline?", history=history) + stream = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=True, + ) + + chunks = [chunk async for chunk in stream] + + assert len(chunks) >= 2 + last_payload = json.loads(chunks[-1][6:-2]) + assert last_payload["payload"]["content"] == "END" + + @pytest.mark.asyncio + async def test_streaming_fallback_to_rag( + self, classifier: ToolClassifier, mock_orchestration_service: MagicMock + ) -> None: + """Streaming: context can't answer -> falls back to RAG streaming.""" + # Force classification to CONTEXT with no answer + no_answer_analysis = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + answer=None, + reasoning="Cannot answer", + ) + + from tool_classifier.enums import WorkflowType as _WorkflowType + + forced_classification = ClassificationResult( + workflow=_WorkflowType.CONTEXT, + confidence=0.95, + metadata={"analysis_result": no_answer_analysis}, + reasoning="Forced for test", + ) + + request = _make_request("Something needing RAG") + stream = await classifier.route_to_workflow( + classification=forced_classification, + request=request, + is_streaming=True, + ) + + chunks = [chunk async for chunk in stream] + + # Should have received RAG streaming output + assert len(chunks) >= 1 + + +# --------------------------------------------------------------------------- +# Integration: cost tracking across the chain +# --------------------------------------------------------------------------- + + +class TestCostTrackingIntegration: + """Test that cost data flows through the full classify -> execute chain.""" + + @pytest.mark.asyncio + async def test_costs_propagated_through_classification( + self, classifier: ToolClassifier + ) -> None: + """Cost dict from context analysis should be tracked during workflow execution. 
+ + With the hybrid-search classifier, costs are tracked inside the context + workflow executor (execute_async/execute_streaming), not in classify(). + The cost dict is stored in the workflow's internal context dictionary. + """ + with _mock_dspy_greeting("Hello!"), _patch_cost_utils(): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # Verify classify succeeded and routes to CONTEXT + assert classification.workflow.value == "context" + + # Execute the workflow to trigger cost tracking + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + # Verify workflow ran successfully (costs tracked internally) + assert isinstance(response, OrchestrationResponse) + assert response.chatId == "integration-test-chat" + + +# --------------------------------------------------------------------------- +# Integration: error resilience +# --------------------------------------------------------------------------- + + +class TestErrorResilience: + """Test that errors in context analysis gracefully fall back to RAG.""" + + @pytest.mark.asyncio + async def test_llm_exception_falls_back_to_rag( + self, classifier: ToolClassifier + ) -> None: + """If context analyzer LLM call raises, the route chain falls back to RAG. + + With the hybrid-search classifier, classify() returns CONTEXT for + non-service queries. When the context workflow LLM call raises, the + context workflow returns None and route_to_workflow falls back to RAG. 
+ """ + with ( + patch( + "dspy.ChainOfThought", + return_value=MagicMock(side_effect=Exception("LLM unavailable")), + ), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # classify() returns CONTEXT (non-service query) + assert classification.workflow.value == "context" + + # Full route: context LLM fails → falls back to RAG gracefully + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content + + @pytest.mark.asyncio + async def test_json_parse_error_falls_back_to_rag( + self, classifier: ToolClassifier + ) -> None: + """If LLM returns invalid JSON, the route chain falls back to RAG. + + JSON parse failure causes context analysis to return is_greeting=False, + answer=None. The context workflow then returns None and the fallback + chain routes to RAG. + """ + mock_response = MagicMock() + mock_response.analysis_result = "not valid json at all" + + with ( + patch( + "dspy.ChainOfThought", + return_value=MagicMock(return_value=mock_response), + ), + _patch_cost_utils(), + ): + classification = await classifier.classify( + query="Hello!", + conversation_history=[], + language="en", + ) + + # classify() returns CONTEXT (non-service query) + assert classification.workflow.value == "context" + + # Full route: JSON parse fails → context returns None → RAG fallback + request = _make_request("Hello!") + response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + assert isinstance(response, OrchestrationResponse) + assert "RAG" in response.content