diff --git a/internal/store/store.go b/internal/store/store.go
index 445a2ab..3d53775 100644
--- a/internal/store/store.go
+++ b/internal/store/store.go
@@ -2720,7 +2720,7 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error)
 	sqlQ := `
 		SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, o.tool_name, o.project,
 		       o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.created_at, o.updated_at, o.deleted_at,
-		       fts.rank
+		       bm25(observations_fts, 5.0, 1.0, 0.0, 0.0, 0.0, 3.0) as rank -- weights: title=5, content=1, tool_name=0, type=0, project=0, topic_key=3
 		FROM observations_fts fts
 		JOIN observations o ON o.id = fts.rowid
 		WHERE observations_fts MATCH ? AND o.deleted_at IS NULL
@@ -2742,7 +2742,7 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error)
 		args = append(args, normalizeScope(opts.Scope))
 	}
 
-	sqlQ += " ORDER BY fts.rank LIMIT ?"
+	sqlQ += " ORDER BY rank LIMIT ?"
 	args = append(args, limit)
 
 	rows, err := s.queryItHook(s.db, sqlQ, args...)
diff --git a/internal/store/store_test.go b/internal/store/store_test.go
index 9ddeed4..5f601a4 100644
--- a/internal/store/store_test.go
+++ b/internal/store/store_test.go
@@ -7338,146 +7338,159 @@ func TestAddObservation_DecayNotAppliedToExistingRows(t *testing.T) {
 	}
 }
 
-// ─── C.2 [RED] — ListDeferred / GetDeferred ──────────────────────────────────
-
-// seedDeferredRow is a test helper that inserts a row into sync_apply_deferred.
-// Uses a different name from the existing insertDeferredRow in sync_apply_test.go
-// which has parameter order (syncID, entity, payload, retryCount, applyStatus).
-func seedDeferredRow(t *testing.T, s *Store, syncID, entity, payload string, retryCount int, applyStatus string) {
-	t.Helper()
-	if _, err := s.db.Exec(`
-		INSERT INTO sync_apply_deferred
-			(sync_id, entity, payload, apply_status, retry_count, first_seen_at)
-		VALUES (?, ?, ?, ?, ?, datetime('now'))
-	`, syncID, entity, payload, applyStatus, retryCount); err != nil {
-		t.Fatalf("seedDeferredRow %q: %v", syncID, err)
-	}
-}
-
-// TestListDeferred_HappyPath verifies pagination and status filter.
-func TestListDeferred_HappyPath(t *testing.T) {
+// TestBM25RankingOrdersTitleAboveTopicKeyAboveContent verifies that Search ranks
+// a title match first and a topic_key match ahead of a content-only match.
+func TestBM25RankingOrdersTitleAboveTopicKeyAboveContent(t *testing.T) {
 	s := newTestStore(t)
 
-	validPayload := `{"relation_type":"conflicts_with","source_id":"obs-aaa","target_id":"obs-bbb"}`
-	seedDeferredRow(t, s, "def-001", "relation", validPayload, 0, "deferred")
-	seedDeferredRow(t, s, "def-002", "relation", validPayload, 1, "deferred")
-	seedDeferredRow(t, s, "def-003", "relation", validPayload, 5, "dead")
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
 
-	// List all.
-	all, err := s.ListDeferred(ListDeferredOptions{Limit: 50})
+	// Observation with "kubernetes" in title only
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "kubernetes deployment strategy",
+		Content:   "we decided on rolling updates for the cluster",
+		Project:   "engram",
+		Scope:     "project",
+	})
 	if err != nil {
-		t.Fatalf("ListDeferred all: %v", err)
-	}
-	if len(all) != 3 {
-		t.Errorf("expected 3 rows; got %d", len(all))
+		t.Fatalf("add title-match obs: %v", err)
 	}
 
-	// List only deferred status.
-	deferred, err := s.ListDeferred(ListDeferredOptions{Status: "deferred", Limit: 50})
+	// Observation with "kubernetes" in topic_key only
+	_, err = s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "container orchestration setup",
+		Content:   "we decided on orchestration for the cluster",
+		Project:   "engram",
+		Scope:     "project",
+		TopicKey:  "infra/kubernetes-config",
+	})
 	if err != nil {
-		t.Fatalf("ListDeferred deferred: %v", err)
-	}
-	if len(deferred) != 2 {
-		t.Errorf("expected 2 deferred rows; got %d", len(deferred))
+		t.Fatalf("add topic-match obs: %v", err)
 	}
 
-	// Pagination: limit=1.
-	page, err := s.ListDeferred(ListDeferredOptions{Limit: 1})
+	// Observation with "kubernetes" in content only
+	_, err = s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "decision",
+		Title:     "infrastructure notes",
+		Content:   "we migrated everything to kubernetes last quarter",
+		Project:   "engram",
+		Scope:     "project",
+	})
 	if err != nil {
-		t.Fatalf("ListDeferred limit=1: %v", err)
-	}
-	if len(page) != 1 {
-		t.Errorf("expected 1 row with limit=1; got %d", len(page))
+		t.Fatalf("add content-match obs: %v", err)
 	}
-}
 
-// TestListDeferred_DecodedPayload verifies that DeferredRow.Payload is decoded
-// and PayloadValid=true for well-formed JSON.
-func TestListDeferred_DecodedPayload(t *testing.T) {
-	s := newTestStore(t)
+	results, err := s.Search("kubernetes", SearchOptions{Project: "engram", Limit: 10})
+	if err != nil {
+		t.Fatalf("search: %v", err)
+	}
 
-	validPayload := `{"relation_type":"conflicts_with","source_id":"obs-src","target_id":"obs-tgt","extra":42}`
-	seedDeferredRow(t, s, "def-valid", "relation", validPayload, 0, "deferred")
+	if len(results) < 3 {
+		t.Fatalf("expected at least 3 results, got %d", len(results))
+	}
 
-	rows, err := s.ListDeferred(ListDeferredOptions{Limit: 10})
-	if err != nil {
-		t.Fatalf("ListDeferred: %v", err)
+	// With BM25 weights (title=5, topic_key=3, content=1):
+	// title-match should rank highest, then topic_key-match, then content-match
+	if !strings.Contains(results[0].Title, "kubernetes") {
+		t.Fatalf("expected first result to have 'kubernetes' in title (title-match), got %q", results[0].Title)
 	}
-	if len(rows) != 1 {
-		t.Fatalf("expected 1 row; got %d", len(rows))
-	}
-	row := rows[0]
-	if !row.PayloadValid {
-		t.Errorf("expected PayloadValid=true for well-formed JSON; got false. PayloadRaw=%q", row.PayloadRaw)
+
+	// The topic_key match should come before the content-only match
+	topicIdx := -1
+	contentIdx := -1
+	for i, r := range results {
+		if r.TopicKey != nil && strings.Contains(*r.TopicKey, "kubernetes") {
+			topicIdx = i
+		}
+		if r.Title == "infrastructure notes" {
+			contentIdx = i
+		}
 	}
-	if row.Payload == nil {
-		t.Fatal("expected decoded Payload map; got nil")
+	if topicIdx == -1 || contentIdx == -1 {
+		t.Fatalf("could not find topic-match (idx=%d) or content-match (idx=%d) in results", topicIdx, contentIdx)
 	}
-	if row.Payload["relation_type"] != "conflicts_with" {
-		t.Errorf("decoded Payload[relation_type]: want conflicts_with; got %v", row.Payload["relation_type"])
+	if topicIdx > contentIdx {
+		t.Fatalf("expected topic_key match (idx=%d) to rank above content match (idx=%d) with BM25 weights", topicIdx, contentIdx)
 	}
 }
 
-// TestListDeferred_MalformedPayload verifies that a malformed JSON payload sets
-// PayloadValid=false and preserves PayloadRaw.
-func TestListDeferred_MalformedPayload(t *testing.T) {
+// TestBM25DoesNotAffectTopicKeyDirectRoute verifies that a query containing "/"
+// still returns the direct topic_key match with Rank -1000.
+func TestBM25DoesNotAffectTopicKeyDirectRoute(t *testing.T) {
 	s := newTestStore(t)
 
-	seedDeferredRow(t, s, "def-bad", "relation", "not valid json", 5, "dead")
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
 
-	rows, err := s.ListDeferred(ListDeferredOptions{Limit: 10})
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "architecture",
+		Title:     "JWT auth middleware",
+		Content:   "Use JWT for all API endpoints",
+		Project:   "engram",
+		Scope:     "project",
+		TopicKey:  "auth/jwt",
+	})
 	if err != nil {
-		t.Fatalf("ListDeferred malformed: %v", err)
+		t.Fatalf("add obs: %v", err)
 	}
-	if len(rows) != 1 {
-		t.Fatalf("expected 1 row; got %d", len(rows))
+
+	// Query containing "/" triggers the direct topic_key route
+	results, err := s.Search("auth/jwt", SearchOptions{Project: "engram", Limit: 10})
+	if err != nil {
+		t.Fatalf("search: %v", err)
 	}
-	row := rows[0]
-	if row.PayloadValid {
-		t.Errorf("expected PayloadValid=false for malformed JSON; got true")
+
+	if len(results) == 0 {
+		t.Fatalf("expected at least 1 result from direct topic_key route")
 	}
-	if row.PayloadRaw != "not valid json" {
-		t.Errorf("expected PayloadRaw preserved; got %q", row.PayloadRaw)
+
+	if results[0].Rank != -1000 {
+		t.Fatalf("expected Rank=-1000 for direct topic_key match, got %f", results[0].Rank)
 	}
 }
 
-// TestGetDeferred_HappyPath verifies GetDeferred returns the correct row.
-func TestGetDeferred_HappyPath(t *testing.T) {
+// TestBM25SearchReturnsNoSQLErrors verifies that Search returns no error for
+// ordinary, punctuation-heavy, and very long queries.
+func TestBM25SearchReturnsNoSQLErrors(t *testing.T) {
 	s := newTestStore(t)
 
-	validPayload := `{"relation_type":"related","source_id":"obs-xyz","target_id":"obs-abc"}`
-	seedDeferredRow(t, s, "def-xyz", "relation", validPayload, 2, "deferred")
+	if err := s.CreateSession("s1", "engram", "/tmp/engram"); err != nil {
+		t.Fatalf("create session: %v", err)
+	}
 
-	row, err := s.GetDeferred("def-xyz")
+	// Add at least one observation so the FTS table isn't empty
+	_, err := s.AddObservation(AddObservationParams{
+		SessionID: "s1",
+		Type:      "note",
+		Title:     "test observation",
+		Content:   "some content for testing",
+		Project:   "engram",
+		Scope:     "project",
+	})
 	if err != nil {
-		t.Fatalf("GetDeferred: %v", err)
-	}
-	if row.SyncID != "def-xyz" {
-		t.Errorf("expected SyncID=def-xyz; got %q", row.SyncID)
-	}
-	if row.ApplyStatus != "deferred" {
-		t.Errorf("expected ApplyStatus=deferred; got %q", row.ApplyStatus)
-	}
-	if row.RetryCount != 2 {
-		t.Errorf("expected RetryCount=2; got %d", row.RetryCount)
+		t.Fatalf("add obs: %v", err)
 	}
-	if !row.PayloadValid {
-		t.Errorf("expected PayloadValid=true for valid JSON; got false")
-	}
-	if row.Payload["relation_type"] != "related" {
-		t.Errorf("decoded Payload[relation_type]: want related; got %v", row.Payload["relation_type"])
-	}
-}
-
-// TestGetDeferred_NotFound verifies GetDeferred returns an error wrapping "not found".
-func TestGetDeferred_NotFound(t *testing.T) {
-	s := newTestStore(t)
 
-	_, err := s.GetDeferred("def-missing")
-	if err == nil {
-		t.Fatal("expected error for missing sync_id; got nil")
+	// Test various edge-case queries that should not produce SQL errors
+	edgeCases := []string{
+		"normal query",
+		"special chars: @#$%^&*()",
+		strings.Repeat("a", 500), // very long query
 	}
-	if !strings.Contains(err.Error(), "not found") {
-		t.Errorf("expected error to contain 'not found'; got %q", err.Error())
+
+	for _, q := range edgeCases {
+		_, err := s.Search(q, SearchOptions{Project: "engram", Limit: 10})
+		if err != nil {
+			t.Fatalf("search with query %q returned error: %v", q[:min(len(q), 50)], err)
+		}
 	}
 }