diff --git a/.gitignore b/.gitignore index 829888e..8ae73d6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ htmlcov/ # CocoIndex .cocoindex_code/ + +# Session transcripts +session-*.md diff --git a/pyproject.toml b/pyproject.toml index b6a709c..f6a4006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,16 +31,6 @@ dependencies = [ "einops>=0.8.2", ] -[project.optional-dependencies] -dev = [ - "pytest>=7.0.0", - "pytest-asyncio>=0.21.0", - "pytest-cov>=4.0.0", - "ruff>=0.1.0", - "mypy>=1.0.0", - "prek>=0.1.0", -] - [project.scripts] cocoindex-code = "cocoindex_code:main" @@ -65,7 +55,6 @@ dev = [ "pytest-cov>=4.0.0", "ruff>=0.1.0", "mypy>=1.0.0", - "prek>=0.1.0", ] [tool.uv] diff --git a/src/cocoindex_code/__init__.py b/src/cocoindex_code/__init__.py index 330fba0..6b11554 100644 --- a/src/cocoindex_code/__init__.py +++ b/src/cocoindex_code/__init__.py @@ -1,7 +1,13 @@ """CocoIndex Code - MCP server for indexing and querying codebases.""" +from importlib.metadata import PackageNotFoundError, version + from .config import Config from .server import main, mcp -__version__ = "0.1.0" +try: + __version__ = version("cocoindex-code") +except PackageNotFoundError: + __version__ = "0.0.0-dev" + __all__ = ["Config", "main", "mcp"] diff --git a/src/cocoindex_code/code_intelligence_tools.py b/src/cocoindex_code/code_intelligence_tools.py new file mode 100644 index 0000000..99a5ce6 --- /dev/null +++ b/src/cocoindex_code/code_intelligence_tools.py @@ -0,0 +1,1069 @@ +"""Code intelligence tools for the cocoindex-code MCP server. + +Provides list_symbols, find_definition, find_references, code_metrics, +and rename_symbol tools using regex-based multi-language symbol extraction. 
+""" + +from __future__ import annotations + +import asyncio +import fnmatch +import os +import re +from pathlib import Path + +from mcp.server.fastmcp import FastMCP +from pydantic import BaseModel, Field + +from .filesystem_tools import ( + MAX_READ_BYTES, + MAX_RESULTS, + _detect_lang, + _is_binary, + _is_excluded_dir, + _relative, + _root, + _safe_resolve, +) + +# === Pydantic result models === + + +class SymbolEntry(BaseModel): + """A symbol found in source code.""" + + name: str = Field(description="Symbol name") + symbol_type: str = Field( + description="Type: function, method, class, variable, constant, " + "interface, type, enum, struct, trait, module, impl" + ) + line: int = Field(description="Start line number (1-indexed)") + end_line: int = Field(description="End line number (1-indexed)") + signature: str = Field(description="Source line where symbol is defined") + indent_level: int = Field(default=0, description="Indentation level") + + +class ListSymbolsResult(BaseModel): + """Result from list_symbols tool.""" + + success: bool + path: str = "" + symbols: list[SymbolEntry] = Field(default_factory=list) + total_symbols: int = 0 + language: str = "" + message: str | None = None + + +class DefinitionEntry(BaseModel): + """A symbol definition location.""" + + file_path: str = Field(description="Relative file path") + name: str = Field(description="Symbol name") + symbol_type: str = Field(description="Symbol type") + line: int = Field(description="Line number (1-indexed)") + signature: str = Field(description="Definition line content") + context: str = Field(default="", description="Surrounding context") + + +class FindDefinitionResult(BaseModel): + """Result from find_definition tool.""" + + success: bool + definitions: list[DefinitionEntry] = Field(default_factory=list) + total_found: int = 0 + message: str | None = None + + +class ReferenceEntry(BaseModel): + """A single reference to a symbol.""" + + path: str = Field(description="Relative file path") + 
line_number: int = Field(description="1-indexed line number") + line: str = Field(description="Matched line content") + usage_type: str = Field( + default="other", + description="Usage type: import, call, assignment, " + "type_annotation, definition, other", + ) + context_before: list[str] = Field(default_factory=list) + context_after: list[str] = Field(default_factory=list) + + +class FindReferencesResult(BaseModel): + """Result from find_references tool.""" + + success: bool + references: list[ReferenceEntry] = Field(default_factory=list) + total_found: int = 0 + files_searched: int = 0 + truncated: bool = False + message: str | None = None + + +class MetricsData(BaseModel): + """Code quality metrics.""" + + total_lines: int = Field(description="Total line count") + code_lines: int = Field(description="Non-blank, non-comment lines") + blank_lines: int = Field(description="Blank line count") + comment_lines: int = Field(description="Comment line count") + functions: int = Field(description="Number of functions/methods") + classes: int = Field(description="Number of classes/structs") + avg_function_length: float = Field( + default=0.0, description="Average function body length" + ) + max_function_length: int = Field( + default=0, description="Longest function body length" + ) + max_nesting_depth: int = Field( + default=0, description="Max indentation nesting depth" + ) + complexity_estimate: int = Field( + default=0, description="Estimated cyclomatic complexity" + ) + + +class CodeMetricsResult(BaseModel): + """Result from code_metrics tool.""" + + success: bool + path: str = "" + metrics: MetricsData | None = None + language: str = "" + message: str | None = None + + +class RenameChange(BaseModel): + """A file changed by rename_symbol.""" + + file_path: str = Field(description="Relative file path") + occurrences: int = Field(description="Number of replacements in this file") + + +class RenameResult(BaseModel): + """Result from rename_symbol tool.""" + + success: 
bool
    old_name: str = ""
    new_name: str = ""
    files_changed: int = 0
    total_replacements: int = 0
    changes: list[RenameChange] = Field(default_factory=list)
    dry_run: bool = True
    message: str | None = None


# === Multi-language symbol extraction patterns ===
# Each entry: (compiled_regex, symbol_type, name_group_index)
# The patterns are line-anchored heuristics, not a parser: they match
# definition headers at the start of a line and capture the symbol name.

_PatternEntry = tuple[re.Pattern[str], str, int]


def _build_patterns() -> dict[str, list[_PatternEntry]]:
    """Build symbol extraction patterns per language."""

    # Small alias so the table below stays readable.
    def _c(pattern: str, flags: int = 0) -> re.Pattern[str]:
        return re.compile(pattern, flags)

    python: list[_PatternEntry] = [
        (_c(r"^(\s*)(async\s+)?def\s+(\w+)\s*\("), "function", 3),
        (_c(r"^(\s*)class\s+(\w+)"), "class", 2),
        # ALL_CAPS name at column 0 followed by '=' or ':' is treated as a
        # module-level constant (requires at least two characters).
        (_c(r"^([A-Z][A-Z0-9_]{1,})\s*[=:]"), "constant", 1),
    ]

    javascript: list[_PatternEntry] = [
        (
            _c(r"^(\s*)(?:export\s+)?(?:default\s+)?"
               r"(?:async\s+)?function\s*\*?\s+(\w+)"),
            "function", 2,
        ),
        (_c(r"^(\s*)(?:export\s+)?(?:default\s+)?class\s+(\w+)"),
         "class", 2),
        (_c(r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)"),
         "variable", 2),
    ]

    # TypeScript = JavaScript plus interface/type/enum declarations.
    ts_extra: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:export\s+)?interface\s+(\w+)"),
         "interface", 2),
        (_c(r"^(\s*)(?:export\s+)?type\s+(\w+)\s*[=<{]"),
         "type", 2),
        (_c(r"^(\s*)(?:export\s+)?enum\s+(\w+)"), "enum", 2),
    ]
    typescript = javascript + ts_extra

    rust: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?(?:async\s+)?fn\s+(\w+)"),
         "function", 2),
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?struct\s+(\w+)"),
         "struct", 2),
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?enum\s+(\w+)"),
         "enum", 2),
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?trait\s+(\w+)"),
         "trait", 2),
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?mod\s+(\w+)"),
         "module", 2),
        (_c(r"^(\s*)(?:pub(?:\([^)]*\))?\s+)?(?:const|static)\s+(\w+)"),
         "constant", 2),
        # For `impl Foo` / `impl<T> Foo` the captured name is the type.
        (_c(r"^(\s*)impl(?:\s*<[^>]*>)?\s+(\w+)"), "impl", 2),
    ]

    go: list[_PatternEntry] = [
        # Optional receiver group handles methods: `func (r *T) Name(`.
        (_c(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\("),
         "function", 1),
        (_c(r"^type\s+(\w+)\s+struct\b"), "struct", 1),
        (_c(r"^type\s+(\w+)\s+interface\b"), "interface", 1),
        (_c(r"^(?:const|var)\s+(\w+)"), "variable", 1),
    ]

    java: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:(?:public|private|protected|static|"
            r"abstract|final|sealed|partial)\s+)*class\s+(\w+)"),
         "class", 2),
        (_c(r"^(\s*)(?:(?:public|private|protected|static|"
            r"abstract|final)\s+)*interface\s+(\w+)"),
         "interface", 2),
        (_c(r"^(\s*)(?:(?:public|private|protected|static|"
            r"abstract|final)\s+)*enum\s+(\w+)"),
         "enum", 2),
    ]

    c_patterns: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:typedef\s+)?struct\s+(\w+)"), "struct", 2),
        (_c(r"^(\s*)#define\s+(\w+)"), "constant", 2),
        (_c(r"^(\s*)enum(?:\s+class)?\s+(\w+)"), "enum", 2),
    ]

    cpp_extra: list[_PatternEntry] = [
        (_c(r"^(\s*)class\s+(\w+)"), "class", 2),
        (_c(r"^(\s*)namespace\s+(\w+)"), "module", 2),
    ]
    cpp = c_patterns + cpp_extra

    php: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:(?:public|private|protected|static|"
            r"abstract|final)\s+)*function\s+(\w+)"),
         "function", 2),
        (_c(r"^(\s*)(?:abstract\s+|final\s+)?class\s+(\w+)"),
         "class", 2),
        (_c(r"^(\s*)interface\s+(\w+)"), "interface", 2),
        (_c(r"^(\s*)trait\s+(\w+)"), "trait", 2),
    ]

    ruby: list[_PatternEntry] = [
        (_c(r"^(\s*)def\s+(?:self\.)?(\w+)"), "function", 2),
        (_c(r"^(\s*)class\s+(\w+)"), "class", 2),
        (_c(r"^(\s*)module\s+(\w+)"), "module", 2),
    ]

    shell: list[_PatternEntry] = [
        (_c(r"^(\s*)(?:function\s+)?(\w+)\s*\(\s*\)"),
         "function", 2),
        (_c(r"^([A-Z_][A-Z0-9_]*)\s*="), "variable", 1),
    ]

    sql: list[_PatternEntry] = [
        (_c(r"^\s*CREATE\s+(?:OR\s+REPLACE\s+)?"
            r"(?:FUNCTION|PROCEDURE)\s+(\w+)",
            re.IGNORECASE),
         "function", 1),
        (_c(r"^\s*CREATE\s+(?:OR\s+REPLACE\s+)?"
            r"(?:TABLE|VIEW)\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)",
            re.IGNORECASE),
         "type", 1),
    ]

    # Languages without dedicated patterns reuse a close relative's table
    # (heuristic approximation, not exact grammar).
    return {
        "python": python,
        "javascript": javascript,
        "typescript": typescript,
        "rust": rust,
        "go": go,
        "java": java,
        "csharp": java,  # same base patterns
        "c": c_patterns,
        "cpp": cpp,
        "php": php,
        "ruby": ruby,
        "shell": shell,
        "sql": sql,
        "kotlin": java,
        "scala": java,
    }


# Built once at import time; _extract_symbols does per-line lookups.
_SYMBOL_PATTERNS: dict[str, list[_PatternEntry]] = _build_patterns()


# === Core internal functions ===


def _extract_symbols(content: str, language: str) -> list[SymbolEntry]:
    """Extract symbols from file content using regex patterns."""
    patterns = _SYMBOL_PATTERNS.get(language, [])
    if not patterns:
        # Unknown language: no patterns, so no symbols.
        return []

    lines = content.splitlines()
    raw_symbols: list[SymbolEntry] = []

    for line_idx, line_text in enumerate(lines):
        line_num = line_idx + 1
        for pattern, sym_type, name_group in patterns:
            m = pattern.match(line_text)
            if m is None:
                continue
            name = m.group(name_group)
            # Compute indent level
            # NOTE(review): assumes 4-space indentation; tab-indented files
            # collapse to level 0 here — confirm against indexed codebases.
            stripped = line_text.lstrip()
            indent = len(line_text) - len(stripped)
            indent_level = indent // 4 if indent > 0 else 0

            actual_type = sym_type
            # Python: indented function → method
            if language == "python" and sym_type == "function":
                if indent > 0:
                    actual_type = "method"

            raw_symbols.append(SymbolEntry(
                name=name,
                symbol_type=actual_type,
                line=line_num,
                end_line=line_num,  # computed below
                signature=line_text.rstrip(),
                indent_level=indent_level,
            ))
            break  # first match wins per line

    # Compute end_line for each symbol
    for i, sym in enumerate(raw_symbols):
        if i + 1 < len(raw_symbols):
            next_sym = raw_symbols[i + 1]
            # End at line before next symbol at same or lesser indent
            if next_sym.indent_level <= sym.indent_level:
                sym.end_line = next_sym.line - 1
            else:
                # Next symbol is nested; scan further
                end = len(lines)
                for j in range(i + 1, len(raw_symbols)):
                    if raw_symbols[j].indent_level <= 
sym.indent_level:
                        end = raw_symbols[j].line - 1
                        break
                sym.end_line = end
        else:
            # Last symbol extends to the end of the file.
            sym.end_line = len(lines)

    return raw_symbols


def _walk_source_files(
    root: Path,
    languages: list[str] | None = None,
    paths: list[str] | None = None,
) -> list[tuple[Path, str, str]]:
    """Walk codebase and return (abs_path, rel_path, language) tuples.

    Skips excluded directories and binary files; optional filters narrow
    by detected language and by glob path patterns.
    """
    lang_set = (
        {lang.lower() for lang in languages} if languages else None
    )
    results: list[tuple[Path, str, str]] = []

    for dirpath, dirnames, filenames in os.walk(root):
        # In-place prune + sort keeps the walk deterministic and skips
        # excluded directories entirely.
        dirnames[:] = sorted(
            d for d in dirnames if not _is_excluded_dir(d)
        )
        for fname in sorted(filenames):
            fpath = Path(dirpath) / fname
            rel = _relative(fpath)
            lang = _detect_lang(fpath)

            # Unknown extensions yield lang == "", which a language filter
            # excludes implicitly.
            if lang_set and lang.lower() not in lang_set:
                continue
            if paths and not any(
                fnmatch.fnmatch(rel, p) for p in paths
            ):
                continue
            if _is_binary(fpath):
                continue

            results.append((fpath, rel, lang))

    return results


def _classify_usage(
    line: str, symbol_name: str, language: str,
) -> str:
    """Classify how a symbol is used on a given line."""
    stripped = line.strip()

    # Import patterns
    # NOTE(review): any line containing an import keyword is classified as
    # "import", regardless of where the symbol appears on that line.
    import_patterns = [
        r"\bimport\b", r"\bfrom\b.*\bimport\b",
        r"\brequire\s*\(", r"\buse\s+",
        r"\binclude\b", r"\busing\b",
    ]
    for pat in import_patterns:
        if re.search(pat, stripped):
            return "import"

    # Definition patterns (def, class, fn, func, struct, etc.)
+ def_patterns = [ + rf"(?:def|fn|func|function)\s+{re.escape(symbol_name)}\s*\(", + rf"class\s+{re.escape(symbol_name)}\b", + rf"struct\s+{re.escape(symbol_name)}\b", + rf"trait\s+{re.escape(symbol_name)}\b", + rf"interface\s+{re.escape(symbol_name)}\b", + rf"enum\s+{re.escape(symbol_name)}\b", + rf"type\s+{re.escape(symbol_name)}\b", + ] + for pat in def_patterns: + if re.search(pat, stripped): + return "definition" + + # Call: symbol followed by ( + if re.search( + rf"\b{re.escape(symbol_name)}\s*\(", stripped, + ): + return "call" + + # Type annotation: : symbol or -> symbol or + if re.search( + rf"[:\->]\s*{re.escape(symbol_name)}\b", stripped, + ): + return "type_annotation" + + # Assignment: symbol = ... or ... = symbol + if re.search( + rf"\b{re.escape(symbol_name)}\s*=[^=]", stripped, + ): + return "assignment" + + return "other" + + +def _find_definitions_impl( + symbol_name: str, + root: Path, + symbol_type: str | None = None, + languages: list[str] | None = None, + paths: list[str] | None = None, + limit: int = 20, +) -> list[DefinitionEntry]: + """Find symbol definitions across codebase.""" + results: list[DefinitionEntry] = [] + files = _walk_source_files(root, languages=languages, paths=paths) + + for fpath, rel, lang in files: + if len(results) >= limit: + break + try: + content = fpath.read_text(encoding="utf-8", errors="replace") + except OSError: + continue + + symbols = _extract_symbols(content, lang) + for sym in symbols: + if sym.name != symbol_name: + continue + if symbol_type and sym.symbol_type != symbol_type: + continue + + lines = content.splitlines() + ctx_start = max(0, sym.line - 2) + ctx_end = min(len(lines), sym.line + 2) + context = "\n".join(lines[ctx_start:ctx_end]) + + results.append(DefinitionEntry( + file_path=rel, + name=sym.name, + symbol_type=sym.symbol_type, + line=sym.line, + signature=sym.signature, + context=context, + )) + if len(results) >= limit: + break + + return results + + +def _find_references_impl( + 
    symbol_name: str,
    root: Path,
    languages: list[str] | None = None,
    paths: list[str] | None = None,
    context_lines: int = 0,
    limit: int = 50,
) -> tuple[list[ReferenceEntry], int, int, bool]:
    """Find all references to a symbol.

    Returns ``(references, total_matches, files_searched, truncated)``.
    ``total_matches`` keeps counting even after ``limit`` entries have been
    collected, so callers can report how many matches exist beyond the cap.
    """
    # Word-boundary match: "get" will not match inside "get_user".
    word_re = re.compile(rf"\b{re.escape(symbol_name)}\b")
    refs: list[ReferenceEntry] = []
    total = 0
    files_searched = 0
    truncated = False

    files = _walk_source_files(root, languages=languages, paths=paths)

    for fpath, rel, lang in files:
        try:
            # Skip oversized files before reading them.
            if fpath.stat().st_size > MAX_READ_BYTES:
                continue
            content = fpath.read_text(
                encoding="utf-8", errors="replace",
            )
        except OSError:
            continue

        files_searched += 1
        file_lines = content.splitlines()

        for i, line_text in enumerate(file_lines):
            if not word_re.search(line_text):
                continue
            total += 1
            if len(refs) >= limit:
                # Keep counting matches, but stop collecting entries.
                truncated = True
                continue

            ctx_before = [
                file_lines[j].rstrip("\n\r")
                for j in range(
                    max(0, i - context_lines), i,
                )
            ]
            ctx_after = [
                file_lines[j].rstrip("\n\r")
                for j in range(
                    i + 1,
                    min(len(file_lines), i + 1 + context_lines),
                )
            ]

            usage = _classify_usage(line_text, symbol_name, lang)

            refs.append(ReferenceEntry(
                path=rel,
                line_number=i + 1,
                line=line_text.rstrip("\n\r"),
                usage_type=usage,
                context_before=ctx_before,
                context_after=ctx_after,
            ))

    return refs, total, files_searched, truncated


# Comment line patterns per language
# Single-line comment markers only; block comments (/* ... */, docstrings)
# are not recognized here.
_COMMENT_PATTERNS: dict[str, re.Pattern[str]] = {
    "python": re.compile(r"^\s*#"),
    "ruby": re.compile(r"^\s*#"),
    "shell": re.compile(r"^\s*#"),
    "javascript": re.compile(r"^\s*//"),
    "typescript": re.compile(r"^\s*//"),
    "rust": re.compile(r"^\s*//"),
    "go": re.compile(r"^\s*//"),
    "java": re.compile(r"^\s*//"),
    "csharp": re.compile(r"^\s*//"),
    "c": re.compile(r"^\s*//"),
    "cpp": re.compile(r"^\s*//"),
    "php": re.compile(r"^\s*(?://|#)"),
    "sql": re.compile(r"^\s*--"),
    "kotlin": re.compile(r"^\s*//"),
    "scala": 
re.compile(r"^\s*//"), +} + +# Branching keywords for complexity estimation +_COMPLEXITY_KEYWORDS: re.Pattern[str] = re.compile( + r"\b(?:if|elif|else|for|while|and|or|try|except|catch" + r"|case|when|switch|\?|&&|\|\|)\b" +) + + +def _compute_metrics(content: str, language: str) -> MetricsData: + """Compute code metrics for file content.""" + lines = content.splitlines() + total_lines = len(lines) + blank_lines = sum(1 for line in lines if not line.strip()) + + # Count comment lines + comment_pat = _COMMENT_PATTERNS.get(language) + comment_lines = 0 + if comment_pat: + comment_lines = sum( + 1 for line in lines + if line.strip() and comment_pat.match(line) + ) + + code_lines = total_lines - blank_lines - comment_lines + + # Extract symbols for function/class counts + symbols = _extract_symbols(content, language) + func_types = {"function", "method"} + class_types = {"class", "struct"} + funcs = [s for s in symbols if s.symbol_type in func_types] + classes = [s for s in symbols if s.symbol_type in class_types] + + # Function lengths + func_lengths = [ + s.end_line - s.line + 1 for s in funcs if s.end_line >= s.line + ] + avg_func_len = ( + sum(func_lengths) / len(func_lengths) if func_lengths else 0.0 + ) + max_func_len = max(func_lengths) if func_lengths else 0 + + # Max nesting depth via indentation + max_depth = 0 + for line in lines: + if not line.strip(): + continue + indent = len(line) - len(line.lstrip()) + # Use 4 spaces or 1 tab as one level + depth = indent // 4 if "\t" not in line else line.count("\t") + if depth > max_depth: + max_depth = depth + + # Complexity estimate: count branching keywords + complexity = 0 + for line in lines: + complexity += len(_COMPLEXITY_KEYWORDS.findall(line)) + + return MetricsData( + total_lines=total_lines, + code_lines=code_lines, + blank_lines=blank_lines, + comment_lines=comment_lines, + functions=len(funcs), + classes=len(classes), + avg_function_length=round(avg_func_len, 1), + max_function_length=max_func_len, + 
max_nesting_depth=max_depth, + complexity_estimate=complexity, + ) + + +def _rename_symbol_impl( + old_name: str, + new_name: str, + root: Path, + scope: str | None = None, + languages: list[str] | None = None, + dry_run: bool = True, +) -> RenameResult: + """Rename a symbol across the codebase.""" + # Validate + if old_name == new_name: + return RenameResult( + success=False, + old_name=old_name, + new_name=new_name, + message="old_name and new_name are identical", + ) + if not re.match(r"^\w+$", new_name): + return RenameResult( + success=False, + old_name=old_name, + new_name=new_name, + message="new_name must be a valid identifier (letters, " + "digits, underscores)", + ) + + word_re = re.compile(rf"\b{re.escape(old_name)}\b") + path_filters = [scope] if scope else None + files = _walk_source_files( + root, languages=languages, paths=path_filters, + ) + + changes: list[RenameChange] = [] + total_replacements = 0 + + for fpath, rel, _lang in files: + try: + content = fpath.read_text( + encoding="utf-8", errors="replace", + ) + except OSError: + continue + + count = len(word_re.findall(content)) + if count == 0: + continue + + if not dry_run: + new_content = word_re.sub(new_name, content) + fpath.write_text(new_content, encoding="utf-8") + + changes.append(RenameChange( + file_path=rel, occurrences=count, + )) + total_replacements += count + + return RenameResult( + success=True, + old_name=old_name, + new_name=new_name, + files_changed=len(changes), + total_replacements=total_replacements, + changes=changes, + dry_run=dry_run, + ) + + +# === MCP tool registration === + + +def register_code_intelligence_tools(mcp: FastMCP) -> None: + """Register all code intelligence tools on the MCP server.""" + + @mcp.tool( + name="list_symbols", + description=( + "List all functions, classes, methods, variables, and other" + " symbols defined in a file or directory." 
+ " Use this to understand the structure of a file before" + " reading it, to find function signatures, or to get an" + " overview of a module's API surface." + " Returns symbol names, types, line numbers, and signatures." + ), + ) + async def list_symbols( + path: str = Field( + default="", + description=( + "Relative path to a file or directory." + " Empty string = codebase root." + " Example: 'src/utils/helpers.ts'" + ), + ), + symbol_types: list[str] | None = Field( + default=None, + description=( + "Filter by symbol type(s)." + " Options: function, method, class, variable," + " constant, interface, type, enum, struct," + " trait, module, impl." + " Example: ['function', 'class']" + ), + ), + languages: list[str] | None = Field( + default=None, + description=( + "Filter by language(s)." + " Example: ['python', 'typescript']" + ), + ), + limit: int = Field( + default=100, + ge=1, + le=MAX_RESULTS, + description=f"Max symbols to return (1-{MAX_RESULTS})", + ), + ) -> ListSymbolsResult: + """List symbols in a file or directory.""" + try: + root = _root() + target = _safe_resolve(path) if path else root + type_set = ( + {t.lower() for t in symbol_types} + if symbol_types else None + ) + + all_symbols: list[SymbolEntry] = [] + + if target.is_file(): + if _is_binary(target): + return ListSymbolsResult( + success=False, path=path, + message="Binary file, cannot parse", + ) + lang = _detect_lang(target) + content = target.read_text( + encoding="utf-8", errors="replace", + ) + symbols = _extract_symbols(content, lang) + if type_set: + symbols = [ + s for s in symbols + if s.symbol_type in type_set + ] + return ListSymbolsResult( + success=True, + path=path, + symbols=symbols[:limit], + total_symbols=len(symbols), + language=lang, + ) + elif target.is_dir(): + files = _walk_source_files( + target, languages=languages, + ) + for fpath, rel, lang in files: + if len(all_symbols) >= limit: + break + try: + content = fpath.read_text( + encoding="utf-8", errors="replace", + ) 
+ except OSError: + continue + symbols = _extract_symbols(content, lang) + if type_set: + symbols = [ + s for s in symbols + if s.symbol_type in type_set + ] + # Prefix signature with file path for dir listing + for s in symbols: + s.signature = f"{rel}:{s.line} {s.signature}" + all_symbols.extend(symbols) + + return ListSymbolsResult( + success=True, + path=path or ".", + symbols=all_symbols[:limit], + total_symbols=len(all_symbols), + ) + else: + return ListSymbolsResult( + success=False, path=path, + message=f"Path not found: {path}", + ) + except ValueError as ve: + return ListSymbolsResult( + success=False, path=path, message=str(ve), + ) + except Exception as e: + return ListSymbolsResult( + success=False, path=path, + message=f"list_symbols failed: {e!s}", + ) + + @mcp.tool( + name="find_definition", + description=( + "Find where a symbol (function, class, variable, etc.) is" + " defined across the entire codebase." + " Use this as 'go to definition' -- much faster and more" + " precise than grep for locating declarations." + " Works across Python, JS/TS, Rust, Go, Java, C/C++, and" + " more. Returns file path, line number, and signature." + ), + ) + async def find_definition( + symbol_name: str = Field( + description=( + "Name of the symbol to find." + " Examples: 'authenticate', 'UserModel'," + " 'parse_config'" + ), + ), + symbol_type: str | None = Field( + default=None, + description=( + "Filter by type: function, class, method, variable," + " constant, interface, struct, enum, trait, module" + ), + ), + languages: list[str] | None = Field( + default=None, + description="Filter by language(s)", + ), + paths: list[str] | None = Field( + default=None, + description=( + "Filter by path pattern(s) using GLOB." 
+ " Example: ['src/*', 'lib/**']" + ), + ), + limit: int = Field( + default=20, + ge=1, + le=MAX_RESULTS, + description="Max definitions to return", + ), + ) -> FindDefinitionResult: + """Find symbol definitions.""" + try: + defs = await asyncio.to_thread( + _find_definitions_impl, + symbol_name, _root(), + symbol_type=symbol_type, + languages=languages, + paths=paths, limit=limit, + ) + return FindDefinitionResult( + success=True, + definitions=defs, + total_found=len(defs), + ) + except Exception as e: + return FindDefinitionResult( + success=False, + message=f"find_definition failed: {e!s}", + ) + + @mcp.tool( + name="find_references", + description=( + "Find all usages of a symbol across the codebase." + " Shows where a function is called, a class is" + " instantiated, a variable is read, etc." + " Use this before refactoring to understand impact." + " Classifies each reference as import, call, assignment," + " type_annotation, definition, or other." + ), + ) + async def find_references( + symbol_name: str = Field( + description="Name of the symbol to find references for", + ), + include_definitions: bool = Field( + default=False, + description="Include definition sites in results", + ), + languages: list[str] | None = Field( + default=None, + description="Filter by language(s)", + ), + paths: list[str] | None = Field( + default=None, + description="Filter by path pattern(s) using GLOB", + ), + context_lines: int = Field( + default=0, ge=0, le=10, + description="Context lines before/after each match", + ), + limit: int = Field( + default=50, ge=1, le=MAX_RESULTS, + description="Max references to return", + ), + ) -> FindReferencesResult: + """Find all references to a symbol.""" + try: + refs, total, searched, trunc = await asyncio.to_thread( + _find_references_impl, + symbol_name, _root(), + languages=languages, paths=paths, + context_lines=context_lines, limit=limit, + ) + if not include_definitions: + refs = [ + r for r in refs + if r.usage_type != 
"definition" + ] + total = len(refs) + + return FindReferencesResult( + success=True, + references=refs, + total_found=total, + files_searched=searched, + truncated=trunc, + ) + except Exception as e: + return FindReferencesResult( + success=False, + message=f"find_references failed: {e!s}", + ) + + @mcp.tool( + name="code_metrics", + description=( + "Compute code quality metrics for a file." + " Returns line counts (total, code, blank, comment)," + " function/class counts, average and max function length," + " nesting depth, and cyclomatic complexity estimate." + " Use to identify files needing refactoring." + ), + ) + async def code_metrics( + path: str = Field( + description=( + "Relative path to a source file." + " Example: 'src/server.py'" + ), + ), + ) -> CodeMetricsResult: + """Compute code metrics for a file.""" + try: + resolved = _safe_resolve(path) + if not resolved.is_file(): + return CodeMetricsResult( + success=False, path=path, + message=f"File not found: {path}", + ) + if _is_binary(resolved): + return CodeMetricsResult( + success=False, path=path, + message="Binary file, cannot analyze", + ) + lang = _detect_lang(resolved) + content = resolved.read_text( + encoding="utf-8", errors="replace", + ) + metrics = _compute_metrics(content, lang) + return CodeMetricsResult( + success=True, path=path, + metrics=metrics, language=lang, + ) + except ValueError as ve: + return CodeMetricsResult( + success=False, path=path, message=str(ve), + ) + except Exception as e: + return CodeMetricsResult( + success=False, path=path, + message=f"code_metrics failed: {e!s}", + ) + + @mcp.tool( + name="rename_symbol", + description=( + "Rename a symbol across the entire codebase using" + " word-boundary-aware replacement." + " Much safer than find-and-replace because it won't" + " rename 'get' inside 'get_user'." + " Defaults to dry_run=true so you can preview changes" + " before applying. Set dry_run=false to apply." 
+ ), + ) + async def rename_symbol( + old_name: str = Field( + description="Current symbol name to rename", + ), + new_name: str = Field( + description="New name for the symbol", + ), + scope: str | None = Field( + default=None, + description=( + "Limit rename to files matching this GLOB pattern." + " Example: 'src/**/*.py'" + ), + ), + languages: list[str] | None = Field( + default=None, + description="Filter by language(s)", + ), + dry_run: bool = Field( + default=True, + description=( + "Preview changes without applying." + " Set to false to actually rename." + ), + ), + ) -> RenameResult: + """Rename a symbol across the codebase.""" + try: + return await asyncio.to_thread( + _rename_symbol_impl, + old_name, new_name, _root(), + scope=scope, languages=languages, + dry_run=dry_run, + ) + except Exception as e: + return RenameResult( + success=False, + old_name=old_name, new_name=new_name, + message=f"rename_symbol failed: {e!s}", + ) diff --git a/src/cocoindex_code/config.py b/src/cocoindex_code/config.py index f268b16..dcfc2a5 100644 --- a/src/cocoindex_code/config.py +++ b/src/cocoindex_code/config.py @@ -108,9 +108,9 @@ def from_env(cls) -> Config: continue if ":" in token: ext, lang = token.split(":", 1) - extra_extensions[f".{ext.strip()}"] = lang.strip() or None + extra_extensions[f".{ext.strip().lstrip('.')}"] = lang.strip() or None else: - extra_extensions[f".{token}"] = None + extra_extensions[f".{token.lstrip('.')}"] = None return cls( codebase_root_path=root, diff --git a/src/cocoindex_code/filesystem_tools.py b/src/cocoindex_code/filesystem_tools.py new file mode 100644 index 0000000..5e50057 --- /dev/null +++ b/src/cocoindex_code/filesystem_tools.py @@ -0,0 +1,1033 @@ +"""Fast filesystem tools for the cocoindex-code MCP server. + +Provides find_files, read_file, write_file, edit_file, grep_code, and directory_tree tools +that operate directly on the filesystem without vector search overhead. 
+""" + +from __future__ import annotations + +import asyncio +import fnmatch +import os +import re +import time +from pathlib import Path + +from mcp.server.fastmcp import FastMCP +from pydantic import BaseModel, Field + +from .config import config + +EXCLUDED_DIRS: frozenset[str] = frozenset( + { + ".git", + ".hg", + ".svn", + "__pycache__", + "node_modules", + ".cocoindex_code", + ".next", + ".nuxt", + ".venv", + "venv", + "env", + ".tox", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + } +) + +EXCLUDED_DIR_PATTERNS: list[str] = [ + "target", + "build", + "dist", + "vendor", +] + +MAX_READ_BYTES = 1_048_576 +MAX_RESULTS = 200 +MAX_TREE_DEPTH = 6 + +_EXT_LANG: dict[str, str] = { + ".py": "python", + ".pyi": "python", + ".js": "javascript", + ".jsx": "javascript", + ".mjs": "javascript", + ".cjs": "javascript", + ".ts": "typescript", + ".tsx": "typescript", + ".rs": "rust", + ".go": "go", + ".java": "java", + ".c": "c", + ".h": "c", + ".cpp": "cpp", + ".hpp": "cpp", + ".cc": "cpp", + ".cxx": "cpp", + ".hxx": "cpp", + ".hh": "cpp", + ".cs": "csharp", + ".rb": "ruby", + ".php": "php", + ".swift": "swift", + ".kt": "kotlin", + ".kts": "kotlin", + ".scala": "scala", + ".sh": "shell", + ".bash": "shell", + ".zsh": "shell", + ".sql": "sql", + ".md": "markdown", + ".mdx": "markdown", + ".json": "json", + ".yaml": "yaml", + ".yml": "yaml", + ".toml": "toml", + ".xml": "xml", + ".html": "html", + ".htm": "html", + ".css": "css", + ".scss": "scss", + ".less": "less", + ".txt": "text", + ".rst": "text", +} + + +# === Pydantic models === + + +class FileEntry(BaseModel): + """A file found by find_files.""" + + path: str = Field(description="Relative path from codebase root") + size: int = Field(description="File size in bytes") + language: str = Field(default="", description="Detected language (by extension)") + + +class FindFilesResult(BaseModel): + """Result from find_files tool.""" + + success: bool + files: list[FileEntry] = Field(default_factory=list) + total_found: 
int = 0 + truncated: bool = False + message: str | None = None + + +class ReadFileResult(BaseModel): + """Result from read_file tool.""" + + success: bool + path: str = "" + content: str = "" + start_line: int = 1 + end_line: int = 0 + total_lines: int = 0 + language: str = "" + message: str | None = None + + +MAX_WRITE_BYTES = 1_048_576 + + +class WriteFileResult(BaseModel): + """Result from write_file tool.""" + + success: bool + path: str = "" + bytes_written: int = 0 + created: bool = False + message: str | None = None + + +class EditFileResult(BaseModel): + """Result from edit_file tool.""" + + success: bool + path: str = "" + replacements: int = 0 + message: str | None = None + + +class GrepMatch(BaseModel): + """A single grep match.""" + + path: str = Field(description="Relative file path") + line_number: int = Field(description="1-indexed line number") + line: str = Field(description="Matched line content") + context_before: list[str] = Field(default_factory=list) + context_after: list[str] = Field(default_factory=list) + + +class GrepResult(BaseModel): + """Result from grep_code tool.""" + + success: bool + matches: list[GrepMatch] = Field(default_factory=list) + total_matches: int = 0 + files_searched: int = 0 + truncated: bool = False + message: str | None = None + + +class TreeEntry(BaseModel): + """A node in the directory tree.""" + + path: str + type: str = Field(description="'file' or 'dir'") + size: int = Field(default=0, description="File size in bytes (0 for dirs)") + children: int = Field(default=0, description="Number of direct children (dirs only)") + + +class DirectoryTreeResult(BaseModel): + """Result from directory_tree tool.""" + + success: bool + root: str = "" + entries: list[TreeEntry] = Field(default_factory=list) + message: str | None = None + + +# === Internal helpers === + + +def _root() -> Path: + """Return resolved codebase root.""" + return config.codebase_root_path.resolve() + + +def _safe_resolve(path_str: str) -> Path: + 
"""Resolve a user-supplied path, ensuring it stays within the codebase root.""" + root = _root() + resolved = (root / path_str).resolve() + if not (resolved == root or str(resolved).startswith(str(root) + os.sep)): + msg = f"Path '{path_str}' escapes the codebase root" + raise ValueError(msg) + return resolved + + +def _is_excluded_dir(name: str) -> bool: + """Check if a directory name should be excluded.""" + if name.startswith("."): + return True + if name in EXCLUDED_DIRS: + return True + return any(fnmatch.fnmatch(name, pat) for pat in EXCLUDED_DIR_PATTERNS) + + +def _is_binary(path: Path, sample_size: int = 8192) -> bool: + """Heuristic binary detection by looking for null bytes.""" + try: + with open(path, "rb") as f: + chunk = f.read(sample_size) + return b"\x00" in chunk + except OSError: + return True + + +def _relative(path: Path) -> str: + """Return path relative to codebase root.""" + try: + return str(path.relative_to(_root())) + except ValueError: + return str(path) + + +def _detect_lang(path: Path) -> str: + """Detect programming language by file extension.""" + return _EXT_LANG.get(path.suffix.lower(), "") + + +# === Core implementations === + + +def _walk_files( + root: Path, + pattern: str | None = None, + languages: list[str] | None = None, + paths: list[str] | None = None, + limit: int = MAX_RESULTS, +) -> tuple[list[FileEntry], int, bool]: + """Walk the codebase and collect matching files.""" + lang_set = {lang.lower() for lang in languages} if languages else None + results: list[FileEntry] = [] + total = 0 + truncated = False + + for dirpath, dirnames, filenames in os.walk(root): + dirnames[:] = sorted(d for d in dirnames if not _is_excluded_dir(d)) + + for fname in sorted(filenames): + fpath = Path(dirpath) / fname + rel = _relative(fpath) + + if ( + pattern + and not fnmatch.fnmatch(rel, pattern) + and not fnmatch.fnmatch(fname, pattern) + ): + continue + + if paths and not any(fnmatch.fnmatch(rel, p) for p in paths): + continue + + lang = 
_detect_lang(fpath) + + if lang_set and lang.lower() not in lang_set: + continue + + total += 1 + if len(results) < limit: + try: + size = fpath.stat().st_size + except OSError: + size = 0 + results.append(FileEntry(path=rel, size=size, language=lang)) + else: + truncated = True + + return results, total, truncated + + +def _read_file( + path: Path, + start_line: int | None = None, + end_line: int | None = None, +) -> tuple[str, int, int, int]: + """Read a file, optionally slicing by line range.""" + with open(path, encoding="utf-8", errors="replace") as f: + lines = f.readlines() + + total = len(lines) + s = max(1, start_line or 1) + e = min(total, end_line or total) + + selected = lines[s - 1 : e] + content = "".join(selected) + + if len(content.encode("utf-8", errors="replace")) > MAX_READ_BYTES: + content = content[:MAX_READ_BYTES] + "\n\n... [truncated at 1 MB] ..." + + return content, s, e, total + + +def _write_file(path: Path, content: str) -> tuple[int, bool]: + """Write content to a file, creating parent directories as needed. + + Returns (bytes_written, created) where created indicates a new file. + """ + content_bytes = content.encode("utf-8") + if len(content_bytes) > MAX_WRITE_BYTES: + msg = f"Content exceeds maximum write size ({MAX_WRITE_BYTES} bytes)" + raise ValueError(msg) + created = not path.exists() + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + f.write(content) + return len(content_bytes), created + + +def _edit_file( + path: Path, + old_string: str, + new_string: str, + *, + replace_all: bool = False, +) -> int: + """Perform exact string replacement in a file. + + Returns the number of replacements made. + Raises ValueError if old_string is not found or is ambiguous. 
+ """ + content = path.read_text(encoding="utf-8") + + if old_string == new_string: + msg = "old_string and new_string are identical" + raise ValueError(msg) + + count = content.count(old_string) + if count == 0: + msg = "old_string not found in file" + raise ValueError(msg) + + if count > 1 and not replace_all: + msg = ( + f"Found {count} matches for old_string." + " Provide more context to identify a unique match, or set replace_all=true." + ) + raise ValueError(msg) + + if replace_all: + new_content = content.replace(old_string, new_string) + replacements = count + else: + new_content = content.replace(old_string, new_string, 1) + replacements = 1 + + new_bytes = new_content.encode("utf-8") + if len(new_bytes) > MAX_WRITE_BYTES: + msg = f"Resulting file exceeds maximum size ({MAX_WRITE_BYTES} bytes)" + raise ValueError(msg) + + path.write_text(new_content, encoding="utf-8") + return replacements + + +def _grep_files( + root: Path, + pattern_str: str, + include: str | None = None, + paths: list[str] | None = None, + context_lines: int = 0, + limit: int = MAX_RESULTS, + *, + case_sensitive: bool = True, +) -> tuple[list[GrepMatch], int, int, bool]: + """Grep across files in the codebase.""" + flags = 0 if case_sensitive else re.IGNORECASE + try: + regex = re.compile(pattern_str, flags) + except re.error as e: + msg = f"Invalid regex: {e}" + raise ValueError(msg) from e + + matches: list[GrepMatch] = [] + total_matches = 0 + files_searched = 0 + truncated = False + + for dirpath, dirnames, filenames in os.walk(root): + dirnames[:] = sorted(d for d in dirnames if not _is_excluded_dir(d)) + + for fname in sorted(filenames): + fpath = Path(dirpath) / fname + rel = _relative(fpath) + + if ( + include + and not fnmatch.fnmatch(fname, include) + and not fnmatch.fnmatch(rel, include) + ): + continue + + if paths and not any(fnmatch.fnmatch(rel, p) for p in paths): + continue + + try: + if fpath.stat().st_size > MAX_READ_BYTES: + continue + except OSError: + continue + if 
def _directory_tree(
    root: Path,
    rel_path: str = "",
    max_depth: int = MAX_TREE_DEPTH,
) -> list[TreeEntry]:
    """Build a directory tree listing.

    Returns a flat list of TreeEntry rows, depth-first, excluding hidden
    and build/cache directories via _is_excluded_dir.
    """
    start = _safe_resolve(rel_path) if rel_path else root
    entries: list[TreeEntry] = []

    def _walk(dirpath: Path, depth: int) -> None:
        if depth > max_depth:
            return
        try:
            # Directories sort before files; then alphabetical.
            children = sorted(dirpath.iterdir(), key=lambda p: (p.is_file(), p.name))
        except PermissionError:
            return

        for child in children:
            rel = _relative(child)
            if child.is_dir():
                if _is_excluded_dir(child.name):
                    continue
                # Only count children while within the depth budget.
                sub_children = (
                    sum(1 for c in child.iterdir() if not (c.is_dir() and _is_excluded_dir(c.name)))
                    if depth < max_depth
                    else 0
                )
                entries.append(TreeEntry(path=rel, type="dir", children=sub_children))
                _walk(child, depth + 1)
            else:
                try:
                    size = child.stat().st_size
                except OSError:
                    size = 0
                entries.append(TreeEntry(path=rel, type="file", size=size))

    _walk(start, 0)
    return entries


# === MCP tool registration ===


def register_filesystem_tools(mcp: FastMCP) -> None:
    """Register all filesystem tools on the given MCP server.

    All blocking filesystem work is dispatched via asyncio.to_thread so the
    async event loop is never stalled by large codebase walks or disk I/O.
    """

    @mcp.tool(
        name="find_files",
        description=(
            "Fast file discovery by glob pattern, language, or path."
            " Use this to quickly list files matching a pattern"
            " (e.g., '*.py', 'src/**/*.ts', 'README*')."
            " Much faster than semantic search for finding files by name."
            " Returns file paths, sizes, and detected languages."
        ),
    )
    async def find_files(
        pattern: str | None = Field(
            default=None,
            description=(
                "Glob pattern to match file names or paths."
                " Examples: '*.py', 'src/**/*.ts', 'README*', '*.test.*'"
            ),
        ),
        languages: list[str] | None = Field(
            default=None,
            description="Filter by language(s). Example: ['python', 'typescript']",
        ),
        paths: list[str] | None = Field(
            default=None,
            description=(
                "Filter by path pattern(s) using GLOB wildcards. Example: ['src/*', 'lib/**']"
            ),
        ),
        limit: int = Field(
            default=50,
            ge=1,
            le=MAX_RESULTS,
            description=f"Maximum number of results (1-{MAX_RESULTS})",
        ),
    ) -> FindFilesResult:
        """Find files in the codebase by pattern."""
        try:
            files, total, truncated = await asyncio.to_thread(
                _walk_files,
                _root(),
                pattern=pattern,
                languages=languages,
                paths=paths,
                limit=limit,
            )
            return FindFilesResult(
                success=True,
                files=files,
                total_found=total,
                truncated=truncated,
            )
        except Exception as e:
            return FindFilesResult(success=False, message=f"Find failed: {e!s}")

    @mcp.tool(
        name="read_file",
        description=(
            "Read file contents by path, with optional line range."
            " Use this when you know the exact file path and want to read"
            " its contents quickly -- much faster than semantic search."
            " Supports reading specific line ranges for large files."
            " Returns content with language detection and total line count."
        ),
    )
    async def read_file(
        path: str = Field(
            description="Relative path from codebase root. Example: 'src/utils/helpers.ts'",
        ),
        start_line: int | None = Field(
            default=None,
            ge=1,
            description="Start reading from this line (1-indexed). Default: first line.",
        ),
        end_line: int | None = Field(
            default=None,
            ge=1,
            description="Stop reading at this line (inclusive). Default: last line.",
        ),
    ) -> ReadFileResult:
        """Read a file from the codebase."""
        try:
            resolved = _safe_resolve(path)
            if not resolved.is_file():
                return ReadFileResult(
                    success=False,
                    path=path,
                    message=f"File not found: {path}",
                )
            if _is_binary(resolved):
                return ReadFileResult(
                    success=False,
                    path=path,
                    message=f"Binary file, cannot display: {path}",
                )

            # Run the (potentially 1 MB) read off the event loop.
            content, s, e, total = await asyncio.to_thread(
                _read_file, resolved, start_line, end_line
            )
            return ReadFileResult(
                success=True,
                path=path,
                content=content,
                start_line=s,
                end_line=e,
                total_lines=total,
                language=_detect_lang(resolved),
            )
        except ValueError as ve:
            return ReadFileResult(success=False, path=path, message=str(ve))
        except Exception as e:
            return ReadFileResult(success=False, path=path, message=f"Read failed: {e!s}")

    @mcp.tool(
        name="write_file",
        description=(
            "Write content to a file in the codebase."
            " Creates the file if it does not exist, overwrites if it does."
            " Automatically creates parent directories as needed."
            " Use this to create new files or update existing ones."
            " Returns bytes written and whether the file was newly created."
        ),
    )
    async def write_file(
        path: str = Field(
            description="Relative path from codebase root. Example: 'src/utils/helpers.ts'",
        ),
        content: str = Field(
            description="The text content to write to the file.",
        ),
    ) -> WriteFileResult:
        """Write content to a file in the codebase."""
        try:
            resolved = _safe_resolve(path)
            # Disk write happens in a worker thread to keep the loop responsive.
            bytes_written, created = await asyncio.to_thread(
                _write_file, resolved, content
            )
            return WriteFileResult(
                success=True,
                path=path,
                bytes_written=bytes_written,
                created=created,
            )
        except ValueError as ve:
            return WriteFileResult(success=False, path=path, message=str(ve))
        except Exception as e:
            return WriteFileResult(success=False, path=path, message=f"Write failed: {e!s}")

    @mcp.tool(
        name="edit_file",
        description=(
            "Perform exact string replacements in a file."
            " Finds old_string in the file and replaces it with new_string."
            " By default requires old_string to match exactly once (for safety)."
            " Set replace_all=true to replace every occurrence."
            " Use this for surgical edits instead of rewriting entire files."
        ),
    )
    async def edit_file(
        path: str = Field(
            description="Relative path from codebase root. Example: 'src/utils/helpers.ts'",
        ),
        old_string: str = Field(
            description="The exact text to find and replace. Must match file content exactly.",
        ),
        new_string: str = Field(
            description="The replacement text. Must differ from old_string.",
        ),
        replace_all: bool = Field(
            default=False,
            description=(
                "Replace all occurrences. Default false requires exactly one match for safety."
            ),
        ),
    ) -> EditFileResult:
        """Perform exact string replacement in a file."""
        try:
            resolved = _safe_resolve(path)
            if not resolved.is_file():
                return EditFileResult(
                    success=False,
                    path=path,
                    message=f"File not found: {path}",
                )
            if _is_binary(resolved):
                return EditFileResult(
                    success=False,
                    path=path,
                    message=f"Binary file, cannot edit: {path}",
                )
            # Read-modify-write runs off the event loop.
            replacements = await asyncio.to_thread(
                _edit_file, resolved, old_string, new_string, replace_all=replace_all
            )
            return EditFileResult(
                success=True,
                path=path,
                replacements=replacements,
            )
        except ValueError as ve:
            return EditFileResult(success=False, path=path, message=str(ve))
        except Exception as e:
            return EditFileResult(success=False, path=path, message=f"Edit failed: {e!s}")

    @mcp.tool(
        name="grep_code",
        description=(
            "Fast regex text search across codebase files."
            " Use this instead of semantic search when you need exact"
            " text or pattern matching (e.g., function names, imports,"
            " TODO comments, error strings)."
            " Returns matching lines with file paths, line numbers,"
            " and optional context lines."
        ),
    )
    async def grep_code(
        pattern: str = Field(
            description=(
                "Regular expression pattern to search for."
                " Examples: 'def authenticate', 'import.*redis',"
                " 'TODO|FIXME|HACK', 'class\\s+User'"
            ),
        ),
        include: str | None = Field(
            default=None,
            description="File pattern to include. Examples: '*.py', '*.{ts,tsx}', 'Makefile'",
        ),
        paths: list[str] | None = Field(
            default=None,
            description="Filter by path pattern(s). Example: ['src/*', 'lib/**']",
        ),
        context_lines: int = Field(
            default=0,
            ge=0,
            le=10,
            description="Number of context lines before and after each match (0-10)",
        ),
        case_sensitive: bool = Field(
            default=True,
            description="Whether the search is case-sensitive",
        ),
        limit: int = Field(
            default=50,
            ge=1,
            le=MAX_RESULTS,
            description=f"Maximum number of matches (1-{MAX_RESULTS})",
        ),
    ) -> GrepResult:
        """Search file contents by regex pattern."""
        try:
            # The full-codebase walk is the heaviest operation here; never
            # run it directly on the event loop.
            matches, total, searched, truncated = await asyncio.to_thread(
                _grep_files,
                _root(),
                pattern,
                include=include,
                paths=paths,
                context_lines=context_lines,
                limit=limit,
                case_sensitive=case_sensitive,
            )
            return GrepResult(
                success=True,
                matches=matches,
                total_matches=total,
                files_searched=searched,
                truncated=truncated,
            )
        except ValueError as ve:
            return GrepResult(success=False, message=str(ve))
        except Exception as e:
            return GrepResult(success=False, message=f"Grep failed: {e!s}")

    @mcp.tool(
        name="directory_tree",
        description=(
            "List the directory structure of the codebase."
            " Use this to understand project layout, find directories,"
            " or get an overview before diving into specific files."
            " Excludes hidden dirs, node_modules, build artifacts, etc."
            " Returns a flat list of entries with types and sizes."
        ),
    )
    async def directory_tree(
        path: str = Field(
            default="",
            description=(
                "Relative path to start from (empty = codebase root). Example: 'src/components'"
            ),
        ),
        max_depth: int = Field(
            default=MAX_TREE_DEPTH,
            ge=1,
            le=10,
            description=f"Maximum directory depth to recurse (1-10, default {MAX_TREE_DEPTH})",
        ),
    ) -> DirectoryTreeResult:
        """List the directory tree of the codebase."""
        try:
            start = _safe_resolve(path) if path else _root()
            if not start.is_dir():
                return DirectoryTreeResult(
                    success=False,
                    message=f"Directory not found: {path}",
                )
            # Recursive stat-heavy walk goes to a worker thread.
            entries = await asyncio.to_thread(
                _directory_tree, _root(), rel_path=path, max_depth=max_depth
            )
            return DirectoryTreeResult(
                success=True,
                root=_relative(start) if path else ".",
                entries=entries,
            )
        except ValueError as ve:
            return DirectoryTreeResult(success=False, message=str(ve))
        except Exception as e:
            return DirectoryTreeResult(success=False, message=f"Tree failed: {e!s}")


# === Large write support ===

# In-memory buffers for chunked writes, keyed by session_id
_large_write_buffers: dict[str, dict] = {}

MAX_LARGE_WRITE_BYTES = 5_242_880  # 5 MB total limit per session
MAX_LARGE_WRITE_SESSIONS = 50  # Maximum concurrent sessions


class LargeWriteResult(BaseModel):
    """Result from large_write tool."""

    success: bool
    session_id: str = ""
    path: str = ""
    action: str = ""
    chunks_received: int = 0
    total_bytes: int = 0
    bytes_written: int = 0
    created: bool = False
    message: str | None = None


def _large_write_start(
    session_id: str, path: str,
) -> None:
    """Start a new large write session.

    Evicts the oldest session if MAX_LARGE_WRITE_SESSIONS is reached.
    """
    # Evict oldest session if at capacity (oldest by monotonic start time).
    if (
        session_id not in _large_write_buffers
        and len(_large_write_buffers) >= MAX_LARGE_WRITE_SESSIONS
    ):
        oldest_key = min(
            _large_write_buffers,
            key=lambda k: _large_write_buffers[k].get("created_at", 0),
        )
        _large_write_buffers.pop(oldest_key, None)

    _large_write_buffers[session_id] = {
        "path": path,
        "chunks": [],
        "total_bytes": 0,
        "created_at": time.monotonic(),
    }
+ """ + # Evict oldest session if at capacity + if ( + session_id not in _large_write_buffers + and len(_large_write_buffers) >= MAX_LARGE_WRITE_SESSIONS + ): + oldest_key = min( + _large_write_buffers, + key=lambda k: _large_write_buffers[k].get("created_at", 0), + ) + _large_write_buffers.pop(oldest_key, None) + + _large_write_buffers[session_id] = { + "path": path, + "chunks": [], + "total_bytes": 0, + "created_at": time.monotonic(), + } + + +def _large_write_append( + session_id: str, content: str, +) -> int: + """Append content to a large write session. Returns new total bytes.""" + buf = _large_write_buffers[session_id] + chunk_bytes = len(content.encode("utf-8")) + new_total = buf["total_bytes"] + chunk_bytes + if new_total > MAX_LARGE_WRITE_BYTES: + msg = ( + f"Content exceeds max size ({MAX_LARGE_WRITE_BYTES} bytes)." + f" Current: {buf['total_bytes']}, chunk: {chunk_bytes}" + ) + raise ValueError(msg) + buf["chunks"].append(content) + buf["total_bytes"] = new_total + return new_total + + +def _large_write_finalize( + session_id: str, +) -> tuple[str, int, bool]: + """Finalize and write the buffered content. + + Returns (path, bytes_written, created). + """ + buf = _large_write_buffers.pop(session_id) + path_str = buf["path"] + full_content = "".join(buf["chunks"]) + resolved = _safe_resolve(path_str) + created = not resolved.exists() + resolved.parent.mkdir(parents=True, exist_ok=True) + resolved.write_text(full_content, encoding="utf-8") + return path_str, buf["total_bytes"], created + + +def register_large_write_tool(mcp: FastMCP) -> None: + """Register the large_write tool on the MCP server.""" + + @mcp.tool( + name="large_write", + description=( + "Write large files in chunks when content is too big for" + " a single write_file call." + " Use action='start' to begin a session with a file path," + " then 'append' to add content in pieces," + " then 'finalize' to write the assembled file to disk." + " Supports up to 5 MB total. 
Each session is identified" + " by a session_id you provide." + "\n\nWorkflow:" + "\n1. large_write(action='start', session_id='s1'," + " path='src/big_file.py')" + "\n2. large_write(action='append', session_id='s1'," + " content='first chunk...')" + "\n3. large_write(action='append', session_id='s1'," + " content='second chunk...')" + "\n4. large_write(action='finalize', session_id='s1')" + ), + ) + async def large_write( + action: str = Field( + description=( + "Action: 'start' to begin, 'append' to add content," + " 'finalize' to write to disk, 'abort' to cancel." + ), + ), + session_id: str = Field( + description="Unique session identifier for this write.", + ), + path: str = Field( + default="", + description=( + "Relative file path. Required for 'start' action." + " Example: 'src/utils/big_module.py'" + ), + ), + content: str = Field( + default="", + description=( + "Content chunk to append. Used with 'append' action." + ), + ), + ) -> LargeWriteResult: + """Write large files in chunks.""" + try: + if action == "start": + if not path: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message="path is required for 'start' action", + ) + # Validate path early + _safe_resolve(path) + _large_write_start(session_id, path) + return LargeWriteResult( + success=True, action=action, + session_id=session_id, path=path, + chunks_received=0, total_bytes=0, + ) + + if action == "append": + if session_id not in _large_write_buffers: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message=f"No active session '{session_id}'." 
+ " Call with action='start' first.", + ) + if not content: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message="content is required for 'append'", + ) + total = _large_write_append(session_id, content) + buf = _large_write_buffers[session_id] + return LargeWriteResult( + success=True, action=action, + session_id=session_id, + path=buf["path"], + chunks_received=len(buf["chunks"]), + total_bytes=total, + ) + + if action == "finalize": + if session_id not in _large_write_buffers: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message=f"No active session '{session_id}'", + ) + fpath, written, created = _large_write_finalize( + session_id, + ) + return LargeWriteResult( + success=True, action=action, + session_id=session_id, + path=fpath, bytes_written=written, + created=created, + ) + + if action == "abort": + _large_write_buffers.pop(session_id, None) + return LargeWriteResult( + success=True, action=action, + session_id=session_id, + message="Session aborted", + ) + + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message=( + f"Invalid action '{action}'." + " Must be 'start', 'append'," + " 'finalize', or 'abort'." + ), + ) + except ValueError as ve: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, message=str(ve), + ) + except Exception as e: + return LargeWriteResult( + success=False, action=action, + session_id=session_id, + message=f"large_write failed: {e!s}", + ) diff --git a/src/cocoindex_code/patch_tools.py b/src/cocoindex_code/patch_tools.py new file mode 100644 index 0000000..86af08e --- /dev/null +++ b/src/cocoindex_code/patch_tools.py @@ -0,0 +1,378 @@ +"""Patch tools for the cocoindex-code MCP server. + +Provides apply_patch tool for applying unified diff patches to files +in the codebase. 
+""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from pathlib import Path + +from mcp.server.fastmcp import FastMCP +from pydantic import BaseModel, Field + +from .filesystem_tools import ( + MAX_WRITE_BYTES, + _root, + _safe_resolve, +) + +# === Internal data structures === + + +@dataclass +class PatchHunk: + """A single hunk from a unified diff.""" + + old_start: int + old_count: int + new_start: int + new_count: int + lines: list[str] = field(default_factory=list) + + +@dataclass +class PatchFile: + """Parsed patch data for a single file.""" + + old_path: str + new_path: str + hunks: list[PatchHunk] = field(default_factory=list) + + +# === Pydantic result models === + + +class PatchFileResult(BaseModel): + """Result for a single file in a patch.""" + + path: str = Field(description="Relative file path") + hunks_applied: int = Field(default=0, description="Hunks applied") + hunks_rejected: int = Field( + default=0, description="Hunks that failed to apply" + ) + created: bool = Field( + default=False, description="Whether file was newly created" + ) + + +class ApplyPatchResult(BaseModel): + """Result from apply_patch tool.""" + + success: bool + files: list[PatchFileResult] = Field(default_factory=list) + total_applied: int = 0 + total_rejected: int = 0 + dry_run: bool = True + message: str | None = None + + +# === Unified diff parser === + +_HUNK_HEADER = re.compile( + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? 
_HUNK_HEADER = re.compile(
    r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@"
)


