diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py
index 110c299..12cb0ed 100644
--- a/src/llm_orchestration_service_api.py
+++ b/src/llm_orchestration_service_api.py
@@ -281,6 +281,12 @@ async def orchestrate_llm_request(
# Process the request
response = await orchestration_service.process_orchestration_request(request)
+ buttons_present = bool(response.buttons)
+ buttons_count = len(response.buttons) if response.buttons else 0
+ logger.info(
+ f"[orchestrate] buttons in response for chatId {request.chatId}: "
+ f"present={buttons_present}, count={buttons_count}"
+ )
logger.info(f"Successfully processed request for chatId: {request.chatId}")
return response
@@ -364,6 +370,10 @@ async def test_orchestrate_llm_request(
# If response is already TestOrchestrationResponse (when environment is testing), return it directly
if isinstance(response, TestOrchestrationResponse):
+ buttons_count = len(response.buttons) if response.buttons else 0
+ logger.info(
+ f"[test_orchestrate] buttons present in response: {buttons_count}"
+ )
logger.info(
f"Successfully processed test request for environment: {request.environment}"
)
@@ -375,9 +385,9 @@ async def test_orchestrate_llm_request(
questionOutOfLLMScope=response.questionOutOfLLMScope,
inputGuardFailed=response.inputGuardFailed,
content=response.content,
+ buttons=response.buttons,
chunks=None, # OrchestrationResponse doesn't have chunks
)
-
logger.info(
f"Successfully processed test request for environment: {request.environment}"
)
diff --git a/src/models/request_models.py b/src/models/request_models.py
index 689c68c..c6c58eb 100644
--- a/src/models/request_models.py
+++ b/src/models/request_models.py
@@ -138,6 +138,16 @@ class DocumentReference(BaseModel):
relevance_score: float = Field(..., description="Relevance score (0-1)")
+class ChoiceButton(BaseModel):
+ """A single MCQ choice button returned in an orchestration response."""
+
+ title: str = Field(..., description="Button label shown to the user")
+ payload: str = Field(
+ ...,
+ description="Routing string sent when the button is clicked (e.g. '#service, /POST/...')",
+ )
+
+
class OrchestrationResponse(BaseModel):
"""Model for LLM orchestration response."""
@@ -150,6 +160,10 @@ class OrchestrationResponse(BaseModel):
..., description="Whether input guard validation failed"
)
content: str = Field(..., description="Response content with citations")
+ buttons: Optional[List[ChoiceButton]] = Field(
+ default=None,
+ description="Optional list of choice buttons for MCQ step responses",
+ )
# New models for embedding and context generation
@@ -261,6 +275,10 @@ class TestOrchestrationResponse(BaseModel):
..., description="Whether input guard validation failed"
)
content: str = Field(..., description="Response content with citations")
+ buttons: Optional[List[ChoiceButton]] = Field(
+ default=None,
+ description="Optional list of choice buttons for MCQ step responses",
+ )
chunks: Optional[List[ChunkInfo]] = Field(
default=None, description="Retrieved chunks with rank and content"
)
diff --git a/src/utils/input_sanitizer.py b/src/utils/input_sanitizer.py
index 3627038..b0bd146 100644
--- a/src/utils/input_sanitizer.py
+++ b/src/utils/input_sanitizer.py
@@ -57,6 +57,8 @@ def strip_html_tags(text: str) -> str:
if not text:
return text
+ text = html.unescape(text)
+
# First pass: Remove dangerous tags and their content
for tag in InputSanitizer.DANGEROUS_TAGS:
# Remove opening tag, content, and closing tag
@@ -74,9 +76,6 @@ def strip_html_tags(text: str) -> str:
# Third pass: Remove all remaining HTML tags
text = re.sub(r"<[^>]+>", "", text)
- # Unescape HTML entities (e.g., < -> <)
- text = html.unescape(text)
-
return text
@staticmethod
diff --git a/tests/test_input_sanitizer.py b/tests/test_input_sanitizer.py
new file mode 100644
index 0000000..ad129f5
--- /dev/null
+++ b/tests/test_input_sanitizer.py
@@ -0,0 +1,125 @@
+"""Unit tests for InputSanitizer — focused on #service prefix safety.
+
+Validates that strip_html_tags() and sanitize_message() leave the
+#service, /POST/... routing prefix characters (#, comma, /) untouched,
+so that prefix detection logic in downstream handlers can always match.
+"""
+
+import pytest
+
+from src.utils.input_sanitizer import InputSanitizer
+
+
+class TestSanitizeMessageServicePrefix:
+ """Primary passthrough: #service, /METHOD/... payloads must survive sanitization unchanged."""
+
+ def test_exact_service_prefix_passthrough(self) -> None:
+ """The canonical #service prefix must survive sanitization bit-for-bit identical."""
+ msg = "#service, /POST/services/active/foo"
+ assert InputSanitizer.sanitize_message(msg) == msg
+
+ @pytest.mark.parametrize(
+ "msg",
+ [
+ "#service, /POST/services/active/foo",
+ "#service, /GET/services/list",
+ "#service, /DELETE/services/active/foo",
+ "#service, /PUT/services/active/foo",
+ "#service, /PATCH/services/active/foo",
+ "#service, /POST/services/active/foo?status=true",
+ "#service, /POST/services/active/foo?a=1&b=2",
+ "#service, /POST/services/active/foo#anchor",
+ ],
+ )
+ def test_service_prefix_variants_passthrough(self, msg: str) -> None:
+ """All #service, /METHOD/... variants must pass through unmodified."""
+ assert InputSanitizer.sanitize_message(msg) == msg
+
+
+class TestSanitizeMessageHtmlStripping:
+ """Confirms HTML IS stripped while #service prefix characters survive.
+
+ These tests prove the sanitizer is active (not a no-op) and that it
+ surgically removes only HTML constructs, leaving #, comma, and / intact.
+ """
+
+ def test_bold_tags_stripped_prefix_survives(self) -> None:
+ result = InputSanitizer.sanitize_message(
+ "#service, /POST/services/active/foo"
+ )
+ assert result == "#service, /POST/services/active/foo"
+
+ def test_script_tag_content_stripped_path_survives(self) -> None:
+ """Dangerous foo"
+ )
+ assert result == "#service, /POST/foo"
+
+ def test_entity_encoded_script_tag_stripped_path_survives(self) -> None:
+ """Entity-encoded