Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions backend/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import asyncio
import httpx
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException, Query, status
Expand Down Expand Up @@ -48,9 +49,27 @@
}


async def fetch_ollama_models() -> List[str]:
    """Return the model names advertised by the configured Ollama server.

    Best-effort: any network, HTTP-status, or payload problem yields an
    empty list so the models endpoints keep working when Ollama is down.

    Returns:
        List of model name strings from Ollama's ``/api/tags`` endpoint,
        or ``[]`` on any failure.
    """
    base_url = cred_store.get_ollama_base_url()
    try:
        # Short timeout: this runs on every models listing, and a down
        # Ollama server must not stall the whole endpoint.
        async with httpx.AsyncClient(timeout=2.0) as client:
            resp = await client.get(f"{base_url.rstrip('/')}/api/tags")
            resp.raise_for_status()
            data = resp.json()
    except httpx.HTTPError:
        # Server unreachable, timed out, or returned an error status.
        return []
    except ValueError:
        # Response body was not valid JSON.
        return []
    models = data.get("models", [])
    if not isinstance(models, list):
        return []
    # Skip malformed entries instead of letting one bad item discard all
    # results (the old blanket `except Exception` hid such KeyErrors).
    return [m["name"] for m in models if isinstance(m, dict) and "name" in m]


async def get_all_models_dict() -> Dict[str, List[str]]:
    """Combine the static provider model lists with the live Ollama list."""
    combined: Dict[str, List[str]] = dict(MODELS)
    combined["ollama"] = await fetch_ollama_models()
    return combined


@router.get("")
async def list_models() -> Dict[str, List[str]]:
    """List every available model id, grouped by provider."""
    all_models = await get_all_models_dict()
    return all_models


@router.get("/details")
Expand All @@ -63,7 +82,8 @@ async def models_details() -> Dict[str, List[Dict[str, Any]]]:
{
"claude": [ {id, context_window, max_output_tokens, source}, ... ],
"openai": [ ... ],
"gemini": [ ... ]
"gemini": [ ... ],
"ollama": [ ... ]
}
"""
async def _one(kind: str, model: str) -> Dict[str, Any]:
Expand All @@ -76,40 +96,31 @@ async def _one(kind: str, model: str) -> Dict[str, Any]:
"source": info.source if info else None,
}

all_models = await get_all_models_dict()
tasks: Dict[str, List[asyncio.Task[Dict[str, Any]]]] = {}
for kind, ids in MODELS.items():
for kind, ids in all_models.items():
tasks[kind] = [asyncio.create_task(_one(kind, m)) for m in ids]

out: Dict[str, List[Dict[str, Any]]] = {}
for kind, task_list in tasks.items():
out[kind] = await asyncio.gather(*task_list)
if task_list:
out[kind] = await asyncio.gather(*task_list)
else:
out[kind] = []
return out


@router.get("/info")
async def model_info_endpoint(
kind: AgentKind = Query(..., description="Provider (claude, openai, gemini)"),
model: str = Query(..., description="Model id — must be in MODELS[kind]"),
kind: AgentKind = Query(..., description="Provider (claude, openai, gemini, ollama)"),
model: str = Query(..., description="Model id"),
) -> Dict[str, Any]:
"""Return context window + max output for a specific model.

