diff --git a/internal/config/state/service_test.go b/internal/config/state/service_test.go index e0a7b9bb..407fe5a6 100644 --- a/internal/config/state/service_test.go +++ b/internal/config/state/service_test.go @@ -184,6 +184,52 @@ func TestSelectionServiceListModelsSnapshotRejectsUnsupportedDriver(t *testing.T } } +func TestSelectionServiceRefreshModelsUsesRefreshPathAndRepairsCurrentModel(t *testing.T) { + t.Parallel() + + manager := newSelectionTestManager(t, testDefaultConfig()) + if err := manager.Update(context.Background(), func(cfg *configpkg.Config) error { + cfg.CurrentModel = "removed-model" + return nil + }); err != nil { + t.Fatalf("seed current model: %v", err) + } + + service := NewService(manager, newDriverSupporterStub(), catalogMethodsStub{ + listModels: []providertypes.ModelDescriptor{ + {ID: OpenAIDefaultModel, Name: "GPT Default"}, + {ID: "gpt-5.4-mini", Name: "GPT Mini"}, + }, + }) + + models, err := service.RefreshModels(context.Background()) + if err != nil { + t.Fatalf("RefreshModels() error = %v", err) + } + if len(models) != 2 { + t.Fatalf("expected refreshed models, got %+v", models) + } + + cfg := manager.Get() + if cfg.CurrentModel != OpenAIDefaultModel { + t.Fatalf("expected current model repaired to provider default, got %q", cfg.CurrentModel) + } +} + +func TestSelectionServiceRefreshModelsReturnsNoModelsAvailable(t *testing.T) { + t.Parallel() + + manager := newSelectionTestManager(t, testDefaultConfig()) + service := NewService(manager, newDriverSupporterStub(), catalogMethodsStub{ + listModels: nil, + }) + + _, err := service.RefreshModels(context.Background()) + if !errors.Is(err, ErrNoModelsAvailable) { + t.Fatalf("expected ErrNoModelsAvailable, got %v", err) + } +} + func TestSelectionServiceSelectProviderAndSetCurrentModel(t *testing.T) { manager := newSelectionTestManager(t, testDefaultConfig()) diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index fec5c6d3..54e4fb44 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -397,6 +397,50 @@ func TestListProviderModelsRefreshesWhenCatalogSnapshotIsEmpty(t *testing.T) { } } +func TestRefreshProviderModelsReturnsConfiguredModelsWhenDiscoveryDisabled(t *testing.T) { + t.Parallel() + + service := NewService("", newRegistry(t, openaicompat.DriverName, nil), newMemoryStore()) + providerCfg := customGatewayProvider() + providerCfg.ModelSource = config.ModelSourceManual + providerCfg.Models = []providertypes.ModelDescriptor{{ID: "manual-model", Name: "Manual Model"}} + + models, err := service.RefreshProviderModels(context.Background(), mustCatalogInput(t, providerCfg)) + if err != nil { + t.Fatalf("RefreshProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ID != "manual-model" { + t.Fatalf("expected configured models without discovery, got %+v", models) + } +} + +func TestRefreshProviderModelsDiscoversAndPersists(t *testing.T) { + t.Setenv(testAPIKeyEnv, "test-key") + + registry := newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + return []providertypes.ModelDescriptor{{ID: "fresh-model", Name: "Fresh Model"}}, nil + }) + store := newMemoryStore() + service := NewService("", registry, store) + + input := customGatewayProviderSource() + models, err := service.RefreshProviderModels(context.Background(), input) + if err != nil { + t.Fatalf("RefreshProviderModels() error = %v", err) + } + if len(models) != 1 || models[0].ID != "fresh-model" { + t.Fatalf("expected discovered models, got %+v", models) + } + + cached, err := store.Load(context.Background(), input.Identity) + if err != nil { + t.Fatalf("load cached catalog: %v", err) + } + if len(cached.Models) != 1 || cached.Models[0].ID != "fresh-model" { + t.Fatalf("expected refreshed models to be persisted, got %+v", cached.Models) + } +} + func TestDiscoverAndPersistFailurePaths(t *testing.T) { t.Run("unsupported driver", func(t *testing.T) { service := NewService("", provider.NewRegistry(), newMemoryStore())