diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 110c299..12cb0ed 100644 --- a/src/llm_orchestration_service_api.py +++ b/src/llm_orchestration_service_api.py @@ -281,6 +281,12 @@ async def orchestrate_llm_request( # Process the request response = await orchestration_service.process_orchestration_request(request) + buttons_present = bool(response.buttons) + buttons_count = len(response.buttons) if response.buttons else 0 + logger.info( + f"[orchestrate] buttons in response for chatId {request.chatId}: " + f"present={buttons_present}, count={buttons_count}" + ) logger.info(f"Successfully processed request for chatId: {request.chatId}") return response @@ -364,6 +370,10 @@ async def test_orchestrate_llm_request( # If response is already TestOrchestrationResponse (when environment is testing), return it directly if isinstance(response, TestOrchestrationResponse): + buttons_count = len(response.buttons) if response.buttons else 0 + logger.info( + f"[test_orchestrate] buttons present in response: {buttons_count}" + ) logger.info( f"Successfully processed test request for environment: {request.environment}" ) @@ -375,9 +385,9 @@ async def test_orchestrate_llm_request( questionOutOfLLMScope=response.questionOutOfLLMScope, inputGuardFailed=response.inputGuardFailed, content=response.content, + buttons=response.buttons, chunks=None, # OrchestrationResponse doesn't have chunks ) - logger.info( f"Successfully processed test request for environment: {request.environment}" ) diff --git a/src/models/request_models.py b/src/models/request_models.py index 689c68c..c6c58eb 100644 --- a/src/models/request_models.py +++ b/src/models/request_models.py @@ -138,6 +138,16 @@ class DocumentReference(BaseModel): relevance_score: float = Field(..., description="Relevance score (0-1)") +class ChoiceButton(BaseModel): + """A single MCQ choice button returned in an orchestration response.""" + + title: str = Field(..., description="Button label shown to the user") + payload: str = Field( + ..., + description="Routing string sent when the button is clicked (e.g. '#service, /POST/...')", + ) + + class OrchestrationResponse(BaseModel): """Model for LLM orchestration response.""" @@ -150,6 +160,10 @@ class OrchestrationResponse(BaseModel): ..., description="Whether input guard validation failed" ) content: str = Field(..., description="Response content with citations") + buttons: Optional[List[ChoiceButton]] = Field( + default=None, + description="Optional list of choice buttons for MCQ step responses", + ) # New models for embedding and context generation @@ -261,6 +275,10 @@ class TestOrchestrationResponse(BaseModel): ..., description="Whether input guard validation failed" ) content: str = Field(..., description="Response content with citations") + buttons: Optional[List[ChoiceButton]] = Field( + default=None, + description="Optional list of choice buttons for MCQ step responses", + ) chunks: Optional[List[ChunkInfo]] = Field( default=None, description="Retrieved chunks with rank and content" ) diff --git a/src/utils/input_sanitizer.py b/src/utils/input_sanitizer.py index 3627038..b0bd146 100644 --- a/src/utils/input_sanitizer.py +++ b/src/utils/input_sanitizer.py @@ -57,6 +57,8 @@ def strip_html_tags(text: str) -> str: if not text: return text + text = html.unescape(text) + # First pass: Remove dangerous tags and their content for tag in InputSanitizer.DANGEROUS_TAGS: # Remove opening tag, content, and closing tag @@ -74,9 +76,6 @@ def strip_html_tags(text: str) -> str: # Third pass: Remove all remaining HTML tags text = re.sub(r"<[^>]+>", "", text) - # Unescape HTML entities (e.g., < -> <) - text = html.unescape(text) - return text @staticmethod diff --git a/tests/test_input_sanitizer.py b/tests/test_input_sanitizer.py new file mode 100644 index 0000000..ad129f5 --- /dev/null +++ b/tests/test_input_sanitizer.py @@ -0,0 +1,125 @@ +"""Unit tests for InputSanitizer — focused on #service prefix safety. + +Validates that strip_html_tags() and sanitize_message() leave the +#service, /POST/... routing prefix characters (#, comma, /) untouched, +so that prefix detection logic in downstream handlers can always match. +""" + +import pytest + +from src.utils.input_sanitizer import InputSanitizer + + +class TestSanitizeMessageServicePrefix: + """Primary passthrough: #service, /METHOD/... payloads must survive sanitization unchanged.""" + + def test_exact_service_prefix_passthrough(self) -> None: + """The canonical #service prefix must survive sanitization bit-for-bit identical.""" + msg = "#service, /POST/services/active/foo" + assert InputSanitizer.sanitize_message(msg) == msg + + @pytest.mark.parametrize( + "msg", + [ + "#service, /POST/services/active/foo", + "#service, /GET/services/list", + "#service, /DELETE/services/active/foo", + "#service, /PUT/services/active/foo", + "#service, /PATCH/services/active/foo", + "#service, /POST/services/active/foo?status=true", + "#service, /POST/services/active/foo?a=1&b=2", + "#service, /POST/services/active/foo#anchor", + ], + ) + def test_service_prefix_variants_passthrough(self, msg: str) -> None: + """All #service, /METHOD/... variants must pass through unmodified.""" + assert InputSanitizer.sanitize_message(msg) == msg + + +class TestSanitizeMessageHtmlStripping: + """Confirms HTML IS stripped while #service prefix characters survive. + + These tests prove the sanitizer is active (not a no-op) and that it + surgically removes only HTML constructs, leaving #, comma, and / intact. + """ + + def test_bold_tags_stripped_prefix_survives(self) -> None: + result = InputSanitizer.sanitize_message( + "#service, /POST/services/active/foo" + ) + assert result == "#service, /POST/services/active/foo" + + def test_script_tag_content_stripped_path_survives(self) -> None: + """Dangerous foo" + ) + assert result == "#service, /POST/foo" + + def test_entity_encoded_script_tag_stripped_path_survives(self) -> None: + """Entity-encoded