From 741379ac41cafea11e28132804f0fb0ee4d54c6e Mon Sep 17 00:00:00 2001 From: Viperdk2020 <69140170+Viperdk2020@users.noreply.github.com> Date: Mon, 1 Sep 2025 05:37:13 +0200 Subject: [PATCH] feat(memory): implement recall scoring --- codex-rs/Cargo.lock | 1 + codex-rs/memory/Cargo.toml | 1 + codex-rs/memory/src/recall.rs | 92 ++++++++++++++++++++- codex-rs/memory/tests/recall.rs | 140 +++++++++++++++++++++++++++++++- 4 files changed, 228 insertions(+), 6 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 152990aa479..445f520f1ee 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1424,6 +1424,7 @@ name = "codex-memory" version = "0.0.0" dependencies = [ "anyhow", + "chrono", "rusqlite", "serde", "serde_json", diff --git a/codex-rs/memory/Cargo.toml b/codex-rs/memory/Cargo.toml index 8d5c26a4bd4..f5bb0440bc9 100644 --- a/codex-rs/memory/Cargo.toml +++ b/codex-rs/memory/Cargo.toml @@ -7,6 +7,7 @@ edition = { workspace = true } anyhow = "1" serde = { version = "1", features = ["derive"] } serde_json = "1" +chrono = { version = "0.4", default-features = false, features = ["clock"] } [features] default = [] diff --git a/codex-rs/memory/src/recall.rs b/codex-rs/memory/src/recall.rs index fe74d550e45..f86a9925c4d 100644 --- a/codex-rs/memory/src/recall.rs +++ b/codex-rs/memory/src/recall.rs @@ -1,4 +1,9 @@ +use crate::store::MemoryStore; use crate::types::MemoryItem; +use crate::types::Status; +use chrono::DateTime; +use chrono::Utc; +use std::collections::BTreeSet; pub struct RecallContext { pub repo_root: Option, @@ -13,9 +18,88 @@ pub struct RecallContext { } pub fn recall( - _store: &dyn crate::store::MemoryStore, - _prompt: &str, - _ctx: &RecallContext, + store: &dyn MemoryStore, + prompt: &str, + ctx: &RecallContext, ) -> anyhow::Result> { - todo!() + let now = DateTime::parse_from_rfc3339(&ctx.now_rfc3339)?.with_timezone(&Utc); + let tokens = tokenize(prompt); + let mut scored: Vec<(f32, usize, MemoryItem)> = store + .list(None, Some(Status::Active))? + .into_iter() + .map(|item| { + let mut score = overlap_score(&tokens, &tokenize(&item.content)); + if let Some(f) = &ctx.current_file + && item.relevance_hints.files.iter().any(|h| f.ends_with(h)) + { + score += 0.4; + } + if let Some(c) = &ctx.crate_name + && item.relevance_hints.crates.iter().any(|h| h == c) + { + score += 0.3; + } + if let Some(l) = &ctx.language + && item + .relevance_hints + .languages + .iter() + .any(|h| h.eq_ignore_ascii_case(l)) + { + score += 0.2; + } + if let Some(cmd) = &ctx.command + && item.relevance_hints.commands.iter().any(|h| h == cmd) + { + score += 0.1; + } + let freq = 1.0 + item.counters.used_count as f32 * 0.1; + score *= freq; + if let Some(last) = &item.counters.last_used_at + && let Ok(dt) = DateTime::parse_from_rfc3339(last) + { + let age_days = (now - dt.with_timezone(&Utc)).num_days(); + let decay = 0.5f32.powf(age_days as f32 / 7.0); + score *= decay; + } + let token_len = item.content.split_whitespace().count(); + (score, token_len, item) + }) + .collect(); + scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + let mut out = Vec::new(); + let mut used_tokens = 0usize; + for (_, tokens, mut item) in scored { + if out.len() >= ctx.item_cap { + break; + } + if used_tokens + tokens > ctx.token_cap { + break; + } + used_tokens += tokens; + item.counters.used_count += 1; + item.counters.last_used_at = Some(ctx.now_rfc3339.clone()); + store.update(&item)?; + out.push(item); + } + Ok(out) +} + +fn tokenize(s: &str) -> BTreeSet { + let mut set = BTreeSet::new(); + for w in s.split(|c: char| !c.is_alphanumeric()) { + if w.is_empty() { + continue; + } + set.insert(w.to_ascii_lowercase()); + } + set +} + +fn overlap_score(a: &BTreeSet, b: &BTreeSet) -> f32 { + if a.is_empty() || b.is_empty() { + return 0.0; + } + let inter = a.intersection(b).count() as f32; + inter / a.len() as f32 } diff --git a/codex-rs/memory/tests/recall.rs b/codex-rs/memory/tests/recall.rs index 27a5b49e5e8..36f8bc1832d 100644 --- a/codex-rs/memory/tests/recall.rs +++ b/codex-rs/memory/tests/recall.rs @@ -1,4 +1,140 @@ +use codex_memory::recall::RecallContext; +use codex_memory::recall::recall; +use codex_memory::store::MemoryStore; +use codex_memory::types::Counters; +use codex_memory::types::Kind; +use codex_memory::types::MemoryItem; +use codex_memory::types::RelevanceHints; +use codex_memory::types::Scope; +use codex_memory::types::Status; +use std::collections::HashMap; +use std::sync::Mutex; + +#[derive(Default)] +struct TestStore { + items: Mutex>, +} + +impl TestStore { + fn new(items: Vec) -> Self { + let map = items.into_iter().map(|i| (i.id.clone(), i)).collect(); + Self { + items: Mutex::new(map), + } + } +} + +impl MemoryStore for TestStore { + fn add(&self, item: MemoryItem) -> anyhow::Result<()> { + self.items.lock().unwrap().insert(item.id.clone(), item); + Ok(()) + } + + fn update(&self, item: &MemoryItem) -> anyhow::Result<()> { + self.items + .lock() + .unwrap() + .insert(item.id.clone(), item.clone()); + Ok(()) + } + + fn delete(&self, _id: &str) -> anyhow::Result<()> { + Ok(()) + } + + fn get(&self, id: &str) -> anyhow::Result> { + Ok(self.items.lock().unwrap().get(id).cloned()) + } + + fn list( + &self, + _scope: Option, + status: Option, + ) -> anyhow::Result> { + let items = self.items.lock().unwrap(); + Ok(items + .values() + .filter(|i| match status.as_ref() { + Some(s) => i.status == *s, + None => true, + }) + .cloned() + .collect()) + } + + fn archive(&self, _id: &str, _archived: bool) -> anyhow::Result<()> { + Ok(()) + } + + fn export(&self, _out: &mut dyn std::io::Write) -> anyhow::Result<()> { + Ok(()) + } + + fn import(&self, _input: &mut dyn std::io::Read) -> anyhow::Result { + Ok(0) + } + + fn stats(&self) -> anyhow::Result { + Ok(serde_json::json!({})) + } +} + +fn item(id: &str, content: &str, lang: &str) -> MemoryItem { + MemoryItem { + id: id.to_string(), + created_at: "2024-01-01T00:00:00Z".into(), + updated_at: "2024-01-01T00:00:00Z".into(), + schema_version: 1, + source: "test".into(), + scope: Scope::Global, + status: Status::Active, + kind: Kind::Fact, + content: content.into(), + tags: vec![], + relevance_hints: RelevanceHints { + files: vec![], + crates: vec![], + languages: vec![lang.into()], + commands: vec![], + }, + counters: Counters { + seen_count: 0, + used_count: 0, + last_used_at: None, + }, + expiry: None, + } +} + #[test] -fn placeholder() { - // placeholder test +fn ranks_and_updates_counters() { + let a = item("1", "use cargo build for rust", "rust"); + let b = item("2", "cargo test runs tests", "rust"); + let c = item("3", "npm install packages", "javascript"); + let store = TestStore::new(vec![a.clone(), b.clone(), c.clone()]); + let now = "2024-01-10T00:00:00Z".to_string(); + let ctx = RecallContext { + repo_root: None, + dir: None, + current_file: None, + crate_name: None, + language: Some("rust".into()), + command: None, + now_rfc3339: now.clone(), + item_cap: 2, + token_cap: 50, + }; + let out = recall(&store, "cargo build rust", &ctx).unwrap(); + assert_eq!(out.len(), 2); + assert_eq!(out[0].id, "1"); + assert_eq!(out[1].id, "2"); + let a_upd = store.get("1").unwrap().unwrap(); + assert_eq!(a_upd.counters.used_count, 1); + assert_eq!(a_upd.counters.last_used_at.as_deref(), Some(now.as_str())); + let b_upd = store.get("2").unwrap().unwrap(); + assert_eq!(b_upd.counters.used_count, 1); + assert_eq!(b_upd.counters.last_used_at.as_deref(), Some(now.as_str())); + let c_upd = store.get("3").unwrap().unwrap(); + assert_eq!(c_upd.counters.used_count, 0); + assert_eq!(c_upd.counters.last_used_at, None); }