diff --git a/tests/unit/test_config_service.py b/tests/unit/test_config_service.py index dd2b75c5..886a0d44 100644 --- a/tests/unit/test_config_service.py +++ b/tests/unit/test_config_service.py @@ -280,6 +280,116 @@ async def test_provider_schema_uses_astrbot_provider_config_classification(self, assert schema["provider_options_by_type"]["embedding"][0]["provider_type_label"] == "Embedding" assert schema["provider_options_by_type"]["rerank"][0]["provider_type_label"] == "Reranker" + @pytest.mark.asyncio + async def test_provider_schema_replaces_stale_config_model_with_live_models(self, tmp_path, monkeypatch): + plugin_config = PluginConfig.create_default() + plugin_config.data_dir = str(tmp_path / "self_learning_data") + + context = Mock() + context.get_all_providers = Mock(return_value=[]) + context.get_all_embedding_providers = Mock(return_value=[]) + context.provider_manager = SimpleNamespace( + provider_insts=[], + embedding_provider_insts=[], + rerank_provider_insts=[], + inst_map={}, + provider_sources_config=[ + { + "id": "openai-source", + "provider_type": "chat_completion", + "api_base": "https://models.example.test/v1/chat/completions", + "key": ["sk-test"], + }, + ], + providers_config=[ + { + "id": "chat-config", + "provider_source_id": "openai-source", + "model": "deleted-model", + }, + ], + ) + + service_factory = Mock() + service_factory.context = context + factory_manager = Mock() + factory_manager.get_service_factory = Mock(return_value=service_factory) + + container = Mock() + container.plugin_config = plugin_config + container.factory_manager = factory_manager + + calls = [] + + async def fake_fetch(self, models_url, api_key="", custom_headers=None): + calls.append((models_url, api_key, custom_headers)) + return ["live-model-b", "live-model-a"] + + monkeypatch.setattr(ConfigService, "_fetch_models_from_endpoint", fake_fetch) + + schema = await ConfigService(container).get_config_schema() + option = schema["provider_options_by_type"]["chat_completion"][0] + groups = {group["key"]: group for group in schema["groups"]} + model_fields = {field["key"]: field for field in groups["Model_Configuration"]["fields"]} + + assert calls == [ + ("https://models.example.test/v1/models", "sk-test", None), + ] + assert option["value"] == "chat-config" + assert option["model_source"] == "live" + assert option["available_models"] == ["live-model-a", "live-model-b"] + assert option["configured_model_available"] is False + assert "deleted-model" not in option["label"] + assert "live-model-a" in option["label"] + assert model_fields["filter_provider_id"]["options"][0] == option + + @pytest.mark.asyncio + async def test_provider_schema_falls_back_to_config_model_when_live_models_unavailable(self, tmp_path, monkeypatch): + plugin_config = PluginConfig.create_default() + plugin_config.data_dir = str(tmp_path / "self_learning_data") + + context = Mock() + context.get_all_providers = Mock(return_value=[]) + context.get_all_embedding_providers = Mock(return_value=[]) + context.provider_manager = SimpleNamespace( + provider_insts=[], + embedding_provider_insts=[], + rerank_provider_insts=[], + inst_map={}, + provider_sources_config=[], + providers_config=[ + { + "id": "chat-config", + "provider_type": "chat_completion", + "api_base": "https://models.example.test/v1", + "key": ["sk-test"], + "model": "configured-model", + }, + ], + ) + + service_factory = Mock() + service_factory.context = context + factory_manager = Mock() + factory_manager.get_service_factory = Mock(return_value=service_factory) + + container = Mock() + container.plugin_config = plugin_config + container.factory_manager = factory_manager + + async def fake_fetch(self, models_url, api_key="", custom_headers=None): + return [] + + monkeypatch.setattr(ConfigService, "_fetch_models_from_endpoint", fake_fetch) + + schema = await ConfigService(container).get_config_schema() + option = schema["provider_options_by_type"]["chat_completion"][0] + + assert option["value"] == "chat-config" + assert option["model_source"] == "configured" + assert "configured-model" in option["label"] + assert "available_models" not in option + def test_provider_option_builders_share_metadata_shape(self): provider_meta = SimpleNamespace( id="embed-live", @@ -305,6 +415,38 @@ def test_provider_option_builders_share_metadata_shape(self): assert config_option["provider_type"] == "embedding" assert config_option["provider_type_label"] == "Embedding" + def test_generic_provider_field_uses_combined_prefetched_options(self, tmp_path): + service = ConfigService(build_container(tmp_path)) + provider_options_by_type = { + "chat_completion": [ + ConfigService._build_provider_option("chat-a", "gpt-test", "chat_completion"), + ], + "embedding": [ + ConfigService._build_provider_option("embed-a", "embed-test", "embedding"), + ], + "rerank": [ + ConfigService._build_provider_option("rerank-a", "rerank-test", "rerank"), + ], + } + + field = service._build_field_spec( + "custom_provider_id", + { + "description": "自定义 Provider", + "type": "string", + }, + {}, + provider_options_by_type, + ) + + assert field["widget"] == "provider" + assert field["provider_type"] == "" + assert {option["value"] for option in field["options"]} == { + "chat-a", + "embed-a", + "rerank-a", + } + @pytest.mark.asyncio async def test_config_schema_covers_all_plugin_config_fields(self, tmp_path): container = build_container(tmp_path) diff --git a/webui/services/config_service.py b/webui/services/config_service.py index 3db1b1fb..db8b1d8b 100644 --- a/webui/services/config_service.py +++ b/webui/services/config_service.py @@ -3,12 +3,15 @@ """ from __future__ import annotations +import asyncio +import inspect import json import os from collections.abc import MutableMapping from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +import aiohttp from astrbot.api import logger from pydantic import ValidationError @@ -294,6 +297,9 @@ def _load_schema_definition() -> Dict[str, Any]: "rerank": "Reranker", } +_PROVIDER_MODELS_TIMEOUT_SECONDS = 3.0 +_PROVIDER_OPTION_MODEL_PREVIEW_LIMIT = 3 + _RESTART_REQUIRED_KEYS = { "data_dir", "db_type", @@ -711,30 +717,99 @@ def _provider_type_value(provider: Any, default_type: str = "") -> str: def _provider_type_label(provider_type: str) -> str: return _PROVIDER_TYPE_LABELS.get(provider_type, provider_type) + @staticmethod + def _model_id_from_item(item: Any) -> Optional[str]: + if isinstance(item, bytes): + try: + item = item.decode("utf-8") + except UnicodeDecodeError: + return None + if isinstance(item, str): + model_id = item.strip() + return model_id or None + if isinstance(item, dict): + for key in ("id", "name", "model"): + value = item.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + value = getattr(item, "id", None) or getattr(item, "name", None) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + @staticmethod + def _normalize_model_list(value: Any) -> List[str]: + if value is None: + return [] + if isinstance(value, dict): + value = value.get("data") or value.get("models") or [] + elif isinstance(value, (str, bytes)): + value = [value] + + models: List[str] = [] + seen = set() + for item in ConfigService._as_provider_list(value): + model_id = ConfigService._model_id_from_item(item) + if not model_id or model_id in seen: + continue + seen.add(model_id) + models.append(model_id) + return sorted(models) + + @staticmethod + def _model_preview(models: List[str]) -> str: + preview = ", ".join(models[:_PROVIDER_OPTION_MODEL_PREVIEW_LIMIT]) + remaining = len(models) - _PROVIDER_OPTION_MODEL_PREVIEW_LIMIT + if remaining > 0: + preview = f"{preview} 等 {len(models)} 个模型" + return preview + + @staticmethod + def _option_model_label(model_name: Any, available_models: List[str]) -> str: + configured_model = str(model_name or "").strip() + if available_models: + if configured_model and configured_model in available_models: + return configured_model + return ConfigService._model_preview(available_models) + return configured_model + @staticmethod def _build_provider_option( provider_id: Any, model_name: Any = None, provider_type: str = "", - ) -> Optional[Dict[str, str]]: + available_models: Any = None, + ) -> Optional[Dict[str, Any]]: if not provider_id: return None + live_models = ConfigService._normalize_model_list(available_models) + model_label = ConfigService._option_model_label(model_name, live_models) + configured_model = str(model_name or "").strip() label_parts = [str(provider_id)] - if model_name and str(model_name) not in str(provider_id): - label_parts.append(str(model_name)) + if model_label and model_label not in str(provider_id): + label_parts.append(model_label) if provider_type: label_parts.append(provider_type) - return { + option: Dict[str, Any] = { "value": str(provider_id), "label": " / ".join(label_parts), "provider_type": provider_type, "provider_type_label": ConfigService._provider_type_label(provider_type), + "model_source": "live" if live_models else "configured", } + if live_models: + option["available_models"] = live_models + option["model_count"] = len(live_models) + if configured_model: + option["configured_model_available"] = configured_model in live_models + return option @staticmethod - def _provider_option(provider: Any, default_type: str = "") -> Optional[Dict[str, str]]: + def _provider_identity(provider: Any, default_type: str = "") -> Tuple[Any, Any, str]: try: meta = provider.meta() except Exception: @@ -742,15 +817,27 @@ def _provider_option(provider: Any, default_type: str = "") -> Optional[Dict[str provider_id = getattr(meta, "id", None) or getattr(provider, "id", None) provider_type = ConfigService._provider_type_value(provider, default_type) - model_name = getattr(meta, "model", None) or getattr(provider, "model", None) + model_name = ( + getattr(meta, "model", None) + or getattr(provider, "model", None) + or getattr(provider, "model_name", None) + ) + return provider_id, model_name, provider_type + + @staticmethod + def _provider_option(provider: Any, default_type: str = "") -> Optional[Dict[str, Any]]: + provider_id, model_name, provider_type = ConfigService._provider_identity( + provider, + default_type, + ) return ConfigService._build_provider_option(provider_id, model_name, provider_type) @staticmethod - def _provider_option_from_config( + def _provider_identity_from_config( provider_config: Any, provider_source_types: Dict[str, str], default_type: str = "", - ) -> Optional[Dict[str, str]]: + ) -> Optional[Tuple[Any, Any, str]]: if not isinstance(provider_config, dict): return None @@ -769,19 +856,44 @@ def _provider_option_from_config( or provider_config.get("embedding_model") or provider_config.get("rerank_model") ) + return provider_id, model_name, provider_type + + @staticmethod + def _provider_option_from_config( + provider_config: Any, + provider_source_types: Dict[str, str], + default_type: str = "", + ) -> Optional[Dict[str, Any]]: + identity = ConfigService._provider_identity_from_config( + provider_config, + provider_source_types, + default_type, + ) + if not identity: + return None + + provider_id, model_name, provider_type = identity return ConfigService._build_provider_option(provider_id, model_name, provider_type) @staticmethod - def _dedupe_options(options: List[Dict[str, str]]) -> List[Dict[str, str]]: + def _dedupe_options(options: List[Dict[str, Any]]) -> List[Dict[str, Any]]: seen = set() - result: List[Dict[str, str]] = [] + result: List[Dict[str, Any]] = [] + positions: Dict[Tuple[Any, Any], int] = {} for option in options: value = option.get("value") provider_type = option.get("provider_type", "") key = (value, provider_type) - if not value or key in seen: + if not value: + continue + if key in seen: + existing_index = positions[key] + existing = result[existing_index] + if option.get("model_source") == "live" and existing.get("model_source") != "live": + result[existing_index] = option continue seen.add(key) + positions[key] = len(result) result.append(option) return result @@ -820,7 +932,276 @@ def _provider_source_types(provider_manager: Any) -> Dict[str, str]: ) return source_types - def _provider_options(self, expected_type: Optional[str] = None) -> List[Dict[str, str]]: + @staticmethod + def _merged_provider_config(provider_manager: Any, provider_config: Any) -> Any: + if not isinstance(provider_config, dict): + return provider_config + + merge_getter = getattr(provider_manager, "get_merged_provider_config", None) + if callable(merge_getter): + try: + merged = merge_getter(provider_config) + if isinstance(merged, dict): + return merged + except Exception: + logger.debug("合并 Provider 配置失败", exc_info=True) + + provider_source_id = provider_config.get("provider_source_id") + if not provider_source_id: + return provider_config + + for provider_source in ConfigService._as_provider_list( + getattr(provider_manager, "provider_sources_config", None) + ): + if not isinstance(provider_source, dict): + continue + if provider_source.get("id") != provider_source_id: + continue + merged = {**provider_source, **provider_config} + if provider_config.get("id"): + merged["id"] = provider_config["id"] + return merged + return provider_config + + @staticmethod + def _resolve_api_key_value(value: Any) -> str: + if isinstance(value, (list, tuple)): + for item in value: + resolved = ConfigService._resolve_api_key_value(item) + if resolved: + return resolved + return "" + if not isinstance(value, str): + return "" + + key = value.strip() + if not key: + return "" + if key.startswith("$"): + env_key = key[1:] + if env_key.startswith("{") and env_key.endswith("}"): + env_key = env_key[1:-1] + return os.getenv(env_key, "").strip() if env_key else "" + return key + + @staticmethod + def _provider_api_key(provider_config: Dict[str, Any], provider_type: str) -> str: + if provider_type == "embedding": + key_fields = ("embedding_api_key", "api_key", "key") + elif provider_type == "rerank": + key_fields = ( + "rerank_api_key", + "nvidia_rerank_api_key", + "api_key", + "key", + ) + else: + key_fields = ("key", "api_key") + + for field in key_fields: + resolved = ConfigService._resolve_api_key_value(provider_config.get(field)) + if resolved: + return resolved + return "" + + @staticmethod + def _provider_api_base(provider_config: Dict[str, Any], provider_type: str) -> str: + if provider_type == "embedding": + base_fields = ("embedding_api_base", "api_base") + elif provider_type == "rerank": + base_fields = ( + "rerank_api_base", + "nvidia_rerank_api_base", + "api_base", + ) + else: + base_fields = ("api_base",) + + for field in base_fields: + value = provider_config.get(field) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + @staticmethod + def _models_url_from_api_base(api_base: str) -> str: + base = str(api_base or "").strip().rstrip("/") + if not base: + return "" + + lower_base = base.lower() + endpoint_suffixes = ( + "/chat/completions", + "/embeddings", + "/rerank", + "/rerank/completions", + ) + for suffix in endpoint_suffixes: + if lower_base.endswith(suffix): + base = base[: -len(suffix)] + lower_base = base.lower() + break + + if lower_base.endswith("/models"): + return base + if lower_base.endswith("/v1") or lower_base.endswith("/v4"): + return f"{base}/models" + return f"{base}/v1/models" + + @staticmethod + def _models_payload_to_list(payload: Any) -> List[str]: + if isinstance(payload, dict): + return ConfigService._normalize_model_list( + payload.get("data") or payload.get("models") or [] + ) + return ConfigService._normalize_model_list(payload) + + async def _fetch_models_from_endpoint( + self, + models_url: str, + api_key: str = "", + custom_headers: Optional[Dict[str, Any]] = None, + ) -> List[str]: + if not models_url: + return [] + + headers: Dict[str, str] = {"Accept": "application/json"} + if isinstance(custom_headers, dict): + for key, value in custom_headers.items(): + if key and value is not None: + headers[str(key)] = str(value) + if api_key and not any(key.lower() == "authorization" for key in headers): + headers["Authorization"] = f"Bearer {api_key}" + + timeout = aiohttp.ClientTimeout(total=_PROVIDER_MODELS_TIMEOUT_SECONDS) + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(models_url, headers=headers) as response: + response.raise_for_status() + payload = await response.json(content_type=None) + except Exception as exc: + logger.debug(f"实时获取模型列表失败: {models_url}: {exc}") + return [] + + return self._models_payload_to_list(payload) + + async def _models_from_provider_config( + self, + provider_config: Any, + provider_type: str, + model_cache: Dict[Tuple[str, str, str], List[str]], + ) -> List[str]: + if not isinstance(provider_config, dict): + return [] + + provider_id = str(provider_config.get("id") or provider_config.get("provider_id") or "") + cache_key = ("config", provider_type, provider_id) + if cache_key in model_cache: + return model_cache[cache_key] + + api_base = self._provider_api_base(provider_config, provider_type) + models_url = self._models_url_from_api_base(api_base) + api_key = self._provider_api_key(provider_config, provider_type) + custom_headers = provider_config.get("custom_headers") + models = await self._fetch_models_from_endpoint( + models_url, + api_key, + custom_headers if isinstance(custom_headers, dict) else None, + ) + model_cache[cache_key] = models + return models + + async def _models_from_provider_instance( + self, + provider: Any, + provider_id: Any, + provider_type: str, + model_cache: Dict[Tuple[str, str, str], List[str]], + ) -> List[str]: + cache_key = ("instance", provider_type, str(provider_id)) + if cache_key in model_cache: + return model_cache[cache_key] + + models: List[str] = [] + get_models = getattr(provider, "get_models", None) + if callable(get_models): + try: + result = get_models() + if inspect.isawaitable(result): + result = await asyncio.wait_for( + result, + timeout=_PROVIDER_MODELS_TIMEOUT_SECONDS, + ) + models = self._normalize_model_list(result) + except Exception: + logger.debug(f"Provider {provider_id} 实时获取模型列表失败", exc_info=True) + + if not models: + provider_config = getattr(provider, "provider_config", None) + models = await self._models_from_provider_config( + provider_config, + provider_type, + model_cache, + ) + + model_cache[cache_key] = models + return models + + async def _provider_option_async( + self, + provider: Any, + default_type: str, + model_cache: Dict[Tuple[str, str, str], List[str]], + ) -> Optional[Dict[str, Any]]: + provider_id, model_name, provider_type = self._provider_identity( + provider, + default_type, + ) + if not provider_id: + return None + + models = await self._models_from_provider_instance( + provider, + provider_id, + provider_type, + model_cache, + ) + return self._build_provider_option( + provider_id, + model_name, + provider_type, + models, + ) + + async def _provider_option_from_config_async( + self, + provider_config: Any, + provider_source_types: Dict[str, str], + default_type: str, + model_cache: Dict[Tuple[str, str, str], List[str]], + ) -> Optional[Dict[str, Any]]: + identity = self._provider_identity_from_config( + provider_config, + provider_source_types, + default_type, + ) + if not identity: + return None + + provider_id, model_name, provider_type = identity + models = await self._models_from_provider_config( + provider_config, + provider_type, + model_cache, + ) + return self._build_provider_option( + provider_id, + model_name, + provider_type, + models, + ) + + def _provider_options(self, expected_type: Optional[str] = None) -> List[Dict[str, Any]]: factory_manager = getattr(self.container, "factory_manager", None) if not factory_manager or not hasattr(factory_manager, "get_service_factory"): return [] @@ -830,7 +1211,7 @@ def _provider_options(self, expected_type: Optional[str] = None) -> List[Dict[st if not context: return [] - options: List[Dict[str, str]] = [] + options: List[Dict[str, Any]] = [] expected = self._normalize_provider_type(expected_type) provider_manager = getattr(context, "provider_manager", None) @@ -890,6 +1271,120 @@ def _provider_options(self, expected_type: Optional[str] = None) -> List[Dict[st logger.warning(f"获取 Provider 列表失败: {e}") return [] + async def _provider_options_async( + self, + expected_type: Optional[str] = None, + model_cache: Optional[Dict[Tuple[str, str, str], List[str]]] = None, + ) -> List[Dict[str, Any]]: + factory_manager = getattr(self.container, "factory_manager", None) + if not factory_manager or not hasattr(factory_manager, "get_service_factory"): + return [] + + try: + context = self._get_provider_context() + if not context: + return [] + + options: List[Dict[str, Any]] = [] + expected = self._normalize_provider_type(expected_type) + provider_manager = getattr(context, "provider_manager", None) + cache = model_cache if model_cache is not None else {} + + if expected in {"", "chat_completion", "llm", "chat"} and callable(getattr(context, "get_all_providers", None)): + for provider in self._as_provider_list(context.get_all_providers()): + option = await self._provider_option_async( + provider, + "chat_completion", + cache, + ) + if option: + options.append(option) + if expected in {"", "chat_completion"} and provider_manager and hasattr(provider_manager, "provider_insts"): + for provider in self._as_provider_list(provider_manager.provider_insts): + option = await self._provider_option_async( + provider, + "chat_completion", + cache, + ) + if option: + options.append(option) + + if expected in {"", "embedding"} and callable(getattr(context, "get_all_embedding_providers", None)): + for provider in self._as_provider_list(context.get_all_embedding_providers()): + option = await self._provider_option_async( + provider, + "embedding", + cache, + ) + if option: + options.append(option) + if expected in {"", "embedding"} and provider_manager and hasattr(provider_manager, "embedding_provider_insts"): + for provider in self._as_provider_list(provider_manager.embedding_provider_insts): + option = await self._provider_option_async( + provider, + "embedding", + cache, + ) + if option: + options.append(option) + + if expected in {"", "rerank", "reranker"}: + rerank_providers = [] + rerank_getter = getattr(context, "get_all_rerank_providers", None) + if callable(rerank_getter): + rerank_providers = self._as_provider_list(rerank_getter()) + if not rerank_providers and provider_manager and hasattr(provider_manager, "rerank_provider_insts"): + rerank_providers = provider_manager.rerank_provider_insts + for provider in self._as_provider_list(rerank_providers): + option = await self._provider_option_async( + provider, + "rerank", + cache, + ) + if option: + options.append(option) + + if expected == "" and provider_manager and hasattr(provider_manager, "inst_map"): + for provider in provider_manager.inst_map.values(): + option = await self._provider_option_async( + provider, + "", + cache, + ) + if option: + options.append(option) + + if provider_manager and hasattr(provider_manager, "providers_config"): + provider_source_types = self._provider_source_types(provider_manager) + for provider_config in self._as_provider_list(provider_manager.providers_config): + merged_config = self._merged_provider_config( + provider_manager, + provider_config, + ) + identity = self._provider_identity_from_config( + merged_config, + provider_source_types, + ) + if not identity: + continue + option_type = identity[2] + if expected and option_type != expected: + continue + option = await self._provider_option_from_config_async( + merged_config, + provider_source_types, + "", + cache, + ) + if not option: + continue + options.append(option) + + return self._dedupe_options(options) + except Exception as e: + logger.warning(f"获取 Provider 列表失败: {e}") + return [] + @staticmethod def _provider_expected_type_for_field(key: str) -> Optional[str]: if key in {"filter_provider_id", "refine_provider_id", "reinforce_provider_id"}: @@ -900,11 +1395,29 @@ def _provider_expected_type_for_field(key: str) -> Optional[str]: return "rerank" return None + @staticmethod + def _provider_options_for_field( + provider_type: str, + provider_options_by_type: Optional[Dict[str, List[Dict[str, Any]]]], + ) -> List[Dict[str, Any]]: + if provider_options_by_type is None: + return [] + if provider_type: + return provider_options_by_type.get(provider_type, []) + return ConfigService._dedupe_options( + [ + option + for options in provider_options_by_type.values() + for option in options + ] + ) + def _build_field_spec( self, key: str, raw_spec: Dict[str, Any], current_config: Dict[str, Any], + provider_options_by_type: Optional[Dict[str, List[Dict[str, Any]]]] = None, ) -> Dict[str, Any]: field_type = raw_spec.get("type", "string") widget = "text" @@ -967,11 +1480,21 @@ def _build_field_spec( field_spec["provider_type"], field_spec["provider_type"] or "Provider", ) - field_spec["options"] = self._provider_options(field_spec["provider_type"]) + if provider_options_by_type is not None: + field_spec["options"] = self._provider_options_for_field( + field_spec["provider_type"], + provider_options_by_type, + ) + else: + field_spec["options"] = self._provider_options(field_spec["provider_type"]) return field_spec - def _build_group_schema(self, schema_definition: Dict[str, Any]) -> List[Dict[str, Any]]: + def _build_group_schema( + self, + schema_definition: Dict[str, Any], + provider_options_by_type: Optional[Dict[str, List[Dict[str, Any]]]] = None, + ) -> List[Dict[str, Any]]: current_config = self.plugin_config.to_dict() if self.plugin_config else {} groups: List[Dict[str, Any]] = [] @@ -983,7 +1506,12 @@ def _build_group_schema(self, schema_definition: Dict[str, Any]) -> List[Dict[st continue fields = [ - self._build_field_spec(field_key, field_spec, current_config) + self._build_field_spec( + field_key, + field_spec, + current_config, + provider_options_by_type, + ) for field_key, field_spec in items.items() ] @@ -1051,17 +1579,38 @@ async def get_config_schema(self) -> Dict[str, Any]: self._sync_config_sources() merged_schema = self._merged_schema_definition() + model_cache: Dict[Tuple[str, str, str], List[str]] = {} + provider_options_by_type = { + "chat_completion": await self._provider_options_async( + "chat_completion", + model_cache, + ), + "embedding": await self._provider_options_async( + "embedding", + model_cache, + ), + "rerank": await self._provider_options_async( + "rerank", + model_cache, + ), + } + provider_options = self._dedupe_options( + [ + option + for options in provider_options_by_type.values() + for option in options + ] + ) return { "config": self.plugin_config.to_dict(), - "groups": self._build_group_schema(merged_schema), + "groups": self._build_group_schema( + merged_schema, + provider_options_by_type, + ), "warnings": get_config_cost_warnings(self.plugin_config), - "provider_options": self._provider_options(), - "provider_options_by_type": { - "chat_completion": self._provider_options("chat_completion"), - "embedding": self._provider_options("embedding"), - "rerank": self._provider_options("rerank"), - }, + "provider_options": provider_options, + "provider_options_by_type": provider_options_by_type, } async def update_config(self, new_config: Dict[str, Any]) -> Tuple[bool, str, Dict[str, Any]]: