Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Shared pytest fixtures for plugins-adapter unit tests."""

# Standard
import sys
from unittest.mock import AsyncMock, MagicMock, Mock

# Third-Party
import pytest


@pytest.fixture
def mock_envoy_modules():
"""Mock envoy protobuf modules to avoid proto build dependencies."""
mock_ep = MagicMock()
mock_ep_grpc = MagicMock()
mock_core = MagicMock()
mock_http_status = MagicMock()

sys.modules["envoy"] = MagicMock()
sys.modules["envoy.service"] = MagicMock()
sys.modules["envoy.service.ext_proc"] = MagicMock()
sys.modules["envoy.service.ext_proc.v3"] = MagicMock()
sys.modules["envoy.service.ext_proc.v3.external_processor_pb2"] = mock_ep
sys.modules["envoy.service.ext_proc.v3.external_processor_pb2_grpc"] = mock_ep_grpc
sys.modules["envoy.config"] = MagicMock()
sys.modules["envoy.config.core"] = MagicMock()
sys.modules["envoy.config.core.v3"] = MagicMock()
sys.modules["envoy.config.core.v3.base_pb2"] = mock_core
sys.modules["envoy.type"] = MagicMock()
sys.modules["envoy.type.v3"] = MagicMock()
sys.modules["envoy.type.v3.http_status_pb2"] = mock_http_status

yield {
"ep": mock_ep,
"ep_grpc": mock_ep_grpc,
"core": mock_core,
"http_status": mock_http_status,
}

for key in list(sys.modules.keys()):
if key.startswith("envoy"):
del sys.modules[key]
if "src.server" in sys.modules:
del sys.modules["src.server"]


@pytest.fixture
def mock_manager():
"""Create a mock PluginManager with async invoke_hook."""
mock = Mock()
mock.invoke_hook = AsyncMock()
return mock


@pytest.fixture
def sample_tool_result_body():
"""Sample MCP tool result response body."""
return {
"jsonrpc": "2.0",
"id": "test-123",
"result": {"content": [{"type": "text", "text": "Tool execution result"}]},
}
142 changes: 142 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Unit tests for server helper functions.

Covers: set_result_in_body, get_modified_response,
create_mcp_immediate_error_response
"""

# Standard
import json

# First-Party
from cpex.framework import PluginViolation


def test_set_result_in_body(mock_envoy_modules):
"""set_result_in_body mutates body['params']['arguments'] in place."""
import src.server

body = {"params": {"arguments": {"old_key": "old_value"}}}
new_args = {"new_key": "new_value", "count": 42}

src.server.set_result_in_body(body, new_args)

assert body["params"]["arguments"] == new_args


def test_set_result_in_body_overwrites_existing(mock_envoy_modules):
"""set_result_in_body replaces all previous arguments."""
import src.server

body = {"params": {"arguments": {"a": 1, "b": 2, "c": 3}}}
src.server.set_result_in_body(body, {"x": 99})

assert body["params"]["arguments"] == {"x": 99}
assert "a" not in body["params"]["arguments"]


def test_get_modified_response_returns_body_response(mock_envoy_modules):
"""get_modified_response encodes the body dict as JSON in a BodyResponse."""
import src.server

body = {
"jsonrpc": "2.0",
"id": "1",
"result": {"content": [{"type": "text", "text": "hello"}]},
}
response = src.server.get_modified_response(body)

assert response is not None


def test_create_mcp_immediate_error_response_default_code(mock_envoy_modules):
"""No violation → error code defaults to -32000 (generic server error)."""
import src.server

body = {"jsonrpc": "2.0", "id": "test-001"}

captured = []
original_dumps = json.dumps

def spy(obj, **kwargs):
if isinstance(obj, dict) and "error" in obj:
captured.append(obj)
return original_dumps(obj, **kwargs)

json.dumps = spy
try:
response = src.server.create_mcp_immediate_error_response(body, "Something went wrong")
finally:
json.dumps = original_dumps

assert response is not None
assert len(captured) == 1
err = captured[0]
assert err["error"]["code"] == -32000
assert err["error"]["message"] == "Something went wrong"
assert err["jsonrpc"] == "2.0"
assert err["id"] == "test-001"


def test_create_mcp_immediate_error_response_with_violation_reason(mock_envoy_modules):
"""Violation reason/description override the fallback message."""
import src.server

body = {"jsonrpc": "2.0", "id": "test-002"}
violation = PluginViolation(
reason="Content policy violated",
description="Detected restricted content in response",
code="POLICY_VIOLATION",
)

captured = []
original_dumps = json.dumps

def spy(obj, **kwargs):
if isinstance(obj, dict) and "error" in obj:
captured.append(obj)
return original_dumps(obj, **kwargs)

json.dumps = spy
try:
response = src.server.create_mcp_immediate_error_response(body, "fallback msg", violation=violation)
finally:
json.dumps = original_dumps