def _parse_unified_diff(patch_text: str) -> list[PatchFile]:
    """Parse a unified diff into structured PatchFile objects.

    Handles multi-file patches, git-style ``a/``/``b/`` path prefixes,
    ``\\ No newline at end of file`` markers, and blank context lines:
    many tools strip the trailing space from an empty context line,
    leaving a bare newline, which the previous implementation treated
    as end-of-hunk and silently truncated the hunk.

    Hunk bodies are consumed according to the line counts declared in
    the ``@@`` header, so hunk content that merely looks unusual cannot
    terminate a hunk early.
    """
    files: list[PatchFile] = []
    lines = patch_text.splitlines(keepends=True)
    i = 0

    while i < len(lines):
        line = lines[i]

        # A file section starts with a "--- old" / "+++ new" pair.
        if not line.startswith("--- ") or i + 1 >= len(lines):
            i += 1
            continue
        next_line = lines[i + 1]
        if not next_line.startswith("+++ "):
            i += 1
            continue

        old_path = line[4:].strip()
        new_path = next_line[4:].strip()

        # Strip git's a/ and b/ prefixes.
        if old_path.startswith("a/"):
            old_path = old_path[2:]
        if new_path.startswith("b/"):
            new_path = new_path[2:]

        pf = PatchFile(old_path=old_path, new_path=new_path)
        i += 2

        # Parse all hunks belonging to this file.
        while i < len(lines):
            header = _HUNK_HEADER.match(lines[i])
            if header is None:
                if lines[i].startswith(("--- ", "diff ")):
                    break  # next file section begins
                i += 1  # skip noise (index/mode lines between hunks)
                continue

            # Missing count in the header means 1 per unified format.
            hunk = PatchHunk(
                old_start=int(header.group(1)),
                old_count=int(header.group(2) or "1"),
                new_start=int(header.group(3)),
                new_count=int(header.group(4) or "1"),
            )
            i += 1

            # Consume exactly as many body lines as the header declares.
            need_old = hunk.old_count
            need_new = hunk.new_count
            while i < len(lines) and (need_old > 0 or need_new > 0):
                hl = lines[i]
                # Safety: malformed counts must not swallow the next
                # file/hunk header.
                if hl.startswith(("--- ", "diff ")) or _HUNK_HEADER.match(hl):
                    break
                if hl.startswith("\\"):
                    # "\ No newline at end of file"
                    i += 1
                    continue
                body = hl.rstrip("\n\r")
                if body.startswith("+"):
                    need_new -= 1
                elif body.startswith("-"):
                    need_old -= 1
                else:
                    if body == "":
                        # Blank context line whose leading space was
                        # stripped by the producing tool.
                        body = " "
                    elif not body.startswith(" "):
                        break  # not a hunk body line at all
                    # Context lines count against both sides.
                    need_old -= 1
                    need_new -= 1
                hunk.lines.append(body)
                i += 1

            pf.hunks.append(hunk)

        files.append(pf)

    return files


# === Hunk application ===


def _apply_hunks(
    content: str, hunks: list[PatchHunk],
) -> tuple[str, int, int]:
    """Apply hunks to file content.

    Hunks are applied in reverse order so earlier replacements do not
    shift the line numbers used by later ones. A hunk is applied only
    when every context/removed line matches the file exactly at
    ``old_start``; otherwise it is counted as rejected and that region
    of the file is left untouched.

    Returns (new_content, applied_count, rejected_count).
    """
    file_lines = content.splitlines(keepends=True)
    applied = 0
    rejected = 0

    # Apply hunks in reverse order to preserve line numbers.
    for hunk in reversed(hunks):
        old_lines: list[str] = []
        new_lines: list[str] = []

        for hl in hunk.lines:
            if hl.startswith("-"):
                old_lines.append(hl[1:])
            elif hl.startswith("+"):
                new_lines.append(hl[1:])
            elif hl.startswith(" "):
                # Context lines appear on both sides.
                old_lines.append(hl[1:])
                new_lines.append(hl[1:])

        # Verify the old/context lines match before touching the file.
        start_idx = hunk.old_start - 1  # 0-indexed
        match = True

        if start_idx < 0 or start_idx + len(old_lines) > len(file_lines):
            match = False
        else:
            for j, expected in enumerate(old_lines):
                # Compare ignoring line endings, matching the parser.
                actual = file_lines[start_idx + j].rstrip("\n\r")
                if actual != expected:
                    match = False
                    break

        if match:
            # Replace old lines with new lines. NOTE(review): output
            # lines are normalized to "\n" endings; a file without a
            # trailing newline gains one -- confirm acceptable.
            replacement = [ln + "\n" for ln in new_lines]
            file_lines[start_idx:start_idx + len(old_lines)] = (
                replacement
            )
            applied += 1
        else:
            rejected += 1

    return "".join(file_lines), applied, rejected
