Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 153 additions & 15 deletions src/tool_classifier/context_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def __init__(self, llm_manager: Any) -> None: # noqa: ANN401
# Phase 1 & 2 modules for two-phase detection+generation flow
self._detection_module: Optional[dspy.Module] = None
self._response_generation_module: Optional[dspy.Module] = None
self._stream_predictor: Optional[Any] = None
logger.info("Context analyzer initialized")

def _format_conversation_history(
Expand Down Expand Up @@ -357,6 +356,111 @@ async def detect_context(
)
return result, cost_dict

async def detect_context_with_summary_fallback(
    self,
    query: str,
    conversation_history: List[Dict[str, Any]],
    recent_turns: int = 10,
) -> tuple[ContextDetectionResult, Dict[str, Any]]:
    """
    Phase 1 with summary fallback: detect if query can be answered from history.

    Implements a 3-step flow:
    1. Check the last ``recent_turns`` turns via detect_context().
    2. If cannot answer AND total history > ``recent_turns`` turns:
       - Generate a concise summary of the older turns (everything before
         the last ``recent_turns``).
       - Check whether the query can be answered from that summary.
    3. If still cannot answer, return can_answer=False (workflow falls back to RAG).

    When the summary path succeeds, the returned ContextDetectionResult has:
    - can_answer_from_context=True
    - answered_from_summary=True
    - context_snippet set to the answer extracted from the summary, so that
      Phase 2 (stream_context_response / generate_context_response) can use it
      directly as the context for response generation.

    Args:
        query: User query to classify
        conversation_history: Full conversation history
        recent_turns: Number of most recent turns checked directly before
            the summary fallback is attempted over the older turns
            (default 10, matching the original hard-coded window).

    Returns:
        Tuple of (ContextDetectionResult, cost_dict)
    """
    total_turns = len(conversation_history)

    # Step 1: check the most recent `recent_turns` turns.
    result, cost_dict = await self.detect_context(
        query=query, conversation_history=conversation_history
    )

    # If already answered or it's a greeting, return immediately.
    if result.is_greeting or result.can_answer_from_context:
        return result, cost_dict

    # Nothing older to summarize — fall straight through to RAG.
    if total_turns <= recent_turns:
        logger.info(
            f"History has {total_turns} turns (<= {recent_turns}) | "
            f"No summary needed | Falling back to RAG"
        )
        return result, cost_dict

    # Steps 2 & 3: summary-based detection over the older turns.
    logger.info(
        f"History has {total_turns} turns (> {recent_turns}) | "
        f"Cannot answer from recent {recent_turns} | Attempting summary-based detection"
    )
    older_history = conversation_history[:-recent_turns]
    logger.info(f"Summarizing {len(older_history)} older turns")

    try:
        summary, summary_cost = await self._generate_conversation_summary(
            older_history
        )
        cost_dict = self._merge_cost_dicts(cost_dict, summary_cost)

        if not summary:
            logger.warning(
                "Summary generation returned empty | Falling back to RAG"
            )
            return result, cost_dict

        summary_result, analysis_cost = await self._analyze_from_summary(
            query=query, summary=summary
        )
        cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost)

        if summary_result.can_answer_from_context and summary_result.answer:
            logger.info(
                f"DETECTION: Can answer from summary | "
                f"Reasoning: {summary_result.reasoning}"
            )
            # Surface the summary-derived answer as context_snippet so
            # Phase 2 can generate a polished response from it.
            return ContextDetectionResult(
                is_greeting=False,
                can_answer_from_context=True,
                reasoning=summary_result.reasoning,
                context_snippet=summary_result.answer,
                answered_from_summary=True,
            ), cost_dict

        logger.info(
            "Cannot answer from summary either | Falling back to RAG"
        )

    except Exception as e:
        # Best-effort: a summary failure must never break Phase 1 detection;
        # the recent-turns result below sends the workflow on to RAG.
        logger.error(f"Summary-based detection failed: {e}", exc_info=True)

    return result, cost_dict

@staticmethod
def _yield_in_chunks(text: str, chunk_size: int = 5) -> list[str]:
    """Break *text* into groups of up to ``chunk_size`` whitespace-separated words.

    Every chunk except the last carries a trailing space, so emitting the
    chunks back-to-back reproduces the (whitespace-normalized) text —
    used to simulate token streaming when real streaming is unavailable.
    """
    tokens = text.split()
    total = len(tokens)
    return [
        " ".join(tokens[start : start + chunk_size])
        + (" " if start + chunk_size < total else "")
        for start in range(0, total, chunk_size)
    ]