assert response is not None
assert len(captured) == 1
err = captured[0]
assert "Content policy violated" in err["error"]["message"]
assert "Detected restricted content" in err["error"]["message"]
# mcp_error_code not set → still uses default -32000
assert err["error"]["code"] == -32000


def test_create_mcp_immediate_error_response_with_mcp_error_code(mock_envoy_modules):
"""Violation mcp_error_code overrides the default -32000 code."""
import src.server

body = {"jsonrpc": "2.0", "id": "test-003"}
violation = PluginViolation(
reason="Invalid params",
description="Tool args failed validation",
code="INVALID_ARGS",
mcp_error_code=-32602,
)

captured = []
original_dumps = json.dumps

def spy(obj, **kwargs):
if isinstance(obj, dict) and "error" in obj:
captured.append(obj)
return original_dumps(obj, **kwargs)

json.dumps = spy
try:
response = src.server.create_mcp_immediate_error_response(body, "fallback", violation=violation)
finally:
json.dumps = original_dumps

assert response is not None
err = captured[0]
assert err["error"]["code"] == -32602
123 changes: 123 additions & 0 deletions tests/test_prompt_pre_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Unit tests for getPromptPreFetchResponse.

Tests the prompt pre-fetch path: validation, modification, and blocking.
"""

# Standard
import json
from unittest.mock import Mock

# Third-Party
import pytest

# First-Party
from cpex.framework import PluginViolation, PromptPrehookPayload


@pytest.fixture
def prompt_body():
"""Sample MCP prompts/get request body."""
return {
"jsonrpc": "2.0",
"id": "test-456",
"method": "prompts/get",
"params": {
"name": "test_prompt",
"arguments": {"arg0": "some value"},
},
}


def _make_result(continue_processing=True, modified_payload=None, violation=None):
result = Mock()
result.continue_processing = continue_processing
result.modified_payload = modified_payload
result.violation = violation
return result


@pytest.mark.asyncio
async def test_getPromptPreFetchResponse_continue_no_modification(mock_envoy_modules, mock_manager, prompt_body):
"""Plugin allows the prompt fetch with no changes."""
import src.server

mock_manager.invoke_hook.return_value = (_make_result(), None)
src.server.manager = mock_manager

response = await src.server.getPromptPreFetchResponse(prompt_body)

assert mock_manager.invoke_hook.called
call_args = mock_manager.invoke_hook.call_args[0]
payload = call_args[1]
assert isinstance(payload, PromptPrehookPayload)
assert payload.prompt_id == "test_prompt"
assert response is not None


@pytest.mark.asyncio
async def test_getPromptPreFetchResponse_continue_with_modified_args(mock_envoy_modules, mock_manager, prompt_body):
"""Plugin modifies prompt arguments — modified args are forwarded."""
import src.server

modified_args = {"arg0": "rewritten value"}
modified_payload = Mock()
modified_payload.args = {"tool_args": modified_args}

mock_manager.invoke_hook.return_value = (_make_result(modified_payload=modified_payload), None)
src.server.manager = mock_manager

captured_bodies = []
original_dumps = json.dumps

def spy_dumps(obj, **kwargs):
if isinstance(obj, dict) and "params" in obj:
captured_bodies.append(obj)
return original_dumps(obj, **kwargs)

json.dumps = spy_dumps
try:
response = await src.server.getPromptPreFetchResponse(prompt_body)
finally:
json.dumps = original_dumps

assert mock_manager.invoke_hook.called
assert response is not None
assert len(captured_bodies) > 0
assert captured_bodies[0]["params"]["arguments"] == modified_args


@pytest.mark.asyncio
async def test_getPromptPreFetchResponse_blocked(mock_envoy_modules, mock_manager, prompt_body):
"""Plugin blocks the prompt fetch — response is an MCP error."""
import src.server

violation = PluginViolation(
reason="Prompt not permitted",
description="This prompt template is restricted",
code="PROMPT_BLOCKED",
)
mock_manager.invoke_hook.return_value = (_make_result(continue_processing=False, violation=violation), None)
src.server.manager = mock_manager

captured_bodies = []
original_dumps = json.dumps

def spy_dumps(obj, **kwargs):
if isinstance(obj, dict) and "error" in obj:
captured_bodies.append(obj)
return original_dumps(obj, **kwargs)

json.dumps = spy_dumps
try:
response = await src.server.getPromptPreFetchResponse(prompt_body)
finally:
json.dumps = original_dumps

assert mock_manager.invoke_hook.called
assert response is not None
assert len(captured_bodies) > 0
error_body = captured_bodies[0]
assert "error" in error_body
assert "Prompt not permitted" in error_body["error"]["message"]
assert error_body["id"] == "test-456"
assert error_body["jsonrpc"] == "2.0"
Loading
Loading