diff --git a/pkg/embedding/gemini.go b/pkg/embedding/gemini.go index f09fc39b..fd8e6cb7 100644 --- a/pkg/embedding/gemini.go +++ b/pkg/embedding/gemini.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" ) @@ -20,7 +21,12 @@ type GeminiProvider struct { type GeminiOption func(*GeminiProvider) func WithGeminiModel(model string) GeminiOption { - return func(p *GeminiProvider) { p.model = model } + return func(p *GeminiProvider) { + p.model = normalizeGeminiModel(model) + if dim := geminiModelDimension(p.model); dim > 0 { + p.dimension = dim + } + } } func NewGeminiProvider(apiKey string, opts ...GeminiOption) (*GeminiProvider, error) { @@ -71,6 +77,7 @@ func (p *GeminiProvider) EmbedBatch(ctx context.Context, texts []string) ([][]fl func (p *GeminiProvider) embedSingle(ctx context.Context, text string) ([]float32, error) { payload := map[string]any{ + "model": p.model, "content": map[string]any{ "parts": []any{ map[string]string{"text": text}, @@ -108,3 +115,25 @@ func (p *GeminiProvider) embedSingle(ctx context.Context, text string) ([]float3 return result.Embedding.Values, nil } + +func normalizeGeminiModel(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "models/text-embedding-004" + } + if !strings.Contains(model, "/") { + return "models/" + model + } + return model +} + +func geminiModelDimension(model string) int { + switch normalizeGeminiModel(model) { + case "models/gemini-embedding-001": + return 3072 + case "models/text-embedding-004": + return 768 + default: + return 0 + } +} diff --git a/pkg/embedding/manager_test.go b/pkg/embedding/manager_test.go index a726082e..c3bffd6d 100644 --- a/pkg/embedding/manager_test.go +++ b/pkg/embedding/manager_test.go @@ -224,6 +224,48 @@ func TestGeminiProvider(t *testing.T) { } } +func TestGeminiProviderGeminiEmbedding001(t *testing.T) { + p, err := NewGeminiProvider("test-key", WithGeminiModel("gemini-embedding-001")) + if err != nil { + t.Fatalf("failed to create provider: %v", err) + } + + if p.model != "models/gemini-embedding-001" { + t.Fatalf("expected normalized model path, got %q", p.model) + } + if p.Dimension() != 3072 { + t.Fatalf("expected dimension 3072, got %d", p.Dimension()) + } +} + +func TestGeminiProviderNormalizesModelPath(t *testing.T) { + p, err := NewGeminiProvider("test-key", WithGeminiModel("text-embedding-004")) + if err != nil { + t.Fatalf("failed to create provider: %v", err) + } + + if p.model != "models/text-embedding-004" { + t.Fatalf("expected normalized model path, got %q", p.model) + } + if p.Dimension() != 768 { + t.Fatalf("expected dimension 768, got %d", p.Dimension()) + } +} + +func TestGeminiProviderUnknownModelKeepsDefaultDimension(t *testing.T) { + p, err := NewGeminiProvider("test-key", WithGeminiModel("custom-gemini-embed")) + if err != nil { + t.Fatalf("failed to create provider: %v", err) + } + + if p.model != "models/custom-gemini-embed" { + t.Fatalf("expected normalized model path, got %q", p.model) + } + if p.Dimension() != 768 { + t.Fatalf("expected default dimension 768 for unknown Gemini models, got %d", p.Dimension()) + } +} + func TestOllamaProvider(t *testing.T) { p := NewOllamaProvider( WithOllamaBaseURL("http://localhost:11434"),