diff --git a/mcp-servers/litert-mcp/README.md b/mcp-servers/litert-mcp/README.md index 96987a527..bd256c3b4 100644 --- a/mcp-servers/litert-mcp/README.md +++ b/mcp-servers/litert-mcp/README.md @@ -50,6 +50,8 @@ Runs inference using the configured LiteRT-LM model. } ``` +**Note**: The current CLI wrapper supports text-only inference. For multimodal capabilities (image/audio), use the LiteRT-LM C++ or Python API directly. + ## Setup for Development This server uses a manual JSON-RPC implementation to avoid external dependencies in the base environment. Just run: diff --git a/mcp-servers/litert-mcp/server.py b/mcp-servers/litert-mcp/server.py index 0e0b2bbb8..9561d0e9d 100644 --- a/mcp-servers/litert-mcp/server.py +++ b/mcp-servers/litert-mcp/server.py @@ -147,10 +147,26 @@ async def _handle_tools_call(self, request_id, params): async def _run_inference(self, args: Dict[str, Any]) -> Dict[str, Any]: prompt = args.get("prompt") + + # Validate prompt parameter + if not prompt or not isinstance(prompt, str) or not prompt.strip(): + return { + "status": "error", + "message": "Invalid or empty prompt. Please provide a non-empty text prompt." + } + model_path = args.get("model_path") or self.default_model_path image_path = args.get("image_path") audio_path = args.get("audio_path") backend = args.get("backend", "cpu") + + # Validate backend parameter + valid_backends = ["cpu", "gpu", "npu"] + if backend not in valid_backends: + return { + "status": "error", + "message": f"Invalid backend '{backend}'. Must be one of {valid_backends}." + } # Validate Prompt if not prompt: @@ -251,7 +267,7 @@ async def main(): ) writer = asyncio.StreamWriter(w_transport, w_protocol, None, asyncio.get_event_loop()) except Exception as e: - LOGGER.warning(f"Could not connect write pipe to stdout: {e}. Falling back to print.") + LOGGER.warning(f"Could not connect write pipe to stdout: {e}. 
Falling back to sys.stdout.write().") writer = None else: # Windows fallback: @@ -281,7 +297,8 @@ async def main(): writer = None print(response_str, flush=True) else: - print(response_str, flush=True) + sys.stdout.write(response_str) + sys.stdout.flush() except json.JSONDecodeError: LOGGER.error(f"Invalid JSON received: {line}")