From 2e97393a00367b40ec91e6b6c9b87b0ec3ef50c3 Mon Sep 17 00:00:00 2001 From: Evgenii Danilin Date: Fri, 13 Mar 2026 23:11:23 +0700 Subject: [PATCH 1/3] vector_size --- internal/config/config.go | 7 +++++++ internal/indexer/indexer.go | 2 +- internal/mcp/server.go | 2 +- internal/store/qdrant.go | 8 +++++--- internal/store/qdrant_test.go | 30 +++++++++++++++--------------- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 6428d1c..8d848ec 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type StorageConfig struct { Type string `yaml:"type"` URL string `yaml:"url"` CollectionPrefix string `yaml:"collection_prefix"` + VectorSize int `yaml:"vector_size"` } type PromptsConfig struct { @@ -177,6 +178,9 @@ func merge(home, project *Config) *Config { if project.Storage.CollectionPrefix != "" { cfg.Storage.CollectionPrefix = project.Storage.CollectionPrefix } + if project.Storage.VectorSize != 0 { + cfg.Storage.VectorSize = project.Storage.VectorSize + } // Indexer: override non-zero fields if project.Indexer.MaxFileSize != 0 { @@ -223,6 +227,9 @@ func setDefaults(cfg *Config) { if cfg.Indexer.Workers <= 0 { cfg.Indexer.Workers = 2 } + if cfg.Storage.VectorSize <= 0 { + cfg.Storage.VectorSize = 3072 + } if cfg.Prompts.ProjectStructureAnalysis == "" { cfg.Prompts.ProjectStructureAnalysis = prompts.DefaultProjectStructureAnalysis } diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index 59563e5..1e68d1d 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -359,7 +359,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { } // Initialize store - db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, logger) + db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, cfg.Storage.VectorSize, logger) // Force mode: delete existing data and start fresh if force { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index b78109e..a70b742 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -191,7 +191,7 @@ func RunServer(configPath string, logger *slog.Logger) error { return fmt.Errorf("creating embedding provider: %w", err) } - db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, logger) + db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, cfg.Storage.VectorSize, logger) srv := NewServer(db, embedder, rootPath, logger) diff --git a/internal/store/qdrant.go b/internal/store/qdrant.go index 2880b89..2ba799f 100644 --- a/internal/store/qdrant.go +++ b/internal/store/qdrant.go @@ -14,15 +14,17 @@ import ( type QdrantStore struct { baseURL string collection string + vectorSize int client *http.Client logger *slog.Logger } // NewQdrantStore creates a new Qdrant store client. -func NewQdrantStore(url, collectionPrefix, projectName string, logger *slog.Logger) *QdrantStore { +func NewQdrantStore(url, collectionPrefix, projectName string, vectorSize int, logger *slog.Logger) *QdrantStore { return &QdrantStore{ baseURL: url, collection: collectionPrefix + projectName, + vectorSize: vectorSize, client: &http.Client{Timeout: 30 * time.Second}, logger: logger, } @@ -46,10 +48,10 @@ func (q *QdrantStore) EnsureCollection() error { q.logger.Debug("EnsureCollection: creating", "collection", q.collection) - // Create collection with vector size 3072 and cosine distance + // Create collection with configured vector size and cosine distance body := map[string]any{ "vectors": map[string]any{ - "size": 3072, + "size": q.vectorSize, "distance": "Cosine", }, } diff --git a/internal/store/qdrant_test.go b/internal/store/qdrant_test.go index a257a83..fec657e 100644 --- a/internal/store/qdrant_test.go +++ b/internal/store/qdrant_test.go @@ -48,7 +48,7 @@ func TestEnsureCollection_AlreadyExists(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) if err := s.EnsureCollection(); err != nil { t.Fatalf("EnsureCollection failed: %v", err) } @@ -83,7 +83,7 @@ func TestEnsureCollection_Creates(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) if err := s.EnsureCollection(); err != nil { t.Fatalf("EnsureCollection failed: %v", err) } @@ -105,7 +105,7 @@ func TestUpsertPoint(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) err := s.UpsertPoint(&Point{ @@ -153,7 +153,7 @@ func TestUpsertPoint_UsesProvidedID(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) customID := "custom-uuid-12345" err := s.UpsertPoint(&Point{ ID: customID, @@ -185,7 +185,7 @@ func TestUpsertPoints(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) err := s.UpsertPoints([]*Point{ @@ -231,7 +231,7 @@ func TestUpsertPoints(t *testing.T) { } func TestUpsertPoints_Empty(t *testing.T) { - s := NewQdrantStore("http://localhost:6333", "vedcode_", "test", noopLogger) + s := NewQdrantStore("http://localhost:6333", "vedcode_", "test", 3072, noopLogger) err := s.UpsertPoints([]*Point{}) if err != nil { t.Fatalf("UpsertPoints with empty slice should not error: %v", err) @@ -289,7 +289,7 @@ func TestGetAllFilePoints(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) points, err := s.GetAllFilePoints() if err != nil { t.Fatalf("GetAllFilePoints failed: %v", err) @@ -347,7 +347,7 @@ func TestGetAllDirPoints(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) points, err := s.GetAllDirPoints() if err != nil { t.Fatalf("GetAllDirPoints failed: %v", err) @@ -386,7 +386,7 @@ func TestGetPointByFilePath_Found(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) point, err := s.GetPointByFilePath("src/payment.go") if err != nil { t.Fatalf("GetPointByFilePath failed: %v", err) @@ -406,7 +406,7 @@ func TestGetPointByFilePath_NotFound(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) point, err := s.GetPointByFilePath("nonexistent.go") if err != nil { t.Fatalf("GetPointByFilePath failed: %v", err) @@ -429,7 +429,7 @@ func TestDeletePoints(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) err := s.DeletePoints([]string{"uuid-1", "uuid-2"}) if err != nil { t.Fatalf("DeletePoints failed: %v", err) @@ -476,7 +476,7 @@ func TestSearch(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) results, err := s.Search([]float32{0.1, 0.2, 0.3}, 3) if err != nil { t.Fatalf("Search failed: %v", err) @@ -512,7 +512,7 @@ func TestSearch_DefaultLimit(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) _, err := s.Search([]float32{0.1}, 0) if err != nil { t.Fatalf("Search failed: %v", err) @@ -520,7 +520,7 @@ func TestSearch_DefaultLimit(t *testing.T) { } func TestCollectionName(t *testing.T) { - s := NewQdrantStore("http://localhost:6333", "vedcode_", "my-app", noopLogger) + s := NewQdrantStore("http://localhost:6333", "vedcode_", "my-app", 3072, noopLogger) if s.collection != "vedcode_my-app" { t.Errorf("expected collection name vedcode_my-app, got %s", s.collection) } @@ -533,7 +533,7 @@ func TestQdrantError(t *testing.T) { })) defer srv.Close() - s := NewQdrantStore(srv.URL, "vedcode_", "test", noopLogger) + s := NewQdrantStore(srv.URL, "vedcode_", "test", 3072, noopLogger) // EnsureCollection should fail on creation err := s.EnsureCollection() From a437f3c40060958bedf64be1c543973bd137b045 Mon Sep 17 00:00:00 2001 From: Evgenii Danilin Date: Sun, 15 Mar 2026 21:15:59 +0700 Subject: [PATCH 2/3] logger + generic-http + trace --- cmd/vedcode/main.go | 12 +- internal/config/config.go | 26 +++ internal/config/config_test.go | 238 +++++++++++++++------ internal/indexer/indexer.go | 79 ++++--- internal/providers/gemini.go | 57 +---- internal/providers/gemini_test.go | 11 - internal/providers/generic_http.go | 265 +++++++++++++++++++++++ internal/providers/generic_http_test.go | 266 ++++++++++++++++++++++++ internal/providers/provider.go | 14 +- internal/providers/providers_test.go | 8 + internal/providers/retry.go | 58 ++++++ internal/trace/trace.go | 129 ++++++++++-- 12 files changed, 957 insertions(+), 206 deletions(-) create mode 100644 internal/providers/generic_http.go create mode 100644 internal/providers/generic_http_test.go create mode 100644 internal/providers/providers_test.go create mode 100644 internal/providers/retry.go diff --git a/cmd/vedcode/main.go b/cmd/vedcode/main.go index 5fdfd24..32823f7 100644 --- a/cmd/vedcode/main.go +++ b/cmd/vedcode/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "log" "os" "path/filepath" "strings" @@ -13,8 +12,6 @@ import ( ) func main() { - log.SetFlags(0) - if len(os.Args) < 2 { printUsage() os.Exit(1) @@ -47,9 +44,11 @@ func main() { traceLogPath = filepath.Join(".vedcode", "mcp-trace.log") } - logger, closer, err := trace.NewLogger(traceEnabled, traceLogPath) + console := command == "indexer" + logger, closer, err := trace.NewLogger(traceEnabled, traceLogPath, console) if err != nil { - log.Fatalf("Error: %v", err) + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } if closer != nil { defer closer.Close() @@ -67,7 +66,8 @@ func main() { } if err != nil { - log.Fatalf("Error: %v", err) + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } } diff --git a/internal/config/config.go b/internal/config/config.go index 8d848ec..795e2ae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -270,5 +270,31 @@ func validate(cfg *Config) error { if cfg.Storage.CollectionPrefix == "" { return fmt.Errorf("config validation: storage.collection_prefix is required") } + // URL is required only for HTTP-based providers + if requiresURL(cfg.LLM.Provider) && cfg.LLM.URL == "" { + return fmt.Errorf("config validation: llm.url is required for provider %q", cfg.LLM.Provider) + } + if requiresURL(cfg.Embedding.Provider) && cfg.Embedding.URL == "" { + return fmt.Errorf("config validation: embedding.url is required for provider %q", cfg.Embedding.Provider) + } + + // API key is required for SDK-based providers + if cfg.LLM.Provider == "gemini" && cfg.LLM.APIKey == "" { + return fmt.Errorf("config validation: llm.api_key is required for gemini provider") + } + if cfg.Embedding.Provider == "gemini" && cfg.Embedding.APIKey == "" { + return fmt.Errorf("config validation: embedding.api_key is required for gemini provider") + } + return nil } + +// requiresURL returns true for HTTP-based providers that need a base URL. +func requiresURL(provider string) bool { + switch provider { + case "generic-http": + return true + default: + return false + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b0797db..15d3304 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -91,12 +91,12 @@ func TestLoad_ValidConfig(t *testing.T) { func TestLoad_DefaultMaxFileSize(t *testing.T) { yml := ` llm: - provider: "gemini" - api_key: "key" + provider: "generic-http" + url: "http://localhost:11434/v1" model: "model" embedding: - provider: "gemini" - api_key: "key" + provider: "generic-http" + url: "http://localhost:11434/v1" model: "emb" storage: type: "qdrant" @@ -122,11 +122,13 @@ func TestLoad_EnvVarSubstitution(t *testing.T) { yml := ` llm: - provider: "gemini" + provider: "generic-http" + url: "http://localhost:11434/v1" api_key: "${TEST_VEDCODE_API_KEY}" model: "model" embedding: - provider: "gemini" + provider: "generic-http" + url: "http://localhost:11434/v1" api_key: "${TEST_VEDCODE_API_KEY}" model: "emb" storage: @@ -153,11 +155,13 @@ func TestLoad_EnvVarNotSet(t *testing.T) { yml := ` llm: - provider: "gemini" + provider: "generic-http" + url: "http://localhost:11434/v1" api_key: "${NONEXISTENT_VAR_VEDCODE}" model: "model" embedding: - provider: "gemini" + provider: "generic-http" + url: "http://localhost:11434/v1" api_key: "key" model: "emb" storage: @@ -187,10 +191,11 @@ func TestLoad_ValidationErrors(t *testing.T) { name: "missing llm.provider", yml: ` llm: - api_key: "k" + url: "http://x" model: "m" embedding: - provider: "g" + provider: "generic-http" + url: "http://x" model: "e" storage: type: "q" @@ -203,10 +208,11 @@ storage: name: "missing embedding.provider", yml: ` llm: - provider: "g" - api_key: "k" + provider: "generic-http" + url: "http://x" model: "m" embedding: + url: "http://x" model: "e" storage: type: "q" @@ -219,11 +225,12 @@ storage: name: "missing embedding.model", yml: ` llm: - provider: "g" - api_key: "k" + provider: "generic-http" + url: "http://x" model: "m" embedding: - provider: "g" + provider: "generic-http" + url: "http://x" storage: type: "q" url: "http://x" @@ -235,11 +242,12 @@ storage: name: "missing storage.url", yml: ` llm: - provider: "g" - api_key: "k" + provider: "generic-http" + url: "http://x" model: "m" embedding: - provider: "g" + provider: "generic-http" + url: "http://x" model: "e" storage: type: "q" @@ -247,6 +255,74 @@ storage: `, wantErr: "storage.url is required", }, + { + name: "missing llm.url for generic-http", + yml: ` +llm: + provider: "generic-http" + model: "m" +embedding: + provider: "generic-http" + url: "http://x" + model: "e" +storage: + type: "q" + url: "http://x" + collection_prefix: "p" +`, + wantErr: `llm.url is required for provider "generic-http"`, + }, + { + name: "missing embedding.url for generic-http", + yml: ` +llm: + provider: "generic-http" + url: "http://x" + model: "m" +embedding: + provider: "generic-http" + model: "e" +storage: + type: "q" + url: "http://x" + collection_prefix: "p" +`, + wantErr: `embedding.url is required for provider "generic-http"`, + }, + { + name: "missing llm.api_key for gemini", + yml: ` +llm: + provider: "gemini" + model: "gemini-2.5-flash" +embedding: + provider: "gemini" + api_key: "key" + model: "gemini-embedding-001" +storage: + type: "q" + url: "http://x" + collection_prefix: "p" +`, + wantErr: "llm.api_key is required for gemini provider", + }, + { + name: "missing embedding.api_key for gemini", + yml: ` +llm: + provider: "gemini" + api_key: "key" + model: "gemini-2.5-flash" +embedding: + provider: "gemini" + model: "gemini-embedding-001" +storage: + type: "q" + url: "http://x" + collection_prefix: "p" +`, + wantErr: "embedding.api_key is required for gemini provider", + }, } for _, tt := range tests { @@ -356,7 +432,7 @@ indexer: workers: 8 llm: - model: "gemini-2.0-pro" + model: "gpt-4o" ` projectPath := writeTestConfig(t, projectYml) @@ -369,8 +445,8 @@ llm: if cfg.Indexer.Workers != 8 { t.Errorf("workers = %d, want %d", cfg.Indexer.Workers, 8) } - if cfg.LLM.Model != "gemini-2.0-pro" { - t.Errorf("model = %q, want %q", cfg.LLM.Model, "gemini-2.0-pro") + if cfg.LLM.Model != "gpt-4o" { + t.Errorf("model = %q, want %q", cfg.LLM.Model, "gpt-4o") } // Inherited from home @@ -415,49 +491,13 @@ indexer: } } -func TestLoad_DifferentLLMAndEmbeddingProviders(t *testing.T) { - yml := ` -llm: - provider: "gemini" - api_key: "key" - model: "gemini-2.5-flash" -embedding: - provider: "ollama" - url: "http://localhost:11434" - model: "nomic-embed-text" -storage: - type: "qdrant" - url: "http://localhost:6333" - collection_prefix: "v_" -` - path := writeTestConfig(t, yml) - - cfg, err := loadWithPaths("", path) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if cfg.LLM.Provider != "gemini" { - t.Errorf("llm.provider = %q, want %q", cfg.LLM.Provider, "gemini") - } - if cfg.Embedding.Provider != "ollama" { - t.Errorf("embedding.provider = %q, want %q", cfg.Embedding.Provider, "ollama") - } - if cfg.Embedding.URL != "http://localhost:11434" { - t.Errorf("embedding.url = %q, want %q", cfg.Embedding.URL, "http://localhost:11434") - } - if cfg.Embedding.Model != "nomic-embed-text" { - t.Errorf("embedding.model = %q, want %q", cfg.Embedding.Model, "nomic-embed-text") - } -} - func TestLoad_EmbeddingMerge_ProjectOverridesHome(t *testing.T) { homePath := writeTestConfig(t, homeConfig) projectYml := ` embedding: - provider: "ollama" - url: "http://localhost:11434" - model: "project-model" + url: "http://localhost:11434/v1" + model: "nomic-embed-text" ` projectPath := writeTestConfig(t, projectYml) @@ -465,11 +505,11 @@ embedding: if err != nil { t.Fatalf("unexpected error: %v", err) } - if cfg.Embedding.Provider != "ollama" { - t.Errorf("embedding.provider = %q, want %q", cfg.Embedding.Provider, "ollama") + if cfg.Embedding.URL != "http://localhost:11434/v1" { + t.Errorf("embedding.url = %q, want %q", cfg.Embedding.URL, "http://localhost:11434/v1") } - if cfg.Embedding.Model != "project-model" { - t.Errorf("embedding.model = %q, want %q", cfg.Embedding.Model, "project-model") + if cfg.Embedding.Model != "nomic-embed-text" { + t.Errorf("embedding.model = %q, want %q", cfg.Embedding.Model, "nomic-embed-text") } } @@ -527,6 +567,78 @@ prompts: } } +func TestLoad_VectorSizeDefault(t *testing.T) { + path := writeTestConfig(t, validConfig) + + cfg, err := loadWithPaths("", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Storage.VectorSize != 3072 { + t.Errorf("default vector_size = %d, want 3072", cfg.Storage.VectorSize) + } +} + +func TestLoad_VectorSizeCustom(t *testing.T) { + yml := ` +llm: + provider: "generic-http" + url: "http://localhost:11434/v1" + model: "m" +embedding: + provider: "generic-http" + url: "http://localhost:11434/v1" + model: "emb" +storage: + type: "qdrant" + url: "http://localhost:6333" + collection_prefix: "v_" + vector_size: 1536 +` + path := writeTestConfig(t, yml) + + cfg, err := loadWithPaths("", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Storage.VectorSize != 1536 { + t.Errorf("vector_size = %d, want 1536", cfg.Storage.VectorSize) + } +} + +func TestLoad_GenericHTTPProviderWithURL(t *testing.T) { + yml := ` +llm: + provider: "generic-http" + url: "http://localhost:11434/v1" + model: "llama3.1" +embedding: + provider: "generic-http" + url: "http://localhost:11434/v1" + model: "nomic-embed-text" +storage: + type: "qdrant" + url: "http://localhost:6333" + collection_prefix: "v_" + vector_size: 768 +` + path := writeTestConfig(t, yml) + + cfg, err := loadWithPaths("", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.LLM.Provider != "generic-http" { + t.Errorf("llm.provider = %q, want generic-http", cfg.LLM.Provider) + } + if cfg.LLM.URL != "http://localhost:11434/v1" { + t.Errorf("llm.url = %q, want http://localhost:11434/v1", cfg.LLM.URL) + } + if cfg.Storage.VectorSize != 768 { + t.Errorf("vector_size = %d, want 768", cfg.Storage.VectorSize) + } +} + func TestLoad_PromptsMerge_ProjectOverridesHome(t *testing.T) { homeYml := homeConfig + ` prompts: diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index 1e68d1d..b42edba 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "log/slog" "os" "path/filepath" @@ -235,7 +234,7 @@ func (t *dirTracker) tryAnalyzeDir(dirPath string) { defer func() { <-t.sem }() n := t.progress.Add(1) - log.Printf("[%d/%d] Analyzing dir %s", n, t.totalItems, dirPath) + t.logger.Info(fmt.Sprintf("[%d/%d] Analyzing dir %s", n, t.totalItems, dirPath)) t.logger.Debug("dir indexing started", "dir", dirPath, "index", n, "total", t.totalItems, "hash", newHash) dirStart := time.Now() @@ -248,7 +247,7 @@ func (t *dirTracker) tryAnalyzeDir(dirPath string) { response, err := t.llm.GenerateJSON(dirPrompt, dirAnalysisSchema) if err != nil { - log.Printf("Error analyzing dir %s: %v", dirPath, err) + t.logger.Error(fmt.Sprintf("Error analyzing dir %s: %v", dirPath, err)) t.errors.Add(1) t.notifyParent(dirPath) return @@ -256,7 +255,7 @@ func (t *dirTracker) tryAnalyzeDir(dirPath string) { analysis, err := parseDirAnalysis(response) if err != nil { - log.Printf("Error parsing dir analysis for %s: %v", dirPath, err) + t.logger.Error(fmt.Sprintf("Error parsing dir analysis for %s: %v", dirPath, err)) t.errors.Add(1) t.notifyParent(dirPath) return @@ -264,7 +263,7 @@ func (t *dirTracker) tryAnalyzeDir(dirPath string) { embedding, err := t.embedder.EmbedContent(analysis.Summary) if err != nil { - log.Printf("Error embedding dir %s: %v", dirPath, err) + t.logger.Error(fmt.Sprintf("Error embedding dir %s: %v", dirPath, err)) t.errors.Add(1) t.notifyParent(dirPath) return @@ -283,7 +282,7 @@ func (t *dirTracker) tryAnalyzeDir(dirPath string) { } if err := t.db.UpsertPoint(point); err != nil { - log.Printf("Error saving dir %s: %v", dirPath, err) + t.logger.Error(fmt.Sprintf("Error saving dir %s: %v", dirPath, err)) t.errors.Add(1) t.notifyParent(dirPath) return @@ -363,19 +362,18 @@ func Run(configPath string, force bool, logger *slog.Logger) error { // Force mode: delete existing data and start fresh if force { - log.Println("Force mode: cleaning up existing data...") - logger.Debug("force mode: cleaning up") + logger.Info("Force mode: cleaning up existing data...") overviewPath := filepath.Join(rootPath, ".vedcode", "project_overview.md") if err := os.Remove(overviewPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("removing project overview: %w", err) } - log.Println("Deleted .vedcode/project_overview.md") + logger.Info("Deleted .vedcode/project_overview.md") if err := db.DeleteCollection(); err != nil { - log.Printf("Warning: could not delete collection: %v", err) + logger.Warn(fmt.Sprintf("could not delete collection: %v", err)) } else { - log.Println("Deleted Qdrant collection") + logger.Info("Deleted Qdrant collection") } } @@ -383,12 +381,12 @@ func Run(configPath string, force bool, logger *slog.Logger) error { return fmt.Errorf("ensuring collection: %w", err) } - log.Println("=== VedCode Indexer ===") - log.Printf("Project: %s", cfg.Project.Name) - log.Printf("Root: %s", rootPath) + logger.Info("=== VedCode Indexer ===") + logger.Info(fmt.Sprintf("Project: %s", cfg.Project.Name)) + logger.Info(fmt.Sprintf("Root: %s", rootPath)) // --- Stage 1: Project structure analysis & cleanup --- - log.Println("\n--- Stage 1: Project structure analysis & cleanup ---") + logger.Info("\n--- Stage 1: Project structure analysis & cleanup ---") walkResult, err := walker.Walk(walker.Options{ RootPath: rootPath, @@ -398,7 +396,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { if err != nil { return fmt.Errorf("walking project: %w", err) } - log.Printf("Found %d files", len(walkResult.Files)) + logger.Info(fmt.Sprintf("Found %d files", len(walkResult.Files))) logger.Debug("walker completed", "files_found", len(walkResult.Files), "root_path", rootPath, @@ -426,12 +424,12 @@ func Run(configPath string, force bool, logger *slog.Logger) error { deletedCount := 0 if len(deleteIDs) > 0 { if err := db.DeletePoints(deleteIDs); err != nil { - log.Printf("Warning: error deleting stale points: %v", err) + logger.Warn(fmt.Sprintf("error deleting stale points: %v", err)) } else { deletedCount = len(deleteIDs) } } - log.Printf("Deleted %d stale file records from Qdrant", deletedCount) + logger.Info(fmt.Sprintf("Deleted %d stale file records from Qdrant", deletedCount)) logger.Debug("stale file cleanup", "deleted", deletedCount, "total_existing", len(existingPoints)) // Clean up deleted directories from Qdrant @@ -451,12 +449,12 @@ func Run(configPath string, force bool, logger *slog.Logger) error { deletedDirCount := 0 if len(deleteDirIDs) > 0 { if err := db.DeletePoints(deleteDirIDs); err != nil { - log.Printf("Warning: error deleting stale dir points: %v", err) + logger.Warn(fmt.Sprintf("error deleting stale dir points: %v", err)) } else { deletedDirCount = len(deleteDirIDs) } } - log.Printf("Deleted %d stale directory records from Qdrant", deletedDirCount) + logger.Info(fmt.Sprintf("Deleted %d stale directory records from Qdrant", deletedDirCount)) logger.Debug("stale dir cleanup", "deleted", deletedDirCount, "total_existing", len(existingDirPoints)) // Analyze project structure via LLM @@ -464,7 +462,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { "CONTENT": walkResult.Tree, }) - log.Println("Analyzing project structure...") + logger.Info("Analyzing project structure...") logger.Debug("analyzing project structure", "prompt_length", len(structurePrompt)) projectOverview, err := llm.GenerateContent(structurePrompt) @@ -481,7 +479,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { if err := os.WriteFile(overviewPath, []byte(projectOverview), 0o644); err != nil { return fmt.Errorf("saving project overview: %w", err) } - log.Printf("Project overview saved to %s", overviewPath) + logger.Info(fmt.Sprintf("Project overview saved to %s", overviewPath)) // Build existing points map for hash comparison (keyed by file_path) existingByPath := make(map[string]*store.Point, len(existingPoints)) @@ -490,8 +488,8 @@ func Run(configPath string, force bool, logger *slog.Logger) error { } // --- Stage 2: File & directory indexing (interleaved) --- - log.Println("\n--- Stage 2: File & directory indexing ---") - log.Printf("Using %d worker(s)", cfg.Indexer.Workers) + logger.Info("\n--- Stage 2: File & directory indexing ---") + logger.Info(fmt.Sprintf("Using %d worker(s)", cfg.Indexer.Workers)) var indexedCount atomic.Int64 var errorCount atomic.Int64 @@ -512,7 +510,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { sem, &wg, &progress, totalItems, logger, ) - log.Printf("Found %d items to analyze (%d files, %d dirs)", totalItems, len(walkResult.Files), totalDirs) + logger.Info(fmt.Sprintf("Found %d items to analyze (%d files, %d dirs)", totalItems, len(walkResult.Files), totalDirs)) for _, relPath := range walkResult.Files { absPath := filepath.Join(rootPath, relPath) @@ -521,7 +519,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { content, err := os.ReadFile(absPath) if err != nil { n := progress.Add(1) - log.Printf("[%d/%d] Error reading %s: %v", n, totalItems, relPath, err) + logger.Error(fmt.Sprintf("[%d/%d] Error reading %s: %v", n, totalItems, relPath, err)) errorCount.Add(1) tracker.fileFailed(relPath) continue @@ -545,7 +543,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { defer func() { <-sem }() n := progress.Add(1) - log.Printf("[%d/%d] Indexing %s", n, totalItems, relPath) + logger.Info(fmt.Sprintf("[%d/%d] Indexing %s", n, totalItems, relPath)) logger.Debug("file indexing started", "file", relPath, "index", n, @@ -563,7 +561,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { response, err := llm.GenerateJSON(filePrompt, fileAnalysisSchema) if err != nil { - log.Printf("[%d/%d] Error analyzing %s: %v", n, totalItems, relPath, err) + logger.Error(fmt.Sprintf("[%d/%d] Error analyzing %s: %v", n, totalItems, relPath, err)) errorCount.Add(1) tracker.fileFailed(relPath) return @@ -571,7 +569,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { analysis, err := parseAnalysis(response) if err != nil { - log.Printf("[%d/%d] Error parsing analysis for %s: %v", n, totalItems, relPath, err) + logger.Error(fmt.Sprintf("[%d/%d] Error parsing analysis for %s: %v", n, totalItems, relPath, err)) errorCount.Add(1) tracker.fileFailed(relPath) return @@ -587,7 +585,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { // Get embedding for the summary embedding, err := embedder.EmbedContent(analysis.Summary) if err != nil { - log.Printf("[%d/%d] Error embedding %s: %v", n, totalItems, relPath, err) + logger.Error(fmt.Sprintf("[%d/%d] Error embedding %s: %v", n, totalItems, relPath, err)) errorCount.Add(1) tracker.fileFailed(relPath) return @@ -608,7 +606,7 @@ func Run(configPath string, force bool, logger *slog.Logger) error { } if err := db.UpsertPoint(point); err != nil { - log.Printf("[%d/%d] Error saving %s: %v", n, totalItems, relPath, err) + logger.Error(fmt.Sprintf("[%d/%d] Error saving %s: %v", n, totalItems, relPath, err)) errorCount.Add(1) tracker.fileFailed(relPath) return @@ -628,15 +626,15 @@ func Run(configPath string, force bool, logger *slog.Logger) error { dirIndexed, dirSkipped, dirErrors := tracker.results() // --- Summary --- - log.Println("\n=== Indexing complete ===") - log.Printf("Total files: %d", len(walkResult.Files)) - log.Printf("Indexed: %d", indexedCount.Load()) - log.Printf("Skipped: %d (unchanged)", skippedCount) - log.Printf("Deleted: %d (removed from project)", deletedCount) - log.Printf("Errors: %d", errorCount.Load()) - log.Printf("Dirs indexed: %d", dirIndexed) - log.Printf("Dirs skipped: %d (unchanged)", dirSkipped) - log.Printf("Dirs errors: %d", dirErrors) + logger.Info("\n=== Indexing complete ===") + logger.Info(fmt.Sprintf("Total files: %d", len(walkResult.Files))) + logger.Info(fmt.Sprintf("Indexed: %d", indexedCount.Load())) + logger.Info(fmt.Sprintf("Skipped: %d (unchanged)", skippedCount)) + logger.Info(fmt.Sprintf("Deleted: %d (removed from project)", deletedCount)) + logger.Info(fmt.Sprintf("Errors: %d", errorCount.Load())) + logger.Info(fmt.Sprintf("Dirs indexed: %d", dirIndexed)) + logger.Info(fmt.Sprintf("Dirs skipped: %d (unchanged)", dirSkipped)) + logger.Info(fmt.Sprintf("Dirs errors: %d", dirErrors)) logger.Debug("indexing complete", "total_files", len(walkResult.Files), @@ -759,4 +757,3 @@ func buildSubdirsSummariesText(childDirs []string, dirSummary map[string]string) sort.Strings(lines) return strings.Join(lines, "\n") } - diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index e2bc667..355ade6 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -5,18 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" - "strings" "time" "google.golang.org/genai" ) -const ( - defaultTimeout = 120 * time.Second - maxRetries = 3 - baseRetryDelay = time.Second -) - // modelsAPI abstracts the genai Models API for testability. type modelsAPI interface { GenerateContent(ctx context.Context, model string, contents []*genai.Content, config *genai.GenerateContentConfig) (*genai.GenerateContentResponse, error) @@ -72,7 +65,7 @@ func (g *GeminiProvider) GenerateContent(prompt string) (string, error) { } var resp *genai.GenerateContentResponse - err := g.retryOnRateLimit(func(ctx context.Context) error { + err := retryOnRateLimit(g.logger, func(ctx context.Context) error { var apiErr error resp, apiErr = g.models.GenerateContent(ctx, g.model, contents, nil) return apiErr @@ -127,7 +120,7 @@ func (g *GeminiProvider) GenerateJSON(prompt string, schema string) (string, err } var resp *genai.GenerateContentResponse - err := g.retryOnRateLimit(func(ctx context.Context) error { + err := retryOnRateLimit(g.logger, func(ctx context.Context) error { var apiErr error resp, apiErr = g.models.GenerateContent(ctx, g.model, contents, config) return apiErr @@ -168,7 +161,7 @@ func (g *GeminiProvider) EmbedContent(text string) ([]float32, error) { } var resp *genai.EmbedContentResponse - err := g.retryOnRateLimit(func(ctx context.Context) error { + err := retryOnRateLimit(g.logger, func(ctx context.Context) error { var apiErr error resp, apiErr = g.models.EmbedContent(ctx, g.embeddingModel, contents, nil) return apiErr @@ -194,47 +187,3 @@ func (g *GeminiProvider) EmbedContent(text string) ([]float32, error) { return resp.Embeddings[0].Values, nil } - -// retryOnRateLimit retries the given function with exponential backoff on rate limit errors. -func (g *GeminiProvider) retryOnRateLimit(fn func(ctx context.Context) error) error { - var lastErr error - for attempt := range maxRetries { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) - lastErr = fn(ctx) - cancel() - - if lastErr == nil { - return nil - } - - if !isRetryableError(lastErr) { - return lastErr - } - - if attempt < maxRetries-1 { - delay := baseRetryDelay * time.Duration(1<= slog.LevelInfo +} + +func (h *consoleHandler) Handle(_ context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + _, err := fmt.Fprintln(h.w, r.Message) + return err +} + +func (h *consoleHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h } +func (h *consoleHandler) WithGroup(_ string) slog.Handler { return h } + +// multiHandler dispatches log records to multiple slog.Handlers. +type multiHandler struct { + handlers []slog.Handler +} + +func (m *multiHandler) Enabled(ctx context.Context, level slog.Level) bool { + for _, h := range m.handlers { + if h.Enabled(ctx, level) { + return true + } } + return false +} - var w io.Writer +func (m *multiHandler) Handle(ctx context.Context, r slog.Record) error { + for _, h := range m.handlers { + if h.Enabled(ctx, r.Level) { + if err := h.Handle(ctx, r); err != nil { + return err + } + } + } + return nil +} + +func (m *multiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + handlers := make([]slog.Handler, len(m.handlers)) + for i, h := range m.handlers { + handlers[i] = h.WithAttrs(attrs) + } + return &multiHandler{handlers: handlers} +} + +func (m *multiHandler) WithGroup(name string) slog.Handler { + handlers := make([]slog.Handler, len(m.handlers)) + for i, h := range m.handlers { + handlers[i] = h.WithGroup(name) + } + return &multiHandler{handlers: handlers} +} + +// NewLogger creates a configured slog.Logger. +// +// console: if true, always writes plain-text messages to stderr at Info+ level. +// enabled: if true, writes JSON trace logs at Debug+ level. +// path: trace log file path (empty = stderr for trace). +// +// When both console and trace target stderr, console handler is skipped +// to avoid duplicate output. +func NewLogger(enabled bool, path string, console bool) (*slog.Logger, io.Closer, error) { + var handlers []slog.Handler var closer io.Closer - if path != "" { - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) - if err != nil { - return nil, nil, fmt.Errorf("opening trace log file: %w", err) - } - w = f - closer = f - } else { - w = os.Stderr + traceToStderr := enabled && path == "" + + // Console handler: plain text to stderr at Info level. + // Skip if trace already goes to stderr (avoid duplicate output). + if console && !traceToStderr { + handlers = append(handlers, &consoleHandler{w: os.Stderr}) } - handler := slog.NewJSONHandler(w, &slog.HandlerOptions{ - Level: slog.LevelDebug, - AddSource: true, - }) + // Trace handler: JSON to file or stderr at Debug level. + if enabled { + var w io.Writer + if path != "" { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return nil, nil, fmt.Errorf("opening trace log file: %w", err) + } + w = f + closer = f + } else { + w = os.Stderr + } - return slog.New(handler), closer, nil + handler := slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + AddSource: true, + }) + handlers = append(handlers, handler) + } + + switch len(handlers) { + case 0: + return slog.New(slog.NewTextHandler(io.Discard, nil)), nil, nil + case 1: + return slog.New(handlers[0]), closer, nil + default: + return slog.New(&multiHandler{handlers: handlers}), closer, nil + } } From 6e71a4a2d8f73461e439b162194753f00fcfd94f Mon Sep 17 00:00:00 2001 From: Evgenii Danilin Date: Sun, 15 Mar 2026 21:28:59 +0700 Subject: [PATCH 3/3] vector_size --- internal/config/config.go | 20 ++++++++------------ internal/config/config_test.go | 16 ++++++++-------- internal/indexer/indexer.go | 12 +++++++++++- internal/mcp/server.go | 12 +++++++++++- internal/mcp/server_test.go | 7 +++++++ internal/providers/gemini.go | 9 +++++++++ internal/providers/generic_http.go | 9 +++++++++ internal/providers/provider.go | 3 +++ 8 files changed, 66 insertions(+), 22 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 795e2ae..94a22e6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -34,17 +34,17 @@ type IndexerConfig struct { } type ProviderConfig struct { - Provider string `yaml:"provider"` - APIKey string `yaml:"api_key"` - URL string `yaml:"url"` - Model string `yaml:"model"` + Provider string `yaml:"provider"` + APIKey string `yaml:"api_key"` + URL string `yaml:"url"` + Model string `yaml:"model"` + VectorSize int `yaml:"vector_size"` } type StorageConfig struct { Type string `yaml:"type"` URL string `yaml:"url"` CollectionPrefix string `yaml:"collection_prefix"` - VectorSize int `yaml:"vector_size"` } type PromptsConfig struct { @@ -167,6 +167,9 @@ func merge(home, project *Config) *Config { if project.Embedding.Model != "" { cfg.Embedding.Model = project.Embedding.Model } + if project.Embedding.VectorSize != 0 { + cfg.Embedding.VectorSize = project.Embedding.VectorSize + } // Storage: override non-zero fields if project.Storage.Type != "" { @@ -178,10 +181,6 @@ func merge(home, project *Config) *Config { if project.Storage.CollectionPrefix != "" { cfg.Storage.CollectionPrefix = project.Storage.CollectionPrefix } - if project.Storage.VectorSize != 0 { - cfg.Storage.VectorSize = project.Storage.VectorSize - } - // Indexer: override non-zero fields if project.Indexer.MaxFileSize != 0 { cfg.Indexer.MaxFileSize = project.Indexer.MaxFileSize @@ -227,9 +226,6 @@ func setDefaults(cfg *Config) { if cfg.Indexer.Workers <= 0 { cfg.Indexer.Workers = 2 } - if cfg.Storage.VectorSize <= 0 { - cfg.Storage.VectorSize = 3072 - } if cfg.Prompts.ProjectStructureAnalysis == "" { cfg.Prompts.ProjectStructureAnalysis = prompts.DefaultProjectStructureAnalysis } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 15d3304..adb32c8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -574,8 +574,8 @@ func TestLoad_VectorSizeDefault(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if cfg.Storage.VectorSize != 3072 { - t.Errorf("default vector_size = %d, want 3072", cfg.Storage.VectorSize) + if cfg.Embedding.VectorSize != 0 { + t.Errorf("default vector_size = %d, want 0 (auto-detect)", cfg.Embedding.VectorSize) } } @@ -589,11 +589,11 @@ embedding: provider: "generic-http" url: "http://localhost:11434/v1" model: "emb" + vector_size: 1536 storage: type: "qdrant" url: "http://localhost:6333" collection_prefix: "v_" - vector_size: 1536 ` path := writeTestConfig(t, yml) @@ -601,8 +601,8 @@ storage: if err != nil { t.Fatalf("unexpected error: %v", err) } - if cfg.Storage.VectorSize != 1536 { - t.Errorf("vector_size = %d, want 1536", cfg.Storage.VectorSize) + if cfg.Embedding.VectorSize != 1536 { + t.Errorf("vector_size = %d, want 1536", cfg.Embedding.VectorSize) } } @@ -616,11 +616,11 @@ embedding: provider: "generic-http" url: "http://localhost:11434/v1" model: "nomic-embed-text" + vector_size: 768 storage: type: "qdrant" url: "http://localhost:6333" collection_prefix: "v_" - vector_size: 768 ` path := writeTestConfig(t, yml) @@ -634,8 +634,8 @@ storage: if cfg.LLM.URL != "http://localhost:11434/v1" { t.Errorf("llm.url = %q, want http://localhost:11434/v1", cfg.LLM.URL) } - if cfg.Storage.VectorSize != 768 { - t.Errorf("vector_size = %d, want 768", cfg.Storage.VectorSize) + if cfg.Embedding.VectorSize != 768 { + t.Errorf("vector_size = %d, want 768", cfg.Embedding.VectorSize) } } diff --git a/internal/indexer/indexer.go b/internal/indexer/indexer.go index b42edba..16b792e 100644 --- a/internal/indexer/indexer.go +++ b/internal/indexer/indexer.go @@ -357,8 +357,18 @@ func Run(configPath string, force bool, logger *slog.Logger) error { return fmt.Errorf("creating embedding provider: %w", err) } + // Determine vector size: use config value or auto-detect from provider + vectorSize := cfg.Embedding.VectorSize + if vectorSize <= 0 { + vectorSize, err = embedder.DetectVectorSize() + if err != nil { + return fmt.Errorf("detecting vector size: %w", err) + } + logger.Info("Auto-detected vector size", "vector_size", vectorSize) + } + // Initialize store - db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, cfg.Storage.VectorSize, logger) + db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, vectorSize, logger) // Force mode: delete existing data and start fresh if force { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index a70b742..d1abdc6 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -191,7 +191,17 @@ func RunServer(configPath string, logger *slog.Logger) error { return fmt.Errorf("creating embedding provider: %w", err) } - db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, cfg.Storage.VectorSize, logger) + // Determine vector size: use config value or auto-detect from provider + vectorSize := cfg.Embedding.VectorSize + if vectorSize <= 0 { + vectorSize, err = embedder.DetectVectorSize() + if err != nil { + return fmt.Errorf("detecting vector size: %w", err) + } + logger.Info("Auto-detected vector size", "vector_size", vectorSize) + } + + db := store.NewQdrantStore(cfg.Storage.URL, cfg.Storage.CollectionPrefix, cfg.Project.Name, vectorSize, logger) srv := NewServer(db, embedder, rootPath, logger) diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 98a6f40..4a43de9 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -51,6 +51,13 @@ func (m *mockProvider) EmbedContent(text string) ([]float32, error) { return m.vector, m.err } +func (m *mockProvider) DetectVectorSize() (int, error) { + if m.err != nil { + return 0, m.err + } + return len(m.vector), nil +} + func TestSearchCode_Success(t *testing.T) { results := []*store.SearchResult{ {FilePath: "src/payment.go", Summary: "Payment processing", Score: 0.95}, diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index 355ade6..b740b94 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -187,3 +187,12 @@ func (g *GeminiProvider) EmbedContent(text string) ([]float32, error) { return resp.Embeddings[0].Values, nil } + +// DetectVectorSize returns the embedding dimensionality by generating a test embedding. +func (g *GeminiProvider) DetectVectorSize() (int, error) { + vec, err := g.EmbedContent("test") + if err != nil { + return 0, fmt.Errorf("detect vector size: %w", err) + } + return len(vec), nil +} diff --git a/internal/providers/generic_http.go b/internal/providers/generic_http.go index a8ec1f1..36987b9 100644 --- a/internal/providers/generic_http.go +++ b/internal/providers/generic_http.go @@ -216,6 +216,15 @@ func (p *GenericHTTPProvider) EmbedContent(text string) ([]float32, error) { return resp.Data[0].Embedding, nil } +// DetectVectorSize returns the embedding dimensionality by generating a test embedding. +func (p *GenericHTTPProvider) DetectVectorSize() (int, error) { + vec, err := p.EmbedContent("test") + if err != nil { + return 0, fmt.Errorf("detect vector size: %w", err) + } + return len(vec), nil +} + // doChat sends a chat completions request and decodes the response. func (p *GenericHTTPProvider) doChat(ctx context.Context, reqBody chatRequest, result *chatResponse) error { return p.doRequest(ctx, "/chat/completions", reqBody, result) diff --git a/internal/providers/provider.go b/internal/providers/provider.go index c7d4a35..4ff10de 100644 --- a/internal/providers/provider.go +++ b/internal/providers/provider.go @@ -16,6 +16,9 @@ type TextGenerator interface { // EmbeddingProvider generates vector embeddings from text. type EmbeddingProvider interface { EmbedContent(text string) ([]float32, error) + // DetectVectorSize returns the embedding vector dimensionality + // by generating a test embedding. + DetectVectorSize() (int, error) } // NewTextGenerator creates a TextGenerator based on provider config.