async def stream_context_response(
self,
query: str,
Expand All @@ -365,30 +469,39 @@ async def stream_context_response(
"""
Phase 2 (streaming): Stream a generated answer using DSPy native streaming.

Uses ContextResponseGenerationSignature with DSPy's streamify() so tokens
are yielded in real time — same mechanism as ResponseGeneratorAgent.stream_response().
Creates a fresh streamify predictor per call (avoids stale StreamListener
issues that occur when the cached predictor is reused across calls).

Fallback chain:
1. DSPy streamify → yield StreamResponse tokens as they arrive.
2. If no stream tokens received but final Prediction has an answer,
yield it in word-group chunks.
3. If that is also empty, call generate_context_response() directly
and yield the result in word-group chunks.

Args:
query: The user query to answer
context_snippet: Relevant context extracted during Phase 1 detection

Yields:
Token strings as they arrive from the LLM
Token strings as they arrive from the LLM (or simulated chunks)
"""
logger.info(f"CONTEXT GENERATOR: Phase 2 streaming | Query: '{query[:100]}'")

self.llm_manager.ensure_global_config()
output_stream = None
stream_started = False
prediction_answer: Optional[str] = None
try:
with self.llm_manager.use_task_local():
if self._stream_predictor is None:
answer_listener = StreamListener(signature_field_name="answer")
self._stream_predictor = dspy.streamify(
dspy.Predict(ContextResponseGenerationSignature),
stream_listeners=[answer_listener],
)
output_stream = self._stream_predictor(
# Always create a fresh StreamListener + streamified predictor so that
# the listener's internal state is clean for this call.
answer_listener = StreamListener(signature_field_name="answer")
stream_predictor: Any = dspy.streamify(
dspy.Predict(ContextResponseGenerationSignature),
stream_listeners=[answer_listener],
)
output_stream = stream_predictor(
context_snippet=context_snippet,
user_query=query,
)
Expand All @@ -402,11 +515,11 @@ async def stream_context_response(
logger.info(
"Context response streaming complete (final Prediction received)"
)
if not stream_started:
# Tokens didn't stream — extract answer from the Prediction
# directly as first fallback before leaving the LM context.
prediction_answer = getattr(chunk, "answer", "") or ""

if not stream_started:
logger.warning(
"Context streaming finished but no 'answer' tokens received."
)
except GeneratorExit:
raise
except Exception as e:
Expand All @@ -421,6 +534,31 @@ async def stream_context_response(
f"Error during context stream cleanup: {cleanup_error}"
)

if stream_started:
return

# Fallback 1: answer was in the final Prediction but didn't stream as tokens
if prediction_answer:
logger.warning(
"Stream tokens not received — yielding answer from final Prediction in chunks."
)
for text_chunk in self._yield_in_chunks(prediction_answer):
yield text_chunk
return

# Fallback 2: Prediction had no answer either — call generate_context_response
logger.warning(
"No answer from streamify — falling back to generate_context_response."
)
fallback_answer, _ = await self.generate_context_response(
query=query, context_snippet=context_snippet
)
if fallback_answer:
for text_chunk in self._yield_in_chunks(fallback_answer):
yield text_chunk
else:
logger.error("All Phase 2 streaming fallbacks exhausted — empty response.")

async def generate_context_response(
self,
query: str,
Expand Down
47 changes: 38 additions & 9 deletions src/tool_classifier/workflows/context_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Context workflow executor - Layer 2: Conversation history and greetings."""

from typing import Any, AsyncIterator, Dict, Optional
from typing import Any, AsyncIterator, Dict, Optional, cast
import time
import dspy
from loguru import logger
Expand Down Expand Up @@ -77,10 +77,19 @@ async def _detect(
time_metric: Dict[str, float],
costs_metric: Dict[str, Dict[str, Any]],
) -> Optional[ContextDetectionResult]:
"""Phase 1: run context detection. Returns ContextDetectionResult or None on error."""
"""Phase 1: run context detection with summary fallback.

Checks the last 10 conversation turns first. If the query cannot be
answered from those and the history exceeds 10 turns, falls back to a
summary-based check over the older turns. Returns None on error so the
caller falls through to RAG.
"""
try:
start = time.time()
result, cost = await self.context_analyzer.detect_context(
(
result,
cost,
) = await self.context_analyzer.detect_context_with_summary_fallback(
query=message, conversation_history=history
)
time_metric["context.detection"] = time.time() - start
Expand Down Expand Up @@ -267,12 +276,29 @@ async def execute_async(
language = detect_language(request.message)
history = self._build_history(request)

detection_result = await self._detect(
request.message, history, time_metric, costs_metric
)
if detection_result is None:
self._log_costs(costs_metric)
return None
# Check if analysis is pre-computed (e.g. from classifier classify step)
pre_computed = context.get("analysis_result")
if (
pre_computed is not None
and hasattr(pre_computed, "is_greeting")
and hasattr(pre_computed, "can_answer_from_context")
):
detection_result: ContextDetectionResult = cast(
ContextDetectionResult, pre_computed
)
costs_metric.setdefault(
"context_detection",
{"total_cost": 0.0, "total_tokens": 0, "num_calls": 0},
)
else:
_detected = await self._detect(
request.message, history, time_metric, costs_metric
)
if _detected is None:
self._log_costs(costs_metric)
context["costs_dict"] = costs_metric
return None
detection_result = _detected

logger.info(
f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} "
Expand All @@ -286,6 +312,7 @@ async def execute_async(
greeting_type=detection_result.greeting_type, language=language
)
self._log_costs(costs_metric)
context["costs_dict"] = costs_metric
return OrchestrationResponse(
chatId=request.chatId,
llmServiceActive=True,
Expand All @@ -298,6 +325,7 @@ async def execute_async(
detection_result.can_answer_from_context
and detection_result.context_snippet
):
context["costs_dict"] = costs_metric
return await self._generate_response_async(
request, detection_result.context_snippet, time_metric, costs_metric
)
Expand All @@ -306,6 +334,7 @@ async def execute_async(
f"[{request.chatId}] Cannot answer from context — falling back to RAG"
)
self._log_costs(costs_metric)
context["costs_dict"] = costs_metric
return None

async def execute_streaming(
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,12 @@
# Make repository code importable from tests without installing the package.
project_root = Path(__file__).parent.parent
src_dir = project_root / "src"
models_dir = project_root / "models"

# Project root first, so `src.<module>` style imports resolve.
sys.path.insert(0, str(project_root))

# Then the src directory itself, for direct module imports.
sys.path.insert(0, str(src_dir))

# Legacy layouts keep a top-level `models` package next to src; re-add its
# parent for backward compatibility.
# NOTE(review): models_dir.parent == project_root, so this repeats the first
# insert — presumably intentional, but confirm the intent wasn't str(models_dir).
if models_dir.exists():
    sys.path.insert(0, str(models_dir.parent))
Loading
Loading