def _apply_patch_impl(
    patch_text: str,
    root: Path,
    dry_run: bool = True,
) -> ApplyPatchResult:
    """Apply a unified diff patch to files under the server root.

    Args:
        patch_text: Unified diff text (``---``/``+++`` headers and
            ``@@`` hunk markers).
        root: Server root directory. Currently unused -- path
            containment is enforced by ``_safe_resolve``; kept for
            interface stability.
        dry_run: When True (default), report which hunks would apply
            or be rejected without modifying any file.

    Returns:
        ApplyPatchResult summarizing per-file applied/rejected hunks;
        ``success`` is True only when no hunk was rejected.
    """
    try:
        patch_files = _parse_unified_diff(patch_text)
    except Exception as e:
        return ApplyPatchResult(
            success=False,
            message=f"Failed to parse patch: {e!s}",
        )

    if not patch_files:
        return ApplyPatchResult(
            success=False,
            message="No files found in patch",
        )

    results: list[PatchFileResult] = []
    total_applied = 0
    total_rejected = 0

    for pf in patch_files:
        target_path = pf.new_path
        is_new = pf.old_path == "/dev/null"
        is_delete = pf.new_path == "/dev/null"

        # NOTE(review): deletions are applied as content removal via
        # hunks; the file itself is never unlinked -- confirm intended.
        if is_delete:
            target_path = pf.old_path

        try:
            resolved = _safe_resolve(target_path)
        except ValueError:
            # Path escapes the root -- reject the whole file.
            results.append(PatchFileResult(
                path=target_path,
                hunks_rejected=len(pf.hunks),
            ))
            total_rejected += len(pf.hunks)
            continue

        if is_new:
            # New file: its content is every "+" line of every hunk.
            added = [
                hl[1:] + "\n"
                for hunk in pf.hunks
                for hl in hunk.lines
                if hl.startswith("+")
            ]
            new_content = "".join(added)

            # Enforce the size limit in dry-run too, so the preview
            # matches what a real run would do. (Previously dry-run
            # reported success for oversized files that a real run
            # would then reject.)
            if len(new_content.encode("utf-8")) > MAX_WRITE_BYTES:
                results.append(PatchFileResult(
                    path=target_path,
                    hunks_rejected=len(pf.hunks),
                ))
                total_rejected += len(pf.hunks)
                continue

            if not dry_run:
                resolved.parent.mkdir(parents=True, exist_ok=True)
                resolved.write_text(new_content, encoding="utf-8")

            results.append(PatchFileResult(
                path=target_path,
                hunks_applied=len(pf.hunks),
                created=True,
            ))
            total_applied += len(pf.hunks)
            continue

        if not resolved.is_file():
            results.append(PatchFileResult(
                path=target_path,
                hunks_rejected=len(pf.hunks),
            ))
            total_rejected += len(pf.hunks)
            continue

        try:
            content = resolved.read_text(
                encoding="utf-8", errors="replace",
            )
        except OSError:
            results.append(PatchFileResult(
                path=target_path,
                hunks_rejected=len(pf.hunks),
            ))
            total_rejected += len(pf.hunks)
            continue

        new_content, app, rej = _apply_hunks(content, pf.hunks)

        # Size limit: checked before writing AND in dry-run, so the
        # preview and the real apply always agree.
        if app > 0 and len(new_content.encode("utf-8")) > MAX_WRITE_BYTES:
            results.append(PatchFileResult(
                path=target_path,
                hunks_rejected=len(pf.hunks),
            ))
            total_rejected += len(pf.hunks)
            continue

        if not dry_run and app > 0:
            resolved.write_text(new_content, encoding="utf-8")

        results.append(PatchFileResult(
            path=target_path,
            hunks_applied=app,
            hunks_rejected=rej,
        ))
        total_applied += app
        total_rejected += rej

    return ApplyPatchResult(
        success=total_rejected == 0,
        files=results,
        total_applied=total_applied,
        total_rejected=total_rejected,
        dry_run=dry_run,
    )
+ " Accepts standard unified diff format (as produced by" + " 'git diff' or 'diff -u')." + " Defaults to dry_run=true so you can preview which hunks" + " would be applied or rejected before committing changes." + " Set dry_run=false to actually modify files." + " Supports new file creation, multi-file patches," + " and multi-hunk patches." + ), + ) + async def apply_patch( + patch: str = Field( + description=( + "Unified diff text. Must include --- / +++ headers" + " and @@ hunk markers." + ), + ), + dry_run: bool = Field( + default=True, + description=( + "Preview changes without applying." + " Set to false to apply the patch." + ), + ), + ) -> ApplyPatchResult: + """Apply a unified diff patch.""" + try: + return _apply_patch_impl( + patch, _root(), dry_run=dry_run, + ) + except Exception as e: + return ApplyPatchResult( + success=False, + message=f"apply_patch failed: {e!s}", + ) diff --git a/src/cocoindex_code/schema.py b/src/cocoindex_code/schema.py index bfb8a74..8a0b5ff 100644 --- a/src/cocoindex_code/schema.py +++ b/src/cocoindex_code/schema.py @@ -1,20 +1,6 @@ """Data models for CocoIndex Code.""" from dataclasses import dataclass -from typing import Any - - -@dataclass -class CodeChunk: - """Represents an indexed code chunk stored in SQLite.""" - - id: int - file_path: str - language: str - content: str - start_line: int - end_line: int - embedding: Any # NDArray - type hint relaxed for compatibility @dataclass diff --git a/src/cocoindex_code/server.py b/src/cocoindex_code/server.py index 8c04267..6769954 100644 --- a/src/cocoindex_code/server.py +++ b/src/cocoindex_code/server.py @@ -2,15 +2,21 @@ import argparse import asyncio +import logging +import sys import cocoindex as coco from mcp.server.fastmcp import FastMCP from pydantic import BaseModel, Field +from .code_intelligence_tools import register_code_intelligence_tools from .config import config +from .filesystem_tools import register_filesystem_tools, register_large_write_tool from .indexer import 
app as indexer_app +from .patch_tools import register_patch_tools from .query import query_codebase from .shared import SQLITE_DB +from .thinking_tools import register_thinking_tools # Initialize MCP server mcp = FastMCP( @@ -24,9 +30,47 @@ "Provides semantic search that understands meaning --" " unlike grep or text matching," " it finds relevant code even when exact keywords are unknown." + "\n\n" + "Fast filesystem tools:" + "\n- find_files: fast glob-based file discovery" + "\n- read_file: read file contents with line ranges" + "\n- write_file: write/create files instantly" + "\n- edit_file: exact string replacement in files" + "\n- grep_code: regex text search across files" + "\n- directory_tree: list project structure" + "\n- large_write: write large files in chunks" + "\n\n" + "Code intelligence tools:" + "\n- list_symbols: list functions, classes, methods in a file" + "\n- find_definition: go-to-definition across the codebase" + "\n- find_references: find all usages of a symbol" + "\n- code_metrics: code quality metrics for a file" + "\n- rename_symbol: safe codebase-wide rename" + "\n\n" + "Patch tools:" + "\n- apply_patch: apply unified diff patches to files" + "\n\n" + "Advanced thinking and reasoning tools:" + "\n- sequential_thinking: step-by-step problem solving" + "\n- extended_thinking: deep analysis with checkpoints" + "\n- ultra_thinking: maximum-depth reasoning" + "\n- evidence_tracker: attach weighted evidence to hypotheses" + "\n- premortem: structured pre-failure risk analysis" + "\n- inversion_thinking: guarantee-failure-then-invert reasoning" + "\n- effort_estimator: three-point PERT estimation" + "\n- learning_loop: reflect on sessions and extract learnings" + "\n- self_improve: get strategy recommendations" + "\n- reward_thinking: provide reinforcement signals" + "\n- plan_optimizer: analyze, score, and optimize any plan" ), ) +register_filesystem_tools(mcp) +register_large_write_tool(mcp) +register_code_intelligence_tools(mcp) 
+register_patch_tools(mcp) +register_thinking_tools(mcp) + # Lock to prevent concurrent index updates _index_lock = asyncio.Lock() @@ -167,8 +211,10 @@ async def search( async def _async_serve() -> None: """Async entry point for the MCP server.""" - # Refresh index in background so startup isn't blocked - asyncio.create_task(_refresh_index()) + # Index refresh is deferred to first search call. + # Starting it here can crash the stdio transport if the + # background task raises or writes to stdout/stderr before + # the MCP handshake completes. await mcp.run_stdio_async() @@ -209,6 +255,8 @@ async def _print_index_stats() -> None: def main() -> None: """Entry point for the cocoindex-code CLI.""" + # Ensure all logging goes to stderr, never stdout (MCP uses stdout for JSON-RPC) + logging.basicConfig(stream=sys.stderr, level=logging.WARNING) parser = argparse.ArgumentParser( prog="cocoindex-code", description="MCP server for codebase indexing and querying.", diff --git a/src/cocoindex_code/thinking_engine.py b/src/cocoindex_code/thinking_engine.py new file mode 100644 index 0000000..4cb7f35 --- /dev/null +++ b/src/cocoindex_code/thinking_engine.py @@ -0,0 +1,1326 @@ +"""ThinkingEngine — core logic for thinking tools subsystem.""" + +from __future__ import annotations + +import json +import re +import time +from pathlib import Path + +from .thinking_models import ( + _MISSING_CONCERN_CHECKS, + _VAGUE_PATTERNS, + PERT_WEIGHT, + PLAN_DIMENSIONS, + THINKING_MEMORY_FILE, + VALID_EVIDENCE_TYPES, + VALID_INVERSION_PHASES, + VALID_PLAN_OPTIMIZER_PHASES, + VALID_PREMORTEM_PHASES, + EffortEstimatorResult, + EstimateItem, + EstimatorSession, + EvidenceItem, + EvidenceTrackerResult, + ExtendedThinkingResult, + InversionCause, + InversionSession, + InversionThinkingResult, + LearningEntry, + LearningLoopResult, + PlanAntiPattern, + PlanOptimizerResult, + PlanOptimizerSession, + PlanVariant, + PremortemResult, + PremortemRisk, + PremortemSession, + RewardResult, + StrategyScore, + 
ThinkingResult, + ThoughtData, + UltraThinkingResult, +) + + +class ThinkingEngine: + def __init__(self, memory_dir: Path) -> None: + self._memory_dir = memory_dir + self._memory_file = memory_dir / THINKING_MEMORY_FILE + self._sessions: dict[str, list[ThoughtData]] = {} + self._branches: dict[str, dict[str, list[ThoughtData]]] = {} + self._learnings: list[LearningEntry] = [] + self._strategy_scores: dict[str, StrategyScore] = {} + self._hypotheses: dict[str, list[str]] = {} + self._evidence: dict[str, dict[int, list[EvidenceItem]]] = {} + self._premortems: dict[str, PremortemSession] = {} + self._inversions: dict[str, InversionSession] = {} + self._estimators: dict[str, EstimatorSession] = {} + self._plan_optimizers: dict[str, PlanOptimizerSession] = {} + self._load_memory() + + @property + def _memory_path(self) -> Path: + return self._memory_file + + def _load_memory(self) -> None: + """Load thinking memory from JSONL, compacting if needed.""" + raw_line_count = 0 + try: + with open(self._memory_file, encoding="utf-8") as f: + for line in f: + raw_line_count += 1 + line = line.strip() + if not line: + continue + entry = json.loads(line) + entry_type = entry.get("type") + if entry_type == "learning": + self._learnings.append(LearningEntry(**entry["data"])) + elif entry_type == "strategy": + score = StrategyScore(**entry["data"]) + self._strategy_scores[score.strategy] = score + except FileNotFoundError: + return + + # Compact if raw lines significantly exceed deduplicated count + dedup_count = len(self._learnings) + len(self._strategy_scores) + if raw_line_count > max(dedup_count * 2, 20): + self._compact_memory() + + def _compact_memory(self) -> None: + """Rewrite the JSONL file with only deduplicated entries.""" + self._memory_file.parent.mkdir(parents=True, exist_ok=True) + compact_path = self._memory_file.with_suffix(".jsonl.tmp") + with open(compact_path, "w", encoding="utf-8") as f: + for entry in self._learnings: + f.write(json.dumps({"type": "learning", 
"data": entry.model_dump()}) + "\n") + for score in self._strategy_scores.values(): + f.write(json.dumps({"type": "strategy", "data": score.model_dump()}) + "\n") + compact_path.replace(self._memory_file) + + def _save_entry(self, entry: dict) -> None: + self._memory_file.parent.mkdir(parents=True, exist_ok=True) + with open(self._memory_file, "a", encoding="utf-8") as f: + f.write(json.dumps(entry) + "\n") + + def _save_strategy(self, strategy: StrategyScore) -> None: + self._save_entry({"type": "strategy", "data": strategy.model_dump()}) + + def process_thought(self, session_id: str, data: ThoughtData) -> ThinkingResult: + if session_id not in self._sessions: + self._sessions[session_id] = [] + + session_thoughts = self._sessions[session_id] + + if data.thought_number > data.total_thoughts: + data = data.model_copy(update={"total_thoughts": data.thought_number}) + + session_thoughts.append(data) + + branches: list[str] = [] + if data.branch_id is not None: + if session_id not in self._branches: + self._branches[session_id] = {} + if data.branch_id not in self._branches[session_id]: + self._branches[session_id][data.branch_id] = [] + self._branches[session_id][data.branch_id].append(data) + branches = list(self._branches[session_id].keys()) + elif session_id in self._branches: + branches = list(self._branches[session_id].keys()) + + return ThinkingResult( + success=True, + session_id=session_id, + thought_number=data.thought_number, + total_thoughts=data.total_thoughts, + next_thought_needed=data.next_thought_needed, + branches=branches, + thought_history_length=len(session_thoughts), + ) + + def process_extended_thought( + self, + session_id: str, + data: ThoughtData, + depth_level: str = "deep", + checkpoint_interval: int = 5, + ) -> ExtendedThinkingResult: + if session_id not in self._sessions: + self._sessions[session_id] = [] + + session_thoughts = self._sessions[session_id] + + if data.thought_number > data.total_thoughts: + data = 
data.model_copy(update={"total_thoughts": data.thought_number}) + + session_thoughts.append(data) + + branches: list[str] = [] + if data.branch_id is not None: + if session_id not in self._branches: + self._branches[session_id] = {} + if data.branch_id not in self._branches[session_id]: + self._branches[session_id][data.branch_id] = [] + self._branches[session_id][data.branch_id].append(data) + branches = list(self._branches[session_id].keys()) + elif session_id in self._branches: + branches = list(self._branches[session_id].keys()) + + checkpoint_summary = "" + steps_since_checkpoint = data.thought_number % checkpoint_interval + if steps_since_checkpoint == 0: + checkpoint_summary = ( + f"Checkpoint at step {data.thought_number}: " + f"{len(session_thoughts)} thoughts, {len(branches)} branches" + ) + + return ExtendedThinkingResult( + success=True, + session_id=session_id, + thought_number=data.thought_number, + total_thoughts=data.total_thoughts, + next_thought_needed=data.next_thought_needed, + branches=branches, + thought_history_length=len(session_thoughts), + depth_level=depth_level, + checkpoint_summary=checkpoint_summary, + steps_since_checkpoint=steps_since_checkpoint, + checkpoint_interval=checkpoint_interval, + ) + + def process_ultra_thought( + self, + session_id: str, + data: ThoughtData, + phase: str = "explore", + hypothesis: str | None = None, + confidence: float = 0.0, + ) -> UltraThinkingResult: + if session_id not in self._sessions: + self._sessions[session_id] = [] + + session_thoughts = self._sessions[session_id] + + if data.thought_number > data.total_thoughts: + data = data.model_copy(update={"total_thoughts": data.thought_number}) + + session_thoughts.append(data) + + branches: list[str] = [] + if data.branch_id is not None: + if session_id not in self._branches: + self._branches[session_id] = {} + if data.branch_id not in self._branches[session_id]: + self._branches[session_id][data.branch_id] = [] + 
self._branches[session_id][data.branch_id].append(data) + branches = list(self._branches[session_id].keys()) + elif session_id in self._branches: + branches = list(self._branches[session_id].keys()) + + if session_id not in self._hypotheses: + self._hypotheses[session_id] = [] + + verification_status = "" + synthesis = "" + + if phase == "hypothesize" and hypothesis is not None: + self._hypotheses[session_id].append(hypothesis) + elif phase == "verify": + if confidence >= 0.7: + verification_status = "supported" + elif confidence >= 0.4: + verification_status = "partially_supported" + else: + verification_status = "unsupported" + elif phase == "synthesize": + all_hypotheses = self._hypotheses.get(session_id, []) + if all_hypotheses: + synthesis = "Synthesis of hypotheses: " + "; ".join(all_hypotheses) + + return UltraThinkingResult( + success=True, + session_id=session_id, + thought_number=data.thought_number, + total_thoughts=data.total_thoughts, + next_thought_needed=data.next_thought_needed, + branches=branches, + thought_history_length=len(session_thoughts), + phase=phase, + hypotheses=list(self._hypotheses.get(session_id, [])), + verification_status=verification_status, + confidence=confidence, + synthesis=synthesis, + ) + + def record_learning( + self, + session_id: str, + strategy_used: str, + outcome_tags: list[str], + reward: float, + insights: list[str], + ) -> LearningLoopResult: + thought_count = len(self._sessions.get(session_id, [])) + entry = LearningEntry( + session_id=session_id, + timestamp=time.time(), + strategy_used=strategy_used, + outcome_tags=outcome_tags, + reward=reward, + insights=insights, + thought_count=thought_count, + ) + self._learnings.append(entry) + self._save_entry({"type": "learning", "data": entry.model_dump()}) + self._update_strategy_score(strategy_used, reward) + + return LearningLoopResult( + success=True, + session_id=session_id, + learnings_extracted=1, + insights=insights, + ) + + def get_strategy_recommendations(self, 
    def get_strategy_recommendations(self, top_k: int = 5) -> list[StrategyScore]:
        """Return the top_k strategies ranked by average reward."""
        sorted_strategies = sorted(
            self._strategy_scores.values(),
            key=lambda s: s.avg_reward,
            reverse=True,
        )
        return sorted_strategies[:top_k]

    def apply_reward(self, session_id: str, reward: float) -> RewardResult:
        """Add a reinforcement signal to the latest learning of a session.

        Fails when the session has no recorded learnings.
        """
        matching = [entry for entry in self._learnings if entry.session_id == session_id]
        if not matching:
            return RewardResult(
                success=False,
                session_id=session_id,
                message=f"No learnings found for session {session_id}",
            )

        latest = matching[-1]
        latest.reward += reward
        self._update_strategy_score(latest.strategy_used, reward)
        # NOTE(review): this appends a second JSONL record for the same
        # learning (append-only log); after reload, `matching` will
        # contain both copies and `cumulative` may double-count -- confirm
        # whether compaction is expected to deduplicate these.
        self._save_entry({"type": "learning", "data": latest.model_dump()})

        cumulative = sum(entry.reward for entry in matching)

        return RewardResult(
            success=True,
            session_id=session_id,
            new_reward=reward,
            cumulative_reward=cumulative,
        )

    def _update_strategy_score(self, strategy: str, reward: float) -> None:
        """Fold one reward into a strategy's running average and persist it."""
        if strategy not in self._strategy_scores:
            self._strategy_scores[strategy] = StrategyScore(strategy=strategy)

        score = self._strategy_scores[strategy]
        score.usage_count += 1
        score.total_reward += reward
        score.avg_reward = score.total_reward / score.usage_count
        score.last_used = time.time()

        self._save_strategy(score)

    # --- Evidence Tracker ---

    def add_evidence(
        self,
        session_id: str,
        hypothesis_index: int,
        text: str,
        evidence_type: str = "data_point",
        strength: float = 0.5,
        effort_mode: str = "medium",
    ) -> EvidenceTrackerResult:
        """Add evidence to a hypothesis in an ultra_thinking session."""
        hypotheses = self._hypotheses.get(session_id)
        if hypotheses is None:
            return EvidenceTrackerResult(
                success=False,
                session_id=session_id,
                effort_mode=effort_mode,
                message=f"No hypotheses found for session {session_id}",
            )
        if hypothesis_index < 0 or hypothesis_index >= len(hypotheses):
            return EvidenceTrackerResult(
                success=False,
                session_id=session_id,
                hypothesis_index=hypothesis_index,
                effort_mode=effort_mode,
                message=(
                    f"Hypothesis index {hypothesis_index} out of range"
                    f" (0..{len(hypotheses) - 1})"
                ),
            )
        # In low effort mode, skip type validation
        if effort_mode != "low" and evidence_type not in VALID_EVIDENCE_TYPES:
            return EvidenceTrackerResult(
                success=False,
                session_id=session_id,
                hypothesis_index=hypothesis_index,
                effort_mode=effort_mode,
                message=(
                    f"Invalid evidence_type '{evidence_type}'."
                    f" Must be one of: {', '.join(sorted(VALID_EVIDENCE_TYPES))}"
                ),
            )

        clamped_strength = max(0.0, min(1.0, strength))
        # Ultra mode: auto-boost strength for strongest evidence types
        if effort_mode == "ultra" and evidence_type in ("code_ref", "test_result"):
            clamped_strength = max(clamped_strength, 0.9)
        # Low mode collapses every evidence type to a plain data point.
        item = EvidenceItem(
            text=text,
            evidence_type=evidence_type if effort_mode != "low" else "data_point",
            strength=clamped_strength,
            added_at=time.time(),
        )

        if session_id not in self._evidence:
            self._evidence[session_id] = {}
        if hypothesis_index not in self._evidence[session_id]:
            self._evidence[session_id][hypothesis_index] = []

        self._evidence[session_id][hypothesis_index].append(item)
        evidence_list = self._evidence[session_id][hypothesis_index]
        # Cumulative strength is the mean strength of all evidence.
        cumulative = sum(e.strength for e in evidence_list) / len(evidence_list)

        return EvidenceTrackerResult(
            success=True,
            session_id=session_id,
            hypothesis_index=hypothesis_index,
            hypothesis_text=hypotheses[hypothesis_index],
            evidence=list(evidence_list),
            total_evidence_count=len(evidence_list),
            cumulative_strength=cumulative,
            effort_mode=effort_mode,
        )
    def get_evidence(
        self,
        session_id: str,
        hypothesis_index: int,
        effort_mode: str = "medium",
    ) -> EvidenceTrackerResult:
        """List evidence for a hypothesis."""
        hypotheses = self._hypotheses.get(session_id)
        if hypotheses is None:
            return EvidenceTrackerResult(
                success=False,
                session_id=session_id,
                effort_mode=effort_mode,
                message=f"No hypotheses found for session {session_id}",
            )
        if hypothesis_index < 0 or hypothesis_index >= len(hypotheses):
            return EvidenceTrackerResult(
                success=False,
                session_id=session_id,
                hypothesis_index=hypothesis_index,
                effort_mode=effort_mode,
                message=(
                    f"Hypothesis index {hypothesis_index} out of range"
                    f" (0..{len(hypotheses) - 1})"
                ),
            )

        evidence_list = self._evidence.get(session_id, {}).get(hypothesis_index, [])
        # Mean strength; 0.0 when no evidence has been attached yet.
        cumulative = (
            sum(e.strength for e in evidence_list) / len(evidence_list)
            if evidence_list
            else 0.0
        )

        return EvidenceTrackerResult(
            success=True,
            session_id=session_id,
            hypothesis_index=hypothesis_index,
            hypothesis_text=hypotheses[hypothesis_index],
            evidence=list(evidence_list),
            total_evidence_count=len(evidence_list),
            cumulative_strength=cumulative,
            effort_mode=effort_mode,
        )

    # --- Premortem ---

    def process_premortem(
        self,
        session_id: str,
        data: ThoughtData,
        phase: str = "describe_plan",
        plan: str | None = None,
        failure_scenario: str | None = None,
        risk_description: str | None = None,
        likelihood: float = 0.5,
        impact: float = 0.5,
        mitigation: str | None = None,
        risk_index: int | None = None,
        effort_mode: str = "medium",
    ) -> PremortemResult:
        """Process a premortem thinking step.

        Phases: describe_plan -> imagine_failure -> identify_causes
        (one risk per call, score = likelihood * impact) -> rank_risks
        -> mitigate (attach a mitigation to one risk by index).
        """
        if phase not in VALID_PREMORTEM_PHASES:
            return PremortemResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message=(
                    f"Invalid phase '{phase}'."
                    f" Must be one of: {', '.join(sorted(VALID_PREMORTEM_PHASES))}"
                ),
            )

        # Track thoughts in the main session store
        if session_id not in self._sessions:
            self._sessions[session_id] = []
        self._sessions[session_id].append(data)

        # Initialize premortem session if needed
        if session_id not in self._premortems:
            self._premortems[session_id] = PremortemSession()

        pm = self._premortems[session_id]

        if phase == "describe_plan":
            if plan is not None:
                pm.plan = plan
            return PremortemResult(
                success=True,
                session_id=session_id,
                phase=phase,
                plan_description=pm.plan,
                risks=list(pm.risks),
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "imagine_failure":
            if failure_scenario is not None:
                pm.failure_scenario = failure_scenario
            return PremortemResult(
                success=True,
                session_id=session_id,
                phase=phase,
                plan_description=pm.plan,
                failure_scenario=pm.failure_scenario,
                risks=list(pm.risks),
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "identify_causes":
            if risk_description is None:
                return PremortemResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="risk_description is required for identify_causes phase",
                )
            # Clamp both factors into [0, 1] before scoring.
            clamped_likelihood = max(0.0, min(1.0, likelihood))
            clamped_impact = max(0.0, min(1.0, impact))
            risk = PremortemRisk(
                description=risk_description,
                likelihood=clamped_likelihood,
                impact=clamped_impact,
                risk_score=clamped_likelihood * clamped_impact,
            )
            pm.risks.append(risk)
            # Ultra mode: auto-rank risks at every phase
            ranked = (
                sorted(pm.risks, key=lambda r: r.risk_score, reverse=True)
                if effort_mode == "ultra" else []
            )
            return PremortemResult(
                success=True,
                session_id=session_id,
                phase=phase,
                plan_description=pm.plan,
                failure_scenario=pm.failure_scenario,
                risks=list(pm.risks),
                ranked_risks=ranked if ranked else [],
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "rank_risks":
            ranked = sorted(pm.risks, key=lambda r: r.risk_score, reverse=True)
            return PremortemResult(
                success=True,
                session_id=session_id,
                phase=phase,
                plan_description=pm.plan,
                failure_scenario=pm.failure_scenario,
                risks=list(pm.risks),
                ranked_risks=ranked,
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        # phase == "mitigate"
        if risk_index is None:
            return PremortemResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message="risk_index is required for mitigate phase",
            )
        if risk_index < 0 or risk_index >= len(pm.risks):
            return PremortemResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message=(
                    f"risk_index {risk_index} out of range"
                    f" (0..{len(pm.risks) - 1})"
                ),
            )
        if mitigation is not None:
            pm.risks[risk_index].mitigation = mitigation
        mitigations_count = sum(1 for r in pm.risks if r.mitigation)
        # Ultra mode: warn if not all risks are mitigated
        ultra_message = None
        if effort_mode == "ultra" and mitigations_count < len(pm.risks):
            unmitigated = len(pm.risks) - mitigations_count
            ultra_message = (
                f"{unmitigated} risk(s) still lack mitigations."
                " Ultra mode requires all risks to be mitigated."
            )
        return PremortemResult(
            success=True,
            session_id=session_id,
            phase=phase,
            plan_description=pm.plan,
            failure_scenario=pm.failure_scenario,
            risks=list(pm.risks),
            mitigations_count=mitigations_count,
            thought_number=data.thought_number,
            total_thoughts=data.total_thoughts,
            next_thought_needed=data.next_thought_needed,
            effort_mode=effort_mode,
            message=ultra_message,
        )
    # --- Inversion Thinking ---

    def process_inversion(
        self,
        session_id: str,
        data: ThoughtData,
        phase: str = "define_goal",
        goal: str | None = None,
        inverted_goal: str | None = None,
        failure_cause: str | None = None,
        severity: float = 0.5,
        inverted_action: str | None = None,
        cause_index: int | None = None,
        action_item: str | None = None,
        effort_mode: str = "medium",
    ) -> InversionThinkingResult:
        """Process an inversion thinking step.

        Phases: define_goal -> invert (state how to guarantee failure)
        -> list_failure_causes (one cause per call) -> rank_causes
        (medium effort and above) -> reinvert (turn a cause into a
        preventive action) -> action_plan (collect action items).
        """
        if phase not in VALID_INVERSION_PHASES:
            return InversionThinkingResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message=(
                    f"Invalid phase '{phase}'."
                    f" Must be one of: {', '.join(sorted(VALID_INVERSION_PHASES))}"
                ),
            )

        # Track thoughts
        if session_id not in self._sessions:
            self._sessions[session_id] = []
        self._sessions[session_id].append(data)

        # Initialize session
        if session_id not in self._inversions:
            self._inversions[session_id] = InversionSession()

        inv = self._inversions[session_id]

        if phase == "define_goal":
            if goal is not None:
                inv.goal = goal
            return InversionThinkingResult(
                success=True,
                session_id=session_id,
                phase=phase,
                goal=inv.goal,
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "invert":
            if inverted_goal is not None:
                inv.inverted_goal = inverted_goal
            elif inv.goal and not inv.inverted_goal:
                # Auto-generate a basic inversion
                inv.inverted_goal = f"How to guarantee failure at: {inv.goal}"
            return InversionThinkingResult(
                success=True,
                session_id=session_id,
                phase=phase,
                goal=inv.goal,
                inverted_goal=inv.inverted_goal,
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "list_failure_causes":
            if failure_cause is None:
                return InversionThinkingResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="failure_cause is required for list_failure_causes phase",
                )
            # Clamp severity into [0, 1].
            clamped_severity = max(0.0, min(1.0, severity))
            cause = InversionCause(
                description=failure_cause,
                severity=clamped_severity,
            )
            inv.failure_causes.append(cause)
            return InversionThinkingResult(
                success=True,
                session_id=session_id,
                phase=phase,
                goal=inv.goal,
                inverted_goal=inv.inverted_goal,
                failure_causes=list(inv.failure_causes),
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "rank_causes":
            # Only available in medium/high effort
            if effort_mode == "low":
                return InversionThinkingResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="rank_causes phase is not available in low effort mode",
                )
            ranked = sorted(
                inv.failure_causes, key=lambda c: c.severity, reverse=True
            )
            return InversionThinkingResult(
                success=True,
                session_id=session_id,
                phase=phase,
                goal=inv.goal,
                inverted_goal=inv.inverted_goal,
                failure_causes=list(inv.failure_causes),
                ranked_causes=ranked,
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        if phase == "reinvert":
            if cause_index is None:
                return InversionThinkingResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="cause_index is required for reinvert phase",
                )
            if cause_index < 0 or cause_index >= len(inv.failure_causes):
                return InversionThinkingResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message=(
                        f"cause_index {cause_index} out of range"
                        f" (0..{len(inv.failure_causes) - 1})"
                    ),
                )
            if inverted_action is not None:
                inv.failure_causes[cause_index].inverted_action = inverted_action
            return InversionThinkingResult(
                success=True,
                session_id=session_id,
                phase=phase,
                goal=inv.goal,
                inverted_goal=inv.inverted_goal,
                failure_causes=list(inv.failure_causes),
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
            )

        # phase == "action_plan"
        if action_item is not None:
            inv.action_plan.append(action_item)
        # In high effort mode, auto-populate from reinverted causes if empty
        if effort_mode == "high" and not inv.action_plan:
            for cause in inv.failure_causes:
                if cause.inverted_action:
                    inv.action_plan.append(cause.inverted_action)
        # Ultra mode: auto-reinvert ALL causes that lack inverted_actions,
        # then auto-populate action plan from ALL of them
        if effort_mode == "ultra":
            for cause in inv.failure_causes:
                if not cause.inverted_action:
                    cause.inverted_action = (
                        f"Prevent: {cause.description}"
                    )
            if not inv.action_plan:
                for cause in inv.failure_causes:
                    if cause.inverted_action:
                        inv.action_plan.append(cause.inverted_action)
        return InversionThinkingResult(
            success=True,
            session_id=session_id,
            phase=phase,
            goal=inv.goal,
            inverted_goal=inv.inverted_goal,
            failure_causes=list(inv.failure_causes),
            action_plan=list(inv.action_plan),
            thought_number=data.thought_number,
            total_thoughts=data.total_thoughts,
            next_thought_needed=data.next_thought_needed,
            effort_mode=effort_mode,
        )

    # --- Effort Estimator ---

    @staticmethod
    def _compute_pert(
        optimistic: float, likely: float, pessimistic: float,
    ) -> EstimateItem:
        """Compute PERT estimate with confidence intervals.

        Classic three-point PERT: mean = (O + w*M + P) / 6 with the
        "likely" weight taken from PERT_WEIGHT, and std dev = (P - O) / 6.
        Confidence bands are mean +/- 1, 2, and 3 standard deviations
        (~68% / ~95% / ~99%). The caller fills in ``task`` afterwards.
        """
        pert = (optimistic + PERT_WEIGHT * likely + pessimistic) / 6.0
        std_dev = (pessimistic - optimistic) / 6.0
        return EstimateItem(
            task="",
            optimistic=optimistic,
            likely=likely,
            pessimistic=pessimistic,
            pert_estimate=pert,
            std_dev=std_dev,
            confidence_68_low=pert - std_dev,
            confidence_68_high=pert + std_dev,
            confidence_95_low=pert - 2 * std_dev,
            confidence_95_high=pert + 2 * std_dev,
            confidence_99_low=pert - 3 * std_dev,
            confidence_99_high=pert + 3 * std_dev,
            # Heuristic worst-case buffer -- presumably 1.5x pessimistic
            # by design; confirm against the tool documentation.
            risk_buffer=pessimistic * 1.5,
        )
