diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b6ab0ca --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,62 @@ +"""Shared pytest fixtures for plugins-adapter unit tests.""" + +# Standard +import sys +from unittest.mock import AsyncMock, MagicMock, Mock + +# Third-Party +import pytest + + +@pytest.fixture +def mock_envoy_modules(): + """Mock envoy protobuf modules to avoid proto build dependencies.""" + mock_ep = MagicMock() + mock_ep_grpc = MagicMock() + mock_core = MagicMock() + mock_http_status = MagicMock() + + sys.modules["envoy"] = MagicMock() + sys.modules["envoy.service"] = MagicMock() + sys.modules["envoy.service.ext_proc"] = MagicMock() + sys.modules["envoy.service.ext_proc.v3"] = MagicMock() + sys.modules["envoy.service.ext_proc.v3.external_processor_pb2"] = mock_ep + sys.modules["envoy.service.ext_proc.v3.external_processor_pb2_grpc"] = mock_ep_grpc + sys.modules["envoy.config"] = MagicMock() + sys.modules["envoy.config.core"] = MagicMock() + sys.modules["envoy.config.core.v3"] = MagicMock() + sys.modules["envoy.config.core.v3.base_pb2"] = mock_core + sys.modules["envoy.type"] = MagicMock() + sys.modules["envoy.type.v3"] = MagicMock() + sys.modules["envoy.type.v3.http_status_pb2"] = mock_http_status + + yield { + "ep": mock_ep, + "ep_grpc": mock_ep_grpc, + "core": mock_core, + "http_status": mock_http_status, + } + + for key in list(sys.modules.keys()): + if key.startswith("envoy"): + del sys.modules[key] + if "src.server" in sys.modules: + del sys.modules["src.server"] + + +@pytest.fixture +def mock_manager(): + """Create a mock PluginManager with async invoke_hook.""" + mock = Mock() + mock.invoke_hook = AsyncMock() + return mock + + +@pytest.fixture +def sample_tool_result_body(): + """Sample MCP tool result response body.""" + return { + "jsonrpc": "2.0", + "id": "test-123", + "result": {"content": [{"type": "text", "text": "Tool execution result"}]}, + } diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..150aa8c --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,142 @@ +"""Unit tests for server helper functions. + +Covers: set_result_in_body, get_modified_response, + create_mcp_immediate_error_response +""" + +# Standard +import json + +# First-Party +from cpex.framework import PluginViolation + + +def test_set_result_in_body(mock_envoy_modules): + """set_result_in_body mutates body['params']['arguments'] in place.""" + import src.server + + body = {"params": {"arguments": {"old_key": "old_value"}}} + new_args = {"new_key": "new_value", "count": 42} + + src.server.set_result_in_body(body, new_args) + + assert body["params"]["arguments"] == new_args + + +def test_set_result_in_body_overwrites_existing(mock_envoy_modules): + """set_result_in_body replaces all previous arguments.""" + import src.server + + body = {"params": {"arguments": {"a": 1, "b": 2, "c": 3}}} + src.server.set_result_in_body(body, {"x": 99}) + + assert body["params"]["arguments"] == {"x": 99} + assert "a" not in body["params"]["arguments"] + + +def test_get_modified_response_returns_body_response(mock_envoy_modules): + """get_modified_response encodes the body dict as JSON in a BodyResponse.""" + import src.server + + body = { + "jsonrpc": "2.0", + "id": "1", + "result": {"content": [{"type": "text", "text": "hello"}]}, + } + response = src.server.get_modified_response(body) + + assert response is not None + + +def test_create_mcp_immediate_error_response_default_code(mock_envoy_modules): + """No violation → error code defaults to -32000 (generic server error).""" + import src.server + + body = {"jsonrpc": "2.0", "id": "test-001"} + + captured = [] + original_dumps = json.dumps + + def spy(obj, **kwargs): + if isinstance(obj, dict) and "error" in obj: + captured.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy + try: + response = src.server.create_mcp_immediate_error_response(body, "Something went wrong") + finally: + json.dumps = original_dumps + + assert response is not None + assert len(captured) == 1 + err = captured[0] + assert err["error"]["code"] == -32000 + assert err["error"]["message"] == "Something went wrong" + assert err["jsonrpc"] == "2.0" + assert err["id"] == "test-001" + + +def test_create_mcp_immediate_error_response_with_violation_reason(mock_envoy_modules): + """Violation reason/description override the fallback message.""" + import src.server + + body = {"jsonrpc": "2.0", "id": "test-002"} + violation = PluginViolation( + reason="Content policy violated", + description="Detected restricted content in response", + code="POLICY_VIOLATION", + ) + + captured = [] + original_dumps = json.dumps + + def spy(obj, **kwargs): + if isinstance(obj, dict) and "error" in obj: + captured.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy + try: + response = src.server.create_mcp_immediate_error_response(body, "fallback msg", violation=violation) + finally: + json.dumps = original_dumps + + assert response is not None + assert len(captured) == 1 + err = captured[0] + assert "Content policy violated" in err["error"]["message"] + assert "Detected restricted content" in err["error"]["message"] + # mcp_error_code not set → still uses default -32000 + assert err["error"]["code"] == -32000 + + +def test_create_mcp_immediate_error_response_with_mcp_error_code(mock_envoy_modules): + """Violation mcp_error_code overrides the default -32000 code.""" + import src.server + + body = {"jsonrpc": "2.0", "id": "test-003"} + violation = PluginViolation( + reason="Invalid params", + description="Tool args failed validation", + code="INVALID_ARGS", + mcp_error_code=-32602, + ) + + captured = [] + original_dumps = json.dumps + + def spy(obj, **kwargs): + if isinstance(obj, dict) and "error" in obj: + captured.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy + try: + response = src.server.create_mcp_immediate_error_response(body, "fallback", violation=violation) + finally: + json.dumps = original_dumps + + assert response is not None + err = captured[0] + assert err["error"]["code"] == -32602 diff --git a/tests/test_prompt_pre_fetch.py b/tests/test_prompt_pre_fetch.py new file mode 100644 index 0000000..bba37ac --- /dev/null +++ b/tests/test_prompt_pre_fetch.py @@ -0,0 +1,123 @@ +"""Unit tests for getPromptPreFetchResponse. + +Tests the prompt pre-fetch path: validation, modification, and blocking. +""" + +# Standard +import json +from unittest.mock import Mock + +# Third-Party +import pytest + +# First-Party +from cpex.framework import PluginViolation, PromptPrehookPayload + + +@pytest.fixture +def prompt_body(): + """Sample MCP prompts/get request body.""" + return { + "jsonrpc": "2.0", + "id": "test-456", + "method": "prompts/get", + "params": { + "name": "test_prompt", + "arguments": {"arg0": "some value"}, + }, + } + + +def _make_result(continue_processing=True, modified_payload=None, violation=None): + result = Mock() + result.continue_processing = continue_processing + result.modified_payload = modified_payload + result.violation = violation + return result + + +@pytest.mark.asyncio +async def test_getPromptPreFetchResponse_continue_no_modification(mock_envoy_modules, mock_manager, prompt_body): + """Plugin allows the prompt fetch with no changes.""" + import src.server + + mock_manager.invoke_hook.return_value = (_make_result(), None) + src.server.manager = mock_manager + + response = await src.server.getPromptPreFetchResponse(prompt_body) + + assert mock_manager.invoke_hook.called + call_args = mock_manager.invoke_hook.call_args[0] + payload = call_args[1] + assert isinstance(payload, PromptPrehookPayload) + assert payload.prompt_id == "test_prompt" + assert response is not None + + +@pytest.mark.asyncio +async def test_getPromptPreFetchResponse_continue_with_modified_args(mock_envoy_modules, mock_manager, prompt_body): + """Plugin modifies prompt arguments — modified args are forwarded.""" + import src.server + + modified_args = {"arg0": "rewritten value"} + modified_payload = Mock() + modified_payload.args = {"tool_args": modified_args} + + mock_manager.invoke_hook.return_value = (_make_result(modified_payload=modified_payload), None) + src.server.manager = mock_manager + + captured_bodies = [] + original_dumps = json.dumps + + def spy_dumps(obj, **kwargs): + if isinstance(obj, dict) and "params" in obj: + captured_bodies.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy_dumps + try: + response = await src.server.getPromptPreFetchResponse(prompt_body) + finally: + json.dumps = original_dumps + + assert mock_manager.invoke_hook.called + assert response is not None + assert len(captured_bodies) > 0 + assert captured_bodies[0]["params"]["arguments"] == modified_args + + +@pytest.mark.asyncio +async def test_getPromptPreFetchResponse_blocked(mock_envoy_modules, mock_manager, prompt_body): + """Plugin blocks the prompt fetch — response is an MCP error.""" + import src.server + + violation = PluginViolation( + reason="Prompt not permitted", + description="This prompt template is restricted", + code="PROMPT_BLOCKED", + ) + mock_manager.invoke_hook.return_value = (_make_result(continue_processing=False, violation=violation), None) + src.server.manager = mock_manager + + captured_bodies = [] + original_dumps = json.dumps + + def spy_dumps(obj, **kwargs): + if isinstance(obj, dict) and "error" in obj: + captured_bodies.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy_dumps + try: + response = await src.server.getPromptPreFetchResponse(prompt_body) + finally: + json.dumps = original_dumps + + assert mock_manager.invoke_hook.called + assert response is not None + assert len(captured_bodies) > 0 + error_body = captured_bodies[0] + assert "error" in error_body + assert "Prompt not permitted" in error_body["error"]["message"] + assert error_body["id"] == "test-456" + assert error_body["jsonrpc"] == "2.0" diff --git a/tests/test_server.py b/tests/test_tool_post_invoke.py similarity index 63% rename from tests/test_server.py rename to tests/test_tool_post_invoke.py index 79b6931..7ab51c2 100644 --- a/tests/test_server.py +++ b/tests/test_tool_post_invoke.py @@ -1,12 +1,13 @@ -"""Unit tests for ext-proc server functions +"""Unit tests for getToolPostInvokeResponse and process_response_body_buffer. These tests use dynamic import and mocking to avoid proto dependencies. +Shared fixtures (mock_envoy_modules, mock_manager, sample_tool_result_body) +come from conftest.py. """ # Standard import json -import sys -from unittest.mock import AsyncMock, MagicMock, Mock +from unittest.mock import MagicMock # Third-Party import pytest @@ -19,63 +20,6 @@ ) -@pytest.fixture -def mock_envoy_modules(): - """Mock envoy protobuf modules to avoid proto dependencies.""" - # Create mock modules - mock_ep = MagicMock() - mock_ep_grpc = MagicMock() - mock_core = MagicMock() - mock_http_status = MagicMock() - - # Add to sys.modules before importing server - sys.modules["envoy"] = MagicMock() - sys.modules["envoy.service"] = MagicMock() - sys.modules["envoy.service.ext_proc"] = MagicMock() - sys.modules["envoy.service.ext_proc.v3"] = MagicMock() - sys.modules["envoy.service.ext_proc.v3.external_processor_pb2"] = mock_ep - sys.modules["envoy.service.ext_proc.v3.external_processor_pb2_grpc"] = mock_ep_grpc - sys.modules["envoy.config"] = MagicMock() - sys.modules["envoy.config.core"] = MagicMock() - sys.modules["envoy.config.core.v3"] = MagicMock() - sys.modules["envoy.config.core.v3.base_pb2"] = mock_core - sys.modules["envoy.type"] = MagicMock() - sys.modules["envoy.type.v3"] = MagicMock() - sys.modules["envoy.type.v3.http_status_pb2"] = mock_http_status - - yield { - "ep": mock_ep, - "ep_grpc": mock_ep_grpc, - "core": mock_core, - "http_status": mock_http_status, - } - - # Cleanup - for key in list(sys.modules.keys()): - if key.startswith("envoy"): - del sys.modules[key] - if "src.server" in sys.modules: - del sys.modules["src.server"] - - -@pytest.fixture -def mock_manager(): - """Create a mock PluginManager.""" - mock = Mock() - mock.invoke_hook = AsyncMock() - return mock - - -@pytest.fixture -def sample_tool_result_body(): - """Create a sample tool result body.""" - return { - "jsonrpc": "2.0", - "id": "test-123", - "result": {"content": [{"type": "text", "text": "Tool execution result"}]}, - } - - def setup_response_mocks(mock_envoy_modules): """Setup common response mocks.""" mock_envoy_modules["ep"].ProcessingResponse.return_value = MagicMock() @@ -105,67 +49,42 @@ def verify_payload_content(payload, expected_result, expected_text): @pytest.mark.asyncio async def test_getToolPostInvokeResponse_continue_processing(mock_envoy_modules, mock_manager, sample_tool_result_body): - """Test getToolPostInvokeResponse when plugin allows processing to continue.""" - # Setup mock response objects + """Plugin allows processing to continue — hook is called with correct payload.""" mock_response = MagicMock() mock_response.HasField.return_value = True mock_response.response_body.response.HasField.return_value = False mock_envoy_modules["ep"].ProcessingResponse.return_value = mock_response - # Import server after mocking import src.server - # Setup mock to return continue_processing=True - mock_result = ToolPostInvokeResult( - continue_processing=True, - modified_payload=None, - ) + mock_result = ToolPostInvokeResult(continue_processing=True, modified_payload=None) mock_manager.invoke_hook.return_value = (mock_result, None) - - # Inject mock manager src.server.manager = mock_manager - # Call the function _ = await src.server.getToolPostInvokeResponse(sample_tool_result_body) - # Verify the hook was called assert mock_manager.invoke_hook.called - call_args = mock_manager.invoke_hook.call_args[0] - payload = call_args[1] + payload = mock_manager.invoke_hook.call_args[0][1] assert isinstance(payload, ToolPostInvokePayload) assert payload.result == sample_tool_result_body["result"] - # assert payload.name == "replaceme" # Replace this after better naming @pytest.mark.asyncio async def test_getToolPostInvokeResponse_blocked(mock_envoy_modules, mock_manager, sample_tool_result_body): - """Test getToolPostInvokeResponse when plugin blocks the response. - - This test verifies that when continue_processing=False, the function - uses immediate_response (not response_body) and includes violation details. - """ - # Setup mocks for immediate_response path + """Plugin blocks the response — immediate_response is used with violation details.""" setup_response_mocks(mock_envoy_modules) - # Import server after mocking import src.server - # Setup mock to return continue_processing=False with violation violation = PluginViolation( reason="Sensitive content detected", description="Tool response contains forbidden content", code="CONTENT_VIOLATION", ) - mock_result = ToolPostInvokeResult( - continue_processing=False, - violation=violation, - ) + mock_result = ToolPostInvokeResult(continue_processing=False, violation=violation) mock_manager.invoke_hook.return_value = (mock_result, None) - - # Inject mock manager src.server.manager = mock_manager - # Capture json.dumps calls to verify error body content original_dumps = json.dumps captured_bodies = [] @@ -176,90 +95,62 @@ def spy_dumps(obj, **kwargs): json.dumps = spy_dumps try: - # Call the function response = await src.server.getToolPostInvokeResponse(sample_tool_result_body) finally: json.dumps = original_dumps - # Verify the hook was called with correct payload assert mock_manager.invoke_hook.called - call_args = mock_manager.invoke_hook.call_args[0] - payload = call_args[1] + payload = mock_manager.invoke_hook.call_args[0][1] assert isinstance(payload, ToolPostInvokePayload) assert payload.result == sample_tool_result_body["result"] - - # Verify response was created (error path taken) assert response is not None - - # Verify error body was created with violation details assert len(captured_bodies) > 0 error_body = captured_bodies[0] - assert "error" in error_body assert error_body["error"]["code"] == -32000 - # Verify violation message is included assert "Sensitive content detected" in error_body["error"]["message"] assert "Tool response contains forbidden content" in error_body["error"]["message"] @pytest.mark.asyncio async def test_getToolPostInvokeResponse_modified_payload(mock_envoy_modules, mock_manager, sample_tool_result_body): - """Test getToolPostInvokeResponse when plugin modifies the payload.""" - # Import server after mocking + """Plugin modifies the payload — modified result is serialised into the response.""" import src.server - # Setup mock to return modified payload modified_result = {"content": [{"type": "text", "text": "Modified tool result"}]} modified_payload = ToolPostInvokePayload(name="test_tool", result=modified_result) - mock_result = ToolPostInvokeResult( - continue_processing=True, - modified_payload=modified_payload, - ) + mock_result = ToolPostInvokeResult(continue_processing=True, modified_payload=modified_payload) mock_manager.invoke_hook.return_value = (mock_result, None) - - # Inject mock manager src.server.manager = mock_manager - # Spy on json.dumps to capture what body is being serialized original_dumps = json.dumps captured_body = None def spy_dumps(obj, **kwargs): nonlocal captured_body - # Capture the body dict that's being serialized if isinstance(obj, dict) and "result" in obj and "jsonrpc" in obj: captured_body = obj return original_dumps(obj, **kwargs) json.dumps = spy_dumps try: - # Call the function response = await src.server.getToolPostInvokeResponse(sample_tool_result_body) finally: json.dumps = original_dumps - # Verify the hook was called assert mock_manager.invoke_hook.called - - # Verify response was created assert response is not None - - # Verify the body was modified with the new result assert captured_body is not None, "json.dumps should have been called with the modified body" assert captured_body["result"] == modified_result assert captured_body["result"]["content"][0]["text"] == "Modified tool result" - # Verify original metadata (jsonrpc, id) is preserved assert captured_body["jsonrpc"] == sample_tool_result_body["jsonrpc"] assert captured_body["id"] == sample_tool_result_body["id"] @pytest.mark.asyncio async def test_getToolPostInvokeResponse_multiple_content_items(mock_envoy_modules, mock_manager): - """Test getToolPostInvokeResponse with multiple content items.""" - # Setup mock response - mock_response = MagicMock() - mock_envoy_modules["ep"].ProcessingResponse.return_value = mock_response + """Payload passed to hook carries all content items intact.""" + mock_envoy_modules["ep"].ProcessingResponse.return_value = MagicMock() - # Import server after mocking import src.server body = { @@ -273,19 +164,13 @@ async def test_getToolPostInvokeResponse_multiple_content_items(mock_envoy_modul ] }, } - mock_result = ToolPostInvokeResult(continue_processing=True) mock_manager.invoke_hook.return_value = (mock_result, None) - - # Inject mock manager src.server.manager = mock_manager - # Call the function _ = await src.server.getToolPostInvokeResponse(body) - # Verify the payload passed to the hook contains all content - call_args = mock_manager.invoke_hook.call_args[0] - payload = call_args[1] + payload = mock_manager.invoke_hook.call_args[0][1] assert len(payload.result["content"]) == 3 assert payload.result["content"][0]["text"] == "First item" assert payload.result["content"][1]["text"] == "Second item" @@ -293,13 +178,13 @@ async def test_getToolPostInvokeResponse_multiple_content_items(mock_envoy_modul # ============================================================================ -# Response Body Processing Tests +# Response Body Buffer Processing Tests # ============================================================================ @pytest.mark.asyncio async def test_process_response_body_buffer_with_tool_result(mock_envoy_modules, mock_manager): - """Test process_response_body_buffer with a tool result.""" + """Plain JSON-RPC tool result triggers the post-invoke hook.""" setup_response_mocks(mock_envoy_modules) import src.server @@ -317,13 +202,12 @@ async def test_process_response_body_buffer_with_tool_result(mock_envoy_modules, assert mock_manager.invoke_hook.called payload = mock_manager.invoke_hook.call_args[0][1] verify_payload_content(payload, tool_result["result"], "Result") - # Verify ProcessingResponse was returned assert response is not None @pytest.mark.asyncio async def test_process_response_body_buffer_with_sse_format(mock_envoy_modules, mock_manager): - """Test process_response_body_buffer with SSE formatted content.""" + """SSE-wrapped tool result is parsed and triggers the post-invoke hook.""" setup_response_mocks(mock_envoy_modules) import src.server @@ -342,16 +226,12 @@ async def test_process_response_body_buffer_with_sse_format(mock_envoy_modules, assert mock_manager.invoke_hook.called payload = mock_manager.invoke_hook.call_args[0][1] verify_payload_content(payload, tool_result["result"], "SSE data") - # Verify ProcessingResponse was returned assert response is not None @pytest.mark.asyncio async def test_process_response_body_buffer_multiple_chunks_scenario(mock_envoy_modules, mock_manager): - """Test buffering: content in chunks, then empty end_of_stream chunk. - - Simulates: chunk1 (content) + chunk2 (content) + chunk3 (empty, end_of_stream). - """ + """Multi-chunk buffer is assembled and processed as one unit.""" setup_response_mocks(mock_envoy_modules) import src.server @@ -365,39 +245,35 @@ async def test_process_response_body_buffer_multiple_chunks_scenario(mock_envoy_ } body_bytes = json.dumps(tool_result).encode("utf-8") - # Simulate buffering: chunk1 + chunk2 + empty chunk buffer = bytearray() - buffer.extend(body_bytes[:25]) # Chunk 1 - buffer.extend(body_bytes[25:]) # Chunk 2 - buffer.extend(b"") # Chunk 3 (empty, triggers processing) + buffer.extend(body_bytes[:25]) + buffer.extend(body_bytes[25:]) + buffer.extend(b"") response = await src.server.process_response_body_buffer(buffer) assert mock_manager.invoke_hook.called payload = mock_manager.invoke_hook.call_args[0][1] verify_payload_content(payload, tool_result["result"], "Multi chunk data") - # Verify ProcessingResponse was returned assert response is not None @pytest.mark.asyncio async def test_process_response_body_buffer_empty(mock_envoy_modules, mock_manager): - """Test process_response_body_buffer with empty buffer.""" + """Empty buffer returns a response without invoking the hook.""" setup_response_mocks(mock_envoy_modules) import src.server src.server.manager = mock_manager response = await src.server.process_response_body_buffer(bytearray()) - # Verify hook is NOT called for empty buffer assert not mock_manager.invoke_hook.called, "Tool post-invoke hook should not be called for empty buffer" - # Verify response is returned (function doesn't crash on empty buffer) assert response is not None @pytest.mark.asyncio async def test_process_response_body_buffer_non_tool_result(mock_envoy_modules, mock_manager): - """Test process_response_body_buffer with non-tool result.""" + """Error responses pass through without invoking the hook.""" setup_response_mocks(mock_envoy_modules) import src.server @@ -411,7 +287,5 @@ async def test_process_response_body_buffer_non_tool_result(mock_envoy_modules, buffer = bytearray(json.dumps(error_response).encode("utf-8")) response = await src.server.process_response_body_buffer(buffer) - # Verify hook is NOT called for error responses assert not mock_manager.invoke_hook.called, "Tool post-invoke hook should not be called for error responses" - # Verify response is returned (function handles error responses gracefully) assert response is not None diff --git a/tests/test_tool_pre_invoke.py b/tests/test_tool_pre_invoke.py new file mode 100644 index 0000000..4f6d309 --- /dev/null +++ b/tests/test_tool_pre_invoke.py @@ -0,0 +1,143 @@ +"""Unit tests for getToolPreInvokeResponse. + +Tests the tool pre-invoke path: argument validation, modification, and blocking. +""" + +# Standard +import json +from unittest.mock import Mock + +# Third-Party +import pytest + +# First-Party +from cpex.framework import PluginViolation, ToolPreInvokePayload + + +@pytest.fixture +def tool_call_body(): + """Sample MCP tools/call request body.""" + return { + "jsonrpc": "2.0", + "id": "test-123", + "method": "tools/call", + "params": { + "name": "test_tool", + "arguments": {"param": "value"}, + }, + } + + +def _make_result(continue_processing=True, modified_payload=None, violation=None): + """Build a mock hook result.""" + result = Mock() + result.continue_processing = continue_processing + result.modified_payload = modified_payload + result.violation = violation + return result + + +@pytest.mark.asyncio +async def test_getToolPreInvokeResponse_continue_no_modification(mock_envoy_modules, mock_manager, tool_call_body): + """Plugin allows the tool call through with no argument changes.""" + import src.server + + mock_manager.invoke_hook.return_value = (_make_result(), None) + src.server.manager = mock_manager + + response = await src.server.getToolPreInvokeResponse(tool_call_body) + + assert mock_manager.invoke_hook.called + call_args = mock_manager.invoke_hook.call_args[0] + payload = call_args[1] + assert isinstance(payload, ToolPreInvokePayload) + assert payload.name == "test_tool" + assert response is not None + + +@pytest.mark.asyncio +async def test_getToolPreInvokeResponse_continue_with_modified_args(mock_envoy_modules, mock_manager, tool_call_body): + """Plugin modifies tool arguments — modified args are forwarded in the response.""" + import src.server + + modified_args = {"param": "sanitized_value", "injected": False} + modified_payload = Mock() + modified_payload.args = {"tool_args": modified_args} + + mock_manager.invoke_hook.return_value = (_make_result(modified_payload=modified_payload), None) + src.server.manager = mock_manager + + captured_bodies = [] + original_dumps = json.dumps + + def spy_dumps(obj, **kwargs): + if isinstance(obj, dict) and "params" in obj: + captured_bodies.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy_dumps + try: + response = await src.server.getToolPreInvokeResponse(tool_call_body) + finally: + json.dumps = original_dumps + + assert mock_manager.invoke_hook.called + assert response is not None + assert len(captured_bodies) > 0 + assert captured_bodies[0]["params"]["arguments"] == modified_args + + +@pytest.mark.asyncio +async def test_getToolPreInvokeResponse_blocked_with_violation(mock_envoy_modules, mock_manager, tool_call_body): + """Plugin blocks the tool call — response is an MCP error with violation details.""" + import src.server + + violation = PluginViolation( + reason="Forbidden argument detected", + description="The tool arguments contain disallowed content", + code="ARGS_VIOLATION", + ) + mock_manager.invoke_hook.return_value = (_make_result(continue_processing=False, violation=violation), None) + src.server.manager = mock_manager + + captured_bodies = [] + original_dumps = json.dumps + + def spy_dumps(obj, **kwargs): + if isinstance(obj, dict) and "error" in obj: + captured_bodies.append(obj) + return original_dumps(obj, **kwargs) + + json.dumps = spy_dumps + try: + response = await src.server.getToolPreInvokeResponse(tool_call_body) + finally: + json.dumps = original_dumps + + assert mock_manager.invoke_hook.called + assert response is not None + assert len(captured_bodies) > 0 + error_body = captured_bodies[0] + assert "error" in error_body + assert "Forbidden argument detected" in error_body["error"]["message"] + assert "disallowed content" in error_body["error"]["message"] + + +@pytest.mark.asyncio +async def test_getToolPreInvokeResponse_payload_carries_tool_name(mock_envoy_modules, mock_manager): + """The ToolPreInvokePayload passed to the hook reflects the tool name from the request.""" + import src.server + + body = { + "jsonrpc": "2.0", + "id": "x", + "method": "tools/call", + "params": {"name": "my_special_tool", "arguments": {}}, + } + mock_manager.invoke_hook.return_value = (_make_result(), None) + src.server.manager = mock_manager + + await src.server.getToolPreInvokeResponse(body) + + payload = mock_manager.invoke_hook.call_args[0][1] + assert payload.name == "my_special_tool"