diff --git a/src/aish/config.py b/src/aish/config.py index 6847c57..93b4876 100644 --- a/src/aish/config.py +++ b/src/aish/config.py @@ -12,6 +12,8 @@ import yaml from pydantic import BaseModel, ConfigDict, Field, field_validator +from aish.memory.config import MemoryConfig + class ToolArgPreviewSettingsDict(TypedDict): enabled: bool @@ -234,6 +236,11 @@ class ConfigModel(BaseModel): description="Whether the current configuration uses a free API key", ) + memory: MemoryConfig = Field( + default_factory=MemoryConfig, + description="Long-term memory configuration", + ) + @field_validator("tool_arg_preview", mode="before") @classmethod def normalize_tool_arg_preview(cls, v: Any) -> dict[str, ToolArgPreviewSettings]: @@ -379,6 +386,10 @@ def _load_config(self) -> ConfigModel: if "prompt_theme" not in config_data: config_data["prompt_theme"] = "compact" need_save = True + # Add memory section if missing (new field migration) + if "memory" not in config_data: + config_data["memory"] = {"enabled": True} + need_save = True if need_save: self._save_config_data(config_data) diff --git a/src/aish/llm.py b/src/aish/llm.py index 06e7059..496063f 100644 --- a/src/aish/llm.py +++ b/src/aish/llm.py @@ -325,6 +325,7 @@ def __init__( env_manager=None, interruption_manager=None, history_manager=None, + memory_manager=None, ): # noqa: F821 self.config = config self.model = config.model @@ -389,6 +390,13 @@ def __init__( self.system_diagnose_agent.name: self.system_diagnose_agent, self.skill_tool.name: self.skill_tool, } + + # Register memory tool if memory manager is provided + if memory_manager is not None: + from aish.tools.memory_tool import MemoryTool + + self.memory_tool = MemoryTool(memory_manager=memory_manager) + self.tools[self.memory_tool.name] = self.memory_tool else: # Use the provided tool set self.tools = tools_override @@ -1142,7 +1150,9 @@ def _get_messages_with_system( messages = context_manager.as_messages() if system_message: if messages and messages[0]["role"] == 
"system": - messages[0]["content"] = system_message + # Merge: keep knowledge context, append system prompt + existing = messages[0]["content"] + messages[0]["content"] = f"{existing}\n\n{system_message}" else: messages.insert(0, {"role": "system", "content": system_message}) reminder = self._build_skills_reminder_message() diff --git a/src/aish/memory/__init__.py b/src/aish/memory/__init__.py new file mode 100644 index 0000000..1399e29 --- /dev/null +++ b/src/aish/memory/__init__.py @@ -0,0 +1,4 @@ +from aish.memory.manager import MemoryManager +from aish.memory.models import MemoryCategory, MemoryEntry + +__all__ = ["MemoryManager", "MemoryCategory", "MemoryEntry"] diff --git a/src/aish/memory/config.py b/src/aish/memory/config.py new file mode 100644 index 0000000..f662196 --- /dev/null +++ b/src/aish/memory/config.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import os +from pathlib import Path +from pydantic import BaseModel, Field + + +def _default_data_dir() -> str: + """Resolve default memory directory, following the same pattern as skills. 
+ + Uses AISH_CONFIG_DIR if set, otherwise ~/.config/aish/memory/ + """ + config_dir = os.environ.get("AISH_CONFIG_DIR") + if config_dir: + return str(Path(config_dir) / "memory") + return str(Path.home() / ".config" / "aish" / "memory") + + +class MemoryConfig(BaseModel): + """Configuration for long-term memory system.""" + + enabled: bool = Field(default=True, description="Enable long-term memory") + data_dir: str = Field( + default_factory=_default_data_dir, + description="Directory for memory files and database", + ) + recall_limit: int = Field( + default=5, gt=0, description="Max memories returned per recall" + ) + recall_token_budget: int = Field( + default=512, gt=0, description="Max tokens injected per recall" + ) + daily_retention_days: int = Field( + default=30, gt=0, description="Days to keep daily notes before auto-cleanup" + ) + auto_recall: bool = Field( + default=True, description="Automatically inject relevant memories before AI turns" + ) diff --git a/src/aish/memory/manager.py b/src/aish/memory/manager.py new file mode 100644 index 0000000..cefc52f --- /dev/null +++ b/src/aish/memory/manager.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +import datetime as dt +import math +import re +import sqlite3 +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Optional + +from aish.memory.config import MemoryConfig +from aish.memory.models import MemoryCategory, MemoryEntry + +# Markdown link pattern: [title](path.md) +_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+\.md)\)") + + +class MemoryManager: + """Long-term memory backed by Markdown files + FTS5, all I/O via thread pool. 
class MemoryManager:
    """Long-term memory backed by Markdown files + FTS5, all I/O via thread pool.

    Storage layout (mirrors skills pattern: ~/.config/aish/memory/):
        MEMORY.md     - permanent knowledge (user-editable)
        YYYY-MM-DD.md - daily notes (auto-created, auto-pruned)
        memory.db     - SQLite + FTS5 index
    """

    # FTS5 metacharacters stripped before building a MATCH expression.
    # Precompiled once instead of re-importing/re-compiling per query.
    _FTS_SPECIAL_RE = re.compile(r'["*+\-:^(){}]')

    # Categories worth persisting to MEMORY.md (permanent knowledge);
    # everything else lands in a dated daily note.
    _PERMANENT_CATEGORIES = frozenset({
        MemoryCategory.PREFERENCE,
        MemoryCategory.ENVIRONMENT,
        MemoryCategory.SOLUTION,
    })

    def __init__(self, config: MemoryConfig):
        """Create the storage directory, database, and seed files if missing."""
        self.config = config
        self.memory_dir = Path(config.data_dir).expanduser().resolve()
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self.memory_md = self.memory_dir / "MEMORY.md"
        self.db_path = self.memory_dir / "memory.db"
        self._conn = self._init_db()
        # NOTE(review): one connection is shared between the caller's thread
        # and up to two pool workers. CPython's bundled SQLite is normally
        # built serialized (sqlite3.threadsafety == 3), which makes that safe
        # — confirm for the deployment builds, or add a lock around DB calls.
        self._pool = ThreadPoolExecutor(max_workers=2, thread_name_prefix="memory")
        self._today: Optional[str] = None
        self._ensure_memory_md()
        self._ensure_daily_note()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def today(self) -> str:
        """Today's ISO date, computed lazily and cached.

        NOTE(review): cached for the manager's lifetime, so a session that
        crosses midnight keeps writing to the previous day's note — confirm
        this is acceptable.
        """
        if self._today is None:
            self._today = dt.date.today().isoformat()
        return self._today

    @property
    def memory_dir_path(self) -> str:
        """Human-readable path for system prompt."""
        return str(self.memory_dir)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Flush pending memory work, then close the database.

        Fix: wait for in-flight jobs (dropping anything still queued) so the
        connection is never closed underneath a running worker thread; the
        previous wait=False shutdown raced worker writes against close().
        """
        self._pool.shutdown(wait=True, cancel_futures=True)
        if self._conn:
            self._conn.close()

    # ------------------------------------------------------------------
    # Async wrappers (run SQLite I/O in thread pool, never block UI)
    # ------------------------------------------------------------------

    def store_async(
        self,
        content: str,
        category: MemoryCategory,
        source: str = "explicit",
        tags: str = "",
        importance: float = 0.5,
    ) -> None:
        """Fire-and-forget store. Does NOT block the calling thread."""
        self._pool.submit(
            self.store, content, category, source, tags, importance
        )

    def recall_async(
        self, query: str, limit: int = 5, callback=None
    ) -> None:
        """Fire-and-forget recall. Calls *callback(results)* when done."""
        def _work():
            results = self.recall(query, limit)
            if callback:
                callback(results)
        self._pool.submit(_work)

    # ------------------------------------------------------------------
    # Synchronous core (called from thread pool or directly)
    # ------------------------------------------------------------------

    def store(
        self,
        content: str,
        category: MemoryCategory,
        source: str = "explicit",
        tags: str = "",
        importance: float = 0.5,
    ) -> int:
        """Store a memory entry in the DB and mirror it to a Markdown file.

        Returns the new row ID. Permanent categories are appended to
        MEMORY.md; everything else goes to the daily note named either by a
        ``daily:YYYY-MM-DD`` source or by today's date.
        """
        cursor = self._conn.execute(
            """INSERT INTO memory_meta (source, category, content, tags, importance)
               VALUES (?, ?, ?, ?, ?)""",
            (source, category.value, content, tags, importance),
        )
        row_id: int = cursor.lastrowid  # type: ignore[assignment]
        self._conn.commit()

        # Durable categories go to MEMORY.md; ephemeral goes to daily note
        if category in self._PERMANENT_CATEGORIES:
            self._append_to_memory_md(category, content)
        else:
            date_str = source.split(":", 1)[1] if source.startswith("daily:") else self.today
            self._append_to_daily_note(date_str, category, content)

        return row_id

    def recall(self, query: str, limit: int = 5) -> list[MemoryEntry]:
        """Search memories using FTS5.

        Results stay in FTS5 relevance (rank) order; the exponential recency
        decay only scales the reported *importance* of each entry. Every hit
        gets its access counters bumped.
        """
        fts_expr = self._build_fts_query(query)
        if not fts_expr:
            return []
        try:
            cursor = self._conn.execute(
                """
                SELECT m.id, m.source, m.category, m.content, m.importance,
                       m.tags, m.created_at, m.last_accessed_at, m.access_count
                FROM memory_fts f
                JOIN memory_meta m ON m.id = f.rowid
                WHERE memory_fts MATCH ?
                ORDER BY rank
                LIMIT ?
                """,
                (fts_expr, limit),
            )
        except sqlite3.OperationalError:
            # Malformed MATCH expressions shouldn't crash recall; treat as no hits.
            return []

        results = []
        for row in cursor.fetchall():
            entry = MemoryEntry(
                id=row[0],
                source=row[1],
                category=MemoryCategory(row[2]),
                content=row[3],
                importance=row[4],
                tags=row[5],
                created_at=row[6],
                last_accessed_at=row[7],
                access_count=row[8],
            )
            if entry.created_at:
                try:
                    created = dt.datetime.fromisoformat(str(entry.created_at))
                    # SQLite CURRENT_TIMESTAMP is naive UTC; normalize before diffing.
                    if created.tzinfo is None:
                        created = created.replace(tzinfo=dt.timezone.utc)
                    days_old = (dt.datetime.now(dt.timezone.utc) - created).days
                    # ~30-day half-life-ish exponential decay on importance.
                    decay = math.exp(-days_old / 30.0)
                    entry.importance *= decay
                except (ValueError, TypeError):
                    pass
            results.append(entry)

        if results:
            self._conn.executemany(
                """UPDATE memory_meta
                   SET access_count = access_count + 1,
                       last_accessed_at = CURRENT_TIMESTAMP
                   WHERE id = ?""",
                [(e.id,) for e in results],
            )
            self._conn.commit()

        return results

    def delete(self, entry_id: int) -> None:
        """Remove one memory row (FTS index is kept in sync by triggers)."""
        self._conn.execute("DELETE FROM memory_meta WHERE id = ?", (entry_id,))
        self._conn.commit()

    def list_recent(self, limit: int = 10) -> list[MemoryEntry]:
        """Return the most recently created entries, newest first."""
        cursor = self._conn.execute(
            """
            SELECT id, source, category, content, importance, tags,
                   created_at, last_accessed_at, access_count
            FROM memory_meta
            ORDER BY created_at DESC
            LIMIT ?
            """,
            (limit,),
        )
        return [
            MemoryEntry(
                id=row[0], source=row[1], category=MemoryCategory(row[2]),
                content=row[3], importance=row[4], tags=row[5],
                created_at=row[6], last_accessed_at=row[7], access_count=row[8],
            )
            for row in cursor.fetchall()
        ]

    # ------------------------------------------------------------------
    # Session context (loaded once at startup)
    # ------------------------------------------------------------------

    def get_session_context(self) -> str:
        """Load MEMORY.md + linked files + today's daily note as context."""
        parts: list[str] = []

        if self.memory_md.exists():
            text = self.memory_md.read_text().strip()
            if text:
                parts.append(f"[Long-term Memory]\n{text}")
                # Follow [title](path.md) links and append referenced files
                for title, rel_path in _MD_LINK_RE.findall(text):
                    resolved = (self.memory_dir / rel_path).resolve()
                    # Security: stay inside memory_dir
                    try:
                        resolved.relative_to(self.memory_dir.resolve())
                    except ValueError:
                        continue
                    if resolved.is_file():
                        linked_text = resolved.read_text().strip()
                        if linked_text:
                            parts.append(
                                f"[{title}]\n{linked_text}"
                            )

        daily_path = self.memory_dir / f"{self.today}.md"
        if daily_path.exists():
            text = daily_path.read_text().strip()
            if text:
                parts.append(f"[Today's Memory]\n{text}")

        return "\n\n".join(parts)

    def get_system_prompt_section(self) -> str:
        """Memory instructions injected into the AI system prompt.

        Modeled after OpenClaw's buildMemorySection() — tells the AI to
        actively manage memory instead of relying on Python-side extraction.
        """
        # NOTE(review): item 1 tells the model to use a `memory_search` tool,
        # but the registered tool is named "memory" with a "search" action
        # (see MemoryTool). Align these together with the test that pins the
        # "memory_search" substring.
        return (
            "## Memory System\n"
            "You have persistent long-term memory.\n"
            "1. BEFORE answering about prior work, decisions, dates, people, "
            "or preferences: use the `memory_search` tool.\n"
            "2. When you learn an important fact (user preference, environment "
            "detail, solution, pattern): use the `memory` tool's `store` action "
            "to save it.\n"
            "3. Memory files are in {dir} — MEMORY.md for permanent knowledge, "
            "daily YYYY-MM-DD.md notes for session context.\n"
            "4. If memory/YYYY-MM-DD.md already exists, APPEND only.\n"
            "5. For long or detailed knowledge, create a separate .md file in "
            "the memory directory and add a markdown link in MEMORY.md, e.g. "
            "`- [Title](knowledge/topic.md)`. Linked files are auto-loaded.\n"
        ).format(dir=self.memory_dir_path)

    # ------------------------------------------------------------------
    # File helpers (synchronous, called from thread pool)
    # ------------------------------------------------------------------

    def _ensure_memory_md(self) -> None:
        """Seed MEMORY.md with the standard section skeleton if absent."""
        if not self.memory_md.exists():
            self.memory_md.write_text(
                "# Long-term Memory\n"
                "\n"
                "Permanent knowledge about the user, projects, and preferences.\n"
                "The AI reads and writes this file through the memory system.\n"
                "\n"
                "## Preferences\n\n"
                "## Environment\n\n"
                "## Solutions\n\n"
                "## Patterns\n\n"
            )

    def _append_to_memory_md(
        self, category: MemoryCategory, content: str
    ) -> None:
        """Append a fact to MEMORY.md under the matching section header.

        Deduplicates by checking if the content already exists in the file.
        """
        self._ensure_memory_md()
        text = self.memory_md.read_text()

        # Map category to MEMORY.md section header
        section_map = {
            MemoryCategory.PREFERENCE: "## Preferences",
            MemoryCategory.ENVIRONMENT: "## Environment",
            MemoryCategory.SOLUTION: "## Solutions",
            MemoryCategory.PATTERN: "## Patterns",
        }
        header = section_map.get(category)
        if not header:
            return

        # Skip duplicate
        if content in text:
            return

        # Find the section and append directly below its header
        lines = text.split("\n")
        insert_idx = len(lines)
        for i, line in enumerate(lines):
            if line.strip() == header:
                insert_idx = i + 1
                break

        lines.insert(insert_idx, f"- {content}")
        self.memory_md.write_text("\n".join(lines))

    def _ensure_daily_note(self) -> None:
        """Create today's daily note with a title header if absent."""
        daily_path = self.memory_dir / f"{self.today}.md"
        if not daily_path.exists():
            daily_path.write_text(f"# {self.today} Memory\n\n")

    def _append_to_daily_note(
        self, date_str: str, category: MemoryCategory, content: str
    ) -> None:
        """Append *content* under a per-category section of the daily note.

        Fixes vs. the original: the section header is compared without a
        trailing newline (so a header at EOF is still found), and the
        inserted bullet carries no trailing '\\n' of its own — previously it
        did, and "\\n".join() then produced a doubled blank line after every
        entry.
        """
        daily_path = self.memory_dir / f"{date_str}.md"
        if not daily_path.exists():
            daily_path.write_text(f"# {date_str} Memory\n\n")

        section_header = f"## {category.value.capitalize()}"
        text = daily_path.read_text()

        if section_header not in text:
            daily_path.write_text(text.rstrip() + f"\n\n{section_header}\n")
            text = daily_path.read_text()

        if content not in text:
            lines = text.split("\n")
            insert_idx = len(lines)
            for i, candidate in enumerate(lines):
                if candidate.strip() == section_header:
                    insert_idx = i + 1
                    break
            lines.insert(insert_idx, f"- {content}")
            daily_path.write_text("\n".join(lines))

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    def cleanup_old_notes(self, retention_days: int = 30) -> None:
        """Delete daily notes (and their DB rows) older than *retention_days*."""
        cutoff = dt.date.today() - dt.timedelta(days=retention_days)
        for path in self.memory_dir.glob("????-??-??.md"):
            try:
                note_date = dt.date.fromisoformat(path.stem)
                if note_date < cutoff:
                    self._conn.execute(
                        "DELETE FROM memory_meta WHERE source = ?",
                        (f"daily:{path.stem}",),
                    )
                    self._conn.commit()
                    path.unlink()
            except (ValueError, OSError):
                # Non-date-named file or racing unlink — skip quietly.
                continue

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _init_db(self) -> sqlite3.Connection:
        """Open/create the SQLite DB with an external-content FTS5 index.

        Triggers keep memory_fts in lockstep with memory_meta on
        insert/update/delete (the FTS5 external-content pattern).
        """
        conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute("PRAGMA synchronous=NORMAL;")
        conn.execute("PRAGMA busy_timeout=5000;")
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS memory_meta (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source TEXT NOT NULL,
                category TEXT NOT NULL,
                content TEXT NOT NULL,
                tags TEXT DEFAULT '',
                importance REAL DEFAULT 0.5,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed_at TIMESTAMP,
                access_count INTEGER DEFAULT 0
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
                content, source, tags, category,
                content='memory_meta', content_rowid='id'
            );

            CREATE TRIGGER IF NOT EXISTS memory_meta_ai AFTER INSERT ON memory_meta
            BEGIN
                INSERT INTO memory_fts(rowid, content, source, tags, category)
                VALUES (new.id, new.content, new.source, new.tags, new.category);
            END;

            CREATE TRIGGER IF NOT EXISTS memory_meta_ad AFTER DELETE ON memory_meta
            BEGIN
                INSERT INTO memory_fts(memory_fts, rowid, content, source, tags, category)
                VALUES ('delete', old.id, old.content, old.source, old.tags, old.category);
            END;

            CREATE TRIGGER IF NOT EXISTS memory_meta_au AFTER UPDATE ON memory_meta
            BEGIN
                INSERT INTO memory_fts(memory_fts, rowid, content, source, tags, category)
                VALUES ('delete', old.id, old.content, old.source, old.tags, old.category);
                INSERT INTO memory_fts(rowid, content, source, tags, category)
                VALUES (new.id, new.content, new.source, new.tags, new.category);
            END;
            """
        )
        conn.commit()
        return conn

    @staticmethod
    def _build_fts_query(query: str) -> str:
        """Turn free text into a safe FTS5 MATCH expression (OR of bare tokens).

        Fix: use the precompiled class-level pattern instead of re-importing
        and re-compiling ``re`` on every call.
        """
        cleaned = MemoryManager._FTS_SPECIAL_RE.sub(" ", query)
        tokens = cleaned.split()
        if not tokens:
            return ""
        return " OR ".join(tokens)
for name in refs]) return f"{prefix}\n\n{text}" + def _recall_memories(self, query: str) -> None: + """Inject relevant memories into context before AI interaction.""" + shell = getattr(self, "shell", None) + if not shell: + return + mem_mgr = getattr(shell, "memory_manager", None) + if not mem_mgr: + return + memory_config = getattr(shell.config, "memory", None) + if not memory_config or not getattr(memory_config, "auto_recall", False): + return + try: + results = mem_mgr.recall( + query, limit=getattr(memory_config, "recall_limit", 5) + ) + if results: + lines = [''] + for r in results: + lines.append(f"- [{r.category.value}] {r.content}") + lines.append("") + self.context_manager.add_memory( + MemoryType.KNOWLEDGE, + {"key": "memory_recall", "value": "\n".join(lines)}, + ) + except Exception: + pass # Memory recall is best-effort + + @staticmethod def _get_cancel_exceptions() -> tuple[type[BaseException], ...]: """Return cancellation exception types available in the current context.""" @@ -284,6 +313,7 @@ async def _fix(): system_message=system_message, stream=True, ) + return response shell = self._require_shell() @@ -328,12 +358,17 @@ async def _ask(): question_processed = self._inject_skill_prefix(question) + # Recall: inject relevant memories before AI call + self._recall_memories(question_processed) + response = await self.llm_session.process_input( question_processed, context_manager=self.context_manager, system_message=system_message, stream=True, ) + + # Retain: extract facts after AI call return response shell = self._require_shell() diff --git a/src/aish/shell/runtime/app.py b/src/aish/shell/runtime/app.py index f573c05..4a4464b 100644 --- a/src/aish/shell/runtime/app.py +++ b/src/aish/shell/runtime/app.py @@ -98,6 +98,23 @@ def __init__( model=config.model, enable_token_estimation=getattr(config, "enable_token_estimation", True), ) + + # Long-term memory must be initialized before LLM session so that + # MemoryTool can be registered during session 
class MemoryTool(ToolBase):
    """LLM tool for explicit memory management.

    Exposes one function named ``memory`` with four actions:
        search - FTS lookup of past knowledge
        store  - persist an important fact (high importance, source 'explicit')
        list   - show the most recent entries
        forget - delete an entry by ID
    Every branch returns a ToolResult; invalid input never raises.
    """

    def __init__(self, memory_manager: "MemoryManager") -> None:
        super().__init__(
            name="memory",
            description=(
                "Search, store, or manage long-term memories. "
                "Use 'search' to find relevant past knowledge, "
                "'store' to explicitly save important information, "
                "'list' to see recent memories, "
                "'forget' to remove outdated info."
            ),
            parameters={
                "type": "object",
                "properties": {
                    "action": {
                        "type": "string",
                        "enum": ["search", "store", "forget", "list"],
                        "description": "Memory operation to perform",
                    },
                    "query": {
                        "type": "string",
                        "description": "Search query (for 'search' action)",
                    },
                    "content": {
                        "type": "string",
                        "description": "Content to store (for 'store' action)",
                    },
                    "category": {
                        "type": "string",
                        "enum": [
                            "preference",
                            "environment",
                            "solution",
                            "pattern",
                            "other",
                        ],
                        "description": "Category for stored memory (default: other)",
                    },
                    "memory_id": {
                        "type": "integer",
                        "description": "Memory ID to forget (for 'forget' action)",
                    },
                },
                "required": ["action"],
            },
        )
        self.memory_manager = memory_manager

    def __call__(
        self,
        action: str,
        query: str | None = None,
        content: str | None = None,
        category: str | None = None,
        memory_id: int | None = None,
        **kwargs: Any,
    ) -> ToolResult:
        """Dispatch one memory action and return its ToolResult."""
        if action == "search":
            if not query:
                return ToolResult(
                    ok=False, output="Error: query is required for search"
                )
            results = self.memory_manager.recall(query)
            if not results:
                return ToolResult(ok=True, output="No matching memories found.")
            lines = [
                f"  [{r.category.value}] {r.content} (id={r.id})" for r in results
            ]
            return ToolResult(ok=True, output="Found memories:\n" + "\n".join(lines))

        elif action == "store":
            if not content:
                return ToolResult(
                    ok=False, output="Error: content is required for store"
                )
            from aish.memory.models import MemoryCategory

            # Fix: an unexpected category string previously raised ValueError
            # out of the tool call; report it as a tool error instead.
            try:
                cat = MemoryCategory(category) if category else MemoryCategory.OTHER
            except ValueError:
                return ToolResult(
                    ok=False,
                    output=(
                        f"Error: unknown category '{category}'. Use "
                        "preference/environment/solution/pattern/other."
                    ),
                )
            entry_id = self.memory_manager.store(
                content=content,
                category=cat,
                source="explicit",
                importance=0.8,
            )
            return ToolResult(ok=True, output=f"Stored as memory #{entry_id}.")

        elif action == "forget":
            if memory_id is None:
                return ToolResult(
                    ok=False, output="Error: memory_id is required for forget"
                )
            self.memory_manager.delete(memory_id)
            return ToolResult(ok=True, output=f"Forgot memory #{memory_id}.")

        elif action == "list":
            entries = self.memory_manager.list_recent(limit=10)
            if not entries:
                return ToolResult(ok=True, output="No memories yet.")
            lines = [
                f"  #{e.id} [{e.category.value}] {e.content}" for e in entries
            ]
            return ToolResult(
                ok=True, output="Recent memories:\n" + "\n".join(lines)
            )

        else:
            return ToolResult(
                ok=False,
                output=f"Unknown action: {action}. Use search/store/forget/list.",
            )
aish.memory.config import MemoryConfig +from aish.memory.manager import MemoryManager +from aish.memory.models import MemoryCategory + + +@pytest.fixture +def memory_manager(tmp_path): + config = MemoryConfig(data_dir=str(tmp_path / "memory")) + mgr = MemoryManager(config=config) + yield mgr + mgr.close() + + +def test_init_creates_directories(tmp_path): + config = MemoryConfig(data_dir=str(tmp_path / "memory")) + mgr = MemoryManager(config=config) + assert (tmp_path / "memory").is_dir() + mgr.close() + + +def test_init_creates_database(memory_manager): + import sqlite3 + + conn = sqlite3.connect(str(memory_manager.db_path)) + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + table_names = [t[0] for t in tables] + assert "memory_meta" in table_names + assert "indexed_files" not in table_names + conn.close() + + +def test_store_and_retrieve(memory_manager): + entry_id = memory_manager.store( + content="Production DB on port 5432", + category=MemoryCategory.ENVIRONMENT, + source="daily:2026-04-03", + ) + assert entry_id > 0 + + results = memory_manager.recall("production database", limit=5) + assert len(results) >= 1 + assert any("5432" in r.content for r in results) + + +def test_recall_returns_empty_for_no_match(memory_manager): + results = memory_manager.recall("nonexistent query xyz", limit=5) + assert len(results) == 0 + + +def test_recall_respects_limit(memory_manager): + for i in range(10): + memory_manager.store( + content=f"Test fact number {i} about servers", + category=MemoryCategory.ENVIRONMENT, + source="daily:2026-04-03", + ) + results = memory_manager.recall("servers", limit=3) + assert len(results) <= 3 + + +def test_store_creates_daily_note(memory_manager): + memory_manager.store( + content="Test fact for daily note", + category=MemoryCategory.OTHER, + source="daily:2026-04-03", + ) + daily_path = memory_manager.memory_dir / "2026-04-03.md" + assert daily_path.exists() + content = daily_path.read_text() + 
assert "Test fact for daily note" in content + + +def test_get_session_context_empty(memory_manager): + ctx = memory_manager.get_session_context() + assert isinstance(ctx, str) + # get_session_context returns MEMORY.md and daily note sections + assert "Long-term Memory" in ctx + + +def test_get_session_context_with_memory_md(memory_manager): + memory_file = memory_manager.memory_dir / "MEMORY.md" + memory_file.write_text("# Long-term Memory\n\n- User prefers vim\n") + ctx = memory_manager.get_session_context() + assert "User prefers vim" in ctx + + +def test_delete_memory(memory_manager): + entry_id = memory_manager.store( + content="Fact to delete", + category=MemoryCategory.OTHER, + source="explicit", + ) + memory_manager.delete(entry_id) + results = memory_manager.recall("Fact to delete", limit=5) + assert len(results) == 0 + + +def test_list_recent(memory_manager): + for i in range(5): + memory_manager.store( + content=f"Recent fact {i}", + category=MemoryCategory.PATTERN, + source="daily:2026-04-03", + ) + recent = memory_manager.list_recent(limit=3) + assert len(recent) <= 3 + + +def test_cleanup_old_notes(memory_manager): + old_note = memory_manager.memory_dir / "2025-01-01.md" + old_note.write_text("# Old note\n\n- Old fact\n") + memory_manager.store( + content="Old fact", + category=MemoryCategory.OTHER, + source="daily:2025-01-01", + ) + memory_manager.cleanup_old_notes(retention_days=30) + assert not old_note.exists() + + +def test_ensure_memory_md_created(memory_manager): + """MEMORY.md is auto-created during init.""" + assert memory_manager.memory_md.exists() + content = memory_manager.memory_md.read_text() + assert "Long-term Memory" in content + + +def test_ensure_daily_note_created(memory_manager): + """Today's daily note is auto-created during init.""" + import datetime as dt + today = dt.date.today().isoformat() + daily_path = memory_manager.memory_dir / f"{today}.md" + assert daily_path.exists() + + +def 
test_get_system_prompt_section(memory_manager): + section = memory_manager.get_system_prompt_section() + assert "Memory System" in section + assert "memory_search" in section + + +def test_store_permanent_goes_to_memory_md(memory_manager): + """store() with permanent category writes to MEMORY.md, not daily note.""" + import datetime as dt + today = dt.date.today().isoformat() + memory_manager.store( + content="Fact from explicit source", + category=MemoryCategory.SOLUTION, + source="explicit", + ) + # SOLUTION is a permanent category — goes to MEMORY.md + mem_text = memory_manager.memory_md.read_text() + assert "Fact from explicit source" in mem_text + # Daily note should NOT contain this entry + daily_path = memory_manager.memory_dir / f"{today}.md" + if daily_path.exists(): + assert "Fact from explicit source" not in daily_path.read_text() + + +def test_session_context_follows_links(memory_manager): + """get_session_context() loads files linked from MEMORY.md.""" + # Create a linked knowledge file + knowledge_dir = memory_manager.memory_dir / "knowledge" + knowledge_dir.mkdir(parents=True, exist_ok=True) + deploy_file = knowledge_dir / "deployment.md" + deploy_file.write_text("# Deployment\n\nDeploy via `make deploy` to prod.") + + # Add a link in MEMORY.md + memory_file = memory_manager.memory_md + text = memory_file.read_text() + text = text.replace( + "## Solutions\n", + "## Solutions\n\n- [Deployment Guide](knowledge/deployment.md)\n", + ) + memory_file.write_text(text) + + ctx = memory_manager.get_session_context() + assert "Long-term Memory" in ctx + assert "Deployment Guide" in ctx + assert "make deploy" in ctx + + +def test_session_context_ignores_links_outside_memory_dir(memory_manager): + """Links pointing outside memory_dir are skipped for security.""" + memory_file = memory_manager.memory_md + text = memory_file.read_text() + text = text.replace( + "## Solutions\n", + "## Solutions\n\n- [Escape](../../etc/passwd.md)\n", + ) + 
memory_file.write_text(text) + + ctx = memory_manager.get_session_context() + assert "Long-term Memory" in ctx + # Should NOT load anything from outside memory_dir + assert "Escape" in ctx # link text appears in MEMORY.md itself + assert "root:" not in ctx # would be in /etc/passwd content + + +def test_session_context_ignores_missing_linked_file(memory_manager): + """Broken links are silently skipped.""" + memory_file = memory_manager.memory_md + text = memory_file.read_text() + text = text.replace( + "## Solutions\n", + "## Solutions\n\n- [Missing](no-such-file.md)\n", + ) + memory_file.write_text(text) + + ctx = memory_manager.get_session_context() + assert "Long-term Memory" in ctx + assert "Missing" in ctx # link text in MEMORY.md + # No error, just missing content gracefully skipped diff --git a/tests/test_memory_models.py b/tests/test_memory_models.py new file mode 100644 index 0000000..f91d899 --- /dev/null +++ b/tests/test_memory_models.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from aish.memory.models import MemoryEntry, MemoryCategory + + +def test_memory_entry_creation(): + entry = MemoryEntry( + id=1, + source="daily:2026-04-03", + category=MemoryCategory.ENVIRONMENT, + content="Production DB on port 5432", + importance=0.8, + ) + assert entry.id == 1 + assert entry.category == MemoryCategory.ENVIRONMENT + assert entry.importance == 0.8 + + +def test_memory_category_values(): + assert MemoryCategory.PREFERENCE.value == "preference" + assert MemoryCategory.ENVIRONMENT.value == "environment" + assert MemoryCategory.SOLUTION.value == "solution" + assert MemoryCategory.PATTERN.value == "pattern" + assert MemoryCategory.OTHER.value == "other" diff --git a/tests/test_memory_tool.py b/tests/test_memory_tool.py new file mode 100644 index 0000000..32cd5ee --- /dev/null +++ b/tests/test_memory_tool.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import pytest + +from aish.memory.config import MemoryConfig +from aish.memory.manager 
from aish.memory.manager import MemoryManager
from aish.memory.models import MemoryCategory
from aish.tools.memory_tool import MemoryTool