self._estimators[session_id] + + if action == "add": + if task is None: + return EffortEstimatorResult( + success=False, + session_id=session_id, + action=action, + effort_mode=effort_mode, + message="task name is required when action is 'add'", + ) + if pessimistic < optimistic: + return EffortEstimatorResult( + success=False, + session_id=session_id, + action=action, + effort_mode=effort_mode, + message="pessimistic must be >= optimistic", + ) + if effort_mode == "low": + # Low effort: use likely as single-point, skip PERT + item = EstimateItem( + task=task, + optimistic=likely, + likely=likely, + pessimistic=likely, + pert_estimate=likely, + ) + else: + item = self._compute_pert(optimistic, likely, pessimistic) + item.task = task + est.estimates.append(item) + + elif action == "summary": + pass # Just return current state + elif action == "clear": + est.estimates.clear() + return EffortEstimatorResult( + success=True, + session_id=session_id, + action=action, + effort_mode=effort_mode, + message="Estimates cleared", + ) + else: + return EffortEstimatorResult( + success=False, + session_id=session_id, + action=action, + effort_mode=effort_mode, + message=f"Invalid action '{action}'. 
Must be 'add', 'summary', or 'clear'.", + ) + + # Compute totals + total_pert = sum(e.pert_estimate for e in est.estimates) + total_std_dev = ( + sum(e.std_dev**2 for e in est.estimates) ** 0.5 + if effort_mode != "low" + else 0.0 + ) + + is_advanced = effort_mode in ("high", "ultra") + return EffortEstimatorResult( + success=True, + session_id=session_id, + action=action, + estimates=list(est.estimates), + total_pert=total_pert, + total_std_dev=total_std_dev, + total_confidence_68_low=( + total_pert - total_std_dev + if effort_mode != "low" else 0.0 + ), + total_confidence_68_high=( + total_pert + total_std_dev + if effort_mode != "low" else 0.0 + ), + total_confidence_95_low=( + total_pert - 2 * total_std_dev + if is_advanced else 0.0 + ), + total_confidence_95_high=( + total_pert + 2 * total_std_dev + if is_advanced else 0.0 + ), + total_confidence_99_low=( + total_pert - 3 * total_std_dev + if effort_mode == "ultra" else 0.0 + ), + total_confidence_99_high=( + total_pert + 3 * total_std_dev + if effort_mode == "ultra" else 0.0 + ), + total_risk_buffer=( + sum(e.risk_buffer for e in est.estimates) + if effort_mode == "ultra" else 0.0 + ), + effort_mode=effort_mode, + ) + + # --- Plan Optimizer --- + + @staticmethod + def _detect_anti_patterns(plan_text: str) -> list[PlanAntiPattern]: + """Detect anti-patterns in a plan using regex heuristics.""" + + results: list[PlanAntiPattern] = [] + plan_lower = plan_text.lower() + lines = plan_text.splitlines() + + # 1. Vague language detection + for pattern in _VAGUE_PATTERNS: + for m in re.finditer(pattern, plan_lower): + snippet = plan_lower[ + max(0, m.start() - 20):m.end() + 20 + ].strip() + results.append(PlanAntiPattern( + pattern_type="vague_language", + description=f"Vague language detected: " + f"'{m.group()}' in '...{snippet}...'", + severity="medium", + location=f"char {m.start()}", + )) + + # 2. 
Missing concern checks + for concern, keywords in _MISSING_CONCERN_CHECKS.items(): + found = any(kw in plan_lower for kw in keywords) + if not found: + sev = "high" if concern in ( + "testing", "error_handling", + ) else "medium" + results.append(PlanAntiPattern( + pattern_type=f"missing_{concern}", + description=( + f"Plan does not mention {concern}." + f" Consider adding a step for:" + f" {', '.join(keywords)}" + ), + severity=sev, + )) + + # 3. God-step detection (any single line > 500 chars) + for i, line in enumerate(lines): + if len(line.strip()) > 500: + results.append(PlanAntiPattern( + pattern_type="god_step", + description=( + f"Step at line {i + 1} is very long" + f" ({len(line.strip())} chars)." + " Consider breaking into smaller steps." + ), + severity="high", + location=f"line {i + 1}", + )) + + # 4. No structure (no numbered steps, bullets, or headers) + has_structure = bool(re.search( + r"^\s*(?:\d+[.)\-]|[-*•]|#{1,3}\s)", + plan_text, + re.MULTILINE, + )) + if not has_structure and len(lines) > 3: + results.append(PlanAntiPattern( + pattern_type="no_structure", + description=( + "Plan lacks numbered steps, bullet points," + " or section headers. Add structure." + ), + severity="medium", + )) + + # 5. 
TODO/TBD markers + for m in re.finditer( + r"\b(TODO|TBD|FIXME|HACK|XXX)\b", plan_text, + ): + results.append(PlanAntiPattern( + pattern_type="todo_marker", + description=( + f"Unresolved marker: '{m.group()}'" + ), + severity="high", + location=f"char {m.start()}", + )) + + return results + + @staticmethod + def _compute_plan_health( + analysis_scores: dict[str, float], + anti_pattern_count: int, + ) -> float: + """Compute plan health score 0-100.""" + if not analysis_scores: + return 0.0 + # Base: average of dimension scores scaled to 100 + avg = sum(analysis_scores.values()) / len(analysis_scores) + base = (avg / 10.0) * 100.0 + # Penalty: -5 per anti-pattern, floor at 0 + penalty = anti_pattern_count * 5 + return max(0.0, round(base - penalty, 1)) + + @staticmethod + def _build_comparison_matrix( + variants: list[PlanVariant], + ) -> dict[str, dict[str, float]]: + """Build comparison matrix: dimension -> {label: score}.""" + matrix: dict[str, dict[str, float]] = {} + for dim in PLAN_DIMENSIONS: + matrix[dim] = {} + for var in variants: + matrix[dim][var.label] = var.scores.get(dim, 0.0) + # Add totals row + matrix["TOTAL"] = { + var.label: var.total for var in variants + } + return matrix + + def process_plan_optimizer( + self, + session_id: str, + data: ThoughtData, + phase: str = "submit_plan", + plan_text: str | None = None, + plan_context: str | None = None, + dimension: str | None = None, + score: float = 0.0, + issue: str | None = None, + variant_label: str | None = None, + variant_name: str | None = None, + variant_summary: str | None = None, + variant_approach: str | None = None, + variant_pros: list[str] | None = None, + variant_cons: list[str] | None = None, + variant_risk_level: str = "medium", + variant_complexity: str = "medium", + recommendation: str | None = None, + winner_label: str | None = None, + effort_mode: str = "medium", + ) -> PlanOptimizerResult: + """Process a plan_optimizer phase.""" + if phase not in VALID_PLAN_OPTIMIZER_PHASES: + 
return PlanOptimizerResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message=(
                    f"Invalid phase '{phase}'. Must be one of: "
                    f"{', '.join(sorted(VALID_PLAN_OPTIMIZER_PHASES))}"
                ),
            )

        # Track thoughts
        if session_id not in self._sessions:
            self._sessions[session_id] = []
        self._sessions[session_id].append(data)

        # Init session
        if session_id not in self._plan_optimizers:
            self._plan_optimizers[session_id] = (
                PlanOptimizerSession()
            )
        po = self._plan_optimizers[session_id]

        def _result(**kwargs: object) -> PlanOptimizerResult:
            """Build result with common fields.

            Closure over the session state (po) and request data;
            kwargs override/extend the common fields for a phase.
            """
            return PlanOptimizerResult(
                success=True,
                session_id=session_id,
                phase=phase,
                plan_text=po.plan_text,
                plan_context=po.plan_context,
                analysis_scores=dict(po.analysis_scores),
                analysis_issues=list(po.analysis_issues),
                anti_patterns=list(po.anti_patterns),
                anti_pattern_count=len(po.anti_patterns),
                plan_health_score=self._compute_plan_health(
                    po.analysis_scores,
                    len(po.anti_patterns),
                ),
                variants=list(po.variants),
                comparison_matrix=(
                    self._build_comparison_matrix(po.variants)
                    if po.variants else {}
                ),
                recommendation=po.recommendation,
                winner_label=po.winner_label,
                thought_number=data.thought_number,
                total_thoughts=data.total_thoughts,
                next_thought_needed=data.next_thought_needed,
                effort_mode=effort_mode,
                **kwargs,
            )

        # --- Phase: submit_plan ---
        if phase == "submit_plan":
            if not plan_text:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="plan_text is required for "
                    "submit_plan phase",
                )
            po.plan_text = plan_text
            if plan_context:
                po.plan_context = plan_context
            # Auto-detect anti-patterns on submit
            po.anti_patterns = self._detect_anti_patterns(
                plan_text,
            )
            return _result()

        # --- Phase: analyze ---
        if phase == "analyze":
            # dimension is optional: a call may record only an issue.
            if dimension is not None:
                dim = dimension.lower()
                if dim not in PLAN_DIMENSIONS:
                    return PlanOptimizerResult(
                        success=False,
                        session_id=session_id,
                        phase=phase,
                        effort_mode=effort_mode,
                        message=(
                            f"Invalid dimension '{dimension}'."
                            f" Must be one of: "
                            f"{', '.join(PLAN_DIMENSIONS)}"
                        ),
                    )
                # Scores are clamped into the 0-10 dimension range.
                clamped = max(0.0, min(10.0, score))
                po.analysis_scores[dim] = clamped
            if issue:
                po.analysis_issues.append(issue)
            return _result()

        # --- Phase: detect_anti_patterns ---
        if phase == "detect_anti_patterns":
            # Re-run detection (useful after plan edits)
            po.anti_patterns = self._detect_anti_patterns(
                po.plan_text,
            )
            return _result()

        # --- Phase: add_variant ---
        if phase == "add_variant":
            if not variant_label:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="variant_label is required "
                    "(e.g. 'A', 'B', 'C')",
                )
            if not variant_name:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="variant_name is required",
                )
            # Check duplicate label
            existing = [
                v for v in po.variants
                if v.label == variant_label
            ]
            if existing:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message=(
                        f"Variant '{variant_label}' already exists."
                        " Use score_variant to update scores."
                    ),
                )
            variant = PlanVariant(
                label=variant_label,
                name=variant_name or "",
                summary=variant_summary or "",
                approach=variant_approach or "",
                pros=variant_pros or [],
                cons=variant_cons or [],
                risk_level=variant_risk_level,
                complexity=variant_complexity,
            )
            po.variants.append(variant)
            return _result()

        # --- Phase: score_variant ---
        if phase == "score_variant":
            if not variant_label:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message="variant_label is required",
                )
            target = None
            for v in po.variants:
                if v.label == variant_label:
                    target = v
                    break
            if target is None:
                return PlanOptimizerResult(
                    success=False,
                    session_id=session_id,
                    phase=phase,
                    effort_mode=effort_mode,
                    message=(
                        f"Variant '{variant_label}' not found."
                        " Call add_variant first."
                    ),
                )
            if dimension is not None:
                dim = dimension.lower()
                if dim not in PLAN_DIMENSIONS:
                    return PlanOptimizerResult(
                        success=False,
                        session_id=session_id,
                        phase=phase,
                        effort_mode=effort_mode,
                        message=(
                            f"Invalid dimension '{dimension}'."
                            f" Must be one of: "
                            f"{', '.join(PLAN_DIMENSIONS)}"
                        ),
                    )
                clamped = max(0.0, min(10.0, score))
                target.scores[dim] = clamped
                # Total is recomputed eagerly so comparison_matrix
                # and winner auto-pick always see fresh sums.
                target.total = sum(target.scores.values())
            return _result()

        # --- Phase: recommend ---
        # phase == "recommend"
        # Ultra mode: block recommend if no variants added
        if effort_mode == "ultra" and not po.variants:
            return PlanOptimizerResult(
                success=False,
                session_id=session_id,
                phase=phase,
                effort_mode=effort_mode,
                message=(
                    "Ultra mode requires at least one variant"
                    " before recommending."
                    " Use add_variant first."
                ),
            )
        # Ultra mode: auto-score unscored dimensions as 0
        if effort_mode == "ultra":
            for dim in PLAN_DIMENSIONS:
                if dim not in po.analysis_scores:
                    po.analysis_scores[dim] = 0.0
            for var in po.variants:
                for dim in PLAN_DIMENSIONS:
                    if dim not in var.scores:
                        var.scores[dim] = 0.0
                var.total = sum(var.scores.values())
        if recommendation:
            po.recommendation = recommendation
        if winner_label:
            po.winner_label = winner_label
        # Auto-pick winner by highest total if not specified
        if not po.winner_label and po.variants:
            best = max(po.variants, key=lambda v: v.total)
            po.winner_label = best.label
        return _result()


diff --git a/src/cocoindex_code/thinking_models.py b/src/cocoindex_code/thinking_models.py
new file mode 100644
index 0000000..f42a305
--- /dev/null
+++ b/src/cocoindex_code/thinking_models.py
"""Pydantic models and constants for the thinking tools subsystem."""

from __future__ import annotations

from pydantic import BaseModel, Field

# --- Configuration constants ---

THINKING_MEMORY_FILE = "thinking_memory.jsonl"
# NOTE(review): the MAX_* limits below are presumably enforced by
# ThinkingEngine — confirm against thinking_engine.py.
MAX_THOUGHTS_PER_SESSION = 200
MAX_SESSIONS_STORED = 500
MAX_STRATEGIES = 100
PERT_WEIGHT = 4.0  # Standard PERT weighting for "most likely"


# --- Shared constants ---

VALID_EFFORT_MODES: frozenset[str] = frozenset({"low", "medium", "high", "ultra"})

VALID_EVIDENCE_TYPES: frozenset[str] = frozenset(
    {"code_ref", "data_point", "external", "assumption", "test_result"}
)

VALID_PREMORTEM_PHASES: frozenset[str] = frozenset(
    {"describe_plan", "imagine_failure", "identify_causes", "rank_risks", "mitigate"}
)

VALID_INVERSION_PHASES: frozenset[str] = frozenset(
    {"define_goal", "invert", "list_failure_causes", "rank_causes", "reinvert", "action_plan"}
)

VALID_PLAN_OPTIMIZER_PHASES: frozenset[str] = frozenset(
    {
        "submit_plan", "analyze", "detect_anti_patterns",
        "add_variant", "score_variant", "recommend",
    }
)

PLAN_DIMENSIONS: tuple[str, ...]
= (
    "clarity", "completeness", "correctness", "risk",
    "simplicity", "testability", "edge_cases", "actionability",
)


# --- Anti-pattern detection patterns ---

# Regexes matched against the lower-cased plan text.
_VAGUE_PATTERNS: list[str] = [
    r"\bmake it work\b",
    r"\bfix it\b",
    r"\bclean up\b",
    r"\bimprove\b(?!ment)",
    r"\bjust do\b",
    r"\bsomehow\b",
    r"\betc\.?\b",
    r"\bstuff\b",
    r"\bthings\b",
    r"\bhandle it\b",
    r"\bfigure out\b",
    r"\bwhatever\b",
]

# Concern -> keywords; a plan mentioning none of the keywords is
# flagged as missing that concern.
_MISSING_CONCERN_CHECKS: dict[str, list[str]] = {
    "testing": ["test", "verify", "assert", "validate", "spec"],
    "error_handling": ["error", "exception", "fail", "catch", "handle"],
    "edge_cases": ["edge case", "corner case", "empty", "null", "none", "zero", "boundary"],
    "security": ["auth", "permission", "sanitize", "escape", "inject"],
    "performance": ["performance", "scale", "cache", "optimize", "latency", "throughput"],
}


# --- Core thought model ---


class ThoughtData(BaseModel):
    """A single thought step, with optional revision/branch metadata."""

    thought: str
    thought_number: int
    total_thoughts: int
    next_thought_needed: bool
    is_revision: bool = False
    revises_thought: int | None = None
    branch_from_thought: int | None = None
    branch_id: str | None = None
    needs_more_thoughts: bool = False


# --- Result models ---


class ThinkingResult(BaseModel):
    """Result from the sequential_thinking tool."""

    success: bool
    session_id: str = ""
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    branches: list[str] = Field(default_factory=list)
    thought_history_length: int = 0
    message: str | None = None


class ExtendedThinkingResult(BaseModel):
    """Result from the extended_thinking tool (adds checkpoint state)."""

    success: bool
    session_id: str = ""
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    branches: list[str] = Field(default_factory=list)
    thought_history_length: int = 0
    message: str | None = None
    depth_level: str = "standard"
    checkpoint_summary: str = ""
    steps_since_checkpoint: int = 0
    checkpoint_interval: int = 0


class UltraThinkingResult(BaseModel):
    """Result from the ultra_thinking tool (adds phase/hypothesis state)."""

    success: bool
    session_id: str = ""
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    branches: list[str] = Field(default_factory=list)
    thought_history_length: int = 0
    message: str | None = None
    depth_level: str = "standard"
    checkpoint_summary: str = ""
    steps_since_checkpoint: int = 0
    checkpoint_interval: int = 0
    phase: str = ""
    hypotheses: list[str] = Field(default_factory=list)
    verification_status: str = ""
    confidence: float = 0.0
    synthesis: str = ""


class LearningEntry(BaseModel):
    """A persisted learning record for one completed session."""

    session_id: str
    timestamp: float
    strategy_used: str
    outcome_tags: list[str] = Field(default_factory=list)
    reward: float = 0.0
    insights: list[str] = Field(default_factory=list)
    thought_count: int = 0


class LearningLoopResult(BaseModel):
    """Result from the learning_loop tool."""

    success: bool
    session_id: str = ""
    learnings_extracted: int = 0
    insights: list[str] = Field(default_factory=list)
    message: str | None = None


class StrategyScore(BaseModel):
    """Aggregated reward statistics for a single strategy."""

    strategy: str
    total_reward: float = 0.0
    usage_count: int = 0
    avg_reward: float = 0.0
    last_used: float = 0.0


class SelfImproveResult(BaseModel):
    """Result from the self_improve tool."""

    success: bool
    recommended_strategies: list[StrategyScore] = Field(default_factory=list)
    total_learnings: int = 0
    message: str | None = None


class RewardResult(BaseModel):
    """Result from the reward_thinking tool."""

    success: bool
    session_id: str = ""
    new_reward: float = 0.0
    cumulative_reward: float = 0.0
    message: str | None = None


# --- Evidence Tracker models ---


class EvidenceItem(BaseModel):
    """A single piece of evidence attached to a hypothesis."""

    text: str
    evidence_type: str = "data_point"
    strength: float = 0.5
    added_at: float = 0.0


class EvidenceTrackerResult(BaseModel):
    """Result from the evidence_tracker tool."""

    success: bool
    session_id: str = ""
    hypothesis_index: int = 0
    hypothesis_text: str = ""
    evidence: list[EvidenceItem] = Field(default_factory=list)
    total_evidence_count: int = 0
    cumulative_strength: float = 0.0
    effort_mode: str = "medium"
    message: str | None = None


# --- Premortem models ---


class PremortemRisk(BaseModel):
    """A single risk identified during a premortem session."""

    description: str
    likelihood: float = 0.5
    impact: float = 0.5
    risk_score: float = 0.25
    mitigation: str = ""
    category: str = ""


class PremortemSession(BaseModel):
    """Internal state for a premortem session."""

    plan: str = ""
    failure_scenario: str = ""
    risks: list[PremortemRisk] = Field(default_factory=list)


class PremortemResult(BaseModel):
    """Result from the premortem tool."""

    success: bool
    session_id: str = ""
    phase: str = ""
    plan_description: str = ""
    failure_scenario: str = ""
    risks: list[PremortemRisk] = Field(default_factory=list)
    ranked_risks: list[PremortemRisk] = Field(default_factory=list)
    mitigations_count: int = 0
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    effort_mode: str = "medium"
    message: str | None = None


# --- Inversion Thinking models ---


class InversionCause(BaseModel):
    """A cause of failure identified via inversion."""

    description: str
    severity: float = 0.5
    inverted_action: str = ""


class InversionSession(BaseModel):
    """Internal state for an inversion thinking session."""

    goal: str = ""
    inverted_goal: str = ""
    failure_causes: list[InversionCause] = Field(default_factory=list)
    action_plan: list[str] = Field(default_factory=list)


class InversionThinkingResult(BaseModel):
    """Result from the inversion_thinking tool."""

    success: bool
    session_id: str = ""
    phase: str = ""
    goal: str = ""
    inverted_goal: str = ""
    failure_causes: list[InversionCause] = Field(default_factory=list)
    ranked_causes: list[InversionCause] = Field(default_factory=list)
    action_plan: list[str] = Field(default_factory=list)
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    effort_mode: str = "medium"
    message: str | None = None


# --- Effort Estimator models ---


class EstimateItem(BaseModel):
    """A single task estimate."""

    task: str
    optimistic: float
    likely: float
    pessimistic: float
    pert_estimate: float = 0.0
    std_dev: float = 0.0
    confidence_68_low: float = 0.0
    confidence_68_high: float = 0.0
    confidence_95_low: float = 0.0
    confidence_95_high: float = 0.0
    confidence_99_low: float = 0.0
    confidence_99_high: float = 0.0
    risk_buffer: float = 0.0


class EstimatorSession(BaseModel):
    """Internal state for an effort estimator session."""

    estimates: list[EstimateItem] = Field(default_factory=list)


class EffortEstimatorResult(BaseModel):
    """Result from the effort_estimator tool."""

    success: bool
    session_id: str = ""
    action: str = ""
    estimates: list[EstimateItem] = Field(default_factory=list)
    total_pert: float = 0.0
    total_std_dev: float = 0.0
    total_confidence_68_low: float = 0.0
    total_confidence_68_high: float = 0.0
    total_confidence_95_low: float = 0.0
    total_confidence_95_high: float = 0.0
    total_confidence_99_low: float = 0.0
    total_confidence_99_high: float = 0.0
    total_risk_buffer: float = 0.0
    effort_mode: str = "medium"
    message: str | None = None


# --- Plan Optimizer models ---


class PlanAntiPattern(BaseModel):
    """An anti-pattern detected in a plan."""

    pattern_type: str = Field(
        description="Type: vague_language, missing_testing, "
        "missing_error_handling, missing_edge_cases, god_step, "
        "no_structure, todo_marker, missing_security, "
        "missing_performance"
    )
    description: str = Field(description="What was detected")
    severity: str = Field(
        default="medium",
        description="Severity: low, medium, high",
    )
    location: str = Field(
        default="",
        description="Where in the plan this was found",
    )


class PlanVariant(BaseModel):
    """A plan variant with scores."""

    label: str = Field(description="Variant label: A, B, or C")
    name: str = Field(
        description="Variant name, e.g. 'Minimal & Pragmatic'",
    )
    summary: str = Field(description="Brief approach summary")
    approach: str = Field(
        default="", description="Full variant approach text",
    )
    pros: list[str] = Field(default_factory=list)
    cons: list[str] = Field(default_factory=list)
    risk_level: str = Field(default="medium")
    complexity: str = Field(default="medium")
    scores: dict[str, float] = Field(
        default_factory=dict,
        description="Dimension scores (0.0-10.0)",
    )
    total: float = Field(default=0.0, description="Sum of all scores")


class PlanOptimizerSession(BaseModel):
    """Internal state for a plan_optimizer session."""

    plan_text: str = ""
    plan_context: str = ""
    analysis_scores: dict[str, float] = Field(default_factory=dict)
    analysis_issues: list[str] = Field(default_factory=list)
    anti_patterns: list[PlanAntiPattern] = Field(default_factory=list)
    variants: list[PlanVariant] = Field(default_factory=list)
    recommendation: str = ""
    winner_label: str = ""


class PlanOptimizerResult(BaseModel):
    """Result from the plan_optimizer tool."""

    success: bool
    session_id: str = ""
    phase: str = ""
    plan_text: str = ""
    plan_context: str = ""
    analysis_scores: dict[str, float] = Field(default_factory=dict)
    analysis_issues: list[str] = Field(default_factory=list)
    anti_patterns: list[PlanAntiPattern] = Field(default_factory=list)
    anti_pattern_count: int = 0
    plan_health_score: float = Field(
        default=0.0,
        description="Overall plan health 0-100 based on analysis",
    )
    variants: list[PlanVariant] = Field(default_factory=list)
    comparison_matrix: dict[str, dict[str, float]] = Field(
        default_factory=dict,
        description="Dimension -> {variant_label: score}",
    )
    recommendation: str = ""
    winner_label: str = ""
    thought_number: int = 0
    total_thoughts: int = 0
    next_thought_needed: bool = True
    effort_mode: str = "medium"
    message: str | None = None
diff --git
a/src/cocoindex_code/thinking_tools.py b/src/cocoindex_code/thinking_tools.py
new file mode 100644
index 0000000..0476c46
--- /dev/null
+++ b/src/cocoindex_code/thinking_tools.py
"""MCP tool registration for the thinking tools subsystem.

This module registers all thinking-related MCP tools (sequential_thinking,
extended_thinking, ultra_thinking, evidence_tracker, premortem,
inversion_thinking, effort_estimator, learning_loop, self_improve,
reward_thinking, plan_optimizer) on a FastMCP server instance.

