From 3b329f8902a320f7545b19937543c621472526fb Mon Sep 17 00:00:00 2001
From: Alessandro Costantino Baltag
Date: Wed, 13 May 2026 20:44:57 +0200
Subject: [PATCH] Add Ollama provider support

---
 .env.example                                   |   7 +
 README.md                                      |   5 +-
 finbot/agents/chat.py                          | 104 ++++++---
 finbot/config.py                               |   1 +
 finbot/core/llm/client.py                      |  13 +-
 finbot/core/llm/ollama_client.py               | 205 ++++++++++++++++--
 pyproject.toml                                 |   3 +-
 .../unit/agents/test_chat_provider_routing.py  |  75 +++++++
 tests/unit/llm/test_llm_client.py              |  67 +++---
 tests/unit/llm/test_ollama_client.py           | 133 +++++++++++-
 uv.lock                                        |  15 ++
 11 files changed, 529 insertions(+), 99 deletions(-)
 create mode 100644 tests/unit/agents/test_chat_provider_routing.py

diff --git a/.env.example b/.env.example
index 31526784..1cee7cc9 100644
--- a/.env.example
+++ b/.env.example
@@ -16,7 +16,14 @@ PORT=8000
 SECRET_KEY=super_long_default_key_change_this_in_production
 
 # ── LLM ──────────────────────────────────────────────────────────────
+LLM_PROVIDER=openai
+LLM_DEFAULT_MODEL=gpt-5-nano
 OPENAI_API_KEY=your_openai_api_key_here
+# For Ollama, set LLM_PROVIDER=ollama. FinBot reads OLLAMA_MODEL only when the provider is Ollama.
+# OLLAMA_MODEL=gemma4:e2b
+# OLLAMA_BASE_URL=http://localhost:11434
+# If the app runs in Docker and Ollama runs on your host, use:
+# OLLAMA_BASE_URL=http://host.docker.internal:11434
 
 # ── Email ────────────────────────────────────────────────────────────
 # Use "console" for dev (prints to stdout), "resend" for production
diff --git a/README.md b/README.md
index 9ddd5c6b..b16be369 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,7 @@ uv run python run.py
 
 Platform runs at [http://localhost:8000](http://localhost:8000)
 
-> An LLM API key (OpenAI or Ollama) is needed for AI agent challenges.
+> An LLM backend is needed for AI agent challenges: an OpenAI API key or a reachable Ollama server.
 > Redis is needed for event-driven challenge detection.
 > Without them, you can still explore the UI and codebase.
 
@@ -143,6 +143,7 @@ Key environment variables (see `[.env.example](.env.example)` for the full templ
 | `DATABASE_TYPE` | `sqlite` | `sqlite` or `postgresql` |
 | `OPENAI_API_KEY` | - | Required for AI agent challenges |
 | `LLM_PROVIDER` | `openai` | `openai` or `ollama` |
+| `OLLAMA_MODEL` | `gemma4:e2b` | Model FinBot uses when `LLM_PROVIDER=ollama` |
 | `REDIS_URL` | `redis://localhost:6379` | Event bus for CTF processing |
 | `SECRET_KEY` | dev default | **Change in production** |
 | `EMAIL_PROVIDER` | `console` | `console` (dev) or `resend` (prod) |
@@ -211,4 +212,4 @@ OWASP FinBot CTF is part of the [OWASP GenAI Security Project](https://genai.owa
 - **[Abigail Dede Okley](https://www.linkedin.com/in/abigailokley)** -- Chief Cat Herder (project manager, keeping all the cats aligned and on track)
 - **[Carolina Steadham](https://www.linkedin.com/in/carolinacsteadham)** -- Guardian of Quality Realms (ensuring every feature meets its highest destiny, safeguarding workstream integrity)
 
-And all the amazing [contributors](https://github.com/GenAI-Security-Project/finbot-ctf/graphs/contributors) who make this project possible.
\ No newline at end of file
+And all the amazing [contributors](https://github.com/GenAI-Security-Project/finbot-ctf/graphs/contributors) who make this project possible.
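Before flipping `LLM_PROVIDER=ollama`, it can help to confirm the server is reachable and the model installed. The sketch below is not part of this patch: it mirrors what `OllamaClient._ensure_default_model_available` (introduced further down) does on the first chat call, using the same `ollama.AsyncClient` calls the new client relies on. The env var names come from `.env.example` above; the script name and printed messages are illustrative.

```python
# check_ollama.py — illustrative pre-flight check, not part of this patch.
import asyncio
import os

from ollama import AsyncClient


async def main() -> None:
    host = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    model = os.getenv("OLLAMA_MODEL", "gemma4:e2b")
    client = AsyncClient(host=host)

    # Same defensive name extraction as the client's _extract_model_name:
    # list() entries expose .model or .name depending on the ollama version.
    listed = (await client.list()).models or []
    installed = {getattr(m, "model", None) or getattr(m, "name", None) for m in listed}

    # Ollama stores bare tags as "<name>:latest", so accept both spellings.
    if model not in installed and f"{model}:latest" not in installed:
        print(f"{model} not installed on {host}; pulling (FinBot would do the same)")
        await client.pull(model)
    print(f"{model} is ready at {host}")


if __name__ == "__main__":
    asyncio.run(main())
```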
diff --git a/finbot/agents/chat.py b/finbot/agents/chat.py index 424a64e1..718fe1fe 100644 --- a/finbot/agents/chat.py +++ b/finbot/agents/chat.py @@ -16,13 +16,12 @@ from datetime import UTC, datetime from typing import Any -from openai import AsyncOpenAI - from finbot.config import settings from finbot.core.auth.session import SessionContext from finbot.core.data.database import db_session -from finbot.core.data.models import CTFEvent +from finbot.core.data.models import CTFEvent, LLMRequest from finbot.core.data.repositories import ChatMessageRepository, VendorRepository +from finbot.core.llm.client import get_llm_client from finbot.core.messaging import event_bus from finbot.guardrails.schemas import HookKind from finbot.guardrails.service import GuardrailHookService @@ -66,11 +65,17 @@ def __init__( self.max_history = max_history self.agent_name = agent_name self._workflow_id = self._resolve_workflow_id() - self._client = AsyncOpenAI( - api_key=settings.OPENAI_API_KEY, - timeout=settings.CHAT_STREAM_TIMEOUT, - ) - self._model = settings.LLM_DEFAULT_MODEL + self._provider = settings.LLM_PROVIDER.strip().lower() + self._llm_client = get_llm_client() + self._client = None + if self._provider == "openai": + from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel + + self._client = AsyncOpenAI( + api_key=settings.OPENAI_API_KEY, + timeout=settings.CHAT_STREAM_TIMEOUT, + ) + self._model = self._llm_client.default_model self._mcp_provider: MCPToolProvider | None = None self._mcp_connected = False self._tool_callables = self._build_native_callables() @@ -367,24 +372,60 @@ async def stream_response( user_message=user_message, ) - stream = await self._client.responses.create(**stream_params) - pending_tool_calls: list[dict] = [] + append_tool_call_items = True - async for event in stream: - if event.type == "response.output_text.delta": - full_response += event.delta - yield f"data: {json.dumps({'type': 'token', 'content': event.delta})}\n\n" - - elif event.type == "response.output_item.done": - if event.item.type == "function_call": - pending_tool_calls.append( - { - "name": event.item.name, - "call_id": event.item.call_id, - "arguments": json.loads(event.item.arguments), - } + try: + if self._provider == "openai": + if self._client is None: + raise RuntimeError("OpenAI chat client is not initialized") + + stream = await self._client.responses.create(**stream_params) + + async for event in stream: + if event.type == "response.output_text.delta": + full_response += event.delta + yield f"data: {json.dumps({'type': 'token', 'content': event.delta})}\n\n" + + elif event.type == "response.output_item.done": + if event.item.type == "function_call": + pending_tool_calls.append( + { + "name": event.item.name, + "call_id": event.item.call_id, + "arguments": json.loads(event.item.arguments), + } + ) + else: + append_tool_call_items = False + response = await self._llm_client.chat( + request=LLMRequest( + messages=input_messages, + model=self._model, + temperature=settings.LLM_DEFAULT_TEMPERATURE, + tools=tools, ) + ) + if response.messages: + input_messages = response.messages + if not response.success: + raise RuntimeError(response.content or "LLM provider unavailable") + + content = response.content or "" + if content: + full_response += content + yield f"data: {json.dumps({'type': 'token', 'content': content})}\n\n" + pending_tool_calls = response.tool_calls or [] + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Chat model call failed: %s", e) + 
error_msg = ( + "The configured AI provider is unavailable. " + "Check LLM_PROVIDER, OLLAMA_MODEL or LLM_DEFAULT_MODEL, " + "and provider credentials or URL." + ) + yield f"data: {json.dumps({'type': 'error', 'content': error_msg})}\n\n" + yield f"data: {json.dumps({'type': 'done'})}\n\n" + return await self._guardrail_service.invoke( HookKind.after_model, @@ -427,14 +468,15 @@ async def _keepalive_emitter() -> None: summary=f"Chat tool call: {tc['name']}", ) - input_messages.append( - { - "type": "function_call", - "name": tc["name"], - "call_id": tc["call_id"], - "arguments": json.dumps(tc["arguments"]), - } - ) + if append_tool_call_items: + input_messages.append( + { + "type": "function_call", + "name": tc["name"], + "call_id": tc["call_id"], + "arguments": json.dumps(tc["arguments"]), + } + ) tool_start = datetime.now(UTC) result = await self._execute_tool(tc["name"], tc["arguments"]) tool_duration_ms = int( diff --git a/finbot/config.py b/finbot/config.py index df362f5c..854b461f 100644 --- a/finbot/config.py +++ b/finbot/config.py @@ -111,6 +111,7 @@ class Settings(BaseSettings): # Ollama OLLAMA_BASE_URL: str = "http://localhost:11434" + OLLAMA_MODEL: str = "gemma4:e2b" # Development Config RELOAD: bool = True diff --git a/finbot/core/llm/client.py b/finbot/core/llm/client.py index 461eaa80..b43c775e 100644 --- a/finbot/core/llm/client.py +++ b/finbot/core/llm/client.py @@ -14,8 +14,12 @@ class LLMClient: """LLM Client with configurable provider and model""" def __init__(self): - self.provider = settings.LLM_PROVIDER - self.default_model = settings.LLM_DEFAULT_MODEL + self.provider = settings.LLM_PROVIDER.strip().lower() + self.default_model = ( + settings.OLLAMA_MODEL + if self.provider == "ollama" + else settings.LLM_DEFAULT_MODEL + ) self.default_temperature = settings.LLM_DEFAULT_TEMPERATURE self.client = self._get_client() @@ -26,6 +30,11 @@ def _get_client(self): from finbot.core.llm.openai_client import OpenAIClient return OpenAIClient() + elif self.provider == "ollama": + # pylint: disable=import-outside-toplevel + from finbot.core.llm.ollama_client import OllamaClient + + return OllamaClient() elif self.provider == "mock": # pylint: disable=import-outside-toplevel from finbot.core.llm.mock_client import MockLLMClient diff --git a/finbot/core/llm/ollama_client.py b/finbot/core/llm/ollama_client.py index f4d453a1..9c36e5dc 100644 --- a/finbot/core/llm/ollama_client.py +++ b/finbot/core/llm/ollama_client.py @@ -1,5 +1,7 @@ """Ollama Client with configurable model""" +import asyncio +import json import logging from typing import Any @@ -15,16 +17,187 @@ class OllamaClient: """Ollama Client with configurable model""" def __init__(self): - self.default_model = settings.LLM_DEFAULT_MODEL + raw_provider = getattr(settings, "LLM_PROVIDER", "openai") + self.provider = ( + raw_provider.strip().lower() if isinstance(raw_provider, str) else "openai" + ) + self.default_model = ( + settings.OLLAMA_MODEL if self.provider == "ollama" else settings.LLM_DEFAULT_MODEL + ) self.default_temperature = settings.LLM_DEFAULT_TEMPERATURE self.host = getattr(settings, "OLLAMA_BASE_URL", "http://localhost:11434") + self._default_model_checked = False + self._default_model_lock = asyncio.Lock() - self._client = AsyncClient( host=self.host, timeout=settings.LLM_TIMEOUT, ) + @staticmethod + def _extract_model_name(model: Any) -> str | None: + if isinstance(model, dict): + return model.get("model") or model.get("name") + return getattr(model, "model", None) or getattr(model, "name", None) + + @classmethod 
+ def _model_is_available( + cls, requested_model: str, available_models: list[Any] + ) -> bool: + requested = requested_model.strip() + if not requested: + return False + + requested_with_latest = f"{requested}:latest" if ":" not in requested else requested + for model in available_models: + model_name = cls._extract_model_name(model) + if model_name in {requested, requested_with_latest}: + return True + return False + + async def _ensure_default_model_available(self) -> None: + if self.provider != "ollama" or self._default_model_checked: + return + + async with self._default_model_lock: + if self._default_model_checked: + return + + models_response = await self._client.list() + available_models = list(getattr(models_response, "models", []) or []) + if not self._model_is_available(self.default_model, available_models): + logger.info( + "Ollama model %s is not installed; pulling it now", + self.default_model, + ) + await self._client.pull(self.default_model) + + self._default_model_checked = True + + @staticmethod + def _coerce_tool_arguments(arguments: Any) -> dict[str, Any]: + """Convert provider-neutral tool arguments into Ollama's mapping shape.""" + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + + @classmethod + def _normalize_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert internal/OpenAI-style tool history into Ollama chat messages.""" + normalized: list[dict[str, Any]] = [] + tool_call_names: dict[str, str] = {} + + for message in messages: + msg = dict(message) + msg_type = msg.get("type") + + if msg_type == "function_call": + name = msg.get("name") + call_id = msg.get("call_id") + if call_id and name: + tool_call_names[call_id] = name + if name: + normalized.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": name, + "arguments": cls._coerce_tool_arguments( + msg.get("arguments") + ), + } + } + ], + } + ) + continue + + if msg_type == "function_call_output": + call_id = msg.get("call_id") + tool_name = tool_call_names.get(call_id, call_id or "tool") + normalized.append( + { + "role": "tool", + "content": str(msg.get("output") or ""), + "tool_name": tool_name, + } + ) + continue + + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list): + ollama_tool_calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + function = tool_call.get("function", {}) + name = tool_call.get("name") or function.get("name") + call_id = tool_call.get("call_id") + if call_id and name: + tool_call_names[call_id] = name + if name: + ollama_tool_calls.append( + { + "function": { + "name": name, + "arguments": cls._coerce_tool_arguments( + tool_call.get( + "arguments", function.get("arguments") + ) + ), + } + } + ) + + normalized_message = { + "role": msg.get("role", "assistant"), + "content": str(msg.get("content") or ""), + } + if ollama_tool_calls: + normalized_message["tool_calls"] = ollama_tool_calls + normalized.append(normalized_message) + continue + + normalized.append( + { + key: value + for key, value in msg.items() + if key in {"role", "content", "images", "tool_name"} + } + ) + + return normalized + + @staticmethod + def _normalize_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + """Convert OpenAI Responses-style tool definitions into Ollama's format.""" + if not tools: + return None 
+ + normalized: list[dict[str, Any]] = [] + for tool in tools: + if tool.get("type") == "function" and "function" not in tool: + normalized.append( + { + "type": "function", + "function": { + "name": tool.get("name"), + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + }, + } + ) + else: + normalized.append(tool) + return normalized + @retry(max_retries=3, backoff_seconds=0.5) async def chat( self, @@ -35,11 +208,15 @@ async def chat( """ try: model = request.model or self.default_model - temperature = self.default_temperature if request.temperature is None else request.temperature - + temperature = ( + self.default_temperature if request.temperature is None else request.temperature + ) + if model == self.default_model: + await self._ensure_default_model_available() + # Create a shallow copy to avoid mutating request.messages. # Prevents history leakage when the same LLMRequest object is reused. - messages: list[dict[str,Any]] = list(request.messages) if request.messages else [] + messages = self._normalize_messages(list(request.messages or [])) options = { "temperature": temperature, @@ -52,13 +229,12 @@ async def chat( "options": options, } - if request.output_json_schema: chat_params["format"] = request.output_json_schema.get("schema") - - if request.tools: - chat_params["tools"] = request.tools + tools = self._normalize_tools(request.tools) + if tools: + chat_params["tools"] = tools response = await self._client.chat(**chat_params) @@ -79,8 +255,7 @@ async def chat( # Normalize content to str content = message.content if isinstance(message.content, str) else "" - - tool_calls: list[dict[str,Any]] = [] + tool_calls: list[dict[str, Any]] = [] raw_tool_calls = getattr(message, "tool_calls", []) if isinstance(raw_tool_calls, list) and raw_tool_calls: for idx, tc in enumerate(raw_tool_calls): @@ -99,7 +274,7 @@ async def chat( ) # tool_calls normalized to plain dicts — JSON-serializable - history_entry: dict[str,Any] = { + history_entry: dict[str, Any] = { "role": "assistant", "content": content, } @@ -114,8 +289,6 @@ async def chat( "eval_count": getattr(response, "eval_count", None), } - - return LLMResponse( content=content, provider="ollama", @@ -125,6 +298,6 @@ async def chat( tool_calls=tool_calls, ) - except Exception as e: + except Exception as e: logger.error("Ollama chat failed: %s", e) - raise + raise diff --git a/pyproject.toml b/pyproject.toml index d0a01c6b..2a16e7ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "sqlalchemy>=2.0.44", "uvicorn[standard]>=0.37.0", "gunicorn>=21.2.0", + "ollama>=0.6.2", ] [dependency-groups] @@ -69,4 +70,4 @@ asyncio_mode = "strict" [tool.setuptools.packages.find] where = ["."] -include = ["finbot*"] \ No newline at end of file +include = ["finbot*"] diff --git a/tests/unit/agents/test_chat_provider_routing.py b/tests/unit/agents/test_chat_provider_routing.py new file mode 100644 index 00000000..8811c8ff --- /dev/null +++ b/tests/unit/agents/test_chat_provider_routing.py @@ -0,0 +1,75 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from finbot.agents import chat as chat_module +from finbot.agents.chat import ChatAssistantBase +from finbot.core.auth.session import SessionContext +from finbot.core.data.models import LLMResponse + + +class DummyChatAssistant(ChatAssistantBase): + def _resolve_workflow_id(self) -> str: + return "wf_test_chat" + + async def _connect_mcp(self) -> 
None: + self._mcp_connected = True + + def _get_system_prompt(self) -> str: + return "You are a test assistant." + + def _get_native_tool_definitions(self) -> list[dict]: + return [] + + def _build_native_callables(self) -> dict: + return {} + + def _load_history(self) -> list[dict]: + return [] + + def _save_message(self, role: str, content: str, workflow_id: str | None = None): + return None + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_chat_assistant_uses_llm_router_for_non_openai_provider(monkeypatch): + fake_llm_client = SimpleNamespace( + default_model="gemma4:e2b", + chat=AsyncMock( + return_value=LLMResponse( + content="local response", + provider="ollama", + success=True, + messages=[{"role": "assistant", "content": "local response"}], + tool_calls=[], + ) + ), + ) + + monkeypatch.setattr(chat_module.settings, "LLM_PROVIDER", "ollama") + monkeypatch.setattr(chat_module.settings, "OLLAMA_MODEL", "gemma4:e2b") + monkeypatch.setattr(chat_module, "get_llm_client", lambda: fake_llm_client) + monkeypatch.setattr(chat_module.event_bus, "emit_agent_event", AsyncMock()) + + session_context = SessionContext( + session_id="session_test", + user_id="user_test", + is_temporary=True, + namespace="test", + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + current_vendor_id=1, + ) + assistant = DummyChatAssistant(session_context=session_context) + assistant._guardrail_service.invoke = AsyncMock() + + chunks = [chunk async for chunk in assistant.stream_response("hello")] + + assert assistant._client is None + fake_llm_client.chat.assert_awaited_once() + assert fake_llm_client.chat.await_args.kwargs["request"].model == "gemma4:e2b" + assert any('"content": "local response"' in chunk for chunk in chunks) + assert chunks[-1] == 'data: {"type": "done"}\n\n' diff --git a/tests/unit/llm/test_llm_client.py b/tests/unit/llm/test_llm_client.py index c9794779..09a8f376 100644 --- a/tests/unit/llm/test_llm_client.py +++ b/tests/unit/llm/test_llm_client.py @@ -23,7 +23,7 @@ # LLM-PROV-008: Module-Level Singleton Documents Import-Time Risk # LLM-PROV-009: Bad Provider Raises ValueError At Instantiation # LLM-PROV-010: Ollama Provider Initialization -# LLM-PROV-011: Ollama Provider Not Registered Raises ValueError +# LLM-PROV-011: Ollama Provider Registered In Router # LLM-PROV-012: No Warning When Provider Matches # LLM-PROV-013: LLMClient Does Not Mutate Request Before Delegation # LLM-PROV-014: Error Response Is Well-Formed @@ -117,20 +117,22 @@ def test_ollama_client_default_configuration(): AsyncClient with the correct host and timeout?" Test Steps: - 1. Configure settings.LLM_DEFAULT_MODEL, settings.LLM_DEFAULT_TEMPERATURE, + 1. Configure settings.OLLAMA_MODEL, settings.LLM_DEFAULT_TEMPERATURE, and settings.LLM_TIMEOUT. 2. Instantiate OllamaClient. 3. Inspect initialized attributes. Expected Results: - 1. default_model is set to settings.LLM_DEFAULT_MODEL + 1. default_model is set to settings.OLLAMA_MODEL when LLM_PROVIDER is ollama 2. default_temperature is set to settings.LLM_DEFAULT_TEMPERATURE 3. host is set to settings.OLLAMA_BASE_URL 4. AsyncClient is initialized with the configured host 5. 
AsyncClient uses settings.LLM_TIMEOUT as timeout """ with patch("finbot.core.llm.ollama_client.settings") as mock_settings: - mock_settings.LLM_DEFAULT_MODEL = "llama3.2" + mock_settings.LLM_PROVIDER = "ollama" + mock_settings.OLLAMA_MODEL = "gemma4:e2b" + mock_settings.LLM_DEFAULT_MODEL = "gpt-5-nano" mock_settings.LLM_DEFAULT_TEMPERATURE = 0.7 mock_settings.LLM_TIMEOUT = 60 mock_settings.OLLAMA_BASE_URL = "http://localhost:11434" @@ -139,7 +141,7 @@ def test_ollama_client_default_configuration(): client = OllamaClient() # default_model must come from settings — wrong value means requests use the wrong model - assert client.default_model == "llama3.2" + assert client.default_model == "gemma4:e2b" # default_temperature must come from settings — controls response randomness for every request assert client.default_temperature == pytest.approx(0.7) # host must be set from OLLAMA_BASE_URL so the client knows which server to connect to @@ -471,20 +473,22 @@ def test_ollama_provider_initialization(): AsyncClient with the correct host and timeout?" Test Steps: - 1. Configure settings.LLM_DEFAULT_MODEL, settings.LLM_DEFAULT_TEMPERATURE, + 1. Configure settings.OLLAMA_MODEL, settings.LLM_DEFAULT_TEMPERATURE, and settings.LLM_TIMEOUT 2. Instantiate OllamaClient 3. Inspect initialized attributes Expected Results: - 1. default_model is set to settings.LLM_DEFAULT_MODEL + 1. default_model is set to settings.OLLAMA_MODEL when LLM_PROVIDER is ollama 2. default_temperature is set to settings.LLM_DEFAULT_TEMPERATURE 3. host is set to settings.OLLAMA_BASE_URL or defaults to "http://localhost:11434" 4. AsyncClient is initialized with the configured host 5. AsyncClient uses settings.LLM_TIMEOUT as timeout """ with patch("finbot.core.llm.ollama_client.settings") as mock_settings: - mock_settings.LLM_DEFAULT_MODEL = "llama3.2" + mock_settings.LLM_PROVIDER = "ollama" + mock_settings.OLLAMA_MODEL = "gemma4:e2b" + mock_settings.LLM_DEFAULT_MODEL = "gpt-5-nano" mock_settings.LLM_DEFAULT_TEMPERATURE = 0.7 mock_settings.LLM_TIMEOUT = 60 mock_settings.OLLAMA_BASE_URL = "http://localhost:11434" @@ -493,7 +497,7 @@ def test_ollama_provider_initialization(): client = OllamaClient() # default_model must come from settings — wrong value means requests use the wrong model - assert client.default_model == "llama3.2" + assert client.default_model == "gemma4:e2b" # default_temperature must come from settings — controls response randomness for every request assert client.default_temperature == pytest.approx(0.7) # host must be set from OLLAMA_BASE_URL so the client knows which server to connect to @@ -506,43 +510,24 @@ def test_ollama_provider_initialization(): # ============================================================================ -# LLM-PROV-011: Ollama Provider Not Registered Raises ValueError +# LLM-PROV-011: Ollama Provider Registered In Router # ============================================================================ @pytest.mark.unit -def test_ollama_provider_raises_value_error(patched_settings): - """LLM-PROV-011: Ollama Provider Not Registered Raises ValueError +def test_ollama_provider_initializes_from_router(patched_settings): + """LLM-PROV-011: Ollama Provider Registered In Router - Verify that LLMClient raises ValueError when LLM_PROVIDER = "ollama". - - OllamaClient exists as a standalone class but is not registered in - LLMClient._get_client(). Setting LLM_PROVIDER = "ollama" falls through - to the raise ValueError at the end of the method. 
- - Regression note (client.py): - def _get_client(self): - if self.provider == "openai": - return OpenAIClient() - elif self.provider == "mock": - return MockLLMClient() - raise ValueError(f"Unsupported LLM provider: {self.provider}") - "ollama" is not handled, so it always raises ValueError at startup. - - Test Steps: - 1. Set LLM_PROVIDER = "ollama" via patched_settings - 2. Attempt to create LLMClient instance - 3. Expect ValueError containing "ollama" - - Expected Results: - 1. ValueError raised during initialization - 2. Error message contains "ollama" - 3. OllamaClient is not reachable via the standard LLMClient path + Verify that LLMClient routes LLM_PROVIDER = "ollama" to OllamaClient. """ - patched_settings.LLM_PROVIDER = "ollama" - patched_settings.LLM_DEFAULT_MODEL = "llama3.2" + patched_settings.LLM_PROVIDER = "OLLAMA" + patched_settings.OLLAMA_MODEL = "gemma4:e2b" + patched_settings.LLM_DEFAULT_MODEL = "gpt-5-nano" patched_settings.LLM_DEFAULT_TEMPERATURE = 0.7 - with pytest.raises(ValueError, match="ollama"): - LLMClient() + with patch("finbot.core.llm.ollama_client.OllamaClient") as mock_ollama_client: + client = LLMClient() + assert client.provider == "ollama" + assert client.default_model == "gemma4:e2b" + mock_ollama_client.assert_called_once() # ============================================================================ @@ -772,4 +757,4 @@ def test_google_sheets_integration_verification(): print("✓ Google Sheets integration verified successfully for LLM client tests") except Exception as e: - pytest.fail(f"Google Sheets verification failed: {e}") \ No newline at end of file + pytest.fail(f"Google Sheets verification failed: {e}") diff --git a/tests/unit/llm/test_ollama_client.py b/tests/unit/llm/test_ollama_client.py index a8d51550..ad879827 100644 --- a/tests/unit/llm/test_ollama_client.py +++ b/tests/unit/llm/test_ollama_client.py @@ -63,6 +63,127 @@ import gspread load_dotenv() + + +@pytest.fixture(autouse=True) +def _disable_ollama_auto_pull_for_unit_tests(monkeypatch): + monkeypatch.setattr("finbot.core.llm.ollama_client.settings.LLM_PROVIDER", "openai") + + +def test_normalizes_openai_style_tool_definitions_for_ollama(): + tools = [ + { + "type": "function", + "name": "get_vendor_details", + "strict": True, + "description": "Get vendor details", + "parameters": { + "type": "object", + "properties": {"vendor_id": {"type": "integer"}}, + "required": ["vendor_id"], + "additionalProperties": False, + }, + } + ] + + normalized = OllamaClient._normalize_tools(tools) + + assert normalized == [ + { + "type": "function", + "function": { + "name": "get_vendor_details", + "description": "Get vendor details", + "parameters": tools[0]["parameters"], + }, + } + ] + + +def test_normalizes_tool_call_history_for_ollama_followup(): + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_vendor_details", + "call_id": "ollama_call_0", + "arguments": {"vendor_id": 1}, + } + ], + }, + { + "type": "function_call_output", + "call_id": "ollama_call_0", + "output": '{"company_name": "Acme"}', + }, + ] + + normalized = OllamaClient._normalize_messages(messages) + + assert normalized[0] == { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_vendor_details", + "arguments": {"vendor_id": 1}, + } + } + ], + } + assert normalized[1] == { + "role": "tool", + "content": '{"company_name": "Acme"}', + "tool_name": "get_vendor_details", + } + + +@pytest.mark.parametrize( + ("available_models", "should_pull"), + [ + 
([], True), + ([MagicMock(model="gemma4:e2b")], False), + ], +) +@pytest.mark.asyncio +@pytest.mark.unit +async def test_ensures_default_ollama_model_available(available_models, should_pull): + fake_message = AsyncMock() + fake_message.content = "reply" + fake_message.tool_calls = None + + fake_response = AsyncMock() + fake_response.message = fake_message + + with patch("finbot.core.llm.ollama_client.settings") as mock_settings: + mock_settings.LLM_PROVIDER = "ollama" + mock_settings.OLLAMA_MODEL = "gemma4:e2b" + mock_settings.LLM_DEFAULT_MODEL = "gpt-5-nano" + mock_settings.LLM_DEFAULT_TEMPERATURE = 0.7 + mock_settings.LLM_MAX_TOKENS = 5000 + mock_settings.LLM_TIMEOUT = 60 + mock_settings.OLLAMA_BASE_URL = "http://localhost:11434" + + with patch("finbot.core.llm.ollama_client.AsyncClient") as mock_client: + instance = mock_client.return_value + instance.list = AsyncMock(return_value=MagicMock(models=available_models)) + instance.pull = AsyncMock(return_value=MagicMock(status="success")) + instance.chat = AsyncMock(return_value=fake_response) + + client = OllamaClient() + await client.chat(LLMRequest(messages=[{"role": "user", "content": "hi"}])) + + instance.list.assert_awaited_once() + if should_pull: + instance.pull.assert_awaited_once_with("gemma4:e2b") + else: + instance.pull.assert_not_awaited() + assert instance.chat.await_args.kwargs["model"] == "gemma4:e2b" + + # ============================================================================ # LLM-CONF-001: Default Configuration Loading # ============================================================================ @@ -75,7 +196,7 @@ async def test_default_configuration_loading(): Test Steps: 1. Create OllamaClient instance without custom parameters - 2. Verify default_model is loaded from settings.LLM_DEFAULT_MODEL + 2. Verify default_model is loaded from settings.OLLAMA_MODEL when provider is ollama 3. Verify default_temperature is loaded from settings.LLM_DEFAULT_TEMPERATURE 4. Verify host is set to OLLAMA_BASE_URL or defaults to "https://localhost:11434" 5. Verify AsyncClient is initialized with correct host and timeout @@ -88,7 +209,9 @@ async def test_default_configuration_loading(): 5. AsyncClient configured with proper connection parameters """ with patch("finbot.core.llm.ollama_client.settings") as mock_settings: - mock_settings.LLM_DEFAULT_MODEL = "llama3.2" + mock_settings.LLM_PROVIDER = "ollama" + mock_settings.OLLAMA_MODEL = "gemma4:e2b" + mock_settings.LLM_DEFAULT_MODEL = "gpt-5-nano" mock_settings.LLM_DEFAULT_TEMPERATURE = 0.7 mock_settings.LLM_TIMEOUT = 60 mock_settings.OLLAMA_BASE_URL = "https://custom-ollama:11434" @@ -97,7 +220,7 @@ async def test_default_configuration_loading(): client = OllamaClient() # The client must store the model name from settings — used as default for every request - assert client.default_model == "llama3.2" + assert client.default_model == "gemma4:e2b" # Temperature controls response randomness (0=deterministic, 1=creative); must match settings assert client.default_temperature == pytest.approx(0.7) # The Ollama server URL must be read from OLLAMA_BASE_URL so the client knows where to connect @@ -1472,7 +1595,7 @@ async def test_request_messages_none(): 1. Create an LLMRequest with messages=None. 2. Mock Ollama response with a valid message. 3. Call OllamaClient.chat(). - + Expected Behavior: 1. The client does not crash. 
@@ -1645,5 +1768,3 @@ def test_google_sheets_integration_verification(): except Exception as e: pytest.fail(f"Google Sheets verification failed: {e}") - - diff --git a/uv.lock b/uv.lock index 3d981b60..4c1a17cb 100644 --- a/uv.lock +++ b/uv.lock @@ -520,6 +520,7 @@ dependencies = [ { name = "gunicorn" }, { name = "httpx" }, { name = "jinja2" }, + { name = "ollama" }, { name = "openai" }, { name = "playwright" }, { name = "psycopg2-binary" }, @@ -562,6 +563,7 @@ requires-dist = [ { name = "gunicorn", specifier = ">=21.2.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "jinja2", specifier = ">=3.1.6" }, + { name = "ollama", specifier = ">=0.6.2" }, { name = "openai", specifier = ">=2.6.1" }, { name = "playwright", specifier = ">=1.49.0" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, @@ -1165,6 +1167,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, ] +[[package]] +name = "ollama" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/72/5f12423b6b39ca8430fbe56f77fcf4ef60f63067c7c4a2e30e200ed9ec16/ollama-0.6.2.tar.gz", hash = "sha256:936d55daa684f474364c098611c933626f8d6c7d67065c5b7ae0c477b508b07f", size = 53145, upload-time = "2026-04-29T21:21:15.018Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/ab/d6722beeb2d10f7a3b9ff49375708904fde18f82b5609a0bc4aeb5996a4d/ollama-0.6.2-py3-none-any.whl", hash = "sha256:3ad7daab28e5a973445c36a73882a3ef698c2ebb00e21e308652741577509f7d", size = 15115, upload-time = "2026-04-29T21:21:13.794Z" }, +] + [[package]] name = "openai" version = "2.17.0"
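For reviewers who want to exercise the new path end to end: a minimal sketch, assuming only names that appear in this diff (`get_llm_client`, `LLMRequest`, and the `LLMResponse` fields `success` and `content`). It illustrates the provider routing this patch adds; it is not code from the repository.

```python
# sketch_ollama_path.py — illustration only, built from names used in this diff.
import asyncio

from finbot.core.data.models import LLMRequest
from finbot.core.llm.client import get_llm_client


async def ask(question: str) -> str:
    # get_llm_client() normalizes LLM_PROVIDER; "ollama" now routes to
    # OllamaClient, whose default_model comes from OLLAMA_MODEL rather
    # than LLM_DEFAULT_MODEL.
    client = get_llm_client()
    response = await client.chat(
        request=LLMRequest(
            messages=[{"role": "user", "content": question}],
            model=client.default_model,
        )
    )
    # chat.py surfaces success=False the same way: as a provider error.
    if not response.success:
        raise RuntimeError(response.content or "LLM provider unavailable")
    return response.content or ""


if __name__ == "__main__":
    # Requires LLM_PROVIDER (and, for Ollama, OLLAMA_MODEL/OLLAMA_BASE_URL)
    # configured as in .env.example.
    print(asyncio.run(ask("What can you help me with?")))
```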