Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion cmd/engram/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"syscall"

"github.com/Gentleman-Programming/engram/internal/embedding"
"github.com/Gentleman-Programming/engram/internal/mcp"
"github.com/Gentleman-Programming/engram/internal/project"
"github.com/Gentleman-Programming/engram/internal/server"
Expand Down Expand Up @@ -162,6 +163,8 @@ func main() {
cmdProjects(cfg)
case "setup":
cmdSetup()
case "backfill-embeddings":
cmdBackfillEmbeddings(cfg)
case "version", "--version", "-v":
fmt.Printf("engram %s\n", version)
case "help", "--help", "-h":
Expand Down Expand Up @@ -195,6 +198,8 @@ func cmdServe(cfg store.Config) {
}
defer s.Close()

configureEmbeddings(s, "", "", "")

srv := newHTTPServer(s, port)

// Graceful shutdown on SIGINT/SIGTERM.
Expand All @@ -212,9 +217,12 @@ func cmdServe(cfg store.Config) {
}

func cmdMCP(cfg store.Config) {
// Parse --tools and --project flags
// Parse --tools, --project, and --embedding-* flags
toolsFilter := ""
projectOverride := ""
embProvider := ""
embModel := ""
embURL := ""
for i := 2; i < len(os.Args); i++ {
if strings.HasPrefix(os.Args[i], "--tools=") {
toolsFilter = strings.TrimPrefix(os.Args[i], "--tools=")
Expand All @@ -226,6 +234,21 @@ func cmdMCP(cfg store.Config) {
} else if os.Args[i] == "--project" && i+1 < len(os.Args) {
projectOverride = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-provider=") {
embProvider = strings.TrimPrefix(os.Args[i], "--embedding-provider=")
} else if os.Args[i] == "--embedding-provider" && i+1 < len(os.Args) {
embProvider = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-model=") {
embModel = strings.TrimPrefix(os.Args[i], "--embedding-model=")
} else if os.Args[i] == "--embedding-model" && i+1 < len(os.Args) {
embModel = os.Args[i+1]
i++
} else if strings.HasPrefix(os.Args[i], "--embedding-url=") {
embURL = strings.TrimPrefix(os.Args[i], "--embedding-url=")
} else if os.Args[i] == "--embedding-url" && i+1 < len(os.Args) {
embURL = os.Args[i+1]
i++
}
}

Expand All @@ -248,6 +271,8 @@ func cmdMCP(cfg store.Config) {
}
defer s.Close()

configureEmbeddings(s, embProvider, embModel, embURL)

mcpCfg := mcp.MCPConfig{
DefaultProject: detectedProject,
}
Expand All @@ -260,6 +285,42 @@ func cmdMCP(cfg store.Config) {
}
}

// configureEmbeddings sets up an embedding provider on the store.
// CLI flags take precedence over environment variables.
func configureEmbeddings(s *store.Store, provider, model, url string) {
// Environment variable fallbacks
if provider == "" {
provider = os.Getenv("ENGRAM_EMBEDDING_PROVIDER")
}
if model == "" {
model = os.Getenv("ENGRAM_EMBEDDING_MODEL")
}
if url == "" {
url = os.Getenv("ENGRAM_EMBEDDING_URL")
}

if provider == "" || provider == "none" {
return
}

embCfg := embedding.Config{
Provider: provider,
Model: model,
URL: url,
APIKey: os.Getenv("ENGRAM_EMBEDDING_API_KEY"),
}

emb, err := embedding.NewProvider(embCfg)
if err != nil {
log.Printf("[engram] embedding provider setup failed: %v", err)
return
}
if emb != nil {
s.SetEmbeddingProvider(emb)
log.Printf("[engram] embedding provider: %s (model: %s)", provider, emb.ModelName())
}
}

func cmdTUI(cfg store.Config) {
s, err := storeNew(cfg)
if err != nil {
Expand Down Expand Up @@ -726,6 +787,53 @@ func cmdSync(cfg store.Config) {
fmt.Printf(" git add .engram/ && git commit -m \"sync engram memories\"\n")
}

func cmdBackfillEmbeddings(cfg store.Config) {
batchSize := 50
embProvider := ""
embModel := ""
embURL := ""

for i := 2; i < len(os.Args); i++ {
if strings.HasPrefix(os.Args[i], "--batch-size=") {
if n, err := strconv.Atoi(strings.TrimPrefix(os.Args[i], "--batch-size=")); err == nil {
batchSize = n
}
} else if strings.HasPrefix(os.Args[i], "--embedding-provider=") {
embProvider = strings.TrimPrefix(os.Args[i], "--embedding-provider=")
} else if strings.HasPrefix(os.Args[i], "--embedding-model=") {
embModel = strings.TrimPrefix(os.Args[i], "--embedding-model=")
} else if strings.HasPrefix(os.Args[i], "--embedding-url=") {
embURL = strings.TrimPrefix(os.Args[i], "--embedding-url=")
}
}

s, err := storeNew(cfg)
if err != nil {
fatal(err)
}
defer s.Close()

configureEmbeddings(s, embProvider, embModel, embURL)

if s.EmbeddingProvider() == nil {
fmt.Fprintln(os.Stderr, "error: no embedding provider configured")
fmt.Fprintln(os.Stderr, " set --embedding-provider=ollama or ENGRAM_EMBEDDING_PROVIDER=ollama")
exitFunc(1)
return
}

fmt.Fprintf(os.Stderr, "Backfilling embeddings (batch size: %d, provider: %s)...\n", batchSize, s.EmbeddingProvider().ModelName())

if err := s.BackfillEmbeddings(batchSize, func(done, total int) {
fmt.Fprintf(os.Stderr, "\r %d / %d observations embedded", done, total)
}); err != nil {
fmt.Fprintln(os.Stderr)
fatal(err)
}

fmt.Fprintln(os.Stderr, "\nDone.")
}

func cmdProjects(cfg store.Config) {
// Route: engram projects list | engram projects consolidate [--all] [--dry-run]
subCmd := "list"
Expand Down
Loading