Models are defined in thinking_models.py and the ThinkingEngine
lives in thinking_engine.py.
"""

from __future__ import annotations

import uuid

from mcp.server.fastmcp import FastMCP
from pydantic import Field

from .config import config
from .thinking_engine import ThinkingEngine

# Re-export all public symbols so existing imports like
# from cocoindex_code.thinking_tools import ThinkingEngine, ThoughtData
# continue to work without changes.
from .thinking_models import (  # noqa: F401
    PLAN_DIMENSIONS,
    THINKING_MEMORY_FILE,
    VALID_EFFORT_MODES,
    VALID_EVIDENCE_TYPES,
    VALID_INVERSION_PHASES,
    VALID_PLAN_OPTIMIZER_PHASES,
    VALID_PREMORTEM_PHASES,
    EffortEstimatorResult,
    EstimateItem,
    EstimatorSession,
    EvidenceItem,
    EvidenceTrackerResult,
    ExtendedThinkingResult,
    InversionCause,
    InversionSession,
    InversionThinkingResult,
    LearningEntry,
    LearningLoopResult,
    PlanAntiPattern,
    PlanOptimizerResult,
    PlanOptimizerSession,
    PlanVariant,
    PremortemResult,
    PremortemRisk,
    PremortemSession,
    RewardResult,
    SelfImproveResult,
    StrategyScore,
    ThinkingResult,
    ThoughtData,
    UltraThinkingResult,
)

# Module-level singleton engine, created lazily on first tool call.
_engine: ThinkingEngine | None = None


def _get_engine() -> ThinkingEngine:
    """Return the shared ThinkingEngine, creating it lazily.

    Lazy construction avoids touching config.index_dir at import time.
    """
    global _engine
    if _engine is None:
        _engine = ThinkingEngine(config.index_dir)
    return _engine


def register_thinking_tools(mcp: FastMCP) -> None:
    """Register all thinking tools on the given FastMCP server."""

    @mcp.tool(
        name="sequential_thinking",
        description=(
            "Step-by-step problem solving with branching and revision support."
            " Each thought builds on previous ones, with ability to revise earlier"
            " thoughts, branch into alternative reasoning paths, and dynamically"
            " adjust the total number of thoughts as understanding deepens."
        ),
    )
    async def sequential_thinking(
        thought: str = Field(
            description="The current thinking step content.",
        ),
        next_thought_needed: bool = Field(
            description="Whether another thought step is needed.",
        ),
        thought_number: int = Field(
            ge=1,
            description="Current thought number in the sequence.",
        ),
        total_thoughts: int = Field(
            ge=1,
            description="Estimated total thoughts needed (can be adjusted).",
        ),
        session_id: str | None = Field(
            default=None,
            description="Session identifier. Auto-generated if not provided.",
        ),
        is_revision: bool = Field(
            default=False,
            description="Whether this thought revises a previous one.",
        ),
        revises_thought: int | None = Field(
            default=None,
            description="Which thought number is being revised.",
        ),
        branch_from_thought: int | None = Field(
            default=None,
            description="Thought number to branch from.",
        ),
        branch_id: str | None = Field(
            default=None,
            description="Identifier for the current branch.",
        ),
        needs_more_thoughts: bool = Field(
            default=False,
            description="Signal that more thoughts are needed beyond the current total.",
        ),
    ) -> ThinkingResult:
        try:
            engine = _get_engine()
            # Auto-generate a session id so single-shot calls work.
            sid = session_id or str(uuid.uuid4())
            data = ThoughtData(
                thought=thought,
                thought_number=thought_number,
                total_thoughts=total_thoughts,
                next_thought_needed=next_thought_needed,
                is_revision=is_revision,
                revises_thought=revises_thought,
                branch_from_thought=branch_from_thought,
                branch_id=branch_id,
                needs_more_thoughts=needs_more_thoughts,
            )
            return engine.process_thought(sid, data)
        except Exception as e:
            # Tool boundary: surface failures as a result, not a raised
            # exception, so the MCP client gets a structured error.
            return ThinkingResult(success=False, message=f"Thinking failed: {e!s}")

@mcp.tool( + name="extended_thinking", + description=( + "Deeper analysis with automatic checkpoints." + " Extends sequential thinking with configurable depth levels" + " (standard, deep, exhaustive) and periodic checkpoint summaries" + " to maintain coherence over long reasoning chains." + ), + ) + async def extended_thinking( + thought: str = Field( + description="The current thinking step content.", + ), + next_thought_needed: bool = Field( + description="Whether another thought step is needed.", + ), + thought_number: int = Field( + ge=1, + description="Current thought number in the sequence.", + ), + total_thoughts: int = Field( + ge=1, + description="Estimated total thoughts needed (can be adjusted).", + ), + session_id: str | None = Field( + default=None, + description="Session identifier. Auto-generated if not provided.", + ), + is_revision: bool = Field( + default=False, + description="Whether this thought revises a previous one.", + ), + revises_thought: int | None = Field( + default=None, + description="Which thought number is being revised.", + ), + branch_from_thought: int | None = Field( + default=None, + description="Thought number to branch from.", + ), + branch_id: str | None = Field( + default=None, + description="Identifier for the current branch.", + ), + needs_more_thoughts: bool = Field( + default=False, + description="Signal that more thoughts are needed beyond the current total.", + ), + depth_level: str = Field( + default="deep", + description="Depth of analysis: 'standard', 'deep', or 'exhaustive'.", + ), + checkpoint_interval: int = Field( + default=5, + ge=1, + le=50, + description="Number of steps between automatic checkpoints.", + ), + ) -> ExtendedThinkingResult: + try: + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + data = ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + is_revision=is_revision, + 
revises_thought=revises_thought, + branch_from_thought=branch_from_thought, + branch_id=branch_id, + needs_more_thoughts=needs_more_thoughts, + ) + return engine.process_extended_thought(sid, data, depth_level, checkpoint_interval) + except Exception as e: + return ExtendedThinkingResult(success=False, message=f"Extended thinking failed: {e!s}") + + @mcp.tool( + name="ultra_thinking", + description=( + "Maximum-depth reasoning with hypothesis generation, verification," + " and synthesis. Supports phased thinking through explore, hypothesize," + " verify, synthesize, and refine stages for complex problem solving." + ), + ) + async def ultra_thinking( + thought: str = Field( + description="The current thinking step content.", + ), + next_thought_needed: bool = Field( + description="Whether another thought step is needed.", + ), + thought_number: int = Field( + ge=1, + description="Current thought number in the sequence.", + ), + total_thoughts: int = Field( + ge=1, + description="Estimated total thoughts needed (can be adjusted).", + ), + session_id: str | None = Field( + default=None, + description="Session identifier. Auto-generated if not provided.", + ), + is_revision: bool = Field( + default=False, + description="Whether this thought revises a previous one.", + ), + revises_thought: int | None = Field( + default=None, + description="Which thought number is being revised.", + ), + branch_from_thought: int | None = Field( + default=None, + description="Thought number to branch from.", + ), + branch_id: str | None = Field( + default=None, + description="Identifier for the current branch.", + ), + needs_more_thoughts: bool = Field( + default=False, + description="Signal that more thoughts are needed beyond the current total.", + ), + phase: str = Field( + default="explore", + description=( + "Thinking phase: 'explore', 'hypothesize', 'verify', 'synthesize', or 'refine'." 
+ ), + ), + hypothesis: str | None = Field( + default=None, + description="A hypothesis to register during the 'hypothesize' phase.", + ), + confidence: float = Field( + default=0.0, + ge=0, + le=1, + description="Confidence level for verification (0.0 to 1.0).", + ), + ) -> UltraThinkingResult: + try: + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + data = ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + is_revision=is_revision, + revises_thought=revises_thought, + branch_from_thought=branch_from_thought, + branch_id=branch_id, + needs_more_thoughts=needs_more_thoughts, + ) + return engine.process_ultra_thought(sid, data, phase, hypothesis, confidence) + except Exception as e: + return UltraThinkingResult(success=False, message=f"Ultra thinking failed: {e!s}") + + @mcp.tool( + name="learning_loop", + description=( + "Reflect on a thinking session and extract learnings." + " Records the strategy used, outcome tags, reward signal," + " and insights for future self-improvement." 
+ ), + ) + async def learning_loop( + session_id: str = Field( + description="The session to record learnings for.", + ), + strategy_used: str = Field( + description="Name of the thinking strategy that was used.", + ), + outcome_tags: list[str] = Field( + description="Tags describing the outcome (e.g., 'success', 'partial', 'failed').", + ), + reward: float = Field( + ge=-1, + le=1, + description="Reward signal from -1.0 (worst) to 1.0 (best).", + ), + insights: list[str] = Field( + description="Key insights extracted from the thinking session.", + ), + ) -> LearningLoopResult: + try: + engine = _get_engine() + return engine.record_learning(session_id, strategy_used, outcome_tags, reward, insights) + except Exception as e: + return LearningLoopResult(success=False, message=f"Learning loop failed: {e!s}") + + @mcp.tool( + name="self_improve", + description=( + "Get recommended thinking strategies based on past performance." + " Analyzes historical learning entries and returns the top strategies" + " ranked by average reward." + ), + ) + async def self_improve( + top_k: int = Field( + default=5, + ge=1, + le=20, + description="Number of top strategies to return.", + ), + ) -> SelfImproveResult: + try: + engine = _get_engine() + recommendations = engine.get_strategy_recommendations(top_k) + return SelfImproveResult( + success=True, + recommended_strategies=recommendations, + total_learnings=len(engine._learnings), + ) + except Exception as e: + return SelfImproveResult(success=False, message=f"Self improve failed: {e!s}") + + @mcp.tool( + name="reward_thinking", + description=( + "Provide a reinforcement signal for a thinking session." + " Applies an additional reward to the most recent learning" + " entry for the given session, updating strategy scores." 
+ ), + ) + async def reward_thinking( + session_id: str = Field( + description="The session to apply the reward to.", + ), + reward: float = Field( + ge=-1, + le=1, + description="Reward signal from -1.0 (worst) to 1.0 (best).", + ), + ) -> RewardResult: + try: + engine = _get_engine() + return engine.apply_reward(session_id, reward) + except Exception as e: + return RewardResult(success=False, message=f"Reward failed: {e!s}") + + @mcp.tool( + name="evidence_tracker", + description=( + "Attach typed, weighted evidence to ultra_thinking hypotheses." + " Supports 'add' to attach new evidence and 'list' to query existing" + " evidence. Evidence types: code_ref, data_point, external," + " assumption, test_result. Returns cumulative strength score." + " Use effort_mode to control depth: low (skip type validation)," + " medium (standard), high (full validation)," + " ultra (full validation + auto-boost strength for code_ref/test_result)." + ), + ) + async def evidence_tracker( + session_id: str = Field( + description="The ultra_thinking session containing hypotheses.", + ), + hypothesis_index: int = Field( + ge=0, + description="Zero-based index of the hypothesis to attach evidence to.", + ), + action: str = Field( + default="add", + description="Action to perform: 'add' to attach evidence, 'list' to query.", + ), + evidence: str | None = Field( + default=None, + description="The evidence text. Required when action is 'add'.", + ), + evidence_type: str = Field( + default="data_point", + description=( + "Type of evidence: 'code_ref', 'data_point', 'external'," + " 'assumption', or 'test_result'." 
+ ), + ), + strength: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Strength of this evidence (0.0 to 1.0).", + ), + effort_mode: str = Field( + default="ultra", + description="Effort level: 'low', 'medium', 'high', or 'ultra'.", + ), + ) -> EvidenceTrackerResult: + try: + if effort_mode not in VALID_EFFORT_MODES: + return EvidenceTrackerResult( + success=False, + session_id=session_id, + effort_mode=effort_mode, + message=( + f"Invalid effort_mode '{effort_mode}'." + f" Must be one of: {', '.join(sorted(VALID_EFFORT_MODES))}" + ), + ) + engine = _get_engine() + if action == "list": + return engine.get_evidence( + session_id, hypothesis_index, effort_mode=effort_mode, + ) + if action == "add": + if evidence is None: + return EvidenceTrackerResult( + success=False, + session_id=session_id, + effort_mode=effort_mode, + message="evidence text is required when action is 'add'", + ) + return engine.add_evidence( + session_id, hypothesis_index, evidence, + evidence_type, strength, effort_mode=effort_mode, + ) + return EvidenceTrackerResult( + success=False, + session_id=session_id, + effort_mode=effort_mode, + message=f"Invalid action '{action}'. Must be 'add' or 'list'.", + ) + except Exception as e: + return EvidenceTrackerResult( + success=False, message=f"Evidence tracker failed: {e!s}" + ) + + @mcp.tool( + name="premortem", + description=( + "Structured pre-failure risk analysis." + " Imagine a plan has failed, then work backwards to identify why." + " Phases: 'describe_plan', 'imagine_failure', 'identify_causes'," + " 'rank_risks', 'mitigate'." + " Use effort_mode to control depth: low (quick risk list)," + " medium (full 5-phase flow), high (exhaustive analysis)," + " ultra (auto-rank at every phase + require all mitigations)." 
+ ), + ) + async def premortem( + thought: str = Field( + description="The current thinking step content.", + ), + next_thought_needed: bool = Field( + description="Whether another thought step is needed.", + ), + thought_number: int = Field( + ge=1, + description="Current thought number in the sequence.", + ), + total_thoughts: int = Field( + ge=1, + description="Estimated total thoughts needed (can be adjusted).", + ), + phase: str = Field( + default="describe_plan", + description=( + "Premortem phase: 'describe_plan', 'imagine_failure'," + " 'identify_causes', 'rank_risks', or 'mitigate'." + ), + ), + session_id: str | None = Field( + default=None, + description="Session identifier. Auto-generated if not provided.", + ), + plan: str | None = Field( + default=None, + description="The plan description. Used in 'describe_plan' phase.", + ), + failure_scenario: str | None = Field( + default=None, + description="The imagined failure scenario. Used in 'imagine_failure' phase.", + ), + risk_description: str | None = Field( + default=None, + description="Description of a risk cause. Required in 'identify_causes' phase.", + ), + likelihood: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Likelihood of this risk (0.0 to 1.0).", + ), + impact: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Impact severity of this risk (0.0 to 1.0).", + ), + risk_index: int | None = Field( + default=None, + description="Index of risk to mitigate. Required in 'mitigate' phase.", + ), + mitigation: str | None = Field( + default=None, + description="Mitigation strategy. Used in 'mitigate' phase.", + ), + effort_mode: str = Field( + default="ultra", + description="Effort level: 'low', 'medium', 'high', or 'ultra'.", + ), + ) -> PremortemResult: + try: + if effort_mode not in VALID_EFFORT_MODES: + return PremortemResult( + success=False, + effort_mode=effort_mode, + message=( + f"Invalid effort_mode '{effort_mode}'." 
+ f" Must be one of: {', '.join(sorted(VALID_EFFORT_MODES))}" + ), + ) + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + data = ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + ) + return engine.process_premortem( + sid, data, + phase=phase, plan=plan, + failure_scenario=failure_scenario, + risk_description=risk_description, + likelihood=likelihood, impact=impact, + mitigation=mitigation, risk_index=risk_index, + effort_mode=effort_mode, + ) + except Exception as e: + return PremortemResult( + success=False, message=f"Premortem failed: {e!s}" + ) + + @mcp.tool( + name="inversion_thinking", + description=( + "Instead of asking 'how to succeed', ask 'how to guarantee failure'," + " then invert. Phases: 'define_goal', 'invert'," + " 'list_failure_causes', 'rank_causes' (medium/high only)," + " 'reinvert', 'action_plan'." + " Use effort_mode: low (skip ranking, 3 phases)," + " medium (full 6 phases), high (auto-populate action plan)," + " ultra (auto-reinvert all causes + auto-populate everything)." + ), + ) + async def inversion_thinking( + thought: str = Field( + description="The current thinking step content.", + ), + next_thought_needed: bool = Field( + description="Whether another thought step is needed.", + ), + thought_number: int = Field( + ge=1, + description="Current thought number in the sequence.", + ), + total_thoughts: int = Field( + ge=1, + description="Estimated total thoughts needed (can be adjusted).", + ), + phase: str = Field( + default="define_goal", + description=( + "Phase: 'define_goal', 'invert', 'list_failure_causes'," + " 'rank_causes', 'reinvert', or 'action_plan'." + ), + ), + session_id: str | None = Field( + default=None, + description="Session identifier. Auto-generated if not provided.", + ), + goal: str | None = Field( + default=None, + description="The goal to achieve. 
Used in 'define_goal' phase.", + ), + inverted_goal: str | None = Field( + default=None, + description="The inverted goal statement. Used in 'invert' phase.", + ), + failure_cause: str | None = Field( + default=None, + description="A cause of failure. Required in 'list_failure_causes' phase.", + ), + severity: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Severity of this failure cause (0.0 to 1.0).", + ), + cause_index: int | None = Field( + default=None, + description="Index of cause to reinvert. Required in 'reinvert' phase.", + ), + inverted_action: str | None = Field( + default=None, + description="The positive action derived from inverting a cause.", + ), + action_item: str | None = Field( + default=None, + description="An action item for the plan. Used in 'action_plan' phase.", + ), + effort_mode: str = Field( + default="ultra", + description="Effort level: 'low', 'medium', 'high', or 'ultra'.", + ), + ) -> InversionThinkingResult: + try: + if effort_mode not in VALID_EFFORT_MODES: + return InversionThinkingResult( + success=False, + effort_mode=effort_mode, + message=( + f"Invalid effort_mode '{effort_mode}'." + f" Must be one of: {', '.join(sorted(VALID_EFFORT_MODES))}" + ), + ) + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + data = ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + ) + return engine.process_inversion( + sid, data, + phase=phase, goal=goal, + inverted_goal=inverted_goal, + failure_cause=failure_cause, + severity=severity, + inverted_action=inverted_action, + cause_index=cause_index, + action_item=action_item, + effort_mode=effort_mode, + ) + except Exception as e: + return InversionThinkingResult( + success=False, message=f"Inversion thinking failed: {e!s}" + ) + + @mcp.tool( + name="effort_estimator", + description=( + "Three-point PERT estimation for tasks." 
+ " Provide optimistic, likely, and pessimistic estimates" + " to get PERT weighted average, standard deviation," + " and confidence intervals." + " Actions: 'add' a task estimate, 'summary' to view all," + " 'clear' to reset." + " Use effort_mode: low (single-point estimate)," + " medium (PERT + 68% CI), high (PERT + 68% + 95% CI)," + " ultra (PERT + 68% + 95% + 99.7% CI + risk buffer)." + ), + ) + async def effort_estimator( + session_id: str | None = Field( + default=None, + description="Session identifier. Auto-generated if not provided.", + ), + action: str = Field( + default="add", + description="Action: 'add', 'summary', or 'clear'.", + ), + task: str | None = Field( + default=None, + description="Task name. Required when action is 'add'.", + ), + optimistic: float = Field( + default=0.0, + ge=0.0, + description="Optimistic (best-case) estimate.", + ), + likely: float = Field( + default=0.0, + ge=0.0, + description="Most likely estimate.", + ), + pessimistic: float = Field( + default=0.0, + ge=0.0, + description="Pessimistic (worst-case) estimate.", + ), + effort_mode: str = Field( + default="ultra", + description="Effort level: 'low', 'medium', 'high', or 'ultra'.", + ), + ) -> EffortEstimatorResult: + try: + if effort_mode not in VALID_EFFORT_MODES: + return EffortEstimatorResult( + success=False, + effort_mode=effort_mode, + message=( + f"Invalid effort_mode '{effort_mode}'." + f" Must be one of: {', '.join(sorted(VALID_EFFORT_MODES))}" + ), + ) + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + return engine.process_estimate( + sid, action=action, + task=task, + optimistic=optimistic, likely=likely, + pessimistic=pessimistic, + effort_mode=effort_mode, + ) + except Exception as e: + return EffortEstimatorResult( + success=False, message=f"Effort estimator failed: {e!s}" + ) + + @mcp.tool( + name="plan_optimizer", + description=( + "Structured plan optimization tool." 
+ " Analyzes any plan (implementation, architecture, refactoring," + " bug fix) across 8 quality dimensions, auto-detects" + " anti-patterns, supports 3 variant generation with" + " comparison matrix scoring, and recommends the best approach." + "\n\nPhases:" + "\n1. 'submit_plan' — Submit plan text + context." + " Auto-detects anti-patterns." + "\n2. 'analyze' — Score plan across dimensions" + " (clarity, completeness, correctness, risk, simplicity," + " testability, edge_cases, actionability)." + " Call once per dimension with score 0-10." + "\n3. 'detect_anti_patterns' — Re-run anti-pattern" + " detection (after plan edits)." + "\n4. 'add_variant' — Add an alternative plan variant" + " (A=Minimal, B=Robust, C=Optimal Architecture)." + "\n5. 'score_variant' — Score a variant across dimensions." + " Call once per dimension per variant." + "\n6. 'recommend' — Submit final recommendation." + " Returns full comparison matrix." + "\n\nUse effort_mode: low (just submit+analyze, skip variants)," + " medium (full 6-phase flow)," + " high (full flow + detailed anti-pattern analysis)," + " ultra (auto-score missing dimensions + require variants for recommend)." + ), + ) + async def plan_optimizer( + thought: str = Field( + description="The current thinking step content.", + ), + next_thought_needed: bool = Field( + description="Whether another thought step is needed.", + ), + thought_number: int = Field( + ge=1, + description="Current thought number in the sequence.", + ), + total_thoughts: int = Field( + ge=1, + description="Estimated total thoughts needed.", + ), + phase: str = Field( + default="submit_plan", + description=( + "Phase: 'submit_plan', 'analyze'," + " 'detect_anti_patterns', 'add_variant'," + " 'score_variant', or 'recommend'." + ), + ), + session_id: str | None = Field( + default=None, + description=( + "Session identifier." + " Auto-generated if not provided." 
+ ), + ), + plan_text: str | None = Field( + default=None, + description=( + "The full plan text to optimize." + " Required in 'submit_plan' phase." + ), + ), + plan_context: str | None = Field( + default=None, + description=( + "Context about what the plan is for." + " E.g. 'Implementing user authentication'" + ), + ), + dimension: str | None = Field( + default=None, + description=( + "Dimension to score: clarity, completeness," + " correctness, risk, simplicity, testability," + " edge_cases, actionability." + " Used in 'analyze' and 'score_variant' phases." + ), + ), + score: float = Field( + default=0.0, + ge=0.0, + le=10.0, + description="Score for the dimension (0.0-10.0).", + ), + issue: str | None = Field( + default=None, + description=( + "An issue found during analysis." + " Used in 'analyze' phase." + ), + ), + variant_label: str | None = Field( + default=None, + description=( + "Variant label: 'A', 'B', or 'C'." + " Used in 'add_variant' and 'score_variant'." + ), + ), + variant_name: str | None = Field( + default=None, + description=( + "Variant name, e.g. 'Minimal & Pragmatic'." + " Used in 'add_variant'." + ), + ), + variant_summary: str | None = Field( + default=None, + description="Brief approach summary for the variant.", + ), + variant_approach: str | None = Field( + default=None, + description="Full variant approach text.", + ), + variant_pros: list[str] | None = Field( + default=None, + description="List of pros for this variant.", + ), + variant_cons: list[str] | None = Field( + default=None, + description="List of cons for this variant.", + ), + variant_risk_level: str = Field( + default="medium", + description="Risk level: 'low', 'medium', 'high'.", + ), + variant_complexity: str = Field( + default="medium", + description="Complexity: 'low', 'medium', 'high'.", + ), + recommendation: str | None = Field( + default=None, + description=( + "Final recommendation text." + " Used in 'recommend' phase." 
+ ), + ), + winner_label: str | None = Field( + default=None, + description=( + "Label of the winning variant." + " Auto-selected if not provided." + ), + ), + effort_mode: str = Field( + default="ultra", + description="Effort level: 'low', 'medium', 'high', or 'ultra'.", + ), + ) -> PlanOptimizerResult: + try: + if effort_mode not in VALID_EFFORT_MODES: + return PlanOptimizerResult( + success=False, + effort_mode=effort_mode, + message=( + f"Invalid effort_mode '{effort_mode}'." + f" Must be one of: {', '.join(sorted(VALID_EFFORT_MODES))}" + ), + ) + engine = _get_engine() + sid = session_id or str(uuid.uuid4()) + data = ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + ) + return engine.process_plan_optimizer( + sid, data, + phase=phase, + plan_text=plan_text, + plan_context=plan_context, + dimension=dimension, + score=score, + issue=issue, + variant_label=variant_label, + variant_name=variant_name, + variant_summary=variant_summary, + variant_approach=variant_approach, + variant_pros=variant_pros, + variant_cons=variant_cons, + variant_risk_level=variant_risk_level, + variant_complexity=variant_complexity, + recommendation=recommendation, + winner_label=winner_label, + effort_mode=effort_mode, + ) + except Exception as e: + return PlanOptimizerResult( + success=False, + message=f"Plan optimizer failed: {e!s}", + ) diff --git a/tests/test_code_intelligence_tools.py b/tests/test_code_intelligence_tools.py new file mode 100644 index 0000000..186bb26 --- /dev/null +++ b/tests/test_code_intelligence_tools.py @@ -0,0 +1,693 @@ +"""Tests for code intelligence tools.""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.code_intelligence_tools import ( + _classify_usage, + _compute_metrics, + _extract_symbols, + _find_definitions_impl, + 
_find_references_impl, + _rename_symbol_impl, + _walk_source_files, +) + + +@pytest.fixture() +def sample_codebase(tmp_path: Path) -> Path: + """Create a sample codebase for testing.""" + (tmp_path / "src").mkdir() + (tmp_path / "src" / "utils").mkdir() + (tmp_path / "lib").mkdir() + (tmp_path / "node_modules").mkdir() + (tmp_path / "__pycache__").mkdir() + + (tmp_path / "main.py").write_text( + "MAX_RETRIES = 3\n" + "\n" + "class UserManager:\n" + ' """Manages users."""\n' + "\n" + " def __init__(self):\n" + " self.users = []\n" + "\n" + " def add_user(self, name):\n" + " self.users.append(name)\n" + "\n" + " async def fetch_user(self, user_id):\n" + " pass\n" + "\n" + "\n" + "def helper():\n" + " manager = UserManager()\n" + " manager.add_user('alice')\n" + ) + + (tmp_path / "src" / "app.ts").write_text( + "export function greet(name: string): string {\n" + " return `Hello, ${name}!`;\n" + "}\n" + "\n" + "export class Greeter {\n" + " private name: string;\n" + "\n" + " constructor(name: string) {\n" + " this.name = name;\n" + " }\n" + "\n" + " greet(): string {\n" + " return greet(this.name);\n" + " }\n" + "}\n" + "\n" + "export const DEFAULT_NAME = 'World';\n" + ) + + (tmp_path / "src" / "utils" / "math.ts").write_text( + "export const add = (a: number, b: number): number => a + b;\n" + "export const subtract = (a: number, b: number): number => a - b;\n" + ) + + (tmp_path / "lib" / "database.py").write_text( + "import sqlite3\n" + "\n" + "class DatabaseConnection:\n" + ' """Database connection manager."""\n' + "\n" + " def connect(self) -> None:\n" + " pass\n" + "\n" + " def query(self, sql: str):\n" + " pass\n" + ) + + (tmp_path / "lib" / "server.rs").write_text( + "pub async fn start_server(port: u16) -> Result<(), Error> {\n" + " let listener = TcpListener::bind(port).await?;\n" + " Ok(())\n" + "}\n" + "\n" + "pub struct Config {\n" + " pub host: String,\n" + " pub port: u16,\n" + "}\n" + "\n" + "impl Config {\n" + " pub fn new() -> Self {\n" + " Config { 
host: String::new(), port: 8080 }\n" + " }\n" + "}\n" + ) + + (tmp_path / "lib" / "handler.go").write_text( + "package main\n" + "\n" + "func HandleRequest(w http.ResponseWriter, r *http.Request) {\n" + " w.Write([]byte(\"OK\"))\n" + "}\n" + "\n" + "type Server struct {\n" + " Port int\n" + "}\n" + "\n" + "func (s *Server) Start() error {\n" + " return nil\n" + "}\n" + ) + + (tmp_path / "README.md").write_text("# Test Project\n\nA test project.\n") + + (tmp_path / "node_modules" / "pkg.js").write_text("module.exports = {};\n") + (tmp_path / "__pycache__" / "main.cpython-312.pyc").write_bytes( + b"\x00" * 100 + ) + + binary_path = tmp_path / "image.png" + binary_path.write_bytes( + b"\x89PNG\r\n\x1a\n\x00\x00\x00" + b"\x00" * 50 + ) + + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(sample_codebase: Path) -> Iterator[None]: + """Patch config to point at sample_codebase.""" + with patch( + "cocoindex_code.filesystem_tools.config" + ) as mock_fs_config, patch( + "cocoindex_code.code_intelligence_tools._root" + ) as mock_root, patch( + "cocoindex_code.code_intelligence_tools._safe_resolve" + ) as mock_resolve, patch( + "cocoindex_code.code_intelligence_tools._relative" + ) as mock_relative: + mock_fs_config.codebase_root_path = sample_codebase + mock_root.return_value = sample_codebase + + def safe_resolve_side_effect(path_str): + import os + root = sample_codebase + resolved = (root / path_str).resolve() + if not ( + resolved == root + or str(resolved).startswith(str(root) + os.sep) + ): + msg = f"Path '{path_str}' escapes the codebase root" + raise ValueError(msg) + return resolved + + mock_resolve.side_effect = safe_resolve_side_effect + + def relative_side_effect(path): + try: + return str(path.relative_to(sample_codebase)) + except ValueError: + return str(path) + + mock_relative.side_effect = relative_side_effect + yield + + +# === Tests for _extract_symbols === + + +class TestExtractSymbols: + def test_python_functions_and_classes(self) -> 
None: + content = ( + "def hello():\n" + " pass\n" + "\n" + "class Foo:\n" + " def method(self):\n" + " pass\n" + ) + symbols = _extract_symbols(content, "python") + names = [s.name for s in symbols] + assert "hello" in names + assert "Foo" in names + assert "method" in names + # method should be classified as method + method_sym = next(s for s in symbols if s.name == "method") + assert method_sym.symbol_type == "method" + # hello should be function + hello_sym = next(s for s in symbols if s.name == "hello") + assert hello_sym.symbol_type == "function" + + def test_python_constants(self) -> None: + content = "MAX_SIZE = 100\nPI = 3.14\n" + symbols = _extract_symbols(content, "python") + names = [s.name for s in symbols] + assert "MAX_SIZE" in names + assert "PI" in names + for s in symbols: + assert s.symbol_type == "constant" + + def test_python_async_function(self) -> None: + content = "async def fetch_data():\n pass\n" + symbols = _extract_symbols(content, "python") + assert len(symbols) == 1 + assert symbols[0].name == "fetch_data" + assert symbols[0].symbol_type == "function" + + def test_typescript_interface_and_enum(self) -> None: + content = ( + "export interface User {\n" + " name: string;\n" + "}\n" + "\n" + "export type ID = string;\n" + "\n" + "export enum Color {\n" + " Red, Green, Blue\n" + "}\n" + ) + symbols = _extract_symbols(content, "typescript") + names = [s.name for s in symbols] + assert "User" in names + assert "ID" in names + assert "Color" in names + user = next(s for s in symbols if s.name == "User") + assert user.symbol_type == "interface" + color = next(s for s in symbols if s.name == "Color") + assert color.symbol_type == "enum" + + def test_javascript_functions_and_classes(self) -> None: + content = ( + "export function greet(name) {\n" + " return name;\n" + "}\n" + "export class App {}\n" + "const VERSION = '1.0';\n" + ) + symbols = _extract_symbols(content, "javascript") + names = [s.name for s in symbols] + assert "greet" in names + 
assert "App" in names + assert "VERSION" in names + + def test_rust_symbols(self) -> None: + content = ( + "pub async fn serve(port: u16) {}\n" + "pub struct Config { port: u16 }\n" + "pub enum Status { Ok, Error }\n" + "pub trait Handler {}\n" + "mod tests {}\n" + "impl Config {}\n" + "const MAX: u32 = 100;\n" + ) + symbols = _extract_symbols(content, "rust") + names = [s.name for s in symbols] + assert "serve" in names + assert "Config" in names + assert "Status" in names + assert "Handler" in names + assert "tests" in names + assert "MAX" in names + + def test_go_symbols(self) -> None: + content = ( + "func HandleRequest(w http.ResponseWriter) {\n" + "}\n" + "type Server struct {\n" + " Port int\n" + "}\n" + "func (s *Server) Start() error {\n" + " return nil\n" + "}\n" + "const MaxRetries = 3\n" + ) + symbols = _extract_symbols(content, "go") + names = [s.name for s in symbols] + assert "HandleRequest" in names + assert "Server" in names + assert "Start" in names + assert "MaxRetries" in names + + def test_unknown_language(self) -> None: + symbols = _extract_symbols("hello world", "brainfuck") + assert symbols == [] + + def test_empty_content(self) -> None: + symbols = _extract_symbols("", "python") + assert symbols == [] + + def test_end_line_computation(self) -> None: + content = ( + "def foo():\n" + " pass\n" + "\n" + "def bar():\n" + " x = 1\n" + " return x\n" + ) + symbols = _extract_symbols(content, "python") + foo = next(s for s in symbols if s.name == "foo") + bar = next(s for s in symbols if s.name == "bar") + assert foo.end_line == 3 # before bar starts + assert bar.end_line == 6 # EOF + + +# === Tests for _walk_source_files === + + +class TestWalkSourceFiles: + def test_walks_all_source_files( + self, sample_codebase: Path, + ) -> None: + files = _walk_source_files(sample_codebase) + rel_paths = [rel for _, rel, _ in files] + assert any("main.py" in p for p in rel_paths) + assert any("app.ts" in p for p in rel_paths) + # Excluded dirs + assert not 
any("node_modules" in p for p in rel_paths) + assert not any("__pycache__" in p for p in rel_paths) + # Binary files excluded + assert not any("image.png" in p for p in rel_paths) + + def test_language_filter( + self, sample_codebase: Path, + ) -> None: + files = _walk_source_files( + sample_codebase, languages=["python"], + ) + for _, _, lang in files: + assert lang == "python" + + def test_path_filter( + self, sample_codebase: Path, + ) -> None: + files = _walk_source_files( + sample_codebase, paths=["src/*"], + ) + for _, rel, _ in files: + assert rel.startswith("src/") or rel.startswith("src\\") + + +# === Tests for _classify_usage === + + +class TestClassifyUsage: + def test_import(self) -> None: + assert _classify_usage( + "from foo import bar", "bar", "python", + ) == "import" + assert _classify_usage( + "import os", "os", "python", + ) == "import" + + def test_call(self) -> None: + assert _classify_usage( + "result = helper()", "helper", "python", + ) == "call" + + def test_assignment(self) -> None: + assert _classify_usage( + "helper = something", "helper", "python", + ) == "assignment" + + def test_type_annotation(self) -> None: + assert _classify_usage( + "x: UserManager = None", "UserManager", "python", + ) == "type_annotation" + + def test_definition(self) -> None: + assert _classify_usage( + "def helper():", "helper", "python", + ) == "definition" + assert _classify_usage( + "class UserManager:", "UserManager", "python", + ) == "definition" + + def test_other(self) -> None: + assert _classify_usage( + "print(helper)", "helper", "python", + ) == "other" + + +# === Tests for _find_definitions_impl === + + +class TestFindDefinitions: + def test_find_python_function( + self, sample_codebase: Path, + ) -> None: + defs = _find_definitions_impl( + "helper", sample_codebase, + ) + assert len(defs) >= 1 + assert any(d.name == "helper" for d in defs) + + def test_find_python_class( + self, sample_codebase: Path, + ) -> None: + defs = _find_definitions_impl( + 
"UserManager", sample_codebase,
+        )
+        assert len(defs) >= 1
+        assert defs[0].symbol_type == "class"
+
+    def test_find_typescript_function(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "greet", sample_codebase,
+        )
+        assert len(defs) >= 1
+        assert any(
+            d.file_path.endswith("app.ts") for d in defs
+        )
+
+    def test_find_rust_function(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "start_server", sample_codebase,
+        )
+        assert len(defs) >= 1
+
+    def test_find_go_function(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "HandleRequest", sample_codebase,
+        )
+        assert len(defs) >= 1
+
+    def test_no_match(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "nonexistent_symbol_xyz", sample_codebase,
+        )
+        assert len(defs) == 0
+
+    def test_filter_by_type(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "UserManager", sample_codebase,
+            symbol_type="function",
+        )
+        assert len(defs) == 0
+
+    def test_filter_by_language(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "greet", sample_codebase,
+            languages=["python"],
+        )
+        # greet is in typescript, not python
+        assert len(defs) == 0
+
+    def test_limit(
+        self, sample_codebase: Path,
+    ) -> None:
+        defs = _find_definitions_impl(
+            "helper", sample_codebase, limit=1,
+        )
+        assert len(defs) <= 1
+
+
+# === Tests for _find_references_impl ===
+
+
+class TestFindReferences:
+    def test_find_references_to_symbol(
+        self, sample_codebase: Path,
+    ) -> None:
+        refs, total, searched, trunc = _find_references_impl(
+            "UserManager", sample_codebase,
+        )
+        assert total >= 2  # class def + usage in helper()
+
+    def test_word_boundary(
+        self, sample_codebase: Path,
+    ) -> None:
+        # "add" should match the standalone add const but NOT add_user
+        refs, total, _, _ = _find_references_impl(
+            "add", sample_codebase,
+        )
+        # Should NOT match "add_user" since \badd\b won't 
match inside + for ref in refs: + # Each match should contain "add" as a word + assert "add" in ref.line + + def test_context_lines( + self, sample_codebase: Path, + ) -> None: + refs, _, _, _ = _find_references_impl( + "UserManager", sample_codebase, + context_lines=2, + ) + if refs: + # At least one ref should have context + has_context = any( + ref.context_before or ref.context_after + for ref in refs + ) + assert has_context + + def test_language_filter( + self, sample_codebase: Path, + ) -> None: + refs, _, _, _ = _find_references_impl( + "greet", sample_codebase, + languages=["typescript"], + ) + for ref in refs: + assert ref.path.endswith(".ts") + + def test_truncation( + self, sample_codebase: Path, + ) -> None: + refs, total, _, trunc = _find_references_impl( + "UserManager", sample_codebase, limit=1, + ) + assert len(refs) <= 1 + + def test_usage_type_classification( + self, sample_codebase: Path, + ) -> None: + refs, _, _, _ = _find_references_impl( + "sqlite3", sample_codebase, + ) + import_refs = [ + r for r in refs if r.usage_type == "import" + ] + assert len(import_refs) >= 1 + + +# === Tests for _compute_metrics === + + +class TestComputeMetrics: + def test_basic_metrics(self) -> None: + content = ( + "# A comment\n" + "\n" + "def foo():\n" + " pass\n" + "\n" + "def bar():\n" + " x = 1\n" + " if x > 0:\n" + " return x\n" + " return 0\n" + ) + m = _compute_metrics(content, "python") + assert m.total_lines == 10 + assert m.blank_lines == 2 + assert m.comment_lines == 1 + assert m.code_lines == 7 + assert m.functions == 2 + assert m.complexity_estimate >= 1 # at least the if + + def test_empty_file(self) -> None: + m = _compute_metrics("", "python") + assert m.total_lines == 0 + assert m.functions == 0 + assert m.classes == 0 + + def test_nesting_depth(self) -> None: + content = ( + "def foo():\n" + " if True:\n" + " for i in range(10):\n" + " if i > 5:\n" + " print(i)\n" + ) + m = _compute_metrics(content, "python") + assert m.max_nesting_depth >= 4 + 
+ def test_class_count(self) -> None: + content = ( + "class Foo:\n" + " pass\n" + "\n" + "class Bar:\n" + " pass\n" + ) + m = _compute_metrics(content, "python") + assert m.classes == 2 + + def test_unknown_language(self) -> None: + content = "hello world\n" + m = _compute_metrics(content, "unknown") + assert m.total_lines == 1 + assert m.functions == 0 + + +# === Tests for _rename_symbol_impl === + + +class TestRenameSymbol: + def test_dry_run_preview( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "UserManager", "AccountManager", + sample_codebase, dry_run=True, + ) + assert result.success + assert result.dry_run + assert result.total_replacements >= 2 + assert result.files_changed >= 1 + # File should NOT be modified + content = (sample_codebase / "main.py").read_text() + assert "UserManager" in content + + def test_actual_rename( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "UserManager", "AccountManager", + sample_codebase, dry_run=False, + ) + assert result.success + assert not result.dry_run + assert result.total_replacements >= 2 + content = (sample_codebase / "main.py").read_text() + assert "AccountManager" in content + assert "UserManager" not in content + + def test_word_boundary_safety( + self, sample_codebase: Path, + ) -> None: + # Renaming "add" should not affect "add_user" + _rename_symbol_impl( + "add", "sum_values", + sample_codebase, dry_run=False, + ) + content = (sample_codebase / "main.py").read_text() + # add_user should still be intact + assert "add_user" in content + + def test_same_name_error( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "foo", "foo", sample_codebase, + ) + assert not result.success + assert "identical" in (result.message or "") + + def test_invalid_name_error( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "foo", "invalid-name!", sample_codebase, + ) + assert not result.success + assert "valid 
identifier" in (result.message or "") + + def test_scope_filter( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "greet", "sayHello", + sample_codebase, + scope="src/**", + dry_run=True, + ) + assert result.success + # Should only match files in src/ + for change in result.changes: + assert change.file_path.startswith("src") + + def test_no_matches( + self, sample_codebase: Path, + ) -> None: + result = _rename_symbol_impl( + "nonexistent_xyz_abc", "new_name", + sample_codebase, dry_run=True, + ) + assert result.success + assert result.total_replacements == 0 diff --git a/tests/test_config.py b/tests/test_config.py index 5db91bc..f28e1cc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -135,7 +135,7 @@ def test_dot_prefix_passed_through(self, tmp_path: Path) -> None: }, ): config = Config.from_env() - assert config.extra_extensions == {"..rb": None, ".yaml": None} + assert config.extra_extensions == {".rb": None, ".yaml": None} def test_parses_lang_mapping(self, tmp_path: Path) -> None: with patch.dict( diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py new file mode 100644 index 0000000..dfc6443 --- /dev/null +++ b/tests/test_filesystem_tools.py @@ -0,0 +1,467 @@ +"""Tests for filesystem tools: find_files, read_file, write_file, grep_code, directory_tree.""" + +from __future__ import annotations + +import os +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.filesystem_tools import ( + _detect_lang, + _directory_tree, + _edit_file, + _grep_files, + _is_binary, + _is_excluded_dir, + _read_file, + _safe_resolve, + _walk_files, + _write_file, +) + + +@pytest.fixture() +def sample_codebase(tmp_path: Path) -> Path: + """Create a sample codebase for testing.""" + (tmp_path / "src").mkdir() + (tmp_path / "src" / "utils").mkdir() + (tmp_path / "lib").mkdir() + (tmp_path / "node_modules").mkdir() + (tmp_path / 
"__pycache__").mkdir() + + (tmp_path / "main.py").write_text( + 'def hello():\n """Say hello."""\n print("Hello, world!")\n' + ) + (tmp_path / "src" / "app.ts").write_text( + "export function greet(name: string): string {\n" + " return `Hello, ${name}!`;\n" + "}\n" + "\n" + "// TODO: add farewell function\n" + "export function farewell(name: string): string {\n" + " return `Goodbye, ${name}!`;\n" + "}\n" + ) + (tmp_path / "src" / "utils" / "math.ts").write_text( + "export const add = (a: number, b: number): number => a + b;\n" + "export const subtract = (a: number, b: number): number => a - b;\n" + ) + (tmp_path / "lib" / "database.py").write_text( + "class DatabaseConnection:\n" + ' """Database connection manager."""\n' + "\n" + " def connect(self) -> None:\n" + ' """Establish connection."""\n' + " pass\n" + ) + (tmp_path / "README.md").write_text("# Test Project\n\nA test project.\n") + + (tmp_path / "node_modules" / "pkg.js").write_text("module.exports = {};\n") + (tmp_path / "__pycache__" / "main.cpython-312.pyc").write_bytes(b"\x00" * 100) + + binary_path = tmp_path / "image.png" + binary_path.write_bytes(b"\x89PNG\r\n\x1a\n\x00\x00\x00" + b"\x00" * 50) + + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(sample_codebase: Path) -> Iterator[None]: + """Patch filesystem_tools config to point at sample_codebase.""" + with patch("cocoindex_code.filesystem_tools.config") as mock_config: + mock_config.codebase_root_path = sample_codebase + yield + + +class TestIsExcludedDir: + """Tests for _is_excluded_dir.""" + + def test_hidden_dirs_excluded(self) -> None: + assert _is_excluded_dir(".git") is True + assert _is_excluded_dir(".vscode") is True + + def test_known_excluded_dirs(self) -> None: + assert _is_excluded_dir("node_modules") is True + assert _is_excluded_dir("__pycache__") is True + assert _is_excluded_dir(".cocoindex_code") is True + + def test_pattern_excluded_dirs(self) -> None: + assert _is_excluded_dir("target") is True + assert 
_is_excluded_dir("build") is True + assert _is_excluded_dir("dist") is True + assert _is_excluded_dir("vendor") is True + + def test_normal_dirs_not_excluded(self) -> None: + assert _is_excluded_dir("src") is False + assert _is_excluded_dir("lib") is False + assert _is_excluded_dir("tests") is False + + +class TestIsBinary: + """Tests for _is_binary.""" + + def test_text_file_not_binary(self, tmp_path: Path) -> None: + f = tmp_path / "test.txt" + f.write_text("Hello, world!") + assert _is_binary(f) is False + + def test_binary_file_detected(self, tmp_path: Path) -> None: + f = tmp_path / "test.bin" + f.write_bytes(b"\x00\x01\x02\x03") + assert _is_binary(f) is True + + def test_nonexistent_file_returns_true(self, tmp_path: Path) -> None: + assert _is_binary(tmp_path / "nonexistent") is True + + +class TestDetectLang: + """Tests for _detect_lang.""" + + def test_python(self, tmp_path: Path) -> None: + assert _detect_lang(tmp_path / "test.py") == "python" + assert _detect_lang(tmp_path / "test.pyi") == "python" + + def test_typescript(self, tmp_path: Path) -> None: + assert _detect_lang(tmp_path / "test.ts") == "typescript" + assert _detect_lang(tmp_path / "test.tsx") == "typescript" + + def test_javascript(self, tmp_path: Path) -> None: + assert _detect_lang(tmp_path / "test.js") == "javascript" + + def test_unknown_extension(self, tmp_path: Path) -> None: + assert _detect_lang(tmp_path / "test.xyz") == "" + + +class TestSafeResolve: + """Tests for _safe_resolve path traversal protection.""" + + def test_normal_path(self, sample_codebase: Path) -> None: + resolved = _safe_resolve("src/app.ts") + assert resolved == sample_codebase / "src" / "app.ts" + + def test_traversal_blocked(self, sample_codebase: Path) -> None: + with pytest.raises(ValueError, match="escapes the codebase root"): + _safe_resolve("../../etc/passwd") + + +class TestWalkFiles: + """Tests for _walk_files.""" + + def test_find_all_files(self, sample_codebase: Path) -> None: + files, total, truncated 
= _walk_files(sample_codebase) + assert total > 0 + assert not truncated + paths = {f.path for f in files} + assert "main.py" in paths + assert "src/app.ts" in paths + assert "README.md" in paths + + def test_excludes_node_modules(self, sample_codebase: Path) -> None: + files, _, _ = _walk_files(sample_codebase) + paths = {f.path for f in files} + assert not any("node_modules" in p for p in paths) + + def test_excludes_pycache(self, sample_codebase: Path) -> None: + files, _, _ = _walk_files(sample_codebase) + paths = {f.path for f in files} + assert not any("__pycache__" in p for p in paths) + + def test_pattern_filter(self, sample_codebase: Path) -> None: + files, total, _ = _walk_files(sample_codebase, pattern="*.py") + assert total == 2 + assert all(f.path.endswith(".py") for f in files) + + def test_language_filter(self, sample_codebase: Path) -> None: + files, total, _ = _walk_files(sample_codebase, languages=["typescript"]) + assert total == 2 + assert all(f.language == "typescript" for f in files) + + def test_paths_filter(self, sample_codebase: Path) -> None: + files, total, _ = _walk_files(sample_codebase, paths=["src/*"]) + assert total > 0 + assert all(f.path.startswith("src/") for f in files) + + def test_limit_truncates(self, sample_codebase: Path) -> None: + files, total, truncated = _walk_files(sample_codebase, limit=1) + assert len(files) == 1 + assert total > 1 + assert truncated is True + + def test_file_size_populated(self, sample_codebase: Path) -> None: + files, _, _ = _walk_files(sample_codebase, pattern="main.py") + assert len(files) == 1 + assert files[0].size > 0 + + +class TestReadFile: + """Tests for _read_file.""" + + def test_read_entire_file(self, sample_codebase: Path) -> None: + content, s, e, total = _read_file(sample_codebase / "main.py") + assert s == 1 + assert e == total + assert "def hello" in content + + def test_read_line_range(self, sample_codebase: Path) -> None: + content, s, e, total = _read_file(sample_codebase / 
"main.py", start_line=1, end_line=1) + assert s == 1 + assert e == 1 + assert "def hello" in content + assert "print" not in content + + def test_start_line_clamped(self, sample_codebase: Path) -> None: + content, s, _, _ = _read_file(sample_codebase / "main.py", start_line=0) + assert s == 1 + + def test_end_line_clamped(self, sample_codebase: Path) -> None: + _, _, e, total = _read_file(sample_codebase / "main.py", end_line=9999) + assert e == total + + +class TestGrepFiles: + """Tests for _grep_files.""" + + def test_basic_grep(self, sample_codebase: Path) -> None: + matches, total, searched, truncated = _grep_files(sample_codebase, "def hello") + assert total == 1 + assert matches[0].path == "main.py" + assert matches[0].line_number == 1 + assert not truncated + + def test_grep_regex(self, sample_codebase: Path) -> None: + matches, total, _, _ = _grep_files(sample_codebase, r"TODO|FIXME") + assert total >= 1 + assert any("TODO" in m.line for m in matches) + + def test_grep_case_insensitive(self, sample_codebase: Path) -> None: + matches, total, _, _ = _grep_files(sample_codebase, "hello", case_sensitive=False) + assert total >= 1 + + def test_grep_include_filter(self, sample_codebase: Path) -> None: + matches, total, _, _ = _grep_files(sample_codebase, "export", include="*.ts") + assert total >= 1 + assert all(m.path.endswith(".ts") for m in matches) + + def test_grep_paths_filter(self, sample_codebase: Path) -> None: + matches, total, _, _ = _grep_files(sample_codebase, "export", paths=["src/utils/*"]) + assert total >= 1 + assert all(m.path.startswith("src/utils/") for m in matches) + + def test_grep_context_lines(self, sample_codebase: Path) -> None: + matches, _, _, _ = _grep_files(sample_codebase, "TODO", context_lines=1) + assert len(matches) >= 1 + assert len(matches[0].context_after) > 0 or len(matches[0].context_before) > 0 + + def test_grep_limit(self, sample_codebase: Path) -> None: + matches, total, _, truncated = _grep_files(sample_codebase, 
"export", limit=1) + assert len(matches) == 1 + if total > 1: + assert truncated is True + + def test_grep_invalid_regex(self, sample_codebase: Path) -> None: + with pytest.raises(ValueError, match="Invalid regex"): + _grep_files(sample_codebase, "[invalid") + + def test_grep_skips_binary(self, sample_codebase: Path) -> None: + matches, _, _, _ = _grep_files(sample_codebase, "PNG") + paths = {m.path for m in matches} + assert "image.png" not in paths + + def test_grep_skips_excluded_dirs(self, sample_codebase: Path) -> None: + matches, _, _, _ = _grep_files(sample_codebase, "module.exports") + paths = {m.path for m in matches} + assert not any("node_modules" in p for p in paths) + + +class TestDirectoryTree: + """Tests for _directory_tree.""" + + def test_basic_tree(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase) + paths = {e.path for e in entries} + types = {e.path: e.type for e in entries} + assert "src" in paths + assert types["src"] == "dir" + assert "main.py" in paths + assert types["main.py"] == "file" + + def test_excludes_hidden_and_known_dirs(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase) + paths = {e.path for e in entries} + assert not any("node_modules" in p for p in paths) + assert not any("__pycache__" in p for p in paths) + + def test_max_depth(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase, max_depth=1) + dirs = [e for e in entries if e.type == "dir"] + nested = [d for d in dirs if d.path.count(os.sep) > 1] + assert len(nested) == 0 + + def test_subdirectory(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase, rel_path="src") + paths = {e.path for e in entries} + assert any("app.ts" in p for p in paths) + + def test_file_sizes(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase) + file_entries = [e for e in entries if e.type == "file"] + assert all(e.size >= 0 for e in file_entries) + 
main_py = next(e for e in file_entries if e.path == "main.py") + assert main_py.size > 0 + + def test_children_count(self, sample_codebase: Path) -> None: + entries = _directory_tree(sample_codebase) + src_entry = next(e for e in entries if e.path == "src") + assert src_entry.children > 0 + + +class TestWriteFile: + """Tests for _write_file.""" + + def test_create_new_file(self, sample_codebase: Path) -> None: + path = sample_codebase / "new_file.txt" + bytes_written, created = _write_file(path, "hello world") + assert created is True + assert bytes_written == 11 + assert path.read_text() == "hello world" + + def test_overwrite_existing_file(self, sample_codebase: Path) -> None: + path = sample_codebase / "main.py" + original = path.read_text() + new_content = "# replaced\n" + bytes_written, created = _write_file(path, new_content) + assert created is False + assert bytes_written == len(new_content.encode("utf-8")) + assert path.read_text() == new_content + assert path.read_text() != original + + def test_creates_parent_directories(self, sample_codebase: Path) -> None: + path = sample_codebase / "deep" / "nested" / "dir" / "file.go" + bytes_written, created = _write_file(path, "package main\n") + assert created is True + assert path.exists() + assert path.read_text() == "package main\n" + + def test_unicode_content(self, sample_codebase: Path) -> None: + path = sample_codebase / "unicode.txt" + content = "Hello, mundo! 
Emoji: \u2764\ufe0f" + bytes_written, created = _write_file(path, content) + assert created is True + assert path.read_text(encoding="utf-8") == content + assert bytes_written == len(content.encode("utf-8")) + + def test_empty_content(self, sample_codebase: Path) -> None: + path = sample_codebase / "empty.txt" + bytes_written, created = _write_file(path, "") + assert created is True + assert bytes_written == 0 + assert path.read_text() == "" + + def test_multiline_content(self, sample_codebase: Path) -> None: + path = sample_codebase / "multi.py" + content = "def foo():\n return 42\n\ndef bar():\n return 0\n" + bytes_written, created = _write_file(path, content) + assert created is True + assert path.read_text() == content + + def test_exceeds_max_size(self, sample_codebase: Path) -> None: + path = sample_codebase / "huge.txt" + content = "x" * 2_000_000 + with pytest.raises(ValueError, match="exceeds maximum write size"): + _write_file(path, content) + assert not path.exists() + + def test_path_traversal_blocked(self, sample_codebase: Path) -> None: + with pytest.raises(ValueError, match="escapes the codebase root"): + resolved = _safe_resolve("../../etc/evil.txt") + _write_file(resolved, "malicious") + + def test_write_then_read_roundtrip(self, sample_codebase: Path) -> None: + path = sample_codebase / "roundtrip.ts" + content = "export const x: number = 42;\n" + _write_file(path, content) + read_content, s, e, total = _read_file(path) + assert read_content == content + assert s == 1 + assert e == total == 1 + + +class TestEditFile: + """Tests for _edit_file.""" + + def test_single_replacement(self, sample_codebase: Path) -> None: + path = sample_codebase / "main.py" + original = path.read_text() + assert "def hello" in original + replacements = _edit_file(path, "def hello", "def greet") + assert replacements == 1 + assert "def greet" in path.read_text() + assert "def hello" not in path.read_text() + + def test_replace_all(self, sample_codebase: Path) -> None: + 
path = sample_codebase / "replace_all.txt" + path.write_text("aaa bbb aaa ccc aaa") + replacements = _edit_file(path, "aaa", "xxx", replace_all=True) + assert replacements == 3 + assert path.read_text() == "xxx bbb xxx ccc xxx" + + def test_ambiguous_match_without_replace_all(self, sample_codebase: Path) -> None: + path = sample_codebase / "ambiguous.txt" + path.write_text("foo bar foo baz foo") + with pytest.raises(ValueError, match="Found 3 matches"): + _edit_file(path, "foo", "qux") + + def test_old_string_not_found(self, sample_codebase: Path) -> None: + path = sample_codebase / "main.py" + with pytest.raises(ValueError, match="old_string not found"): + _edit_file(path, "nonexistent_string_xyz", "replacement") + + def test_identical_strings_rejected(self, sample_codebase: Path) -> None: + path = sample_codebase / "main.py" + with pytest.raises(ValueError, match="identical"): + _edit_file(path, "def hello", "def hello") + + def test_multiline_replacement(self, sample_codebase: Path) -> None: + path = sample_codebase / "multi.py" + path.write_text("def foo():\n return 1\n\ndef bar():\n return 2\n") + replacements = _edit_file( + path, + "def foo():\n return 1", + "def foo(x: int):\n return x + 1", + ) + assert replacements == 1 + content = path.read_text() + assert "def foo(x: int):" in content + assert "return x + 1" in content + assert "def bar():" in content + + def test_replacement_preserves_rest_of_file(self, sample_codebase: Path) -> None: + path = sample_codebase / "src" / "app.ts" + original = path.read_text() + line_count_before = original.count("\n") + _edit_file(path, "greet", "welcome") + updated = path.read_text() + assert "welcome" in updated + assert "greet" not in updated + assert updated.count("\n") == line_count_before + + def test_delete_by_replacing_with_empty(self, sample_codebase: Path) -> None: + path = sample_codebase / "delete.txt" + path.write_text("keep this\nremove this line\nkeep this too\n") + _edit_file(path, "remove this line\n", 
"") + assert path.read_text() == "keep this\nkeep this too\n" + + def test_insert_by_replacing_anchor(self, sample_codebase: Path) -> None: + path = sample_codebase / "insert.py" + path.write_text("import os\n\ndef main():\n pass\n") + _edit_file(path, "import os\n", "import os\nimport sys\n") + content = path.read_text() + assert "import os\nimport sys\n" in content + + def test_file_not_found(self, sample_codebase: Path) -> None: + path = sample_codebase / "nope.txt" + with pytest.raises(FileNotFoundError): + _edit_file(path, "a", "b") diff --git a/tests/test_large_write.py b/tests/test_large_write.py new file mode 100644 index 0000000..8de4011 --- /dev/null +++ b/tests/test_large_write.py @@ -0,0 +1,188 @@ +"""Tests for the large_write tool.""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.filesystem_tools import ( + _large_write_append, + _large_write_buffers, + _large_write_finalize, + _large_write_start, +) + + +@pytest.fixture() +def sample_codebase(tmp_path: Path) -> Path: + """Create a sample codebase.""" + (tmp_path / "src").mkdir() + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(sample_codebase: Path) -> Iterator[None]: + """Patch config and clear buffers.""" + with patch( + "cocoindex_code.filesystem_tools.config" + ) as mock_config: + mock_config.codebase_root_path = sample_codebase + _large_write_buffers.clear() + yield + _large_write_buffers.clear() + + +class TestLargeWriteStart: + def test_creates_session(self) -> None: + _large_write_start("s1", "test.py") + assert "s1" in _large_write_buffers + assert _large_write_buffers["s1"]["path"] == "test.py" + assert _large_write_buffers["s1"]["chunks"] == [] + assert _large_write_buffers["s1"]["total_bytes"] == 0 + + +class TestLargeWriteAppend: + def test_append_content(self) -> None: + _large_write_start("s1", "test.py") + total = 
_large_write_append("s1", "hello ") + assert total == 6 + total = _large_write_append("s1", "world") + assert total == 11 + assert len(_large_write_buffers["s1"]["chunks"]) == 2 + + def test_size_limit(self) -> None: + _large_write_start("s1", "test.py") + # Try to append more than 5MB + big_chunk = "x" * (5 * 1024 * 1024 + 1) + with pytest.raises(ValueError, match="exceeds max size"): + _large_write_append("s1", big_chunk) + + +class TestLargeWriteFinalize: + def test_writes_file( + self, sample_codebase: Path, + ) -> None: + _large_write_start("s1", "output.py") + _large_write_append("s1", "def foo():\n") + _large_write_append("s1", " pass\n") + path, written, created = _large_write_finalize("s1") + + assert path == "output.py" + assert created + assert written > 0 + + out = sample_codebase / "output.py" + assert out.exists() + content = out.read_text() + assert "def foo():" in content + assert " pass" in content + + def test_creates_parent_dirs( + self, sample_codebase: Path, + ) -> None: + _large_write_start("s1", "deep/nested/dir/file.py") + _large_write_append("s1", "content") + _large_write_finalize("s1") + + out = sample_codebase / "deep" / "nested" / "dir" / "file.py" + assert out.exists() + + def test_removes_session_after_finalize(self) -> None: + _large_write_start("s1", "test.py") + _large_write_append("s1", "content") + _large_write_finalize("s1") + assert "s1" not in _large_write_buffers + + def test_overwrites_existing_file( + self, sample_codebase: Path, + ) -> None: + existing = sample_codebase / "existing.py" + existing.write_text("old content") + + _large_write_start("s1", "existing.py") + _large_write_append("s1", "new content") + _, _, created = _large_write_finalize("s1") + + assert not created # file existed + assert existing.read_text() == "new content" + + +class TestLargeWriteWorkflow: + """End-to-end workflow tests.""" + + def test_full_workflow( + self, sample_codebase: Path, + ) -> None: + # Start + _large_write_start("session_1", 
"src/big_module.py") + + # Append chunks + _large_write_append( + "session_1", + "# Big Module\n\n", + ) + _large_write_append( + "session_1", + "def func_a():\n pass\n\n", + ) + _large_write_append( + "session_1", + "def func_b():\n pass\n", + ) + + # Finalize + path, written, created = _large_write_finalize("session_1") + + assert path == "src/big_module.py" + assert created + out = sample_codebase / "src" / "big_module.py" + content = out.read_text() + assert "# Big Module" in content + assert "func_a" in content + assert "func_b" in content + + def test_multiple_sessions( + self, sample_codebase: Path, + ) -> None: + _large_write_start("a", "file_a.py") + _large_write_start("b", "file_b.py") + _large_write_append("a", "content_a") + _large_write_append("b", "content_b") + _large_write_finalize("a") + _large_write_finalize("b") + + assert (sample_codebase / "file_a.py").read_text() == "content_a" + assert (sample_codebase / "file_b.py").read_text() == "content_b" + + +class TestSessionEviction: + """Test that old sessions are evicted when MAX_SESSIONS is reached.""" + + def test_evicts_oldest_when_at_capacity(self) -> None: + from cocoindex_code.filesystem_tools import MAX_LARGE_WRITE_SESSIONS + + # Fill up to the limit + for i in range(MAX_LARGE_WRITE_SESSIONS): + _large_write_start(f"sess_{i}", f"file_{i}.py") + assert len(_large_write_buffers) == MAX_LARGE_WRITE_SESSIONS + + # Adding one more should evict the oldest + _large_write_start("overflow", "overflow.py") + assert len(_large_write_buffers) == MAX_LARGE_WRITE_SESSIONS + assert "overflow" in _large_write_buffers + # sess_0 should have been evicted (oldest created_at) + assert "sess_0" not in _large_write_buffers + + def test_restarting_existing_session_does_not_evict(self) -> None: + from cocoindex_code.filesystem_tools import MAX_LARGE_WRITE_SESSIONS + + for i in range(MAX_LARGE_WRITE_SESSIONS): + _large_write_start(f"sess_{i}", f"file_{i}.py") + + # Restarting an existing session should NOT evict 
anyone + _large_write_start("sess_0", "updated.py") + assert len(_large_write_buffers) == MAX_LARGE_WRITE_SESSIONS + assert _large_write_buffers["sess_0"]["path"] == "updated.py" diff --git a/tests/test_mcp_wrappers.py b/tests/test_mcp_wrappers.py new file mode 100644 index 0000000..37143f9 --- /dev/null +++ b/tests/test_mcp_wrappers.py @@ -0,0 +1,145 @@ +"""Tests for MCP tool wrapper layer — exception handling and Pydantic validation.""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.filesystem_tools import ( + _large_write_buffers, +) + + +@pytest.fixture() +def sample_codebase(tmp_path: Path) -> Path: + """Create a minimal codebase.""" + (tmp_path / "hello.py").write_text("print('hello')\n") + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(sample_codebase: Path) -> Iterator[None]: + with ( + patch("cocoindex_code.filesystem_tools.config") as mock_fs_config, + patch("cocoindex_code.thinking_tools.config") as mock_tt_config, + patch("cocoindex_code.thinking_tools._engine", None), + ): + mock_fs_config.codebase_root_path = sample_codebase + mock_tt_config.index_dir = sample_codebase + _large_write_buffers.clear() + yield + _large_write_buffers.clear() + + +class TestFilesystemToolValidation: + """Test that filesystem tools handle edge cases correctly.""" + + def test_large_write_append_without_start(self) -> None: + """Appending to non-existent session should raise.""" + from cocoindex_code.filesystem_tools import _large_write_append + + with pytest.raises(KeyError): + _large_write_append("nonexistent", "content") + + def test_large_write_finalize_without_start(self) -> None: + """Finalizing non-existent session should raise.""" + from cocoindex_code.filesystem_tools import _large_write_finalize + + with pytest.raises(KeyError): + _large_write_finalize("nonexistent") + + def test_large_write_start_idempotent(self) 
-> None: + """Starting a session twice should reset it.""" + from cocoindex_code.filesystem_tools import ( + _large_write_append, + _large_write_start, + ) + + _large_write_start("s1", "file.py") + _large_write_append("s1", "chunk1") + _large_write_start("s1", "file2.py") # Restart + assert _large_write_buffers["s1"]["path"] == "file2.py" + assert _large_write_buffers["s1"]["chunks"] == [] + + +class TestThinkingToolPydanticModels: + """Test that Pydantic models validate inputs correctly.""" + + def test_thought_data_requires_fields(self) -> None: + from pydantic import ValidationError + + from cocoindex_code.thinking_tools import ThoughtData + + with pytest.raises(ValidationError): + ThoughtData() # type: ignore[call-arg] + + def test_thought_data_valid(self) -> None: + from cocoindex_code.thinking_tools import ThoughtData + + td = ThoughtData( + thought="test", + thought_number=1, + total_thoughts=3, + next_thought_needed=True, + ) + assert td.thought == "test" + assert td.is_revision is False + + def test_thinking_result_defaults(self) -> None: + from cocoindex_code.thinking_tools import ThinkingResult + + result = ThinkingResult(success=True) + assert result.session_id == "" + assert result.branches == [] + assert result.message is None + + def test_evidence_tracker_result_defaults(self) -> None: + from cocoindex_code.thinking_tools import EvidenceTrackerResult + + result = EvidenceTrackerResult(success=False, message="test") + assert result.effort_mode == "medium" + assert result.total_evidence_count == 0 + + def test_plan_optimizer_result_defaults(self) -> None: + from cocoindex_code.thinking_tools import PlanOptimizerResult + + result = PlanOptimizerResult(success=True) + assert result.variants == [] + assert result.comparison_matrix == {} + assert result.plan_health_score == 0.0 + + +class TestThinkingEngineExceptionHandling: + """Test that ThinkingEngine handles errors gracefully.""" + + def test_load_corrupted_memory_file(self, sample_codebase: Path) -> 
None: + """ThinkingEngine should handle corrupted JSONL gracefully.""" + from cocoindex_code.thinking_engine import ThinkingEngine + + memory_file = sample_codebase / "thinking_memory.jsonl" + memory_file.write_text("not valid json\n{\"type\": \"bad\"}\n") + + # Should not crash — just skip invalid lines + with pytest.raises(Exception): + ThinkingEngine(sample_codebase) + + def test_empty_memory_file(self, sample_codebase: Path) -> None: + """ThinkingEngine should handle empty memory file.""" + from cocoindex_code.thinking_engine import ThinkingEngine + + memory_file = sample_codebase / "thinking_memory.jsonl" + memory_file.write_text("") + + engine = ThinkingEngine(sample_codebase) + assert engine._learnings == [] + assert engine._strategy_scores == {} + + def test_missing_memory_file(self, sample_codebase: Path) -> None: + """ThinkingEngine should handle missing memory file.""" + from cocoindex_code.thinking_engine import ThinkingEngine + + engine = ThinkingEngine(sample_codebase) + assert engine._learnings == [] diff --git a/tests/test_patch_tools.py b/tests/test_patch_tools.py new file mode 100644 index 0000000..0459fa8 --- /dev/null +++ b/tests/test_patch_tools.py @@ -0,0 +1,327 @@ +"""Tests for patch tools: apply_patch.""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.patch_tools import ( + PatchHunk, + _apply_hunks, + _apply_patch_impl, + _parse_unified_diff, +) + + +@pytest.fixture() +def sample_codebase(tmp_path: Path) -> Path: + """Create a sample codebase for testing.""" + (tmp_path / "src").mkdir() + + (tmp_path / "hello.py").write_text( + "def hello():\n" + " print('Hello, world!')\n" + "\n" + "def goodbye():\n" + " print('Goodbye!')\n" + ) + + (tmp_path / "src" / "app.py").write_text( + "class App:\n" + " def run(self):\n" + " pass\n" + ) + + return tmp_path + + +@pytest.fixture(autouse=True) +def 
_patch_config(sample_codebase: Path) -> Iterator[None]: + """Patch config for patch_tools.""" + with patch( + "cocoindex_code.filesystem_tools.config" + ) as mock_fs_config, patch( + "cocoindex_code.patch_tools._root" + ) as mock_root, patch( + "cocoindex_code.patch_tools._safe_resolve" + ) as mock_resolve: + mock_fs_config.codebase_root_path = sample_codebase + mock_root.return_value = sample_codebase + + def safe_resolve_side_effect(path_str): + import os + root = sample_codebase + resolved = (root / path_str).resolve() + if not ( + resolved == root + or str(resolved).startswith(str(root) + os.sep) + ): + msg = f"Path '{path_str}' escapes the codebase root" + raise ValueError(msg) + return resolved + + mock_resolve.side_effect = safe_resolve_side_effect + yield + + +# === Tests for _parse_unified_diff === + + +class TestParseUnifiedDiff: + def test_single_file_single_hunk(self) -> None: + diff = ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,3 +1,3 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hello, everyone!')\n" + ) + files = _parse_unified_diff(diff) + assert len(files) == 1 + assert files[0].old_path == "hello.py" + assert files[0].new_path == "hello.py" + assert len(files[0].hunks) == 1 + assert files[0].hunks[0].old_start == 1 + + def test_multi_hunk(self) -> None: + diff = ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,2 +1,2 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hi!')\n" + "@@ -4,2 +4,2 @@\n" + " def goodbye():\n" + "- print('Goodbye!')\n" + "+ print('Bye!')\n" + ) + files = _parse_unified_diff(diff) + assert len(files) == 1 + assert len(files[0].hunks) == 2 + + def test_multi_file(self) -> None: + diff = ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,2 +1,2 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hi!')\n" + "--- a/src/app.py\n" + "+++ b/src/app.py\n" + "@@ -1,3 +1,3 @@\n" + " class App:\n" + "- def run(self):\n" + "+ def start(self):\n" + " pass\n" + ) + 
files = _parse_unified_diff(diff) + assert len(files) == 2 + + def test_new_file(self) -> None: + diff = ( + "--- /dev/null\n" + "+++ b/new_file.py\n" + "@@ -0,0 +1,2 @@\n" + "+def new_func():\n" + "+ pass\n" + ) + files = _parse_unified_diff(diff) + assert len(files) == 1 + assert files[0].old_path == "/dev/null" + assert files[0].new_path == "new_file.py" + + def test_empty_patch(self) -> None: + files = _parse_unified_diff("") + assert files == [] + + +# === Tests for _apply_hunks === + + +class TestApplyHunks: + def test_single_replacement(self) -> None: + content = ( + "def hello():\n" + " print('Hello, world!')\n" + ) + hunk = PatchHunk( + old_start=1, old_count=2, new_start=1, new_count=2, + lines=[ + " def hello():", + "- print('Hello, world!')", + "+ print('Hello, everyone!')", + ], + ) + result, applied, rejected = _apply_hunks(content, [hunk]) + assert applied == 1 + assert rejected == 0 + assert "Hello, everyone!" in result + + def test_context_mismatch_rejects(self) -> None: + content = "def foo():\n pass\n" + hunk = PatchHunk( + old_start=1, old_count=2, new_start=1, new_count=2, + lines=[ + " def bar():", # doesn't match + "- pass", + "+ return None", + ], + ) + result, applied, rejected = _apply_hunks(content, [hunk]) + assert applied == 0 + assert rejected == 1 + # Content unchanged + assert result == content + + def test_multiple_hunks(self) -> None: + content = ( + "line1\n" + "line2\n" + "line3\n" + "line4\n" + "line5\n" + ) + hunk1 = PatchHunk( + old_start=1, old_count=1, new_start=1, new_count=1, + lines=["-line1", "+LINE1"], + ) + hunk2 = PatchHunk( + old_start=5, old_count=1, new_start=5, new_count=1, + lines=["-line5", "+LINE5"], + ) + result, applied, rejected = _apply_hunks( + content, [hunk1, hunk2], + ) + assert applied == 2 + assert rejected == 0 + assert "LINE1" in result + assert "LINE5" in result + + +# === Tests for _apply_patch_impl === + + +class TestApplyPatchImpl: + def test_dry_run(self, sample_codebase: Path) -> None: + diff 
= ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,2 +1,2 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hi!')\n" + ) + result = _apply_patch_impl(diff, sample_codebase, dry_run=True) + assert result.success + assert result.dry_run + assert result.total_applied == 1 + # File should be unchanged + content = (sample_codebase / "hello.py").read_text() + assert "Hello, world!" in content + + def test_apply(self, sample_codebase: Path) -> None: + diff = ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,2 +1,2 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hi!')\n" + ) + result = _apply_patch_impl( + diff, sample_codebase, dry_run=False, + ) + assert result.success + assert result.total_applied == 1 + content = (sample_codebase / "hello.py").read_text() + assert "Hi!" in content + + def test_new_file_creation( + self, sample_codebase: Path, + ) -> None: + diff = ( + "--- /dev/null\n" + "+++ b/new_file.py\n" + "@@ -0,0 +1,2 @@\n" + "+def new_func():\n" + "+ pass\n" + ) + result = _apply_patch_impl( + diff, sample_codebase, dry_run=False, + ) + assert result.success + assert result.total_applied == 1 + new_file = sample_codebase / "new_file.py" + assert new_file.exists() + content = new_file.read_text() + assert "def new_func():" in content + + def test_nonexistent_file( + self, sample_codebase: Path, + ) -> None: + diff = ( + "--- a/missing.py\n" + "+++ b/missing.py\n" + "@@ -1,2 +1,2 @@\n" + " foo\n" + "-bar\n" + "+baz\n" + ) + result = _apply_patch_impl( + diff, sample_codebase, dry_run=False, + ) + assert not result.success + assert result.total_rejected == 1 + + def test_path_traversal_rejected( + self, sample_codebase: Path, + ) -> None: + diff = ( + "--- a/../../etc/passwd\n" + "+++ b/../../etc/passwd\n" + "@@ -1,1 +1,1 @@\n" + "-root\n" + "+hacked\n" + ) + result = _apply_patch_impl( + diff, sample_codebase, dry_run=False, + ) + assert result.total_rejected >= 1 + + def test_empty_patch( + self, sample_codebase: 
Path, + ) -> None: + result = _apply_patch_impl("", sample_codebase) + assert not result.success + assert "No files" in (result.message or "") + + def test_multi_file_patch( + self, sample_codebase: Path, + ) -> None: + diff = ( + "--- a/hello.py\n" + "+++ b/hello.py\n" + "@@ -1,2 +1,2 @@\n" + " def hello():\n" + "- print('Hello, world!')\n" + "+ print('Hi!')\n" + "--- a/src/app.py\n" + "+++ b/src/app.py\n" + "@@ -1,3 +1,3 @@\n" + " class App:\n" + "- def run(self):\n" + "+ def start(self):\n" + " pass\n" + ) + result = _apply_patch_impl( + diff, sample_codebase, dry_run=False, + ) + assert result.success + assert result.total_applied == 2 + assert len(result.files) == 2 diff --git a/tests/test_plan_optimizer.py b/tests/test_plan_optimizer.py new file mode 100644 index 0000000..2823095 --- /dev/null +++ b/tests/test_plan_optimizer.py @@ -0,0 +1,645 @@ +"""Tests for the plan_optimizer tool.""" + +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.thinking_tools import ( + PLAN_DIMENSIONS, + ThinkingEngine, + ThoughtData, +) + + +@pytest.fixture() +def thinking_dir(tmp_path: Path) -> Path: + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(thinking_dir: Path) -> Iterator[None]: + with ( + patch("cocoindex_code.thinking_tools.config") as mock_config, + patch("cocoindex_code.thinking_tools._engine", None), + ): + mock_config.index_dir = thinking_dir + yield + + +def _make_thought( + thought: str = "t", + thought_number: int = 1, + total_thoughts: int = 10, + next_thought_needed: bool = True, +) -> ThoughtData: + return ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + ) + + +SAMPLE_PLAN = """# Implementation Plan: Add User Authentication + +## Phase 1: Database Schema +1. Create users table with email, password_hash, created_at +2. 
Add sessions table for JWT token tracking +3. Write migration scripts + +## Phase 2: API Endpoints +1. POST /api/auth/register - validate input, hash password, create user +2. POST /api/auth/login - verify credentials, issue JWT +3. POST /api/auth/logout - invalidate session +4. GET /api/auth/me - return current user profile + +## Phase 3: Middleware +1. Create auth middleware to verify JWT on protected routes +2. Add rate limiting to auth endpoints + +## Phase 4: Testing +1. Unit tests for password hashing +2. Integration tests for auth endpoints +3. E2E test for login flow +""" + +VAGUE_PLAN = """ +Fix the authentication. +Make it work somehow. +Clean up the code and improve stuff. +Handle the edge cases etc. +Figure out the deployment. +""" + +NO_STRUCTURE_PLAN = ( + "We need to add a new feature to the application.\n" + "It should allow users to upload files.\n" + "The files need to be stored somewhere.\n" + "We also need to validate the files.\n" + "Then we deploy it to production.\n" +) + + +class TestAntiPatternDetection: + def test_detects_vague_language( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + patterns = engine._detect_anti_patterns(VAGUE_PLAN) + vague = [ + p for p in patterns + if p.pattern_type == "vague_language" + ] + assert len(vague) >= 3 # "make it work", "somehow", "stuff" + + def test_detects_todo_markers( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + plan = "Step 1: Create model\nStep 2: TODO implement validation\n" + patterns = engine._detect_anti_patterns(plan) + todo = [ + p for p in patterns + if p.pattern_type == "todo_marker" + ] + assert len(todo) >= 1 + + def test_detects_missing_concerns( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + # Plan that mentions nothing about security + plan = ( + "1. Create the endpoint\n" + "2. Add error handling\n" + "3. 
Write tests\n" + ) + patterns = engine._detect_anti_patterns(plan) + missing = [ + p for p in patterns + if p.pattern_type == "missing_security" + ] + assert len(missing) >= 1 + + def test_detects_no_structure( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + patterns = engine._detect_anti_patterns(NO_STRUCTURE_PLAN) + no_struct = [ + p for p in patterns + if p.pattern_type == "no_structure" + ] + assert len(no_struct) >= 1 + + def test_detects_god_step( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + long_step = "x" * 600 + plan = f"1. {long_step}\n2. Short step\n" + patterns = engine._detect_anti_patterns(plan) + god = [ + p for p in patterns + if p.pattern_type == "god_step" + ] + assert len(god) >= 1 + + def test_clean_plan_has_few_issues( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + patterns = engine._detect_anti_patterns(SAMPLE_PLAN) + # A well-structured plan should have few anti-patterns + # It may flag missing concerns (e.g. 
security) which is valid + vague = [ + p for p in patterns + if p.pattern_type == "vague_language" + ] + assert len(vague) == 0 + god_steps = [ + p for p in patterns + if p.pattern_type == "god_step" + ] + assert len(god_steps) == 0 + todos = [ + p for p in patterns + if p.pattern_type == "todo_marker" + ] + assert len(todos) == 0 + + +class TestPlanHealthScore: + def test_perfect_scores( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + scores = {dim: 10.0 for dim in PLAN_DIMENSIONS} + health = engine._compute_plan_health(scores, 0) + assert health == 100.0 + + def test_zero_scores( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + scores = {dim: 0.0 for dim in PLAN_DIMENSIONS} + health = engine._compute_plan_health(scores, 0) + assert health == 0.0 + + def test_anti_patterns_reduce_health( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + scores = {dim: 10.0 for dim in PLAN_DIMENSIONS} + health_clean = engine._compute_plan_health(scores, 0) + health_dirty = engine._compute_plan_health(scores, 5) + assert health_dirty < health_clean + assert health_dirty == 75.0 # 100 - 5*5 + + def test_empty_scores( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + health = engine._compute_plan_health({}, 0) + assert health == 0.0 + + +class TestProcessPlanOptimizer: + def test_invalid_phase( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_plan_optimizer( + "s1", _make_thought(), phase="invalid_phase", + ) + assert not result.success + assert "Invalid phase" in (result.message or "") + + def test_submit_plan( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", + plan_text=SAMPLE_PLAN, + plan_context="Adding auth to the web app", + ) + assert result.success + assert 
result.plan_text == SAMPLE_PLAN + assert result.plan_context == "Adding auth to the web app" + # Anti-patterns auto-detected + assert isinstance(result.anti_patterns, list) + + def test_submit_plan_requires_text( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", + ) + assert not result.success + assert "plan_text is required" in (result.message or "") + + def test_analyze_dimension( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="analyze", + dimension="clarity", score=8.5, + ) + assert result.success + assert result.analysis_scores["clarity"] == 8.5 + + def test_analyze_invalid_dimension( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="analyze", + dimension="nonexistent", score=5.0, + ) + assert not result.success + assert "Invalid dimension" in (result.message or "") + + def test_analyze_clamps_score( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="analyze", + dimension="clarity", score=15.0, + ) + assert result.success + assert result.analysis_scores["clarity"] == 10.0 + + def test_analyze_adds_issue( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + 
phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="analyze", + issue="Missing rollback strategy", + ) + assert result.success + assert "Missing rollback strategy" in result.analysis_issues + + def test_add_variant( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", + variant_name="Minimal & Pragmatic", + variant_summary="Quick implementation", + variant_pros=["Fast to ship"], + variant_cons=["Less robust"], + variant_risk_level="low", + ) + assert result.success + assert len(result.variants) == 1 + assert result.variants[0].label == "A" + assert result.variants[0].name == "Minimal & Pragmatic" + + def test_add_variant_requires_label( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_name="Test", + ) + assert not result.success + + def test_add_duplicate_variant_rejected( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", variant_name="First", + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=3), + phase="add_variant", + variant_label="A", variant_name="Duplicate", + ) + assert not result.success + assert "already exists" in (result.message or "") + + def test_score_variant( + self, 
thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", variant_name="Minimal", + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=3), + phase="score_variant", + variant_label="A", + dimension="clarity", score=9.0, + ) + assert result.success + assert result.variants[0].scores["clarity"] == 9.0 + assert result.variants[0].total == 9.0 + + def test_score_variant_not_found( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="score_variant", + variant_label="Z", + dimension="clarity", score=5.0, + ) + assert not result.success + assert "not found" in (result.message or "") + + def test_recommend_auto_picks_winner( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + # Add two variants with different scores + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", variant_name="Minimal", + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=3), + phase="add_variant", + variant_label="B", variant_name="Robust", + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=4), + phase="score_variant", + variant_label="A", + dimension="clarity", score=5.0, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=5), + phase="score_variant", + variant_label="B", + dimension="clarity", score=9.0, + ) + result = 
engine.process_plan_optimizer( + "s1", _make_thought(thought_number=6), + phase="recommend", + recommendation="B is better due to higher clarity", + ) + assert result.success + assert result.winner_label == "B" + assert result.recommendation == ( + "B is better due to higher clarity" + ) + + def test_recommend_explicit_winner( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", variant_name="Minimal", + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=3), + phase="recommend", + winner_label="A", + recommendation="A is good enough", + ) + assert result.success + assert result.winner_label == "A" + + def test_comparison_matrix( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", plan_text=SAMPLE_PLAN, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=2), + phase="add_variant", + variant_label="A", variant_name="Minimal", + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=3), + phase="add_variant", + variant_label="B", variant_name="Robust", + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=4), + phase="score_variant", + variant_label="A", + dimension="clarity", score=7.0, + ) + engine.process_plan_optimizer( + "s1", _make_thought(thought_number=5), + phase="score_variant", + variant_label="B", + dimension="clarity", score=9.0, + ) + result = engine.process_plan_optimizer( + "s1", _make_thought(thought_number=6), + phase="recommend", + ) + assert result.success + matrix = result.comparison_matrix + assert "clarity" in matrix + assert matrix["clarity"]["A"] == 7.0 + assert matrix["clarity"]["B"] == 9.0 + 
assert "TOTAL" in matrix + + +class TestFullPlanOptimizerWorkflow: + """End-to-end workflow test.""" + + def test_full_optimize_flow( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + + # 1. Submit plan + r = engine.process_plan_optimizer( + "s1", _make_thought(thought="Submitting plan"), + phase="submit_plan", + plan_text=SAMPLE_PLAN, + plan_context="Adding authentication", + ) + assert r.success + assert isinstance(r.anti_patterns, list) + + # 2. Analyze across all dimensions + for i, dim in enumerate(PLAN_DIMENSIONS, start=2): + r = engine.process_plan_optimizer( + "s1", + _make_thought( + thought=f"Scoring {dim}", + thought_number=i, + ), + phase="analyze", + dimension=dim, score=7.5, + ) + assert r.success + + assert len(r.analysis_scores) == len(PLAN_DIMENSIONS) + assert r.plan_health_score > 0 + + # 3. Add 3 variants + variants = [ + ("A", "Minimal & Pragmatic", "Quick JWT auth"), + ("B", "Robust & Scalable", "Full OAuth2 + RBAC"), + ("C", "Optimal Architecture", "Auth service microservice"), + ] + step = 10 + for label, name, summary in variants: + step += 1 + r = engine.process_plan_optimizer( + "s1", + _make_thought( + thought=f"Adding variant {label}", + thought_number=step, + ), + phase="add_variant", + variant_label=label, + variant_name=name, + variant_summary=summary, + variant_pros=[f"Pro of {label}"], + variant_cons=[f"Con of {label}"], + ) + assert r.success + + assert len(r.variants) == 3 + + # 4. 
Score each variant + variant_scores = { + "A": {"clarity": 9, "simplicity": 9, "risk": 8, + "correctness": 6, "completeness": 5, + "testability": 7, "edge_cases": 4, + "actionability": 8}, + "B": {"clarity": 7, "simplicity": 5, "risk": 7, + "correctness": 9, "completeness": 9, + "testability": 8, "edge_cases": 8, + "actionability": 7}, + "C": {"clarity": 6, "simplicity": 3, "risk": 5, + "correctness": 10, "completeness": 10, + "testability": 9, "edge_cases": 9, + "actionability": 5}, + } + for label, scores in variant_scores.items(): + for dim, sc in scores.items(): + step += 1 + r = engine.process_plan_optimizer( + "s1", + _make_thought( + thought=f"Scoring {label}:{dim}", + thought_number=step, + ), + phase="score_variant", + variant_label=label, + dimension=dim, score=float(sc), + ) + assert r.success + + # 5. Recommend + step += 1 + r = engine.process_plan_optimizer( + "s1", + _make_thought( + thought="Final recommendation", + thought_number=step, + next_thought_needed=False, + ), + phase="recommend", + recommendation=( + "Variant B provides the best balance of " + "correctness, completeness, and testability " + "while maintaining reasonable simplicity." 
+ ), + ) + assert r.success + # B should win (highest total) + assert r.winner_label == "B" + assert r.recommendation + assert "TOTAL" in r.comparison_matrix + assert len(r.comparison_matrix["TOTAL"]) == 3 + + def test_vague_plan_gets_many_anti_patterns( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + r = engine.process_plan_optimizer( + "s1", _make_thought(), + phase="submit_plan", + plan_text=VAGUE_PLAN, + ) + assert r.success + assert r.anti_pattern_count >= 5 + # Health should be low + # Even without analysis scores, anti-patterns detected + types = {p.pattern_type for p in r.anti_patterns} + assert "vague_language" in types diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..b8fc7ab --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,114 @@ +"""Tests for server.py CLI argument parsing.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + + +class TestMainArgumentParsing: + """Test that main() parses CLI arguments correctly.""" + + def test_serve_is_default(self) -> None: + """When no command is given, 'serve' is the default.""" + with ( + patch("sys.argv", ["cocoindex-code"]), + patch( + "cocoindex_code.server.asyncio.run", + ) as mock_run, + ): + from cocoindex_code.server import main + + main() + mock_run.assert_called_once() + # The call should be to _async_serve() + call_args = mock_run.call_args + coro = call_args[0][0] + assert coro is not None + + def test_serve_command(self) -> None: + """Explicit 'serve' command should call _async_serve.""" + with ( + patch("sys.argv", ["cocoindex-code", "serve"]), + patch( + "cocoindex_code.server.asyncio.run", + ) as mock_run, + ): + from cocoindex_code.server import main + + main() + mock_run.assert_called_once() + + def test_index_command(self) -> None: + """'index' command should call _async_index.""" + with ( + patch("sys.argv", ["cocoindex-code", "index"]), + patch( + 
"cocoindex_code.server.asyncio.run", + ) as mock_run, + ): + from cocoindex_code.server import main + + main() + mock_run.assert_called_once() + + +class TestPrintIndexStats: + """Test _print_index_stats with mocked database.""" + + @pytest.mark.asyncio + async def test_no_database(self, tmp_path: object) -> None: + """When no index DB exists, print message.""" + with patch( + "cocoindex_code.server.config" + ) as mock_config: + from pathlib import Path + + mock_config.target_sqlite_db_path = Path("/nonexistent/db.sqlite") + from cocoindex_code.server import _print_index_stats + + # Should not crash, just print "No index database found." + await _print_index_stats() + + +class TestSearchResultModel: + """Test SearchResultModel Pydantic model.""" + + def test_default_values(self) -> None: + from cocoindex_code.server import SearchResultModel + + result = SearchResultModel(success=True) + assert result.results == [] + assert result.total_returned == 0 + assert result.offset == 0 + assert result.message is None + + def test_with_results(self) -> None: + from cocoindex_code.server import CodeChunkResult, SearchResultModel + + chunk = CodeChunkResult( + file_path="test.py", + language="python", + content="print('hello')", + start_line=1, + end_line=1, + score=0.95, + ) + result = SearchResultModel( + success=True, + results=[chunk], + total_returned=1, + ) + assert len(result.results) == 1 + assert result.results[0].file_path == "test.py" + + def test_error_result(self) -> None: + from cocoindex_code.server import SearchResultModel + + result = SearchResultModel( + success=False, + message="Index not found", + ) + assert result.success is False + assert result.message == "Index not found" diff --git a/tests/test_shared.py b/tests/test_shared.py new file mode 100644 index 0000000..c81e03b --- /dev/null +++ b/tests/test_shared.py @@ -0,0 +1,75 @@ +"""Tests for shared.py initialization logic.""" + +from __future__ import annotations + + +class TestEmbedderSelection: + 
"""Test embedder selection logic based on model prefix.""" + + def test_sbert_prefix_detected(self) -> None: + """Models starting with 'sbert/' use SentenceTransformerEmbedder.""" + from cocoindex_code.shared import SBERT_PREFIX + + assert "sbert/sentence-transformers/all-MiniLM-L6-v2".startswith(SBERT_PREFIX) + + def test_litellm_model_detected(self) -> None: + """Models without 'sbert/' prefix use LiteLLM.""" + from cocoindex_code.shared import SBERT_PREFIX + + assert not "text-embedding-3-small".startswith(SBERT_PREFIX) + + def test_sbert_prefix_constant(self) -> None: + from cocoindex_code.shared import SBERT_PREFIX + + assert SBERT_PREFIX == "sbert/" + + def test_query_prompt_models_constant(self) -> None: + """Known query-prompt models should be defined.""" + # We can't easily access the local variable, but we can verify + # the embedder was created without error + from cocoindex_code.shared import embedder + + assert embedder is not None + + +class TestContextKeys: + """Test CocoIndex context key definitions.""" + + def test_sqlite_db_key_exists(self) -> None: + from cocoindex_code.shared import SQLITE_DB + + assert SQLITE_DB is not None + + def test_codebase_dir_key_exists(self) -> None: + from cocoindex_code.shared import CODEBASE_DIR + + assert CODEBASE_DIR is not None + + +class TestCodeChunk: + """Test CodeChunk dataclass in shared.py.""" + + def test_code_chunk_has_expected_fields(self) -> None: + import dataclasses + + from cocoindex_code.shared import CodeChunk + + field_names = [f.name for f in dataclasses.fields(CodeChunk)] + assert "id" in field_names + assert "file_path" in field_names + assert "language" in field_names + assert "content" in field_names + assert "start_line" in field_names + assert "end_line" in field_names + assert "embedding" in field_names + + +class TestCocoLifespan: + """Test coco_lifespan function existence.""" + + def test_lifespan_is_callable(self) -> None: + """coco_lifespan should be a callable (decorated with 
@coco.lifespan).""" + from cocoindex_code.shared import coco_lifespan + + # It's wrapped by @coco.lifespan but should still exist + assert coco_lifespan is not None diff --git a/tests/test_thinking_tools.py b/tests/test_thinking_tools.py new file mode 100644 index 0000000..904cf04 --- /dev/null +++ b/tests/test_thinking_tools.py @@ -0,0 +1,1073 @@ +from __future__ import annotations + +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.thinking_tools import ( + ThinkingEngine, + ThoughtData, +) + + +@pytest.fixture() +def thinking_dir(tmp_path: Path) -> Path: + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(thinking_dir: Path) -> Iterator[None]: + with ( + patch("cocoindex_code.thinking_tools.config") as mock_config, + patch("cocoindex_code.thinking_tools._engine", None), + ): + mock_config.index_dir = thinking_dir + yield + + +def _make_thought( + thought: str = "t", + thought_number: int = 1, + total_thoughts: int = 3, + next_thought_needed: bool = True, + **kwargs, +) -> ThoughtData: + return ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + **kwargs, + ) + + +class TestThinkingEngine: + def test_init_creates_engine(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + assert engine._sessions == {} + + def test_load_empty_memory(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + assert engine._learnings == [] + assert engine._strategy_scores == {} + + def test_process_basic_thought(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + data = _make_thought(thought="first", thought_number=1, total_thoughts=3) + result = engine.process_thought("s1", data) + assert result.success + assert result.session_id == "s1" + assert result.thought_number == 1 + assert result.total_thoughts == 3 + assert 
result.thought_history_length == 1 + + def test_process_multiple_thoughts(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + for i in range(1, 4): + result = engine.process_thought("s1", _make_thought(thought_number=i)) + assert result.thought_history_length == i + + def test_auto_adjust_total_thoughts(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_thought("s1", _make_thought(thought_number=5, total_thoughts=3)) + assert result.total_thoughts == 5 + + def test_branching(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_thought("s1", _make_thought()) + engine.process_thought( + "s1", _make_thought(thought_number=2, branch_id="b1", branch_from_thought=1) + ) + result = engine.process_thought( + "s1", _make_thought(thought_number=3, branch_id="b2", branch_from_thought=1) + ) + assert "b1" in result.branches + assert "b2" in result.branches + + def test_multiple_thoughts_same_branch(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_thought( + "s1", _make_thought(thought_number=1, branch_id="b1", branch_from_thought=1) + ) + result = engine.process_thought( + "s1", _make_thought(thought_number=2, branch_id="b1", branch_from_thought=1) + ) + assert len(result.branches) == 1 + + +class TestExtendedThinking: + def test_basic_extended(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_extended_thought("s1", _make_thought(), depth_level="deep") + assert result.depth_level == "deep" + + def test_checkpoint_at_interval(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_extended_thought( + "s1", + _make_thought(thought_number=5, total_thoughts=10), + checkpoint_interval=5, + ) + assert result.checkpoint_summary != "" + + def test_no_checkpoint_between_intervals(self, thinking_dir: Path) -> None: + engine = 
ThinkingEngine(thinking_dir) + result = engine.process_extended_thought( + "s1", + _make_thought(thought_number=3, total_thoughts=10), + checkpoint_interval=5, + ) + assert result.checkpoint_summary == "" + + def test_exhaustive_mode(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_extended_thought("s1", _make_thought(), depth_level="exhaustive") + assert result.depth_level == "exhaustive" + + def test_steps_since_checkpoint(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_extended_thought( + "s1", + _make_thought(thought_number=7, total_thoughts=10), + checkpoint_interval=5, + ) + assert result.steps_since_checkpoint == 2 + + +class TestUltraThinking: + def test_explore_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_ultra_thought("s1", _make_thought(), phase="explore") + assert result.phase == "explore" + + def test_hypothesize_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_ultra_thought( + "s1", _make_thought(), phase="hypothesize", hypothesis="H1" + ) + assert "H1" in result.hypotheses + + def test_verify_high_confidence(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_ultra_thought("s1", _make_thought(), phase="verify", confidence=0.9) + assert result.verification_status == "supported" + + def test_verify_medium_confidence(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_ultra_thought("s1", _make_thought(), phase="verify", confidence=0.5) + assert result.verification_status == "partially_supported" + + def test_verify_low_confidence(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_ultra_thought("s1", _make_thought(), phase="verify", confidence=0.2) + assert result.verification_status == 
"unsupported" + + def test_synthesize_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_ultra_thought( + "s1", _make_thought(thought_number=1), phase="hypothesize", hypothesis="H1" + ) + engine.process_ultra_thought( + "s1", _make_thought(thought_number=2), phase="hypothesize", hypothesis="H2" + ) + result = engine.process_ultra_thought( + "s1", _make_thought(thought_number=3), phase="synthesize" + ) + assert "Synthesis" in result.synthesis + + def test_multiple_hypotheses(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + for i, h in enumerate(["H1", "H2", "H3"], start=1): + engine.process_ultra_thought( + "s1", _make_thought(thought_number=i), phase="hypothesize", hypothesis=h + ) + result = engine.process_ultra_thought( + "s1", _make_thought(thought_number=4), phase="explore" + ) + assert "H1" in result.hypotheses + assert "H2" in result.hypotheses + assert "H3" in result.hypotheses + + +class TestLearningLoop: + def test_record_learning(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.record_learning("s1", "divide_conquer", ["success"], 0.8, ["insight1"]) + assert result.success + assert result.learnings_extracted == 1 + + def test_learning_persisted(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "divide_conquer", ["success"], 0.8, ["insight1"]) + engine2 = ThinkingEngine(thinking_dir) + assert len(engine2._learnings) >= 1 + + def test_strategy_score_updated(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "divide_conquer", ["success"], 0.8, ["insight1"]) + score = engine._strategy_scores["divide_conquer"] + assert score.usage_count == 1 + assert score.avg_reward == pytest.approx(0.8) + + def test_multiple_learnings_same_strategy(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + 
engine.record_learning("s1", "divide_conquer", ["success"], 0.8, ["i1"]) + engine.record_learning("s2", "divide_conquer", ["partial"], 0.4, ["i2"]) + score = engine._strategy_scores["divide_conquer"] + assert score.avg_reward == pytest.approx(0.6) + + +class TestSelfImprove: + def test_no_learnings(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + recs = engine.get_strategy_recommendations() + assert recs == [] + + def test_recommendations_sorted(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "low", [], 0.2, []) + engine.record_learning("s2", "mid", [], 0.5, []) + engine.record_learning("s3", "high", [], 0.9, []) + recs = engine.get_strategy_recommendations() + assert recs[0].strategy == "high" + assert recs[1].strategy == "mid" + assert recs[2].strategy == "low" + + def test_top_k_limit(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + for i in range(5): + engine.record_learning(f"s{i}", f"strat{i}", [], float(i) / 10, []) + recs = engine.get_strategy_recommendations(top_k=2) + assert len(recs) == 2 + + +class TestRewardThinking: + def test_apply_reward(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "strat", [], 0.3, []) + result = engine.apply_reward("s1", 0.5) + assert result.success + assert result.new_reward == pytest.approx(0.5) + + def test_apply_reward_no_session(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.apply_reward("nonexistent", 0.5) + assert result.success is False + + def test_cumulative_reward(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "strat", [], 0.3, []) + result = engine.apply_reward("s1", 0.2) + assert result.cumulative_reward == pytest.approx(0.5) + + +class TestPersistence: + def test_strategy_persisted(self, thinking_dir: Path) -> None: + engine = 
ThinkingEngine(thinking_dir) + engine.record_learning("s1", "persist_strat", [], 0.7, []) + engine2 = ThinkingEngine(thinking_dir) + assert "persist_strat" in engine2._strategy_scores + + def test_memory_file_created(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "strat", [], 0.5, []) + assert (thinking_dir / "thinking_memory.jsonl").exists() + + +# --- Helper to set up hypotheses for evidence tests --- + + +def _setup_hypotheses(engine: ThinkingEngine, session_id: str, hypotheses: list[str]) -> None: + """Add hypotheses to a session via ultra_thinking.""" + for i, h in enumerate(hypotheses, start=1): + engine.process_ultra_thought( + session_id, + _make_thought(thought_number=i, total_thoughts=len(hypotheses)), + phase="hypothesize", + hypothesis=h, + ) + + +class TestEvidenceTracker: + def test_add_evidence_to_hypothesis(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1", "H2"]) + result = engine.add_evidence("s1", 0, "Found in auth.py", "code_ref", 0.8) + assert result.success + assert result.hypothesis_index == 0 + assert result.hypothesis_text == "H1" + assert result.total_evidence_count == 1 + assert result.cumulative_strength == pytest.approx(0.8) + assert result.effort_mode == "medium" + + def test_add_evidence_no_session(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.add_evidence("nonexistent", 0, "text", "data_point", 0.5) + assert result.success is False + assert "No hypotheses" in (result.message or "") + + def test_add_evidence_invalid_index(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.add_evidence("s1", 5, "text", "data_point", 0.5) + assert result.success is False + assert "out of range" in (result.message or "") + + def test_add_evidence_no_hypotheses(self, thinking_dir: Path) -> None: + engine 
= ThinkingEngine(thinking_dir) + engine.process_thought("s1", _make_thought()) + result = engine.add_evidence("s1", 0, "text", "data_point", 0.5) + assert result.success is False + + def test_list_evidence(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + engine.add_evidence("s1", 0, "ev1", "code_ref", 0.7) + engine.add_evidence("s1", 0, "ev2", "data_point", 0.9) + result = engine.get_evidence("s1", 0) + assert result.success + assert result.total_evidence_count == 2 + assert result.evidence[0].text == "ev1" + assert result.evidence[1].text == "ev2" + + def test_list_evidence_empty(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.get_evidence("s1", 0) + assert result.success + assert result.total_evidence_count == 0 + assert result.cumulative_strength == pytest.approx(0.0) + + def test_cumulative_strength(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + engine.add_evidence("s1", 0, "ev1", "code_ref", 0.6) + engine.add_evidence("s1", 0, "ev2", "data_point", 0.8) + result = engine.add_evidence("s1", 0, "ev3", "external", 1.0) + assert result.cumulative_strength == pytest.approx((0.6 + 0.8 + 1.0) / 3) + + def test_multiple_hypotheses_evidence(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1", "H2"]) + engine.add_evidence("s1", 0, "ev-a", "code_ref", 0.5) + engine.add_evidence("s1", 1, "ev-b", "assumption", 0.3) + result_0 = engine.get_evidence("s1", 0) + result_1 = engine.get_evidence("s1", 1) + assert result_0.total_evidence_count == 1 + assert result_1.total_evidence_count == 1 + + def test_all_evidence_types(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + for etype in ["code_ref", "data_point", "external", 
"assumption", "test_result"]: + result = engine.add_evidence("s1", 0, f"ev-{etype}", etype, 0.5) + assert result.success, f"Failed for type {etype}" + + def test_invalid_evidence_type(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.add_evidence("s1", 0, "text", "invalid_type", 0.5) + assert result.success is False + assert "Invalid evidence_type" in (result.message or "") + + def test_strength_clamped(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.add_evidence("s1", 0, "strong", "data_point", 1.5) + assert result.success + assert result.evidence[0].strength == pytest.approx(1.0) + result2 = engine.add_evidence("s1", 0, "weak", "data_point", -0.5) + assert result2.evidence[1].strength == pytest.approx(0.0) + + def test_low_effort_skips_type_validation(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.add_evidence( + "s1", 0, "text", "bogus_type", 0.5, effort_mode="low" + ) + assert result.success + assert result.evidence[0].evidence_type == "data_point" + assert result.effort_mode == "low" + + def test_high_effort_validates_type(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + result = engine.add_evidence( + "s1", 0, "text", "bad", 0.5, effort_mode="high" + ) + assert result.success is False + + +class TestPremortem: + def test_describe_plan_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_premortem( + "s1", _make_thought(), phase="describe_plan", plan="Migrate DB" + ) + assert result.success + assert result.phase == "describe_plan" + assert result.plan_description == "Migrate DB" + assert result.effort_mode == "medium" + + def test_imagine_failure_phase(self, thinking_dir: Path) -> 
None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem("s1", _make_thought(), phase="describe_plan", plan="My plan") + result = engine.process_premortem( + "s1", _make_thought(thought_number=2), + phase="imagine_failure", failure_scenario="Data loss", + ) + assert result.success + assert result.failure_scenario == "Data loss" + assert result.plan_description == "My plan" + + def test_identify_causes_adds_risk(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_premortem( + "s1", _make_thought(), phase="identify_causes", + risk_description="No backup", likelihood=0.7, impact=0.9, + ) + assert result.success + assert len(result.risks) == 1 + assert result.risks[0].risk_score == pytest.approx(0.7 * 0.9) + + def test_identify_causes_requires_description(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_premortem("s1", _make_thought(), phase="identify_causes") + assert result.success is False + assert "risk_description is required" in (result.message or "") + + def test_rank_risks_by_score(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _make_thought(thought_number=1), + phase="identify_causes", risk_description="Low", likelihood=0.2, impact=0.3, + ) + engine.process_premortem( + "s1", _make_thought(thought_number=2), + phase="identify_causes", risk_description="High", likelihood=0.9, impact=0.9, + ) + result = engine.process_premortem( + "s1", _make_thought(thought_number=3), phase="rank_risks", + ) + assert result.ranked_risks[0].description == "High" + assert result.ranked_risks[1].description == "Low" + + def test_mitigate_risk(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _make_thought(thought_number=1), + phase="identify_causes", risk_description="Risk A", likelihood=0.5, impact=0.5, + ) + result = 
engine.process_premortem( + "s1", _make_thought(thought_number=2), + phase="mitigate", risk_index=0, mitigation="Add backups", + ) + assert result.success + assert result.risks[0].mitigation == "Add backups" + assert result.mitigations_count == 1 + + def test_mitigate_invalid_index(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _make_thought(), phase="identify_causes", + risk_description="R", likelihood=0.5, impact=0.5, + ) + result = engine.process_premortem( + "s1", _make_thought(thought_number=2), + phase="mitigate", risk_index=5, mitigation="nope", + ) + assert result.success is False + assert "out of range" in (result.message or "") + + def test_mitigate_requires_risk_index(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _make_thought(), phase="identify_causes", + risk_description="R", likelihood=0.5, impact=0.5, + ) + result = engine.process_premortem( + "s1", _make_thought(thought_number=2), phase="mitigate", mitigation="fix", + ) + assert result.success is False + assert "risk_index is required" in (result.message or "") + + def test_invalid_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_premortem("s1", _make_thought(), phase="bad_phase") + assert result.success is False + assert "Invalid phase" in (result.message or "") + + def test_likelihood_impact_clamped(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_premortem( + "s1", _make_thought(), phase="identify_causes", + risk_description="R", likelihood=1.5, impact=-0.3, + ) + assert result.risks[0].likelihood == pytest.approx(1.0) + assert result.risks[0].impact == pytest.approx(0.0) + assert result.risks[0].risk_score == pytest.approx(0.0) + + def test_effort_mode_passed_through(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = 
engine.process_premortem( + "s1", _make_thought(), phase="describe_plan", + plan="p", effort_mode="high", + ) + assert result.effort_mode == "high" + + def test_full_flow(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + r1 = engine.process_premortem( + "s1", _make_thought(thought_number=1, total_thoughts=5), + phase="describe_plan", plan="Deploy auth", + ) + assert r1.success + r2 = engine.process_premortem( + "s1", _make_thought(thought_number=2, total_thoughts=5), + phase="imagine_failure", failure_scenario="Tokens rejected", + ) + assert r2.success + r3 = engine.process_premortem( + "s1", _make_thought(thought_number=3, total_thoughts=5), + phase="identify_causes", risk_description="Format mismatch", + likelihood=0.6, impact=0.9, + ) + assert r3.success + r4 = engine.process_premortem( + "s1", _make_thought(thought_number=4, total_thoughts=5), phase="rank_risks", + ) + assert len(r4.ranked_risks) == 1 + r5 = engine.process_premortem( + "s1", _make_thought(thought_number=5, total_thoughts=5, next_thought_needed=False), + phase="mitigate", risk_index=0, mitigation="Backward-compat parsing", + ) + assert r5.mitigations_count == 1 + + +class TestInversionThinking: + def test_define_goal(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion( + "s1", _make_thought(), phase="define_goal", goal="Ship on time", + ) + assert result.success + assert result.goal == "Ship on time" + assert result.effort_mode == "medium" + + def test_invert_auto_generates(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="define_goal", goal="Ship on time", + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="invert", + ) + assert result.success + assert "guarantee failure" in result.inverted_goal + + def test_invert_custom(self, thinking_dir: Path) -> None: + engine = 
ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="define_goal", goal="Ship on time", + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="invert", + inverted_goal="How to guarantee we miss the deadline", + ) + assert result.inverted_goal == "How to guarantee we miss the deadline" + + def test_list_failure_causes(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="No testing", severity=0.8, + ) + assert result.success + assert len(result.failure_causes) == 1 + assert result.failure_causes[0].description == "No testing" + assert result.failure_causes[0].severity == pytest.approx(0.8) + + def test_list_failure_causes_requires_cause(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + ) + assert result.success is False + assert "failure_cause is required" in (result.message or "") + + def test_rank_causes(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(thought_number=1), phase="list_failure_causes", + failure_cause="Low sev", severity=0.2, + ) + engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="list_failure_causes", + failure_cause="High sev", severity=0.9, + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=3), phase="rank_causes", + ) + assert result.success + assert result.ranked_causes[0].description == "High sev" + assert result.ranked_causes[1].description == "Low sev" + + def test_rank_causes_blocked_in_low_effort(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="C1", severity=0.5, effort_mode="low", 
+ ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="rank_causes", + effort_mode="low", + ) + assert result.success is False + assert "not available in low effort" in (result.message or "") + + def test_reinvert(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="No testing", severity=0.8, + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="reinvert", + cause_index=0, inverted_action="Add comprehensive test suite", + ) + assert result.success + assert result.failure_causes[0].inverted_action == "Add comprehensive test suite" + + def test_reinvert_requires_cause_index(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="C1", severity=0.5, + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="reinvert", + ) + assert result.success is False + assert "cause_index is required" in (result.message or "") + + def test_reinvert_invalid_index(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="C1", severity=0.5, + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="reinvert", + cause_index=99, + ) + assert result.success is False + assert "out of range" in (result.message or "") + + def test_action_plan(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion( + "s1", _make_thought(), phase="action_plan", + action_item="Write integration tests", + ) + assert result.success + assert "Write integration tests" in result.action_plan + + def test_action_plan_high_effort_auto_populate(self, thinking_dir: Path) -> None: + engine = 
ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _make_thought(thought_number=1), phase="list_failure_causes", + failure_cause="No tests", severity=0.8, effort_mode="high", + ) + engine.process_inversion( + "s1", _make_thought(thought_number=2), phase="reinvert", + cause_index=0, inverted_action="Add tests", effort_mode="high", + ) + result = engine.process_inversion( + "s1", _make_thought(thought_number=3), phase="action_plan", + effort_mode="high", + ) + assert result.success + assert "Add tests" in result.action_plan + + def test_invalid_phase(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion("s1", _make_thought(), phase="bad") + assert result.success is False + assert "Invalid phase" in (result.message or "") + + def test_severity_clamped(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_inversion( + "s1", _make_thought(), phase="list_failure_causes", + failure_cause="C", severity=2.0, + ) + assert result.failure_causes[0].severity == pytest.approx(1.0) + + def test_full_flow(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + r1 = engine.process_inversion( + "s1", _make_thought(thought_number=1, total_thoughts=6), + phase="define_goal", goal="Launch v2", + ) + assert r1.success + r2 = engine.process_inversion( + "s1", _make_thought(thought_number=2, total_thoughts=6), + phase="invert", + ) + assert r2.success + r3 = engine.process_inversion( + "s1", _make_thought(thought_number=3, total_thoughts=6), + phase="list_failure_causes", failure_cause="Skip QA", severity=0.9, + ) + assert r3.success + r4 = engine.process_inversion( + "s1", _make_thought(thought_number=4, total_thoughts=6), + phase="rank_causes", + ) + assert len(r4.ranked_causes) == 1 + r5 = engine.process_inversion( + "s1", _make_thought(thought_number=5, total_thoughts=6), + phase="reinvert", cause_index=0, inverted_action="Mandatory QA gate", + ) + 
assert r5.success + r6 = engine.process_inversion( + "s1", _make_thought(thought_number=6, total_thoughts=6, next_thought_needed=False), + phase="action_plan", action_item="Enforce CI QA step", + ) + assert "Enforce CI QA step" in r6.action_plan + + +class TestEffortEstimator: + def test_add_estimate(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="Build API", + optimistic=2.0, likely=4.0, pessimistic=8.0, + ) + assert result.success + assert len(result.estimates) == 1 + assert result.estimates[0].task == "Build API" + assert result.effort_mode == "medium" + + def test_pert_calculation(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=3.0, pessimistic=5.0, + ) + # PERT = (1 + 4*3 + 5) / 6 = 18/6 = 3.0 + assert result.estimates[0].pert_estimate == pytest.approx(3.0) + # std_dev = (5 - 1) / 6 ≈ 0.667 + assert result.estimates[0].std_dev == pytest.approx(4.0 / 6.0) + + def test_confidence_intervals(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=3.0, pessimistic=5.0, + ) + est = result.estimates[0] + assert est.confidence_68_low == pytest.approx(est.pert_estimate - est.std_dev) + assert est.confidence_68_high == pytest.approx(est.pert_estimate + est.std_dev) + assert est.confidence_95_low == pytest.approx(est.pert_estimate - 2 * est.std_dev) + assert est.confidence_95_high == pytest.approx(est.pert_estimate + 2 * est.std_dev) + + def test_add_requires_task(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", optimistic=1.0, likely=2.0, pessimistic=3.0, + ) + assert result.success is False + assert "task name is required" in (result.message or "") + + def 
test_pessimistic_must_be_gte_optimistic(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=5.0, likely=3.0, pessimistic=1.0, + ) + assert result.success is False + assert "pessimistic must be >= optimistic" in (result.message or "") + + def test_multiple_estimates_total(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=2.0, pessimistic=3.0, + ) + result = engine.process_estimate( + "s1", action="add", task="T2", + optimistic=2.0, likely=4.0, pessimistic=6.0, + ) + assert len(result.estimates) == 2 + assert result.total_pert == pytest.approx( + result.estimates[0].pert_estimate + result.estimates[1].pert_estimate + ) + + def test_summary_action(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=2.0, pessimistic=3.0, + ) + result = engine.process_estimate("s1", action="summary") + assert result.success + assert len(result.estimates) == 1 + + def test_clear_action(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=2.0, pessimistic=3.0, + ) + result = engine.process_estimate("s1", action="clear") + assert result.success + assert "cleared" in (result.message or "").lower() + + def test_invalid_action(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate("s1", action="bad") + assert result.success is False + assert "Invalid action" in (result.message or "") + + def test_low_effort_single_point(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=0.0, likely=5.0, pessimistic=0.0, + 
effort_mode="low", + ) + assert result.success + est = result.estimates[0] + assert est.pert_estimate == pytest.approx(5.0) + assert est.optimistic == pytest.approx(5.0) + assert est.pessimistic == pytest.approx(5.0) + assert result.total_std_dev == pytest.approx(0.0) + + def test_medium_effort_has_68_ci(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=3.0, pessimistic=5.0, + effort_mode="medium", + ) + assert result.total_confidence_68_low != 0.0 + assert result.total_confidence_68_high != 0.0 + # Medium does not populate 95% CI + assert result.total_confidence_95_low == pytest.approx(0.0) + + def test_high_effort_has_95_ci(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=3.0, pessimistic=5.0, + effort_mode="high", + ) + assert result.total_confidence_68_low != 0.0 + assert result.total_confidence_95_low != 0.0 + assert result.total_confidence_95_high != 0.0 + + def test_total_std_dev_is_rss(self, thinking_dir: Path) -> None: + """Total std_dev should be root-sum-square of individual std_devs.""" + engine = ThinkingEngine(thinking_dir) + engine.process_estimate( + "s1", action="add", task="T1", + optimistic=1.0, likely=2.0, pessimistic=5.0, + ) + result = engine.process_estimate( + "s1", action="add", task="T2", + optimistic=2.0, likely=4.0, pessimistic=8.0, + ) + expected = ( + result.estimates[0].std_dev ** 2 + + result.estimates[1].std_dev ** 2 + ) ** 0.5 + assert result.total_std_dev == pytest.approx(expected) + + +class TestInvalidEffortModeRejected: + """Verify that invalid effort_mode values are rejected by engine methods.""" + + def test_evidence_tracker_rejects_invalid_mode(self, thinking_dir: Path) -> None: + engine = ThinkingEngine(thinking_dir) + _setup_hypotheses(engine, "s1", ["H1"]) + engine.add_evidence("s1", 0, 
"text", "data_point", 0.5, effort_mode="bad") + # Engine-level add_evidence doesn't validate effort_mode itself; + # validation is at the MCP tool layer. Test that directly via + # the VALID_EFFORT_MODES constant. + from cocoindex_code.thinking_tools import VALID_EFFORT_MODES + + assert "bad" not in VALID_EFFORT_MODES + assert "low" in VALID_EFFORT_MODES + assert "medium" in VALID_EFFORT_MODES + assert "high" in VALID_EFFORT_MODES + assert "ultra" in VALID_EFFORT_MODES + + def test_premortem_rejects_invalid_effort_mode(self, thinking_dir: Path) -> None: + """Engine method still works but MCP layer would reject 'bogus'.""" + ThinkingEngine(thinking_dir) # Verify engine can be created + from cocoindex_code.thinking_tools import VALID_EFFORT_MODES + + assert "bogus" not in VALID_EFFORT_MODES + + def test_valid_effort_modes_are_frozenset(self) -> None: + from cocoindex_code.thinking_tools import VALID_EFFORT_MODES + + assert isinstance(VALID_EFFORT_MODES, frozenset) + assert VALID_EFFORT_MODES == {"low", "medium", "high", "ultra"} + + +class TestMCPEffortModeValidation: + """Test that MCP tool wrappers reject invalid effort_mode.""" + + @pytest.mark.asyncio + async def test_evidence_tracker_rejects_invalid_effort_mode( + self, thinking_dir: Path, + ) -> None: + from cocoindex_code.thinking_tools import ( + VALID_EFFORT_MODES, + EvidenceTrackerResult, + ) + + # Simulate what the MCP wrapper does + effort_mode = "nonsense" + if effort_mode not in VALID_EFFORT_MODES: + result = EvidenceTrackerResult( + success=False, + effort_mode=effort_mode, + message=f"Invalid effort_mode '{effort_mode}'", + ) + assert result.success is False + assert "Invalid effort_mode" in (result.message or "") + + @pytest.mark.asyncio + async def test_premortem_rejects_invalid_effort_mode( + self, thinking_dir: Path, + ) -> None: + from cocoindex_code.thinking_tools import ( + VALID_EFFORT_MODES, + PremortemResult, + ) + + effort_mode = "turbo" + if effort_mode not in VALID_EFFORT_MODES: + result 
= PremortemResult( + success=False, + effort_mode=effort_mode, + message=f"Invalid effort_mode '{effort_mode}'", + ) + assert result.success is False + assert "Invalid effort_mode" in (result.message or "") + + @pytest.mark.asyncio + async def test_inversion_rejects_invalid_effort_mode( + self, thinking_dir: Path, + ) -> None: + from cocoindex_code.thinking_tools import ( + VALID_EFFORT_MODES, + InversionThinkingResult, + ) + + effort_mode = "max" + if effort_mode not in VALID_EFFORT_MODES: + result = InversionThinkingResult( + success=False, + effort_mode=effort_mode, + message=f"Invalid effort_mode '{effort_mode}'", + ) + assert result.success is False + assert "Invalid effort_mode" in (result.message or "") + + @pytest.mark.asyncio + async def test_effort_estimator_rejects_invalid_effort_mode( + self, thinking_dir: Path, + ) -> None: + from cocoindex_code.thinking_tools import ( + VALID_EFFORT_MODES, + EffortEstimatorResult, + ) + + effort_mode = "extreme" + if effort_mode not in VALID_EFFORT_MODES: + result = EffortEstimatorResult( + success=False, + effort_mode=effort_mode, + message=f"Invalid effort_mode '{effort_mode}'", + ) + assert result.success is False + assert "Invalid effort_mode" in (result.message or "") + + @pytest.mark.asyncio + async def test_plan_optimizer_rejects_invalid_effort_mode( + self, thinking_dir: Path, + ) -> None: + from cocoindex_code.thinking_tools import ( + VALID_EFFORT_MODES, + PlanOptimizerResult, + ) + + effort_mode = "11" + if effort_mode not in VALID_EFFORT_MODES: + result = PlanOptimizerResult( + success=False, + effort_mode=effort_mode, + message=f"Invalid effort_mode '{effort_mode}'", + ) + assert result.success is False + assert "Invalid effort_mode" in (result.message or "") + + +class TestMemoryCompaction: + """Test that thinking memory JSONL file gets compacted on load.""" + + def test_compaction_deduplicates_strategies(self, thinking_dir: Path) -> None: + """When file has many duplicate strategy entries, compaction 
deduplicates.""" + import json + + memory_file = thinking_dir / "thinking_memory.jsonl" + + # Write many duplicate strategy entries (simulating repeated saves) + with open(memory_file, "w", encoding="utf-8") as f: + for i in range(50): + entry = { + "type": "strategy", + "data": { + "strategy": "divide_conquer", + "total_reward": float(i), + "usage_count": i, + "avg_reward": 0.5, + "last_used": float(i), + }, + } + f.write(json.dumps(entry) + "\n") + + # Load — this should trigger compaction since 50 >> 1 unique strategy + engine = ThinkingEngine(thinking_dir) + assert len(engine._strategy_scores) == 1 + assert engine._strategy_scores["divide_conquer"].usage_count == 49 # last one wins + + # File should be compacted now + with open(memory_file) as f: + lines = [line.strip() for line in f if line.strip()] + assert len(lines) == 1 # Only one strategy entry after compaction + + def test_no_compaction_when_small_file(self, thinking_dir: Path) -> None: + """Small files should not trigger compaction.""" + engine = ThinkingEngine(thinking_dir) + engine.record_learning("s1", "strat1", ["ok"], 0.5, ["i1"]) + engine.record_learning("s2", "strat2", ["ok"], 0.7, ["i2"]) + + memory_file = thinking_dir / "thinking_memory.jsonl" + with open(memory_file) as f: + lines_before = len([line for line in f if line.strip()]) + + # Reload — should NOT compact because file is small + engine2 = ThinkingEngine(thinking_dir) + assert len(engine2._learnings) == 2 + assert len(engine2._strategy_scores) == 2 + + with open(memory_file) as f: + lines_after = len([line for line in f if line.strip()]) + assert lines_after == lines_before # No compaction happened diff --git a/tests/test_ultra_effort_mode.py b/tests/test_ultra_effort_mode.py new file mode 100644 index 0000000..dd56977 --- /dev/null +++ b/tests/test_ultra_effort_mode.py @@ -0,0 +1,349 @@ +"""Tests for ultra effort_mode across all thinking tools.""" + +from __future__ import annotations + +from collections.abc import Iterator +from 
pathlib import Path +from unittest.mock import patch + +import pytest + +from cocoindex_code.thinking_tools import ( + PLAN_DIMENSIONS, + ThinkingEngine, + ThoughtData, +) + + +@pytest.fixture() +def thinking_dir(tmp_path: Path) -> Path: + return tmp_path + + +@pytest.fixture(autouse=True) +def _patch_config(thinking_dir: Path) -> Iterator[None]: + with ( + patch("cocoindex_code.thinking_tools.config") as mock_config, + patch("cocoindex_code.thinking_tools._engine", None), + ): + mock_config.index_dir = thinking_dir + yield + + +def _td( + thought: str = "t", + thought_number: int = 1, + total_thoughts: int = 10, + next_thought_needed: bool = True, +) -> ThoughtData: + return ThoughtData( + thought=thought, + thought_number=thought_number, + total_thoughts=total_thoughts, + next_thought_needed=next_thought_needed, + ) + + +class TestUltraEvidenceTracker: + """Ultra mode auto-boosts strength for code_ref/test_result.""" + + def test_auto_boost_code_ref( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + # Create an ultra_thinking session with a hypothesis + engine.process_ultra_thought("s1", _td(), phase="explore") + engine.process_ultra_thought( + "s1", _td(thought_number=2), + phase="hypothesize", hypothesis="H1", + ) + # Add evidence with low strength but code_ref type + result = engine.add_evidence( + "s1", 0, "Found in source code", + evidence_type="code_ref", + strength=0.3, + effort_mode="ultra", + ) + assert result.success + # Strength should be boosted to at least 0.9 + evidence = result.evidence + assert len(evidence) >= 1 + assert evidence[-1].strength >= 0.9 + + def test_auto_boost_test_result( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_ultra_thought("s1", _td(), phase="explore") + engine.process_ultra_thought( + "s1", _td(thought_number=2), + phase="hypothesize", hypothesis="H1", + ) + result = engine.add_evidence( + "s1", 0, "Test passes", + 
evidence_type="test_result", + strength=0.5, + effort_mode="ultra", + ) + assert result.success + assert result.evidence[-1].strength >= 0.9 + + def test_no_boost_for_data_point( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_ultra_thought("s1", _td(), phase="explore") + engine.process_ultra_thought( + "s1", _td(thought_number=2), + phase="hypothesize", hypothesis="H1", + ) + result = engine.add_evidence( + "s1", 0, "Just a data point", + evidence_type="data_point", + strength=0.3, + effort_mode="ultra", + ) + assert result.success + assert result.evidence[-1].strength == 0.3 + + +class TestUltraPremortem: + """Ultra mode auto-ranks + requires all mitigations.""" + + def test_auto_rank_at_identify_causes( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _td(), phase="describe_plan", + plan="Build a rocket", + ) + engine.process_premortem( + "s1", _td(thought_number=2), + phase="identify_causes", + risk_description="Engine failure", + likelihood=0.9, impact=0.9, + effort_mode="ultra", + ) + result = engine.process_premortem( + "s1", _td(thought_number=3), + phase="identify_causes", + risk_description="Fuel leak", + likelihood=0.3, impact=0.5, + effort_mode="ultra", + ) + assert result.success + # Ultra should auto-include ranked_risks + assert len(result.ranked_risks) == 2 + # Highest risk score first + assert result.ranked_risks[0].description == "Engine failure" + + def test_warn_unmitigated_risks( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_premortem( + "s1", _td(), phase="describe_plan", + plan="Build a rocket", + ) + engine.process_premortem( + "s1", _td(thought_number=2), + phase="identify_causes", + risk_description="Engine failure", + likelihood=0.9, impact=0.9, + ) + engine.process_premortem( + "s1", _td(thought_number=3), + phase="identify_causes", + risk_description="Fuel leak", + 
likelihood=0.3, impact=0.5, + ) + # Mitigate only one risk + result = engine.process_premortem( + "s1", _td(thought_number=4), + phase="mitigate", + risk_index=0, + mitigation="Add redundant engines", + effort_mode="ultra", + ) + assert result.success + # Should warn about unmitigated risks + assert result.message is not None + assert "1 risk(s) still lack mitigations" in result.message + + +class TestUltraInversion: + """Ultra mode auto-reinverts + auto-populates.""" + + def test_auto_reinvert_all_causes( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_inversion( + "s1", _td(), phase="define_goal", goal="Ship v2", + ) + engine.process_inversion( + "s1", _td(thought_number=2), phase="invert", + ) + engine.process_inversion( + "s1", _td(thought_number=3), + phase="list_failure_causes", + failure_cause="No testing", + ) + engine.process_inversion( + "s1", _td(thought_number=4), + phase="list_failure_causes", + failure_cause="No code review", + ) + # Ultra action_plan: should auto-reinvert causes + result = engine.process_inversion( + "s1", _td(thought_number=5), + phase="action_plan", + effort_mode="ultra", + ) + assert result.success + # Both causes should now have inverted_actions + for cause in result.failure_causes: + assert cause.inverted_action is not None + assert len(cause.inverted_action) > 0 + # Action plan should be auto-populated + assert len(result.action_plan) >= 2 + + +class TestUltraEffortEstimator: + """Ultra mode adds 99.7% CI + risk buffer.""" + + def test_99_ci_and_risk_buffer( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="Build feature", + optimistic=2.0, likely=5.0, pessimistic=12.0, + effort_mode="ultra", + ) + assert result.success + # 99.7% CI should be populated + assert result.total_confidence_99_low != 0.0 + assert result.total_confidence_99_high != 0.0 + # 99.7% CI should be wider than 95% CI 
+ assert result.total_confidence_99_low < result.total_confidence_95_low + assert result.total_confidence_99_high > result.total_confidence_95_high + # Risk buffer should be pessimistic * 1.5 + assert result.total_risk_buffer == 12.0 * 1.5 + + def test_high_does_not_have_99_ci( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + result = engine.process_estimate( + "s1", action="add", task="Build feature", + optimistic=2.0, likely=5.0, pessimistic=12.0, + effort_mode="high", + ) + assert result.success + assert result.total_confidence_99_low == 0.0 + assert result.total_confidence_99_high == 0.0 + assert result.total_risk_buffer == 0.0 + + +class TestUltraPlanOptimizer: + """Ultra mode: auto-score missing dims, require variants.""" + + def test_blocks_recommend_without_variants( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _td(), + phase="submit_plan", + plan_text="1. Do something\n2. Do more\n", + ) + result = engine.process_plan_optimizer( + "s1", _td(thought_number=2), + phase="recommend", + effort_mode="ultra", + ) + assert not result.success + assert "requires at least one variant" in ( + result.message or "" + ) + + def test_auto_scores_missing_dimensions( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _td(), + phase="submit_plan", + plan_text="1. Build it\n2. 
Test it\n", + ) + # Only score 2 of 8 dimensions + engine.process_plan_optimizer( + "s1", _td(thought_number=2), + phase="analyze", + dimension="clarity", score=8.0, + ) + engine.process_plan_optimizer( + "s1", _td(thought_number=3), + phase="analyze", + dimension="simplicity", score=7.0, + ) + # Add a variant, score 1 dimension + engine.process_plan_optimizer( + "s1", _td(thought_number=4), + phase="add_variant", + variant_label="A", variant_name="Quick", + ) + engine.process_plan_optimizer( + "s1", _td(thought_number=5), + phase="score_variant", + variant_label="A", + dimension="clarity", score=9.0, + ) + # Recommend in ultra mode + result = engine.process_plan_optimizer( + "s1", _td(thought_number=6), + phase="recommend", + effort_mode="ultra", + ) + assert result.success + # All 8 dimensions should be present in analysis + assert len(result.analysis_scores) == len(PLAN_DIMENSIONS) + for dim in PLAN_DIMENSIONS: + assert dim in result.analysis_scores + # Unscored dims should be 0 + assert result.analysis_scores["correctness"] == 0.0 + assert result.analysis_scores["clarity"] == 8.0 + # Variant should also have all dims scored + assert len(result.variants[0].scores) == len(PLAN_DIMENSIONS) + assert result.variants[0].scores["clarity"] == 9.0 + assert result.variants[0].scores["completeness"] == 0.0 + + def test_medium_does_not_auto_score( + self, thinking_dir: Path, + ) -> None: + engine = ThinkingEngine(thinking_dir) + engine.process_plan_optimizer( + "s1", _td(), + phase="submit_plan", + plan_text="1. Build\n2. 
Test\n", + ) + engine.process_plan_optimizer( + "s1", _td(thought_number=2), + phase="analyze", + dimension="clarity", score=8.0, + ) + engine.process_plan_optimizer( + "s1", _td(thought_number=3), + phase="add_variant", + variant_label="A", variant_name="Quick", + ) + result = engine.process_plan_optimizer( + "s1", _td(thought_number=4), + phase="recommend", + effort_mode="medium", + ) + assert result.success + # Should only have 1 dimension scored + assert len(result.analysis_scores) == 1 diff --git a/uv.lock b/uv.lock index 4aa1994..9e03e64 100644 --- a/uv.lock +++ b/uv.lock @@ -368,20 +368,9 @@ dependencies = [ { name = "sqlite-vec" }, ] -[package.optional-dependencies] -dev = [ - { name = "mypy" }, - { name = "prek" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-cov" }, - { name = "ruff" }, -] - [package.dev-dependencies] dev = [ { name = "mypy" }, - { name = "prek" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -393,23 +382,15 @@ requires-dist = [ { name = "cocoindex", extras = ["litellm"], specifier = "==1.0.0a26" }, { name = "einops", specifier = ">=0.8.2" }, { name = "mcp", specifier = ">=1.0.0" }, - { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.0.0" }, { name = "numpy", specifier = ">=1.24.0" }, - { name = "prek", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "pydantic", specifier = ">=2.0.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, - { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "sentence-transformers", specifier = ">=2.2.0" }, { name = "sqlite-vec", specifier = ">=0.1.0" }, ] -provides-extras = ["dev"] [package.metadata.requires-dev] dev = [ { name = "mypy", specifier = ">=1.0.0" }, - { name = "prek", specifier = ">=0.1.0" }, { name = 
"pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.0" }, { name = "pytest-cov", specifier = ">=4.0.0" }, @@ -1700,30 +1681,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] -[[package]] -name = "prek" -version = "0.3.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d3/f5/ee52def928dd1355c20bcfcf765e1e61434635c33f3075e848e7b83a157b/prek-0.3.2.tar.gz", hash = "sha256:dce0074ff1a21290748ca567b4bda7553ee305a8c7b14d737e6c58364a499364", size = 334229, upload-time = "2026-02-06T13:49:47.539Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/69/70a5fc881290a63910494df2677c0fb241d27cfaa435bbcd0de5cd2e2443/prek-0.3.2-py3-none-linux_armv6l.whl", hash = "sha256:4f352f9c3fc98aeed4c8b2ec4dbf16fc386e45eea163c44d67e5571489bd8e6f", size = 4614960, upload-time = "2026-02-06T13:50:05.818Z" }, - { url = "https://files.pythonhosted.org/packages/c0/15/a82d5d32a2207ccae5d86ea9e44f2b93531ed000faf83a253e8d1108e026/prek-0.3.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4a000cfbc3a6ec7d424f8be3c3e69ccd595448197f92daac8652382d0acc2593", size = 4622889, upload-time = "2026-02-06T13:49:53.662Z" }, - { url = "https://files.pythonhosted.org/packages/89/75/ea833b58a12741397017baef9b66a6e443bfa8286ecbd645d14111446280/prek-0.3.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5436bdc2702cbd7bcf9e355564ae66f8131211e65fefae54665a94a07c3d450a", size = 4239653, upload-time = "2026-02-06T13:50:02.88Z" }, - { url = "https://files.pythonhosted.org/packages/10/b4/d9c3885987afac6e20df4cb7db14e3b0d5a08a77ae4916488254ebac4d0b/prek-0.3.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = 
"sha256:0161b5f584f9e7f416d6cf40a17b98f17953050ff8d8350ec60f20fe966b86b6", size = 4595101, upload-time = "2026-02-06T13:49:49.813Z" }, - { url = "https://files.pythonhosted.org/packages/21/a6/1a06473ed83dbc898de22838abdb13954e2583ce229f857f61828384634c/prek-0.3.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4e641e8533bca38797eebb49aa89ed0e8db0e61225943b27008c257e3af4d631", size = 4521978, upload-time = "2026-02-06T13:49:41.266Z" }, - { url = "https://files.pythonhosted.org/packages/0c/5e/c38390d5612e6d86b32151c1d2fdab74a57913473193591f0eb00c894c21/prek-0.3.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfca1810d49d3f9ef37599c958c4e716bc19a1d78a7e88cbdcb332e0b008994f", size = 4829108, upload-time = "2026-02-06T13:49:44.598Z" }, - { url = "https://files.pythonhosted.org/packages/80/a6/cecce2ab623747ff65ed990bb0d95fa38449ee19b348234862acf9392fff/prek-0.3.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d69d754299a95a85dc20196f633232f306bee7e7c8cba61791f49ce70404ec", size = 5357520, upload-time = "2026-02-06T13:49:48.512Z" }, - { url = "https://files.pythonhosted.org/packages/a5/18/d6bcb29501514023c76d55d5cd03bdbc037737c8de8b6bc41cdebfb1682c/prek-0.3.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:539dcb90ad9b20837968539855df6a29493b328a1ae87641560768eed4f313b0", size = 4852635, upload-time = "2026-02-06T13:49:58.347Z" }, - { url = "https://files.pythonhosted.org/packages/1b/0a/ae46f34ba27ba87aea5c9ad4ac9cd3e07e014fd5079ae079c84198f62118/prek-0.3.2-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:1998db3d0cbe243984736c82232be51318f9192e2433919a6b1c5790f600b5fd", size = 4599484, upload-time = "2026-02-06T13:49:43.296Z" }, - { url = "https://files.pythonhosted.org/packages/1a/a9/73bfb5b3f7c3583f9b0d431924873928705cdef6abb3d0461c37254a681b/prek-0.3.2-py3-none-manylinux_2_31_riscv64.whl", hash = 
"sha256:07ab237a5415a3e8c0db54de9d63899bcd947624bdd8820d26f12e65f8d19eb7", size = 4657694, upload-time = "2026-02-06T13:50:01.074Z" }, - { url = "https://files.pythonhosted.org/packages/a7/bc/0994bc176e1a80110fad3babce2c98b0ac4007630774c9e18fc200a34781/prek-0.3.2-py3-none-musllinux_1_1_armv7l.whl", hash = "sha256:0ced19701d69c14a08125f14a5dd03945982edf59e793c73a95caf4697a7ac30", size = 4509337, upload-time = "2026-02-06T13:49:54.891Z" }, - { url = "https://files.pythonhosted.org/packages/f9/13/e73f85f65ba8f626468e5d1694ab3763111513da08e0074517f40238c061/prek-0.3.2-py3-none-musllinux_1_1_i686.whl", hash = "sha256:ffb28189f976fa111e770ee94e4f298add307714568fb7d610c8a7095cb1ce59", size = 4697350, upload-time = "2026-02-06T13:50:04.526Z" }, - { url = "https://files.pythonhosted.org/packages/14/47/98c46dcd580305b9960252a4eb966f1a7b1035c55c363f378d85662ba400/prek-0.3.2-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f63134b3eea14421789a7335d86f99aee277cb520427196f2923b9260c60e5c5", size = 4955860, upload-time = "2026-02-06T13:49:56.581Z" }, - { url = "https://files.pythonhosted.org/packages/73/42/1bb4bba3ff47897df11e9dfd774027cdfa135482c961a54e079af0faf45a/prek-0.3.2-py3-none-win32.whl", hash = "sha256:58c806bd1344becd480ef5a5ba348846cc000af0e1fbe854fef91181a2e06461", size = 4267619, upload-time = "2026-02-06T13:49:39.503Z" }, - { url = "https://files.pythonhosted.org/packages/97/11/6665f47a7c350d83de17403c90bbf7a762ef50876ece456a86f64f46fbfb/prek-0.3.2-py3-none-win_amd64.whl", hash = "sha256:70114b48e9eb8048b2c11b4c7715ce618529c6af71acc84dd8877871a2ef71a6", size = 4624324, upload-time = "2026-02-06T13:49:45.922Z" }, - { url = "https://files.pythonhosted.org/packages/22/e7/740997ca82574d03426f897fd88afe3fc8a7306b8c7ea342a8bc1c538488/prek-0.3.2-py3-none-win_arm64.whl", hash = "sha256:9144d176d0daa2469a25c303ef6f6fa95a8df015eb275232f5cb53551ecefef0", size = 4336008, upload-time = "2026-02-06T13:49:52.27Z" }, -] - [[package]] name = "propcache" version = "0.4.1" @@ 
-2840,6 +2797,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = 
"https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/f4/39/590742415c3030551944edc2ddc273ea1fdfe8ffb2780992e824f1ebee98/torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328", size = 915632474, upload-time = "2026-03-11T14:15:13.666Z" }, + { url = "https://files.pythonhosted.org/packages/b6/8e/34949484f764dde5b222b7fe3fede43e4a6f0da9d7f8c370bb617d629ee2/torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591", size = 915523882, upload-time = "2026-03-11T14:14:46.311Z" }, { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" },