From 7bb490d13a66ef4ce64f3b13f61ab6ae07a7830f Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Wed, 15 Apr 2026 16:42:34 +0545 Subject: [PATCH 1/2] feat: OpenAI support --- backend/core/consts.py | 5 + .../services/factories/ai_provider_factory.py | 2 + .../services/providers/ai/openai_provider.py | 141 +++++ backend/core/tests/test_openai_provider.py | 510 ++++++++++++++++++ frontend/components/icons/OpenAIIcon.vue | 6 + frontend/composables/useAIProviderIcon.ts | 3 + 6 files changed, 667 insertions(+) create mode 100644 backend/core/services/providers/ai/openai_provider.py create mode 100644 backend/core/tests/test_openai_provider.py create mode 100644 frontend/components/icons/OpenAIIcon.vue diff --git a/backend/core/consts.py b/backend/core/consts.py index 9a2c1d5..fbd20dc 100644 --- a/backend/core/consts.py +++ b/backend/core/consts.py @@ -13,6 +13,11 @@ 'label': 'Google Gemini', 'base_url': 'https://generativelanguage.googleapis.com/v1beta' }, + { + 'id': 'openai', + 'label': 'OpenAI', + 'base_url': 'https://api.openai.com/v1' + }, { 'id': 'custom', 'label': 'Custom Provider', diff --git a/backend/core/services/factories/ai_provider_factory.py b/backend/core/services/factories/ai_provider_factory.py index 0113edf..4555b5e 100644 --- a/backend/core/services/factories/ai_provider_factory.py +++ b/backend/core/services/factories/ai_provider_factory.py @@ -3,11 +3,13 @@ from ..contracts.ai_provider_contract import AIProviderContract from ..providers.ai.custom_provider import CustomProvider from ..providers.ai.gemini_provider import GeminiProvider +from ..providers.ai.openai_provider import OpenAIProvider class AIProviderFactory: PROVIDER_CLASSES = { 'gemini': GeminiProvider, + 'openai': OpenAIProvider, 'custom': CustomProvider, } diff --git a/backend/core/services/providers/ai/openai_provider.py b/backend/core/services/providers/ai/openai_provider.py new file mode 100644 index 0000000..7a94825 --- /dev/null +++ b/backend/core/services/providers/ai/openai_provider.py @@ -0,0 +1,141 @@ +import json +from typing import Optional, Dict, Any +from pydantic import BaseModel +import openai +from ...contracts.ai_provider_contract import AIProviderContract +from core.agent_response_schema import SupportAgentResponse + +EXCLUDED_PREFIXES = ("whisper", "tts", "dall-e", "davinci", "babbage", "text-embedding-ada") +EXCLUDED_SUFFIXES = ("-instruct",) + + +class OpenAIProvider(AIProviderContract): + def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None) -> None: + super().__init__(api_key, config) + base_url = self.config.get("base_url") if self.config else None + try: + kwargs: Dict[str, Any] = {"api_key": api_key} + if base_url: + kwargs["base_url"] = base_url + self.client = openai.OpenAI(**kwargs) + except Exception as e: + raise ValueError(f"Failed to initialize OpenAI client: {e}") + + def validate_connection(self) -> tuple[bool, list[Dict[str, Any]]]: + try: + models = self.get_models() + return True, models + except Exception: + return False, [] + + def get_models(self) -> list[Dict[str, Any]]: + try: + raw_models = self.client.models.list().data + except Exception as e: + raise ValueError(f"Failed to retrieve models from OpenAI API: {e}") + + result = [] + for model in raw_models: + model_id = model.id + if any(model_id.startswith(p) for p in EXCLUDED_PREFIXES): + continue + if any(model_id.endswith(s) for s in EXCLUDED_SUFFIXES): + continue + result.append({ + "id": model_id, + "name": model_id, + "object": model.object, + "created": model.created, + "owned_by": model.owned_by, + }) + return result + + def generate_with_conversation( + self, + model: str, + messages: list[dict], + tools: list[dict] | None, + response_schema: type[BaseModel], + ) -> tuple: + kwargs: Dict[str, Any] = { + "model": model, + "messages": messages, + } + + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = "auto" + else: + kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": response_schema.__name__, + "schema": response_schema.model_json_schema(), + "strict": True, + }, + } + + try: + response = self.client.chat.completions.create(**kwargs) + except openai.AuthenticationError as e: + raise ValueError(f"Invalid OpenAI API key: {e}") + except openai.RateLimitError as e: + raise ValueError(f"OpenAI rate limit exceeded: {e}") + except openai.APIError as e: + raise ValueError(f"OpenAI API error: {e}") + + usage_metadata = self._extract_usage(response) + choice = response.choices[0] + + if choice.finish_reason == "tool_calls": + raw_tool_calls = [] + for tc in (choice.message.tool_calls or []): + raw_tool_calls.append({ + "id": tc.id, + "name": tc.function.name, + "args": json.loads(tc.function.arguments), + }) + return choice.message.content or "", raw_tool_calls, usage_metadata + + try: + parsed = response_schema.model_validate_json(choice.message.content) + except Exception as e: + raise ValueError(f"Failed to parse OpenAI response as {response_schema.__name__}: {e}") + + return parsed, [], usage_metadata + + def _extract_usage(self, response) -> dict: + usage: dict = {} + try: + meta = getattr(response, "usage", None) + if meta is not None: + for key in ("prompt_tokens", "completion_tokens", "total_tokens"): + val = getattr(meta, key, None) + if val is not None: + usage[key] = val + details = getattr(meta, "prompt_tokens_details", None) + if details is not None: + cached = getattr(details, "cached_tokens", None) + if cached is not None: + usage["cached_tokens"] = cached + except Exception: + pass + return usage + + def embed(self, model: str, texts: list[str]) -> list[list[float]]: + try: + response = self.client.embeddings.create(input=texts, model=model) + except Exception as e: + raise ValueError(f"OpenAI embedding error: {e}") + sorted_data = sorted(response.data, key=lambda item: item.index) + return [item.embedding for item in sorted_data] + + def generate_text(self, model: str, contents: str, **kwargs) -> SupportAgentResponse: + messages = [{"role": "user", "content": contents}] + result, _, _ = self.generate_with_conversation( + model=model, + messages=messages, + tools=None, + response_schema=SupportAgentResponse, + ) + return result diff --git a/backend/core/tests/test_openai_provider.py b/backend/core/tests/test_openai_provider.py new file mode 100644 index 0000000..47dd16d --- /dev/null +++ b/backend/core/tests/test_openai_provider.py @@ -0,0 +1,510 @@ +""" +Tests for OpenAIProvider — unit tests and property-based tests. + +Property tests: + - Property 2: Tool call / structured output mutual exclusion (Validates: Requirements 6.2, 7.2) + - Property 3: Model list filter stability (Validates: Requirements 5.2, 5.3, 5.6) + - Property 4: Embedding order preservation (Validates: Requirements 9.2) + - Property 5: validate_connection never raises (Validates: Requirements 4.2, 4.3) +""" +import json +import pytest +from unittest.mock import MagicMock, patch, PropertyMock +from types import SimpleNamespace + +from hypothesis import given, settings, HealthCheck +from hypothesis import strategies as st + +from core.services.providers.ai.openai_provider import OpenAIProvider, EXCLUDED_PREFIXES, EXCLUDED_SUFFIXES +from core.agent_response_schema import SupportAgentResponse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_provider(api_key="sk-test", config=None): + """Return an OpenAIProvider with a mocked openai.OpenAI client.""" + with patch("core.services.providers.ai.openai_provider.openai.OpenAI"): + provider = OpenAIProvider(api_key=api_key, config=config) + return provider + + +def _make_model(model_id, object_="model", created=0, owned_by="openai"): + m = MagicMock() + m.id = model_id + m.object = object_ + m.created = created + m.owned_by = owned_by + return m + + +def _make_stop_response(content: str): + """Build a minimal ChatCompletion-like object with finish_reason='stop'.""" + choice = MagicMock() + choice.finish_reason = "stop" + choice.message.content = content + choice.message.tool_calls = None + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 20 + usage.total_tokens = 30 + usage.prompt_tokens_details = None + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + +def _make_tool_call_response(tool_name="search", args=None): + """Build a minimal ChatCompletion-like object with finish_reason='tool_calls'.""" + tc = MagicMock() + tc.id = "call_abc123" + tc.function.name = tool_name + tc.function.arguments = json.dumps(args or {"query": "test"}) + + choice = MagicMock() + choice.finish_reason = "tool_calls" + choice.message.content = "" + choice.message.tool_calls = [tc] + + usage = MagicMock() + usage.prompt_tokens = 5 + usage.completion_tokens = 10 + usage.total_tokens = 15 + usage.prompt_tokens_details = None + + response = MagicMock() + response.choices = [choice] + response.usage = usage + return response + + +# --------------------------------------------------------------------------- +# Unit tests — __init__ +# --------------------------------------------------------------------------- + +class TestOpenAIProviderInit: + def test_init_success(self): + with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: + provider = OpenAIProvider(api_key="sk-test") + mock_cls.assert_called_once_with(api_key="sk-test") + assert provider.api_key == "sk-test" + + def test_init_with_base_url(self): + with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: + provider = OpenAIProvider(api_key="sk-test", config={"base_url": "https://proxy.example.com"}) + mock_cls.assert_called_once_with(api_key="sk-test", base_url="https://proxy.example.com") + + def test_init_with_none_config(self): + with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: + provider = OpenAIProvider(api_key="sk-test", config=None) + mock_cls.assert_called_once_with(api_key="sk-test") + + def test_init_raises_value_error_on_sdk_failure(self): + with patch("core.services.providers.ai.openai_provider.openai.OpenAI", side_effect=Exception("bad key")): + with pytest.raises(ValueError, match="Failed to initialize OpenAI client"): + OpenAIProvider(api_key="bad") + + +# --------------------------------------------------------------------------- +# Unit tests — get_models +# --------------------------------------------------------------------------- + +class TestGetModels: + def test_filters_excluded_prefixes(self): + provider = _make_provider() + raw = [ + _make_model("gpt-4o"), + _make_model("whisper-1"), + _make_model("tts-1"), + _make_model("dall-e-3"), + _make_model("davinci-002"), + _make_model("babbage-002"), + _make_model("text-embedding-ada-002"), + ] + provider.client.models.list.return_value.data = raw + result = provider.get_models() + ids = [m["id"] for m in result] + assert "gpt-4o" in ids + assert "whisper-1" not in ids + assert "tts-1" not in ids + assert "dall-e-3" not in ids + assert "davinci-002" not in ids + assert "babbage-002" not in ids + assert "text-embedding-ada-002" not in ids + + def test_filters_instruct_suffix(self): + provider = _make_provider() + provider.client.models.list.return_value.data = [ + _make_model("gpt-3.5-turbo-instruct"), + _make_model("gpt-4o"), + ] + result = provider.get_models() + ids = [m["id"] for m in result] + assert "gpt-3.5-turbo-instruct" not in ids + assert "gpt-4o" in ids + + def test_returns_id_and_name_keys(self): + provider = _make_provider() + provider.client.models.list.return_value.data = [_make_model("gpt-4o")] + result = provider.get_models() + assert len(result) == 1 + assert "id" in result[0] + assert "name" in result[0] + + def test_raises_value_error_on_api_failure(self): + provider = _make_provider() + provider.client.models.list.side_effect = Exception("network error") + with pytest.raises(ValueError, match="Failed to retrieve models"): + provider.get_models() + + def test_returns_empty_list_when_all_filtered(self): + provider = _make_provider() + provider.client.models.list.return_value.data = [ + _make_model("whisper-1"), + _make_model("tts-1"), + ] + result = provider.get_models() + assert result == [] + + +# --------------------------------------------------------------------------- +# Unit tests — validate_connection +# --------------------------------------------------------------------------- + +class TestValidateConnection: + def test_returns_true_and_models_on_success(self): + provider = _make_provider() + provider.client.models.list.return_value.data = [_make_model("gpt-4o")] + ok, models = provider.validate_connection() + assert ok is True + assert len(models) == 1 + + def test_returns_false_empty_on_exception(self): + provider = _make_provider() + provider.client.models.list.side_effect = Exception("auth error") + ok, models = provider.validate_connection() + assert ok is False + assert models == [] + + def test_never_raises(self): + provider = _make_provider() + provider.client.models.list.side_effect = RuntimeError("unexpected") + # Must not raise + result = provider.validate_connection() + assert isinstance(result, tuple) + + +# --------------------------------------------------------------------------- +# Unit tests — generate_with_conversation +# --------------------------------------------------------------------------- + +class TestGenerateWithConversation: + def _valid_response_json(self): + return json.dumps({ + "answer": "Reset your password via Settings.", + "status": "ANSWERED", + "escalation": False, + "reason_for_escalation": "", + "sentiment_score": 70, + "escalation_score": 5, + "criticality_score": 10, + }) + + def test_no_tools_returns_parsed_schema(self): + provider = _make_provider() + provider.client.chat.completions.create.return_value = _make_stop_response(self._valid_response_json()) + messages = [{"role": "user", "content": "How do I reset my password?"}] + result, tool_calls, usage = provider.generate_with_conversation( + model="gpt-4o", messages=messages, tools=None, response_schema=SupportAgentResponse + ) + assert isinstance(result, SupportAgentResponse) + assert tool_calls == [] + assert "prompt_tokens" in usage + + def test_no_tools_uses_response_format(self): + provider = _make_provider() + provider.client.chat.completions.create.return_value = _make_stop_response(self._valid_response_json()) + messages = [{"role": "user", "content": "Hello"}] + provider.generate_with_conversation( + model="gpt-4o", messages=messages, tools=None, response_schema=SupportAgentResponse + ) + call_kwargs = provider.client.chat.completions.create.call_args[1] + assert "response_format" in call_kwargs + assert "tools" not in call_kwargs + + def test_with_tools_returns_tool_calls(self): + provider = _make_provider() + provider.client.chat.completions.create.return_value = _make_tool_call_response("search_kb", {"query": "reset"}) + tools = [{"type": "function", "function": {"name": "search_kb", "description": "Search", "parameters": {}}}] + messages = [{"role": "user", "content": "Help"}] + text, tool_calls, usage = provider.generate_with_conversation( + model="gpt-4o", messages=messages, tools=tools, response_schema=SupportAgentResponse + ) + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "search_kb" + assert "id" in tool_calls[0] + assert "args" in tool_calls[0] + + def test_with_tools_uses_tool_choice_auto(self): + provider = _make_provider() + provider.client.chat.completions.create.return_value = _make_tool_call_response() + tools = [{"type": "function", "function": {"name": "fn", "description": "", "parameters": {}}}] + provider.generate_with_conversation( + model="gpt-4o", messages=[{"role": "user", "content": "x"}], tools=tools, response_schema=SupportAgentResponse + ) + call_kwargs = provider.client.chat.completions.create.call_args[1] + assert call_kwargs.get("tool_choice") == "auto" + assert "response_format" not in call_kwargs + + def test_raises_value_error_on_authentication_error(self): + import openai as openai_module + provider = _make_provider() + provider.client.chat.completions.create.side_effect = openai_module.AuthenticationError( + message="Invalid key", response=MagicMock(), body={} + ) + with pytest.raises(ValueError, match="Invalid OpenAI API key"): + provider.generate_with_conversation( + model="gpt-4o", messages=[{"role": "user", "content": "x"}], + tools=None, response_schema=SupportAgentResponse + ) + + def test_raises_value_error_on_rate_limit_error(self): + import openai as openai_module + provider = _make_provider() + provider.client.chat.completions.create.side_effect = openai_module.RateLimitError( + message="Rate limit", response=MagicMock(), body={} + ) + with pytest.raises(ValueError, match="rate limit"): + provider.generate_with_conversation( + model="gpt-4o", messages=[{"role": "user", "content": "x"}], + tools=None, response_schema=SupportAgentResponse + ) + + def test_raises_value_error_on_api_error(self): + import openai as openai_module + provider = _make_provider() + provider.client.chat.completions.create.side_effect = openai_module.APIStatusError( + message="Server error", response=MagicMock(status_code=500), body={} + ) + with pytest.raises(ValueError, match="OpenAI API error"): + provider.generate_with_conversation( + model="gpt-4o", messages=[{"role": "user", "content": "x"}], + tools=None, response_schema=SupportAgentResponse + ) + + def test_raises_value_error_on_bad_json(self): + provider = _make_provider() + provider.client.chat.completions.create.return_value = _make_stop_response("not valid json {{{") + with pytest.raises(ValueError, match="Failed to parse OpenAI response"): + provider.generate_with_conversation( + model="gpt-4o", messages=[{"role": "user", "content": "x"}], + tools=None, response_schema=SupportAgentResponse + ) + + +# --------------------------------------------------------------------------- +# Unit tests — embed +# --------------------------------------------------------------------------- + +class TestEmbed: + def test_returns_vectors_in_input_order(self): + provider = _make_provider() + # Return shuffled: index 1 first, then 0 + e0 = MagicMock(); e0.index = 0; e0.embedding = [0.1, 0.2] + e1 = MagicMock(); e1.index = 1; e1.embedding = [0.3, 0.4] + provider.client.embeddings.create.return_value.data = [e1, e0] # shuffled + result = provider.embed(model="text-embedding-3-small", texts=["a", "b"]) + assert result[0] == [0.1, 0.2] + assert result[1] == [0.3, 0.4] + + def test_raises_value_error_on_failure(self): + provider = _make_provider() + provider.client.embeddings.create.side_effect = Exception("network error") + with pytest.raises(ValueError, match="OpenAI embedding error"): + provider.embed(model="text-embedding-3-small", texts=["hello"]) + + +# --------------------------------------------------------------------------- +# Property-based tests +# --------------------------------------------------------------------------- + +# --- Property 5: validate_connection never raises --- +# Validates: Requirements 4.2, 4.3 + +@given(api_key=st.text()) +@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) +def test_property_5_validate_connection_never_raises(api_key): + """ + **Validates: Requirements 4.2, 4.3** + + Property 5: validate_connection never raises. + For any api_key string, validate_connection always returns tuple[bool, list] + and never propagates an exception. + """ + with patch("core.services.providers.ai.openai_provider.openai.OpenAI"): + provider = OpenAIProvider(api_key=api_key if api_key else "sk-x") + + # Make the client raise an arbitrary exception + provider.client.models.list.side_effect = Exception("simulated failure") + + result = provider.validate_connection() + assert isinstance(result, tuple) + assert len(result) == 2 + ok, models = result + assert isinstance(ok, bool) + assert isinstance(models, list) + assert ok is False + assert models == [] + + +# --- Property 3: Model list filter stability --- +# Validates: Requirements 5.2, 5.3, 5.6 + +def _model_id_strategy(): + """Generate model IDs including edge cases.""" + normal = st.text(alphabet=st.characters(whitelist_categories=("Ll", "Lu", "Nd"), whitelist_characters="-_."), min_size=1, max_size=30) + prefixed = st.sampled_from(list(EXCLUDED_PREFIXES)).map(lambda p: p + "-extra") + suffixed = st.just("gpt-4o-instruct") + clean = st.just("gpt-4o") + return st.one_of(normal, prefixed, suffixed, clean) + + +@given(model_ids=st.lists(_model_id_strategy(), min_size=0, max_size=20)) +@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) +def test_property_3_model_list_filter_stability(model_ids): + """ + **Validates: Requirements 5.2, 5.3, 5.6** + + Property 3: Model list filter stability. + - Filter excludes models with excluded prefixes or suffixes. + - Filter is idempotent. + - Every returned dict contains 'id' and 'name'. + """ + provider = _make_provider() + raw = [_make_model(mid) for mid in model_ids] + provider.client.models.list.return_value.data = raw + + result = provider.get_models() + + # Every returned model must have id and name + for m in result: + assert "id" in m + assert "name" in m + + # No excluded prefix + for m in result: + for prefix in EXCLUDED_PREFIXES: + assert not m["id"].startswith(prefix), f"{m['id']} starts with excluded prefix {prefix}" + + # No excluded suffix + for m in result: + for suffix in EXCLUDED_SUFFIXES: + assert not m["id"].endswith(suffix), f"{m['id']} ends with excluded suffix {suffix}" + + # Idempotency: applying filter again yields same result + result_ids = {m["id"] for m in result} + for mid in result_ids: + assert not any(mid.startswith(p) for p in EXCLUDED_PREFIXES) + assert not any(mid.endswith(s) for s in EXCLUDED_SUFFIXES) + + +# --- Property 4: Embedding order preservation --- +# Validates: Requirements 9.2 + +@given(texts=st.lists(st.text(min_size=1), min_size=1, max_size=10)) +@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) +def test_property_4_embedding_order_preservation(texts): + """ + **Validates: Requirements 9.2** + + Property 4: Embedding order preservation. + embed() returns len(result) == len(texts) and result[i] corresponds to texts[i]. + """ + provider = _make_provider() + + # Build shuffled embeddings (reverse order to test sorting) + embeddings = [] + for i, _ in enumerate(texts): + e = MagicMock() + e.index = i + e.embedding = [float(i), float(i) * 0.1] + embeddings.append(e) + + # Shuffle: reverse order + shuffled = list(reversed(embeddings)) + provider.client.embeddings.create.return_value.data = shuffled + + result = provider.embed(model="text-embedding-3-small", texts=texts) + + assert len(result) == len(texts) + for i, vec in enumerate(result): + assert vec == [float(i), float(i) * 0.1], f"result[{i}] does not match expected vector for texts[{i}]" + + +# --- Property 2: Tool call / structured output mutual exclusion --- +# Validates: Requirements 6.2, 7.2 + +def _valid_response_json(): + return json.dumps({ + "answer": "Here is the answer.", + "status": "ANSWERED", + "escalation": False, + "reason_for_escalation": "", + "sentiment_score": 50, + "escalation_score": 10, + "criticality_score": 5, + }) + + +@given( + has_tools=st.booleans(), + finish_reason=st.sampled_from(["stop", "tool_calls"]), +) +@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) +def test_property_2_tool_call_structured_output_mutual_exclusion(has_tools, finish_reason): + """ + **Validates: Requirements 6.2, 7.2** + + Property 2: Tool call / structured output mutual exclusion. + - If tools non-empty and finish_reason='tool_calls': tool_calls non-empty, first element is str. + - If tools empty/None and finish_reason='stop': tool_calls=[], first element is response_schema instance. + """ + provider = _make_provider() + + tools = [{"type": "function", "function": {"name": "fn", "description": "", "parameters": {}}}] if has_tools else None + + if finish_reason == "tool_calls": + provider.client.chat.completions.create.return_value = _make_tool_call_response("fn", {"x": 1}) + else: + provider.client.chat.completions.create.return_value = _make_stop_response(_valid_response_json()) + + messages = [{"role": "user", "content": "test"}] + + result_tuple = provider.generate_with_conversation( + model="gpt-4o", + messages=messages, + tools=tools, + response_schema=SupportAgentResponse, + ) + + first, tool_calls, usage = result_tuple + assert isinstance(tool_calls, list) + assert isinstance(usage, dict) + + if finish_reason == "tool_calls": + # Tool calls branch: tool_calls non-empty, first is str + assert len(tool_calls) > 0 + assert isinstance(first, str) + for tc in tool_calls: + assert "id" in tc + assert "name" in tc + assert "args" in tc + else: + # Stop branch (no tools): tool_calls empty, first is parsed schema + if not has_tools: + assert tool_calls == [] + assert isinstance(first, SupportAgentResponse) diff --git a/frontend/components/icons/OpenAIIcon.vue b/frontend/components/icons/OpenAIIcon.vue new file mode 100644 index 0000000..c1ba5ce --- /dev/null +++ b/frontend/components/icons/OpenAIIcon.vue @@ -0,0 +1,6 @@ + diff --git a/frontend/composables/useAIProviderIcon.ts b/frontend/composables/useAIProviderIcon.ts index f3bf9d9..96f0f47 100644 --- a/frontend/composables/useAIProviderIcon.ts +++ b/frontend/composables/useAIProviderIcon.ts @@ -1,12 +1,15 @@ import { computed, defineComponent, h } from 'vue' import { Sparkles } from 'lucide-vue-next' import GeminiIcon from '~/components/icons/GeminiIcon.vue' +import OpenAIIcon from '~/components/icons/OpenAIIcon.vue' export function useAIProviderIcon(provider: string) { return computed(() => { switch (provider?.toLowerCase()) { case 'gemini': return GeminiIcon + case 'openai': + return OpenAIIcon case 'custom': default: return defineComponent({ From 7f4ea3066a2b58603bbf42afb5c9c0a3c9538635 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Fri, 17 Apr 2026 10:53:14 +0545 Subject: [PATCH 2/2] feat: OpenAI support & filter models --- backend/core/consts.py | 10 +- .../services/providers/ai/openai_provider.py | 130 +++-- backend/core/tests/test_openai_provider.py | 510 ------------------ frontend/components/App/ConfigureAIModels.vue | 36 +- 4 files changed, 124 insertions(+), 562 deletions(-) delete mode 100644 backend/core/tests/test_openai_provider.py diff --git a/backend/core/consts.py b/backend/core/consts.py index fbd20dc..f702bd6 100644 --- a/backend/core/consts.py +++ b/backend/core/consts.py @@ -18,11 +18,11 @@ 'label': 'OpenAI', 'base_url': 'https://api.openai.com/v1' }, - { - 'id': 'custom', - 'label': 'Custom Provider', - 'base_url': '' - } + # { + # 'id': 'custom', + # 'label': 'Custom Provider', + # 'base_url': '' + # } ] SUPPORTED_INTEGRATIONS = [ diff --git a/backend/core/services/providers/ai/openai_provider.py b/backend/core/services/providers/ai/openai_provider.py index 7a94825..f8e0383 100644 --- a/backend/core/services/providers/ai/openai_provider.py +++ b/backend/core/services/providers/ai/openai_provider.py @@ -5,10 +5,6 @@ from ...contracts.ai_provider_contract import AIProviderContract from core.agent_response_schema import SupportAgentResponse -EXCLUDED_PREFIXES = ("whisper", "tts", "dall-e", "davinci", "babbage", "text-embedding-ada") -EXCLUDED_SUFFIXES = ("-instruct",) - - class OpenAIProvider(AIProviderContract): def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None) -> None: super().__init__(api_key, config) @@ -37,10 +33,6 @@ def get_models(self) -> list[Dict[str, Any]]: result = [] for model in raw_models: model_id = model.id - if any(model_id.startswith(p) for p in EXCLUDED_PREFIXES): - continue - if any(model_id.endswith(s) for s in EXCLUDED_SUFFIXES): - continue result.append({ "id": model_id, "name": model_id, @@ -57,26 +49,71 @@ def generate_with_conversation( tools: list[dict] | None, response_schema: type[BaseModel], ) -> tuple: - kwargs: Dict[str, Any] = { - "model": model, - "messages": messages, - } - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = "auto" - else: - kwargs["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": response_schema.__name__, - "schema": response_schema.model_json_schema(), - "strict": True, - }, - } + has_tool_history = any(m.get("role") == "tool" for m in messages) try: - response = self.client.chat.completions.create(**kwargs) + if tools: + response = self.client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + tool_choice="auto", + ) + usage_metadata = self._extract_usage(response) + choice = response.choices[0] + + if choice.finish_reason == "tool_calls": + raw_tool_calls = [ + { + "id": tc.id, + "name": tc.function.name, + "args": json.loads(tc.function.arguments), + } + for tc in (choice.message.tool_calls or []) + ] + messages.append({ + "role": "assistant", + "content": choice.message.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in (choice.message.tool_calls or []) + ], + }) + return choice.message.content or "", raw_tool_calls, usage_metadata + + return self._parse_structured_response(choice, response_schema, self._extract_usage(response)) + + elif has_tool_history: + response = self.client.chat.completions.create( + model=model, + messages=messages, + response_format={"type": "json_object"}, + ) + usage_metadata = self._extract_usage(response) + choice = response.choices[0] + return self._parse_structured_response(choice, response_schema, usage_metadata) + + else: + completion = self.client.beta.chat.completions.parse( + model=model, + messages=messages, + response_format=response_schema, + ) + usage_metadata = self._extract_usage(completion) + message = completion.choices[0].message + + if message.refusal: + raise ValueError(f"OpenAI refused the request: {message.refusal}") + + return message.parsed, [], usage_metadata + except openai.AuthenticationError as e: raise ValueError(f"Invalid OpenAI API key: {e}") except openai.RateLimitError as e: @@ -84,24 +121,37 @@ def generate_with_conversation( except openai.APIError as e: raise ValueError(f"OpenAI API error: {e}") - usage_metadata = self._extract_usage(response) - choice = response.choices[0] - - if choice.finish_reason == "tool_calls": - raw_tool_calls = [] - for tc in (choice.message.tool_calls or []): - raw_tool_calls.append({ - "id": tc.id, - "name": tc.function.name, - "args": json.loads(tc.function.arguments), - }) - return choice.message.content or "", raw_tool_calls, usage_metadata + def _parse_structured_response(self, choice, response_schema, usage_metadata): + raw = choice.message.content or "" + stripped = raw.strip() + + if stripped.startswith("```"): + stripped = stripped.split("\n", 1)[-1] + if stripped.rstrip().endswith("```"): + stripped = stripped.rstrip()[:-3].rstrip() + + brace_idx = stripped.find("{") + if brace_idx > 0: + stripped = stripped[brace_idx:] + last_brace = stripped.rfind("}") + if last_brace != -1 and last_brace < len(stripped) - 1: + stripped = stripped[:last_brace + 1] + + if not stripped.startswith("{"): + stripped = json.dumps({ + "answer": raw.strip(), + "status": "ANSWERED", + "escalation": False, + "reason_for_escalation": "", + "sentiment_score": 50, + "escalation_score": 0, + "criticality_score": 0, + }) try: - parsed = response_schema.model_validate_json(choice.message.content) + parsed = response_schema.model_validate_json(stripped) except Exception as e: raise ValueError(f"Failed to parse OpenAI response as {response_schema.__name__}: {e}") - return parsed, [], usage_metadata def _extract_usage(self, response) -> dict: diff --git a/backend/core/tests/test_openai_provider.py b/backend/core/tests/test_openai_provider.py deleted file mode 100644 index 47dd16d..0000000 --- a/backend/core/tests/test_openai_provider.py +++ /dev/null @@ -1,510 +0,0 @@ -""" -Tests for OpenAIProvider — unit tests and property-based tests. - -Property tests: - - Property 2: Tool call / structured output mutual exclusion (Validates: Requirements 6.2, 7.2) - - Property 3: Model list filter stability (Validates: Requirements 5.2, 5.3, 5.6) - - Property 4: Embedding order preservation (Validates: Requirements 9.2) - - Property 5: validate_connection never raises (Validates: Requirements 4.2, 4.3) -""" -import json -import pytest -from unittest.mock import MagicMock, patch, PropertyMock -from types import SimpleNamespace - -from hypothesis import given, settings, HealthCheck -from hypothesis import strategies as st - -from core.services.providers.ai.openai_provider import OpenAIProvider, EXCLUDED_PREFIXES, EXCLUDED_SUFFIXES -from core.agent_response_schema import SupportAgentResponse - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_provider(api_key="sk-test", config=None): - """Return an OpenAIProvider with a mocked openai.OpenAI client.""" - with patch("core.services.providers.ai.openai_provider.openai.OpenAI"): - provider = OpenAIProvider(api_key=api_key, config=config) - return provider - - -def _make_model(model_id, object_="model", created=0, owned_by="openai"): - m = MagicMock() - m.id = model_id - m.object = object_ - m.created = created - m.owned_by = owned_by - return m - - -def _make_stop_response(content: str): - """Build a minimal ChatCompletion-like object with finish_reason='stop'.""" - choice = MagicMock() - choice.finish_reason = "stop" - choice.message.content = content - choice.message.tool_calls = None - - usage = MagicMock() - usage.prompt_tokens = 10 - usage.completion_tokens = 20 - usage.total_tokens = 30 - usage.prompt_tokens_details = None - - response = MagicMock() - response.choices = [choice] - response.usage = usage - return response - - -def _make_tool_call_response(tool_name="search", args=None): - """Build a minimal ChatCompletion-like object with finish_reason='tool_calls'.""" - tc = MagicMock() - tc.id = "call_abc123" - tc.function.name = tool_name - tc.function.arguments = json.dumps(args or {"query": "test"}) - - choice = MagicMock() - choice.finish_reason = "tool_calls" - choice.message.content = "" - choice.message.tool_calls = [tc] - - usage = MagicMock() - usage.prompt_tokens = 5 - usage.completion_tokens = 10 - usage.total_tokens = 15 - usage.prompt_tokens_details = None - - response = MagicMock() - response.choices = [choice] - response.usage = usage - return response - - -# --------------------------------------------------------------------------- -# Unit tests — __init__ -# --------------------------------------------------------------------------- - -class TestOpenAIProviderInit: - def test_init_success(self): - with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: - provider = OpenAIProvider(api_key="sk-test") - mock_cls.assert_called_once_with(api_key="sk-test") - assert provider.api_key == "sk-test" - - def test_init_with_base_url(self): - with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: - provider = OpenAIProvider(api_key="sk-test", config={"base_url": "https://proxy.example.com"}) - mock_cls.assert_called_once_with(api_key="sk-test", base_url="https://proxy.example.com") - - def test_init_with_none_config(self): - with patch("core.services.providers.ai.openai_provider.openai.OpenAI") as mock_cls: - provider = OpenAIProvider(api_key="sk-test", config=None) - mock_cls.assert_called_once_with(api_key="sk-test") - - def test_init_raises_value_error_on_sdk_failure(self): - with patch("core.services.providers.ai.openai_provider.openai.OpenAI", side_effect=Exception("bad key")): - with pytest.raises(ValueError, match="Failed to initialize OpenAI client"): - OpenAIProvider(api_key="bad") - - -# --------------------------------------------------------------------------- -# Unit tests — get_models -# --------------------------------------------------------------------------- - -class TestGetModels: - def test_filters_excluded_prefixes(self): - provider = _make_provider() - raw = [ - _make_model("gpt-4o"), - _make_model("whisper-1"), - _make_model("tts-1"), - _make_model("dall-e-3"), - _make_model("davinci-002"), - _make_model("babbage-002"), - _make_model("text-embedding-ada-002"), - ] - provider.client.models.list.return_value.data = raw - result = provider.get_models() - ids = [m["id"] for m in result] - assert "gpt-4o" in ids - assert "whisper-1" not in ids - assert "tts-1" not in ids - assert "dall-e-3" not in ids - assert "davinci-002" not in ids - assert "babbage-002" not in ids - assert "text-embedding-ada-002" not in ids - - def test_filters_instruct_suffix(self): - provider = _make_provider() - provider.client.models.list.return_value.data = [ - _make_model("gpt-3.5-turbo-instruct"), - _make_model("gpt-4o"), - ] - result = provider.get_models() - ids = [m["id"] for m in result] - assert "gpt-3.5-turbo-instruct" not in ids - assert "gpt-4o" in ids - - def test_returns_id_and_name_keys(self): - provider = _make_provider() - provider.client.models.list.return_value.data = [_make_model("gpt-4o")] - result = provider.get_models() - assert len(result) == 1 - assert "id" in result[0] - assert "name" in result[0] - - def test_raises_value_error_on_api_failure(self): - provider = _make_provider() - provider.client.models.list.side_effect = Exception("network error") - with pytest.raises(ValueError, match="Failed to retrieve models"): - provider.get_models() - - def test_returns_empty_list_when_all_filtered(self): - provider = _make_provider() - provider.client.models.list.return_value.data = [ - _make_model("whisper-1"), - _make_model("tts-1"), - ] - result = provider.get_models() - assert result == [] - - -# --------------------------------------------------------------------------- -# Unit tests — validate_connection -# --------------------------------------------------------------------------- - -class TestValidateConnection: - def test_returns_true_and_models_on_success(self): - provider = _make_provider() - provider.client.models.list.return_value.data = [_make_model("gpt-4o")] - ok, models = provider.validate_connection() - assert ok is True - assert len(models) == 1 - - def test_returns_false_empty_on_exception(self): - provider = _make_provider() - provider.client.models.list.side_effect = Exception("auth error") - ok, models = provider.validate_connection() - assert ok is False - assert models == [] - - def test_never_raises(self): - provider = _make_provider() - provider.client.models.list.side_effect = RuntimeError("unexpected") - # Must not raise - result = provider.validate_connection() - assert isinstance(result, tuple) - - -# --------------------------------------------------------------------------- -# Unit tests — generate_with_conversation -# --------------------------------------------------------------------------- - -class TestGenerateWithConversation: - def _valid_response_json(self): - return json.dumps({ - "answer": "Reset your password via Settings.", - "status": "ANSWERED", - "escalation": False, - "reason_for_escalation": "", - "sentiment_score": 70, - "escalation_score": 5, - "criticality_score": 10, - }) - - def test_no_tools_returns_parsed_schema(self): - provider = _make_provider() - provider.client.chat.completions.create.return_value = _make_stop_response(self._valid_response_json()) - messages = [{"role": "user", "content": "How do I reset my password?"}] - result, tool_calls, usage = provider.generate_with_conversation( - model="gpt-4o", messages=messages, tools=None, response_schema=SupportAgentResponse - ) - assert isinstance(result, SupportAgentResponse) - assert tool_calls == [] - assert "prompt_tokens" in usage - - def test_no_tools_uses_response_format(self): - provider = _make_provider() - provider.client.chat.completions.create.return_value = _make_stop_response(self._valid_response_json()) - messages = [{"role": "user", "content": "Hello"}] - provider.generate_with_conversation( - model="gpt-4o", messages=messages, tools=None, response_schema=SupportAgentResponse - ) - call_kwargs = provider.client.chat.completions.create.call_args[1] - assert "response_format" in call_kwargs - assert "tools" not in call_kwargs - - def test_with_tools_returns_tool_calls(self): - provider = _make_provider() - provider.client.chat.completions.create.return_value = _make_tool_call_response("search_kb", {"query": "reset"}) - tools = [{"type": "function", "function": {"name": "search_kb", "description": "Search", "parameters": {}}}] - messages = [{"role": "user", "content": "Help"}] - text, tool_calls, usage = provider.generate_with_conversation( - model="gpt-4o", messages=messages, tools=tools, response_schema=SupportAgentResponse - ) - assert len(tool_calls) == 1 - assert tool_calls[0]["name"] == "search_kb" - assert "id" in tool_calls[0] - assert "args" in tool_calls[0] - - def test_with_tools_uses_tool_choice_auto(self): - provider = _make_provider() - provider.client.chat.completions.create.return_value = _make_tool_call_response() - tools = [{"type": "function", "function": {"name": "fn", "description": "", "parameters": {}}}] - provider.generate_with_conversation( - model="gpt-4o", messages=[{"role": "user", "content": "x"}], tools=tools, response_schema=SupportAgentResponse - ) - call_kwargs = provider.client.chat.completions.create.call_args[1] - assert call_kwargs.get("tool_choice") == "auto" - assert "response_format" not in call_kwargs - - def test_raises_value_error_on_authentication_error(self): - import openai as openai_module - provider = _make_provider() - provider.client.chat.completions.create.side_effect = openai_module.AuthenticationError( - message="Invalid key", response=MagicMock(), body={} - ) - with pytest.raises(ValueError, match="Invalid OpenAI API key"): - provider.generate_with_conversation( - model="gpt-4o", messages=[{"role": "user", "content": "x"}], - tools=None, response_schema=SupportAgentResponse - ) - - def test_raises_value_error_on_rate_limit_error(self): - import openai as openai_module - provider = _make_provider() - provider.client.chat.completions.create.side_effect = openai_module.RateLimitError( - message="Rate limit", response=MagicMock(), body={} - ) - with pytest.raises(ValueError, match="rate limit"): - provider.generate_with_conversation( - model="gpt-4o", messages=[{"role": "user", "content": "x"}], - tools=None, response_schema=SupportAgentResponse - ) - - def test_raises_value_error_on_api_error(self): - import openai as openai_module - provider = _make_provider() - provider.client.chat.completions.create.side_effect = openai_module.APIStatusError( - message="Server error", response=MagicMock(status_code=500), body={} - ) - with pytest.raises(ValueError, match="OpenAI API error"): - provider.generate_with_conversation( - model="gpt-4o", messages=[{"role": "user", "content": "x"}], - tools=None, response_schema=SupportAgentResponse - ) - - def test_raises_value_error_on_bad_json(self): - provider = _make_provider() - provider.client.chat.completions.create.return_value = _make_stop_response("not valid json {{{") - with pytest.raises(ValueError, match="Failed to parse OpenAI response"): - provider.generate_with_conversation( - model="gpt-4o", messages=[{"role": "user", "content": "x"}], - tools=None, response_schema=SupportAgentResponse - ) - - -# --------------------------------------------------------------------------- -# Unit tests — embed -# --------------------------------------------------------------------------- - -class TestEmbed: - def test_returns_vectors_in_input_order(self): - provider = _make_provider() - # Return shuffled: index 1 first, then 0 - e0 = MagicMock(); e0.index = 0; e0.embedding = [0.1, 0.2] - e1 = MagicMock(); e1.index = 1; e1.embedding = [0.3, 0.4] - provider.client.embeddings.create.return_value.data = [e1, e0] # shuffled - result = provider.embed(model="text-embedding-3-small", texts=["a", "b"]) - assert result[0] == [0.1, 0.2] - assert result[1] == [0.3, 0.4] - - def test_raises_value_error_on_failure(self): - provider = _make_provider() - provider.client.embeddings.create.side_effect = Exception("network error") - with pytest.raises(ValueError, match="OpenAI embedding error"): - provider.embed(model="text-embedding-3-small", texts=["hello"]) - - -# --------------------------------------------------------------------------- -# Property-based tests -# --------------------------------------------------------------------------- - -# --- Property 5: validate_connection never raises --- -# Validates: Requirements 4.2, 4.3 - -@given(api_key=st.text()) -@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) -def test_property_5_validate_connection_never_raises(api_key): - """ - **Validates: Requirements 4.2, 4.3** - - Property 5: validate_connection never raises. - For any api_key string, validate_connection always returns tuple[bool, list] - and never propagates an exception. - """ - with patch("core.services.providers.ai.openai_provider.openai.OpenAI"): - provider = OpenAIProvider(api_key=api_key if api_key else "sk-x") - - # Make the client raise an arbitrary exception - provider.client.models.list.side_effect = Exception("simulated failure") - - result = provider.validate_connection() - assert isinstance(result, tuple) - assert len(result) == 2 - ok, models = result - assert isinstance(ok, bool) - assert isinstance(models, list) - assert ok is False - assert models == [] - - -# --- Property 3: Model list filter stability --- -# Validates: Requirements 5.2, 5.3, 5.6 - -def _model_id_strategy(): - """Generate model IDs including edge cases.""" - normal = st.text(alphabet=st.characters(whitelist_categories=("Ll", "Lu", "Nd"), whitelist_characters="-_."), min_size=1, max_size=30) - prefixed = st.sampled_from(list(EXCLUDED_PREFIXES)).map(lambda p: p + "-extra") - suffixed = st.just("gpt-4o-instruct") - clean = st.just("gpt-4o") - return st.one_of(normal, prefixed, suffixed, clean) - - -@given(model_ids=st.lists(_model_id_strategy(), min_size=0, max_size=20)) -@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) -def test_property_3_model_list_filter_stability(model_ids): - """ - **Validates: Requirements 5.2, 5.3, 5.6** - - Property 3: Model list filter stability. - - Filter excludes models with excluded prefixes or suffixes. - - Filter is idempotent. - - Every returned dict contains 'id' and 'name'. - """ - provider = _make_provider() - raw = [_make_model(mid) for mid in model_ids] - provider.client.models.list.return_value.data = raw - - result = provider.get_models() - - # Every returned model must have id and name - for m in result: - assert "id" in m - assert "name" in m - - # No excluded prefix - for m in result: - for prefix in EXCLUDED_PREFIXES: - assert not m["id"].startswith(prefix), f"{m['id']} starts with excluded prefix {prefix}" - - # No excluded suffix - for m in result: - for suffix in EXCLUDED_SUFFIXES: - assert not m["id"].endswith(suffix), f"{m['id']} ends with excluded suffix {suffix}" - - # Idempotency: applying filter again yields same result - result_ids = {m["id"] for m in result} - for mid in result_ids: - assert not any(mid.startswith(p) for p in EXCLUDED_PREFIXES) - assert not any(mid.endswith(s) for s in EXCLUDED_SUFFIXES) - - -# --- Property 4: Embedding order preservation --- -# Validates: Requirements 9.2 - -@given(texts=st.lists(st.text(min_size=1), min_size=1, max_size=10)) -@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) -def test_property_4_embedding_order_preservation(texts): - """ - **Validates: Requirements 9.2** - - Property 4: Embedding order preservation. - embed() returns len(result) == len(texts) and result[i] corresponds to texts[i]. - """ - provider = _make_provider() - - # Build shuffled embeddings (reverse order to test sorting) - embeddings = [] - for i, _ in enumerate(texts): - e = MagicMock() - e.index = i - e.embedding = [float(i), float(i) * 0.1] - embeddings.append(e) - - # Shuffle: reverse order - shuffled = list(reversed(embeddings)) - provider.client.embeddings.create.return_value.data = shuffled - - result = provider.embed(model="text-embedding-3-small", texts=texts) - - assert len(result) == len(texts) - for i, vec in enumerate(result): - assert vec == [float(i), float(i) * 0.1], f"result[{i}] does not match expected vector for texts[{i}]" - - -# --- Property 2: Tool call / structured output mutual exclusion --- -# Validates: Requirements 6.2, 7.2 - -def _valid_response_json(): - return json.dumps({ - "answer": "Here is the answer.", - "status": "ANSWERED", - "escalation": False, - "reason_for_escalation": "", - "sentiment_score": 50, - "escalation_score": 10, - "criticality_score": 5, - }) - - -@given( - has_tools=st.booleans(), - finish_reason=st.sampled_from(["stop", "tool_calls"]), -) -@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow]) -def test_property_2_tool_call_structured_output_mutual_exclusion(has_tools, finish_reason): - """ - **Validates: Requirements 6.2, 7.2** - - Property 2: Tool call / structured output mutual exclusion. - - If tools non-empty and finish_reason='tool_calls': tool_calls non-empty, first element is str. - - If tools empty/None and finish_reason='stop': tool_calls=[], first element is response_schema instance. - """ - provider = _make_provider() - - tools = [{"type": "function", "function": {"name": "fn", "description": "", "parameters": {}}}] if has_tools else None - - if finish_reason == "tool_calls": - provider.client.chat.completions.create.return_value = _make_tool_call_response("fn", {"x": 1}) - else: - provider.client.chat.completions.create.return_value = _make_stop_response(_valid_response_json()) - - messages = [{"role": "user", "content": "test"}] - - result_tuple = provider.generate_with_conversation( - model="gpt-4o", - messages=messages, - tools=tools, - response_schema=SupportAgentResponse, - ) - - first, tool_calls, usage = result_tuple - assert isinstance(tool_calls, list) - assert isinstance(usage, dict) - - if finish_reason == "tool_calls": - # Tool calls branch: tool_calls non-empty, first is str - assert len(tool_calls) > 0 - assert isinstance(first, str) - for tc in tool_calls: - assert "id" in tc - assert "name" in tc - assert "args" in tc - else: - # Stop branch (no tools): tool_calls empty, first is parsed schema - if not has_tools: - assert tool_calls == [] - assert isinstance(first, SupportAgentResponse) diff --git a/frontend/components/App/ConfigureAIModels.vue b/frontend/components/App/ConfigureAIModels.vue index 34e7b78..62d99f2 100644 --- a/frontend/components/App/ConfigureAIModels.vue +++ b/frontend/components/App/ConfigureAIModels.vue @@ -158,6 +158,8 @@ const configuredAIProviderOptions = computed(() => })), ) +const EXCLUDED_MODELS = ['text-embedding-3-small'] + const getProviderModels = (providerId: number) => { const providerWithModels = AIProviderModelsStore.providerModels.find( pm => pm.ai_provider.id === providerId, @@ -167,13 +169,33 @@ const getProviderModels = (providerId: number) => { return [] } - return providerWithModels.ai_provider_models.models_data.map((model) => { - const modelName = model.name || model.displayName || model.id || Object.values(model)[0] || '' - return { - label: modelName, - value: modelName, - } - }) + const capability = props.config.capability.toLowerCase() + + return providerWithModels.ai_provider_models.models_data + .filter((model) => { + const modelName = (model.name || model.displayName || model.id || Object.values(model)[0] || '').toLowerCase() + + if (EXCLUDED_MODELS.some(excluded => modelName.includes(excluded.toLowerCase()))) { + return false + } + + if (capability === 'embedding') { + return modelName.includes('embedding') + } + + if (capability === 'text') { + return !modelName.includes('embedding') + } + + return true + }) + .map((model) => { + const modelName = model.name || model.displayName || model.id || Object.values(model)[0] || '' + return { + label: modelName, + value: modelName, + } + }) } const configureModel = form.handleSubmit(async (values) => {