@pytest.fixture
def memory_tool(tmp_path):
    """Yield a MemoryTool wired to a MemoryManager in a tmp directory."""
    manager = MemoryManager(
        config=MemoryConfig(data_dir=str(tmp_path / "memory"))
    )
    yield MemoryTool(memory_manager=manager)
    manager.close()


def test_tool_spec(memory_tool):
    spec = memory_tool.to_func_spec()
    assert spec["type"] == "function"
    assert spec["function"]["name"] == "memory"
    assert "action" in spec["function"]["parameters"]["properties"]


def test_search_action(memory_tool):
    memory_tool.memory_manager.store(
        content="Redis runs on port 6379",
        category=MemoryCategory.ENVIRONMENT,
        source="explicit",
    )
    outcome = memory_tool(action="search", query="Redis port")
    assert "6379" in str(outcome)


def test_store_action(memory_tool):
    outcome = str(memory_tool(action="store", content="User prefers dark theme"))
    assert "Stored" in outcome or "stored" in outcome


def test_list_action(memory_tool):
    for fact, cat in (
        ("Fact one", MemoryCategory.PATTERN),
        ("Fact two", MemoryCategory.SOLUTION),
    ):
        memory_tool.memory_manager.store(
            content=fact, category=cat, source="explicit"
        )
    assert "Fact" in str(memory_tool(action="list"))


def test_forget_action(memory_tool):
    entry_id = memory_tool.memory_manager.store(
        content="Temporary fact",
        category=MemoryCategory.OTHER,
        source="explicit",
    )
    outcome = str(memory_tool(action="forget", memory_id=entry_id))
    assert (
        "Forgot" in outcome
        or "forgot" in outcome
        or "removed" in outcome.lower()
    )


def test_invalid_action(memory_tool):
    outcome = str(memory_tool(action="invalid_action")).lower()
    assert "error" in outcome or "unknown" in outcome