Response shape:
{
"kind": "gemini",
"model": "gemini-2.5-pro",
"context_window": 2000000,
"max_output_tokens": 64000,
"source": "api" // or "static", or null if unknown
}
"""
if model not in MODELS.get(kind, []):
all_models = await get_all_models_dict()
if model not in all_models.get(kind, []):
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"Unknown model '{model}' for provider '{kind}'.",
)
# For providers that require a key to live-fetch (Gemini), pass the
# saved credential through. Static-table lookups don't need it but
# it's harmless to pass an empty string.
api_key = cred_store.get_key(kind) or ""
info = await get_model_info(kind, model, api_key)
if info is None:
Expand Down
7 changes: 7 additions & 0 deletions backend/app/llm/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ async def get_model_info(
async with _cache_lock:
_cache[key] = live
return live

if agent_kind == "ollama":
return await _fetch_ollama(model)

return _STATIC.get(key)

Expand Down Expand Up @@ -130,3 +133,7 @@ async def _fetch_gemini(model: str, api_key: str) -> Optional[ModelInfo]:
exc,
)
return None

async def _fetch_ollama(model: str) -> Optional[ModelInfo]:
    """Return token limits for an Ollama-hosted *model*.

    Ollama does not expose per-model limits uniformly, so for now every
    model gets the same generous hard-coded defaults.

    Args:
        model: Ollama model id (currently unused — see above).

    Returns:
        A ``ModelInfo`` with default limits; never ``None`` today.
    """
    # These values are hard-coded defaults, not fetched from the Ollama
    # API, so label them "static" like the other table-based lookups
    # rather than the misleading "api".
    return ModelInfo(context_window=256000, max_output_tokens=8192, source="static")

107 changes: 107 additions & 0 deletions backend/app/llm/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,4 +445,111 @@ def build_provider(agent_kind: str, api_key: str):
return OpenAIProvider(api_key)
if agent_kind == "gemini":
return GeminiProvider(api_key)
if agent_kind == "ollama":
return OllamaProvider(api_key)
raise ValueError(f"Unknown agent_kind: {agent_kind}")

class OllamaProvider:
    """Chat provider backed by a local Ollama server.

    Ollama exposes an OpenAI-compatible API under ``/v1``, so this
    provider reuses the OpenAI SDK client and the same message
    translation as the OpenAI provider.
    """

    kind = "ollama"

    def __init__(self, api_key: str) -> None:
        # Lazy imports keep module import light and avoid a cycle with
        # the credentials store.
        from openai import AsyncOpenAI
        from app.store.credentials import get_ollama_base_url
        base_url = get_ollama_base_url().rstrip('/') + '/v1'
        # Ollama doesn't check the key, but the SDK requires a non-empty
        # one, hence the "ollama" placeholder.
        self._client = AsyncOpenAI(api_key=api_key if api_key else "ollama", base_url=base_url)

    async def close(self) -> None:
        """Close the underlying HTTP client, ignoring shutdown errors."""
        try:
            await self._client.close()
        except Exception:
            pass

    async def stream_turn(
        self,
        *,
        system: str,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        model: str,
        emit: EmitFn,
    ) -> Dict[str, Any]:
        """Stream one assistant turn, emitting UI events via *emit*.

        Emits ``assistant.delta`` per text chunk, then ``assistant.complete``
        (if any text) and ``usage`` (if token counts were reported).

        Returns:
            ``{"text": full_text, "tool_uses": [{"id", "name", "input"}, ...]}``
        """
        msg_id = f"msg-{uuid.uuid4()}"
        openai_msgs: List[Dict[str, Any]] = [
            {"role": "system", "content": system}
        ]
        openai_msgs.extend(_translate_to_openai(messages))

        text = ""
        # Partial tool calls keyed by stream index; fields accumulate
        # across chunks until the stream ends.
        tool_calls: Dict[int, Dict[str, str]] = {}
        input_tokens = 0
        output_tokens = 0

        stream = await self._client.chat.completions.create(
            model=model,
            messages=openai_msgs,
            tools=tools or None,  # send None instead of an empty list
            stream=True,
            stream_options={"include_usage": True},
        )

        async for chunk in stream:
            # Capture usage whenever a chunk carries it; latest wins.
            # Checked before the choices guard because usage can arrive
            # on a chunk with no choices.
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                input_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
                output_tokens = int(getattr(usage, "completion_tokens", 0) or 0)

            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            content = getattr(delta, "content", None)
            if content:
                text += content
                await emit({
                    "type": "assistant.delta",
                    "text": content,
                    "message_id": msg_id,
                })

            tc_deltas = getattr(delta, "tool_calls", None)
            if tc_deltas:
                for tc_delta in tc_deltas:
                    idx = tc_delta.index
                    tc = tool_calls.setdefault(
                        idx, {"id": "", "name": "", "arguments": ""}
                    )
                    if tc_delta.id:
                        tc["id"] = tc_delta.id
                    fn = getattr(tc_delta, "function", None)
                    if fn is not None:
                        if getattr(fn, "name", None):
                            tc["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            # Argument JSON streams in fragments; concatenate.
                            tc["arguments"] += fn.arguments

        if text:
            await emit({
                "type": "assistant.complete",
                "text": text,
                "message_id": msg_id,
            })
        if input_tokens or output_tokens:
            await emit({
                "type": "usage",
                "message_id": msg_id,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            })

        tool_uses: List[Dict[str, Any]] = []
        for tc in tool_calls.values():
            try:
                args = json.loads(tc["arguments"] or "{}")
            except json.JSONDecodeError:
                # Malformed/truncated argument JSON → empty input.
                args = {}
            # Ollama may omit the OpenAI-style tool-call id; synthesize one.
            t_id = tc["id"] if tc["id"] else f"call_{uuid.uuid4()}"
            tool_uses.append(
                {"id": t_id, "name": tc["name"], "input": args}
            )
        return {"text": text, "tool_uses": tool_uses}
18 changes: 17 additions & 1 deletion backend/app/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# modified by agent: add ollama settings routes
import logging
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

# Make app.* loggers surface at INFO so WS lifecycle events are visible.
Expand All @@ -15,6 +16,7 @@
from app.api.sessions import router as sessions_router
from app.db import init_db
from app.ws.session_ws import router as ws_router
from app.store.credentials import get_ollama_base_url, set_ollama_base_url


@asynccontextmanager
Expand All @@ -37,6 +39,20 @@ async def hello() -> dict:
return {"message": "Hello from the Forge backend"}


@app.get("/api/settings/ollama")
async def get_ollama_settings() -> dict:
    """Report the currently configured Ollama base URL."""
    current = get_ollama_base_url()
    return {"base_url": current}


@app.post("/api/settings/ollama")
async def post_ollama_settings(body: dict) -> dict:
    """Persist the Ollama base URL after basic validation."""
    base_url = body.get("base_url")
    # Validate before stripping, matching the stored (stripped) form only
    # when the raw value already passes the scheme check.
    is_valid = (
        isinstance(base_url, str)
        and bool(base_url.strip())
        and base_url.startswith(("http://", "https://"))
    )
    if not is_valid:
        return JSONResponse(
            status_code=422,
            content={"error": "Invalid URL. Must start with http:// or https://"},
        )
    set_ollama_base_url(base_url.strip())
    return {"ok": True}


# Serve the built frontend bundle from app/static/ when present. The directory
# is created by scripts/build-wheel.sh before packaging; in local dev it won't
# exist and we skip mounting — use `npm run dev` for the frontend instead.
Expand Down
2 changes: 1 addition & 1 deletion backend/app/schemas/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, Field

AgentKind = Literal["claude", "openai", "gemini"]
AgentKind = Literal["claude", "openai", "gemini", "ollama"]


class CredentialStatus(BaseModel):
Expand Down
12 changes: 11 additions & 1 deletion backend/app/store/credentials.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# modified by agent: add ollama base url config functions
from datetime import datetime, timezone
from typing import Optional

from app.db import get_conn
from app.schemas.credentials import AgentKind, CredentialStatus

AGENT_KINDS: tuple[AgentKind, ...] = ("claude", "openai", "gemini")
AGENT_KINDS: tuple[AgentKind, ...] = ("claude", "openai", "gemini", "ollama")


def _now_iso() -> str:
Expand Down Expand Up @@ -76,3 +77,12 @@ def delete_key(agent_kind: AgentKind) -> bool:
conn.rollback()
raise
return cur.rowcount > 0


def get_ollama_base_url() -> str:
    """Return the stored Ollama base URL, defaulting to the local server."""
    stored = get_key("ollama_url")  # type: ignore
    return stored or "http://localhost:11434"


def set_ollama_base_url(url: str) -> None:
    """Persist *url* as the Ollama server base URL.

    Reuses the credential store under the pseudo-kind "ollama_url";
    that string is not a member of the AgentKind Literal, hence the
    type: ignore.
    """
    upsert_key("ollama_url", url)  # type: ignore
Loading