From 55eb2c1ad4962bd8e82402b72194003c643315ac Mon Sep 17 00:00:00 2001 From: precious112 Date: Wed, 11 Mar 2026 05:02:44 +0100 Subject: [PATCH] fixed BYOK config --- .../agent/src/argus_agent/llm/registry.py | 33 +++++++++++++++++++ .../agent/src/argus_agent/queue/worker.py | 6 ++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/packages/agent/src/argus_agent/llm/registry.py b/packages/agent/src/argus_agent/llm/registry.py index bc41c26..1f44033 100644 --- a/packages/agent/src/argus_agent/llm/registry.py +++ b/packages/agent/src/argus_agent/llm/registry.py @@ -37,6 +37,39 @@ def get_provider() -> LLMProvider: return _providers[provider_name](settings.llm) +async def get_provider_for_tenant(tenant_id: str) -> LLMProvider: + """Get an LLM provider using the tenant's BYOK keys if configured, + otherwise fall back to the platform default.""" + from argus_agent.api.llm_keys import get_tenant_llm_key + + tenant_config = await get_tenant_llm_key(tenant_id) + if tenant_config and tenant_config.get("api_key"): + provider_name = tenant_config.get("provider", "openai") + if provider_name not in _providers: + _discover_providers() + if provider_name not in _providers: + raise ValueError( + f"Unknown LLM provider: {provider_name}. Available: {list(_providers.keys())}" + ) + + settings = get_settings() + # Build a temporary LLMConfig with tenant's keys + from argus_agent.config import LLMConfig + + tenant_llm = LLMConfig( + provider=provider_name, + api_key=tenant_config["api_key"], + model=tenant_config.get("model") or settings.llm.model, + base_url=tenant_config.get("base_url") or settings.llm.base_url, + temperature=settings.llm.temperature, + max_tokens=settings.llm.max_tokens, + ) + return _providers[provider_name](tenant_llm) + + # No BYOK config — use platform default + return get_provider() + + def _discover_providers() -> None: """Auto-discover available providers based on installed packages.""" try: diff --git a/packages/agent/src/argus_agent/queue/worker.py b/packages/agent/src/argus_agent/queue/worker.py index 93bf187..034dd9b 100644 --- a/packages/agent/src/argus_agent/queue/worker.py +++ b/packages/agent/src/argus_agent/queue/worker.py @@ -176,10 +176,10 @@ async def on_event(event_type: str, data: dict[str, Any]) -> None: msg = json.dumps({"event_type": event_type, "data": data}) await redis_pub.publish(f"{STREAM_KEY_PREFIX}{task_id}", msg) - # 4. Get LLM provider - from argus_agent.llm.registry import get_provider + # 4. Get LLM provider (tenant BYOK keys take priority) + from argus_agent.llm.registry import get_provider_for_tenant - provider = get_provider() + provider = await get_provider_for_tenant(payload.tenant_id) # 5. Build AgentLoop from argus_agent.agent.loop import AgentLoop