From 4469e6b18a93f8f8eca7f5669be4c66caccd63bb Mon Sep 17 00:00:00 2001 From: Manuel Retamozo Date: Sun, 26 Apr 2026 19:34:45 +0200 Subject: [PATCH] feat(store): replace FTS5 default ranking with weighted BM25 scoring Replace fts.rank with bm25(observations_fts, 5.0, 1.0, 0.0, 0.0, 0.0, 3.0) to apply weighted BM25 scoring in FTS5 search queries. Weights: title=5, content=1, tool_name=0, type=0, project=0, topic_key=3. Direct topic_key route (Rank=-1000) remains unchanged. Closes #241 --- internal/store/store.go | 4 +- internal/store/store_test.go | 156 +++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 2 deletions(-) diff --git a/internal/store/store.go b/internal/store/store.go index d1d9f8e..2af70f5 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -2621,7 +2621,7 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error) sqlQ := ` SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, o.tool_name, o.project, o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.created_at, o.updated_at, o.deleted_at, - fts.rank + bm25(observations_fts, 5.0, 1.0, 0.0, 0.0, 0.0, 3.0) as rank -- weights: title=5, content=1, tool_name=0, type=0, project=0, topic_key=3 FROM observations_fts fts JOIN observations o ON o.id = fts.rowid WHERE observations_fts MATCH ? AND o.deleted_at IS NULL @@ -2643,7 +2643,7 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error) args = append(args, normalizeScope(opts.Scope)) } - sqlQ += " ORDER BY fts.rank LIMIT ?" + sqlQ += " ORDER BY rank LIMIT ?" args = append(args, limit) rows, err := s.queryItHook(s.db, sqlQ, args...) 
diff --git a/internal/store/store_test.go b/internal/store/store_test.go
index 81e85c2..22558ea 100644
--- a/internal/store/store_test.go
+++ b/internal/store/store_test.go
@@ -7004,3 +7004,159 @@ func TestAddObservation_DecayNotAppliedToExistingRows(t *testing.T) {
 		t.Errorf("revision must not overwrite review_after: was %q, now %q", ra1, ra2)
 	}
 }
+
+// TestBM25RankingOrdersTitleAboveTopicKeyAboveContent checks that a term
+// matched in the title ranks first and that a topic_key match ranks above a
+// content-only match.
+func TestBM25RankingOrdersTitleAboveTopicKeyAboveContent(t *testing.T) {
+	s := newTestStore(t)
+
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
+
+	// Observation with "kubernetes" in title only
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "kubernetes deployment strategy",
+		Content:   "we decided on rolling updates for the cluster",
+		Project:   "engram",
+		Scope:     "project",
+	})
+	if err != nil {
+		t.Fatalf("add title-match obs: %v", err)
+	}
+
+	// Observation with "kubernetes" in topic_key only
+	_, err = s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "container orchestration setup",
+		Content:   "we decided on orchestration for the cluster",
+		Project:   "engram",
+		Scope:     "project",
+		TopicKey:  "infra/kubernetes-config",
+	})
+	if err != nil {
+		t.Fatalf("add topic-match obs: %v", err)
+	}
+
+	// Observation with "kubernetes" in content only
+	_, err = s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "infrastructure notes",
+		Content:   "we migrated everything to kubernetes last quarter",
+		Project:   "engram",
+		Scope:     "project",
+	})
+	if err != nil {
+		t.Fatalf("add content-match obs: %v", err)
+	}
+
+	results, err := s.Search("kubernetes", SearchOptions{Project: "engram", Limit: 10})
+	if err != nil {
+		t.Fatalf("search: %v", err)
+	}
+
+	if len(results) < 3 {
+		t.Fatalf("expected at least 3 results, got %d", len(results))
+	}
+
+	// With BM25 weights (title=5, topic_key=3, content=1) the title match
+	// must rank first; only one fixture has "kubernetes" in its title.
+	if !strings.Contains(results[0].Title, "kubernetes") {
+		t.Fatalf("expected first result to have 'kubernetes' in title (title-match), got %q", results[0].Title)
+	}
+
+	// The topic_key match should come before the content-only match
+	topicIdx := -1
+	contentIdx := -1
+	for i, r := range results {
+		if r.TopicKey != nil && strings.Contains(*r.TopicKey, "kubernetes") {
+			topicIdx = i
+		}
+		if r.Title == "infrastructure notes" {
+			contentIdx = i
+		}
+	}
+	if topicIdx == -1 || contentIdx == -1 {
+		t.Fatalf("could not find topic-match (idx=%d) or content-match (idx=%d) in results", topicIdx, contentIdx)
+	}
+	if topicIdx > contentIdx {
+		t.Fatalf("expected topic_key match (idx=%d) to rank above content match (idx=%d) with BM25 weights", topicIdx, contentIdx)
+	}
+}
+
+// TestBM25DoesNotAffectTopicKeyDirectRoute checks that a direct topic_key hit keeps Rank=-1000.
+func TestBM25DoesNotAffectTopicKeyDirectRoute(t *testing.T) {
+	s := newTestStore(t)
+
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
+
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "architecture",
+		Title:     "JWT auth middleware",
+		Content:   "Use JWT for all API endpoints",
+		Project:   "engram",
+		Scope:     "project",
+		TopicKey:  "auth/jwt",
+	})
+	if err != nil {
+		t.Fatalf("add obs: %v", err)
+	}
+
+	// Query containing "/" triggers the direct topic_key route
+	results, err := s.Search("auth/jwt", SearchOptions{Project: "engram", Limit: 10})
+	if err != nil {
+		t.Fatalf("search: %v", err)
+	}
+
+	if len(results) == 0 {
+		t.Fatalf("expected at least 1 result from direct topic_key route")
+	}
+
+	if results[0].Rank != -1000 {
+		t.Fatalf("expected Rank=-1000 for direct topic_key match, got %f", results[0].Rank)
+	}
+}
+
+// TestBM25SearchReturnsNoSQLErrors checks that edge-case queries do not make Search fail.
+func TestBM25SearchReturnsNoSQLErrors(t *testing.T) {
+	s := newTestStore(t)
+
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
+
+	// Add at least one observation so the FTS table isn't empty
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "note",
+		Title:     "test observation",
+		Content:   "some content for testing",
+		Project:   "engram",
+		Scope:     "project",
+	})
+	if err != nil {
+		t.Fatalf("add obs: %v", err)
+	}
+
+	// Test various edge-case queries that should not produce SQL errors
+	edgeCases := []string{
+		"normal query",
+		"special chars: @#$%^&*()",
+		strings.Repeat("a", 500), // very long query
+	}
+
+	for _, q := range edgeCases {
+		_, err := s.Search(q, SearchOptions{Project: "engram", Limit: 10})
+		if err != nil {
+			t.Fatalf("search with query %q returned error: %v", q[:min(len(q), 50)], err)
+		}
+	}
+}