diff --git a/api_types.py b/api_types.py index 6e8f445..03bb09d 100644 --- a/api_types.py +++ b/api_types.py @@ -469,6 +469,37 @@ class SearchHistoryCommitDetailResponse(TypedDict): ) +# --------------------------------------------------------------------------- +# find_dead_code Tool +# --------------------------------------------------------------------------- + + +class DeadCodeCandidate(TypedDict): + """A single dead-code candidate returned by find_dead_code.""" + + name: str + kind: str + file_path: str + line_start: int + line_end: int + confidence: float + reasons: list[str] + source_excerpt: str | None + + +class FindDeadCodeResponse(TypedDict): + """Response from the find_dead_code tool.""" + + status: Literal["ok"] + directory: str + candidates: list[DeadCodeCandidate] + count: int + scanned_symbols: int + total_symbols: int + limitations: list[str] + hint: NotRequired[str] + + # --------------------------------------------------------------------------- # Tool Response Union Types # --------------------------------------------------------------------------- @@ -481,5 +512,6 @@ class SearchHistoryCommitDetailResponse(TypedDict): | IndexCodebaseResponse | SearchDocsResponse | SearchHistoryResponse + | FindDeadCodeResponse | ErrorResponse ) diff --git a/pyproject.toml b/pyproject.toml index db69848..174fcfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "code-memory" -version = "1.0.29" +version = "1.0.30" description = "A deterministic, high-precision code intelligence MCP server" readme = "README.md" license = "MIT" diff --git a/queries.py b/queries.py index b998153..c8f8108 100644 --- a/queries.py +++ b/queries.py @@ -815,3 +815,350 @@ def _truncate_code(source_text: str, max_lines: int = 15, max_chars: int = 500) if len(truncated) > max_chars: truncated = truncated[:max_chars] return truncated + "\n// ... 
(truncated)" + + +# --------------------------------------------------------------------------- +# Dead-code detection +# --------------------------------------------------------------------------- + +# Symbol kinds eligible for dead-code analysis. +_DEAD_CODE_DEFAULT_KINDS: tuple[str, ...] = ("function", "method", "class") +_DEAD_CODE_ALLOWED_KINDS: frozenset[str] = frozenset(_DEAD_CODE_DEFAULT_KINDS) + +# Names that are almost always entry points or framework callbacks; never flag. +_DEAD_CODE_ENTRYPOINT_NAMES: frozenset[str] = frozenset({"main"}) + +# Languages where the AST reference extractor does NOT capture method calls +# made via member access (e.g. ``obj.method()``). Methods defined in files +# with these extensions get a confidence penalty and an explicit caveat. +_MEMBER_ACCESS_BLIND_EXTENSIONS: frozenset[str] = frozenset({ + ".js", ".jsx", ".ts", ".tsx", + ".go", ".rs", + ".kt", ".kts", + ".c", ".h", ".cpp", ".hpp", ".cc", ".cxx", +}) + +# Filenames that typically re-export public API (or list it via a string). +_REEXPORT_FILENAMES: frozenset[str] = frozenset({ + "__init__.py", "index.js", "index.jsx", "index.ts", "index.tsx", + "mod.rs", "lib.rs", +}) + + +def _is_test_path(path: str) -> bool: + """Return True if *path* looks like a test or fixture file.""" + import os + + norm = path.replace("\\", "/").lower() + basename = os.path.basename(norm) + + if basename == "conftest.py": + return True + if basename.startswith("test_"): + return True + if basename.endswith(( + "_test.py", + ".test.js", ".test.jsx", ".test.ts", ".test.tsx", + ".spec.js", ".spec.jsx", ".spec.ts", ".spec.tsx", + )): + return True + + parts = norm.split("/") + return any(seg in {"tests", "test", "__tests__", "spec", "specs"} for seg in parts) + + +def _is_excluded_from_dead_code( + name: str, kind: str, path: str, include_tests: bool +) -> tuple[bool, str | None]: + """Return ``(excluded, reason)``. 
+ + Symbols matching exclusion rules are never reported as dead code, regardless + of reference count. Used to filter framework hooks, entry points, and + test fixtures that are invoked by mechanisms our reference extractor + can't observe. + """ + if not name or name.startswith(" 4 and name.startswith("__") and name.endswith("__"): + return True, "dunder method (Python protocol)" + + if name in _DEAD_CODE_ENTRYPOINT_NAMES: + return True, "common entry-point name" + + if kind == "file": + return True, "file-level fallback symbol" + + if not include_tests and _is_test_path(path): + return True, "in a test file (use include_tests=True to scan)" + + return False, None + + +def _has_decorator_above(path: str, line_start: int) -> bool: + """Best-effort check for a decorator on the line(s) immediately above. + + Returns True if a non-blank line above ``line_start`` begins with ``@``. + Falls back to False on read errors. This heuristic catches Python and + TypeScript decorators that aren't part of the symbol's source_text in + the index. + """ + try: + with open(path) as f: + lines = f.readlines() + except OSError: + return False + + # Walk upward past blank lines until we find content (or run out). + idx = line_start - 2 # convert to zero-indexed line above line_start + while idx >= 0: + stripped = lines[idx].strip() + if not stripped: + idx -= 1 + continue + return stripped.startswith("@") + return False + + +def _score_dead_code_candidate( + name: str, + kind: str, + path: str, + name_share_count: int, + has_decorator: bool, +) -> tuple[float, list[str]]: + """Return ``(confidence, reasons)`` for a candidate that has no external refs. + + Confidence is in [0.0, 0.99] — never claims absolute certainty since + dynamic dispatch, reflection, and string-based imports can always hide + usages. ``reasons`` is an ordered list of human-readable explanations. 
+ """ + import os + + reasons: list[str] = ["No references found outside this symbol's own definition"] + confidence = 0.6 + + ext = os.path.splitext(path)[1].lower() + basename = os.path.basename(path) + + # ── Privacy ──────────────────────────────────────────────────────────── + if name.startswith("__") and not name.endswith("__"): + confidence += 0.25 + reasons.append(f"Name '{name}' is name-mangled (Python private)") + elif name.startswith("_"): + confidence += 0.2 + reasons.append(f"Underscore-prefixed name '{name}' suggests an internal helper") + else: + confidence -= 0.05 + reasons.append(f"Public name '{name}' may be part of an exported API; verify before removing") + + # ── Name uniqueness ──────────────────────────────────────────────────── + if name_share_count > 1: + confidence -= 0.3 + reasons.append( + f"Name '{name}' is shared by {name_share_count} symbols; " + "reference counts can't disambiguate which one is being used" + ) + + # ── Language / kind specific caveats ─────────────────────────────────── + if kind == "method": + if ext in _MEMBER_ACCESS_BLIND_EXTENSIONS: + confidence -= 0.3 + reasons.append( + f"Method in a {ext} file: calls via member access (obj.{name}()) " + "aren't captured by the reference index — verify manually" + ) + else: + reasons.append( + f"Method in a {ext or 'unknown-ext'} file: dynamic dispatch may hide some callers" + ) + elif kind == "class": + confidence -= 0.05 + reasons.append("Class: dynamic instantiation (reflection, string lookup) may hide usages") + + # ── Re-export files ──────────────────────────────────────────────────── + if basename in _REEXPORT_FILENAMES: + confidence -= 0.4 + reasons.append(f"Defined in {basename} — likely a re-export of a public API") + + # ── Decorators ───────────────────────────────────────────────────────── + if has_decorator: + confidence -= 0.25 + reasons.append("Decorated symbol — may be registered with a framework or DI system") + + confidence = max(0.0, min(0.99, confidence)) + 
return round(confidence, 3), reasons + + +def _source_excerpt(source_text: str | None, max_chars: int = 120) -> str | None: + """Return the first non-empty trimmed line of *source_text*, truncated.""" + if not source_text: + return None + for line in source_text.splitlines(): + trimmed = line.strip() + if trimmed: + if len(trimmed) > max_chars: + trimmed = trimmed[: max_chars - 3] + "..." + return trimmed + return None + + +def find_dead_code( + db, + *, + min_confidence: float = 0.5, + kinds: list[str] | None = None, + include_tests: bool = False, + top_k: int = 50, +) -> dict: + """Find symbols that look like dead code. + + Cross-references the ``symbols`` table against the ``references_`` table + to identify symbols with no reference outside their own body. Each + candidate is scored with a confidence in [0.0, 0.99] and a list of + human-readable reasons. + + Args: + db: Open ``sqlite3.Connection`` from ``db.get_db()``. + min_confidence: Lower bound on confidence (filters out low-signal hits). + kinds: Symbol kinds to consider; defaults to function/method/class. + include_tests: If True, also scan test files (default False). + top_k: Maximum candidates to return, sorted by confidence desc. + + Returns: + Dict with: + - candidates: list of result dicts with name, kind, file_path, + line_start, line_end, confidence, reasons, source_excerpt. + - scanned_symbols: count of symbols inspected after exclusions. + - total_symbols: total symbols of the requested kinds in the index. + - limitations: list of caveats explaining where false positives may + arise (e.g., languages where member access isn't tracked). + + This function is a heuristic — it cannot detect symbols invoked via + reflection, dynamic dispatch, string-based imports, or framework + registration. Treat results as candidates to investigate, not as a + definitive deletion list. 
+ """ + import os + + requested_kinds = list(kinds) if kinds is not None else list(_DEAD_CODE_DEFAULT_KINDS) + if not requested_kinds: + return { + "candidates": [], + "scanned_symbols": 0, + "total_symbols": 0, + "limitations": [], + } + + placeholders = ",".join("?" * len(requested_kinds)) + rows = db.execute( + f""" + SELECT s.id, s.name, s.kind, s.file_id, f.path, + s.line_start, s.line_end, s.source_text + FROM symbols s + JOIN files f ON f.id = s.file_id + WHERE s.kind IN ({placeholders}) + """, + requested_kinds, + ).fetchall() + + total_symbols = len(rows) + if not rows: + return { + "candidates": [], + "scanned_symbols": 0, + "total_symbols": 0, + "limitations": [], + } + + # Count how many symbols share each name (across kinds in the index). + name_share: dict[str, int] = {} + for r in db.execute("SELECT name, COUNT(*) FROM symbols GROUP BY name").fetchall(): + name_share[r[0]] = r[1] + + # Prefetch all references and group by name; this avoids an N+1 query + # pattern across thousands of candidate symbols. Memory footprint is + # ~24 bytes per ref, well within reasonable limits even for large repos. + refs_by_name: dict[str, list[tuple[int, int]]] = {} + for r_name, f_id, ln in db.execute( + "SELECT symbol_name, file_id, line_number FROM references_" + ).fetchall(): + refs_by_name.setdefault(r_name, []).append((f_id, ln)) + + candidates: list[dict] = [] + scanned = 0 + seen_extensions: set[str] = set() + + for sid, name, kind, file_id, path, line_start, line_end, source_text in rows: + excluded, _exclusion_reason = _is_excluded_from_dead_code( + name, kind, path, include_tests + ) + if excluded: + continue + + scanned += 1 + seen_extensions.add(os.path.splitext(path)[1].lower()) + + # An "external" reference is any reference to this name that lives + # outside the symbol's own [line_start, line_end] body. References + # at the definition line and recursive self-calls are internal. 
+ has_external = False + for ref_file_id, ref_line in refs_by_name.get(name, ()): + if not (ref_file_id == file_id and line_start <= ref_line <= line_end): + has_external = True + break + + if has_external: + continue + + confidence, reasons = _score_dead_code_candidate( + name=name, + kind=kind, + path=path, + name_share_count=name_share.get(name, 1), + has_decorator=_has_decorator_above(path, line_start), + ) + + if confidence < min_confidence: + continue + + candidates.append({ + "name": name, + "kind": kind, + "file_path": path, + "line_start": line_start, + "line_end": line_end, + "confidence": confidence, + "reasons": reasons, + "source_excerpt": _source_excerpt(source_text), + }) + + # Highest confidence first; break ties by file/line for stable output. + candidates.sort(key=lambda c: (-c["confidence"], c["file_path"], c["line_start"])) + + limitations: list[str] = [] + blind = sorted(seen_extensions & _MEMBER_ACCESS_BLIND_EXTENSIONS) + if blind: + limitations.append( + "Member-access calls (obj.method()) aren't tracked for: " + + ", ".join(blind) + + ". Methods in these files have lower confidence." + ) + shared = sum(1 for c in name_share.values() if c > 1) + if shared: + limitations.append( + f"{shared} symbol name(s) are reused across multiple definitions; " + "the reference index can't tell which definition a call resolves to." + ) + limitations.append( + "Dynamic dispatch, reflection, string-based imports, and " + "framework-registered callbacks may produce false positives." 
+ ) + + return { + "candidates": candidates[:top_k], + "scanned_symbols": scanned, + "total_symbols": total_symbols, + "limitations": limitations, + } diff --git a/server.py b/server.py index a8608b9..f93f6f6 100644 --- a/server.py +++ b/server.py @@ -772,6 +772,158 @@ def search_history( return errors.format_error(e) +# ── Tool 5: find_dead_code ──────────────────────────────────────────────── +_DEAD_CODE_ALLOWED_KINDS = ("function", "method", "class") + + +@mcp.tool() +def find_dead_code( + directory: str, + min_confidence: float = 0.5, + kinds: list[str] | None = None, + include_tests: bool = False, + top_k: int = 50, +) -> api_types.FindDeadCodeResponse | api_types.ErrorResponse: + """USE THIS TOOL to find functions, methods, and classes that look like dead code (defined but never called). + + PREREQUISITE: This tool requires indexing. If results are empty or you haven't indexed this session, call index_codebase(directory) first. + + HOW IT WORKS: + Cross-references the indexed symbol table against the indexed reference table. + Any symbol with no reference outside its own definition body is flagged as a + candidate. Each candidate is scored with a confidence in [0.0, 0.99] and a + list of human-readable reasons explaining the verdict. + + TRIGGER - Call this tool when the user asks: + - "Find dead code / unused functions / unused classes" + - "What's not used in this codebase?" + - "Are there functions I can safely delete?" 
+ - "Show me dead code in " + - "Find unreachable / orphaned code" + + HEURISTICS APPLIED: + - Excludes Python dunder methods (__init__, __call__, etc) — protocol methods + - Excludes 'main' — common entry point + - Excludes test files by default (override via include_tests=True) + - Excludes anonymous and file-level fallback symbols + - Lower confidence for methods in JS/TS/Go/Rust/C++/Kotlin (member-access + calls aren't captured by the reference index) + - Lower confidence for symbols defined in __init__.py / index.{js,ts} / + mod.rs (likely re-exports) + - Lower confidence for decorated symbols (likely framework-registered) + - Lower confidence when the name is shared across multiple symbols + + LIMITATIONS: + Cannot detect symbols invoked via reflection, dynamic dispatch, string-based + imports, or framework registration. Treat results as candidates to + investigate, NOT as a definitive deletion list. Always verify before + removing code. + + Do NOT use this tool for: + - Finding code definitions (use search_code with "definition") + - Finding where code is used (use search_code with "references") + - General code search (use search_code with "topic_discovery") + + Args: + directory: Path to the project directory to scan. + min_confidence: Minimum confidence (0.0-1.0) to include a candidate. + Default 0.5. Raise to filter aggressively. + kinds: Symbol kinds to scan. Default ['function', 'method', 'class']. + Allowed values: 'function', 'method', 'class'. + include_tests: If True, also scan symbols in test files. Default False. + top_k: Maximum candidates to return, sorted by confidence desc + (default 50, max 500). + + Returns: + Dict with: + - candidates: list, each containing name, kind, file_path, line_start, + line_end, confidence, reasons, source_excerpt. + - count: number of candidates returned. + - scanned_symbols: count of symbols inspected after exclusions. + - total_symbols: total symbols of the requested kinds in the index. 
+ - limitations: list of caveats for interpreting the results. + """ + with logging_config.ToolLogger( + "find_dead_code", + directory=directory, + min_confidence=min_confidence, + kinds=kinds, + include_tests=include_tests, + top_k=top_k, + ) as log: + try: + directory_path = val.validate_directory(directory) + top_k_validated = val.validate_top_k(top_k, max_val=500, default=50) + + if not isinstance(min_confidence, (int, float)): + raise errors.ValidationError( + "min_confidence must be a number between 0.0 and 1.0", + {"provided_type": type(min_confidence).__name__}, + ) + if not (0.0 <= min_confidence <= 1.0): + raise errors.ValidationError( + "min_confidence must be between 0.0 and 1.0", + {"provided": min_confidence}, + ) + + kinds_validated: list[str] | None = None + if kinds is not None: + if not isinstance(kinds, list) or not all(isinstance(k, str) for k in kinds): + raise errors.ValidationError( + "kinds must be a list of strings", + {"allowed_values": list(_DEAD_CODE_ALLOWED_KINDS)}, + ) + if not kinds: + raise errors.ValidationError( + "kinds cannot be empty; omit the argument to use defaults", + {"allowed_values": list(_DEAD_CODE_ALLOWED_KINDS)}, + ) + invalid = [k for k in kinds if k not in _DEAD_CODE_ALLOWED_KINDS] + if invalid: + raise errors.ValidationError( + f"Invalid kind(s): {invalid}", + {"allowed_values": list(_DEAD_CODE_ALLOWED_KINDS)}, + ) + kinds_validated = list(dict.fromkeys(kinds)) # dedupe, preserve order + + database = db_mod.get_db(str(directory_path)) + + result = queries.find_dead_code( + database, + min_confidence=float(min_confidence), + kinds=kinds_validated, + include_tests=bool(include_tests), + top_k=top_k_validated, + ) + log.set_result_count(len(result["candidates"])) + + response = cast(api_types.FindDeadCodeResponse, { + "status": "ok", + "directory": str(directory_path), + "candidates": result["candidates"], + "count": len(result["candidates"]), + "scanned_symbols": result["scanned_symbols"], + "total_symbols": 
result["total_symbols"], + "limitations": result["limitations"], + }) + + if result["total_symbols"] == 0: + symbols_count = database.execute( + "SELECT COUNT(*) FROM symbols" + ).fetchone()[0] + if symbols_count == 0: + response["hint"] = ( # type: ignore[typeddict-unknown-key] + "Codebase may not be indexed. Call index_codebase(directory) first." + ) + + return response + + except errors.CodeMemoryError as e: + return e.to_dict() + except Exception as e: + return errors.format_error(e) + + # ── Entrypoint ──────────────────────────────────────────────────────────── def build_arg_parser() -> argparse.ArgumentParser: """Build and return the CLI argument parser for code-memory.""" diff --git a/tests/test_dead_code.py b/tests/test_dead_code.py new file mode 100644 index 0000000..693dd75 --- /dev/null +++ b/tests/test_dead_code.py @@ -0,0 +1,688 @@ +"""Tests for find_dead_code: query layer + server tool.""" + +from __future__ import annotations + +import sqlite3 +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import db as db_mod +import queries +from queries import ( + _has_decorator_above, + _is_excluded_from_dead_code, + _is_test_path, + _score_dead_code_candidate, + _source_excerpt, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_dead_code_db(tmp_path: Path) -> sqlite3.Connection: + """Create a SQLite DB with just the tables find_dead_code needs. + + Avoids loading sqlite-vec or the embedding model so unit tests stay fast. 
+ """ + db_path = tmp_path / "test.db" + db = sqlite3.connect(db_path) + db.executescript( + """ + CREATE TABLE IF NOT EXISTS files ( + id INTEGER PRIMARY KEY, + path TEXT UNIQUE NOT NULL, + last_modified REAL NOT NULL, + file_hash TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS symbols ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + kind TEXT NOT NULL, + file_id INTEGER NOT NULL REFERENCES files(id), + line_start INTEGER NOT NULL, + line_end INTEGER NOT NULL, + parent_symbol_id INTEGER, + source_text TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS references_ ( + id INTEGER PRIMARY KEY, + symbol_name TEXT NOT NULL, + file_id INTEGER NOT NULL REFERENCES files(id), + line_number INTEGER NOT NULL + ); + """ + ) + db.commit() + return db + + +def _add_file(db: sqlite3.Connection, path: str) -> int: + cur = db.execute( + "INSERT INTO files (path, last_modified, file_hash) VALUES (?, ?, ?)", + (path, 0.0, "x"), + ) + db.commit() + return cur.lastrowid + + +def _add_symbol( + db: sqlite3.Connection, + file_id: int, + name: str, + kind: str, + line_start: int, + line_end: int, + source_text: str = "", + parent_id: int | None = None, +) -> int: + cur = db.execute( + """INSERT INTO symbols + (name, kind, file_id, line_start, line_end, parent_symbol_id, source_text) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (name, kind, file_id, line_start, line_end, parent_id, source_text), + ) + db.commit() + return cur.lastrowid + + +def _add_ref(db: sqlite3.Connection, file_id: int, name: str, line: int) -> None: + db.execute( + "INSERT INTO references_ (symbol_name, file_id, line_number) VALUES (?, ?, ?)", + (name, file_id, line), + ) + db.commit() + + +@pytest.fixture +def dc_db(temp_dir): + """Empty database with the schema find_dead_code needs.""" + db = _build_dead_code_db(temp_dir) + yield db + db.close() + + +# --------------------------------------------------------------------------- +# Helper: _is_test_path +# 
class TestIsTestPath:
    """Each test pins a single rule of ``_is_test_path``."""

    def test_test_prefix(self):
        # The path deliberately avoids any test-like directory so that only
        # the ``test_`` basename prefix can make this pass.
        assert _is_test_path("/repo/src/test_foo.py") is True

    def test_test_suffix(self):
        assert _is_test_path("/repo/foo_test.py") is True

    def test_jest_spec(self):
        # Jest/Vitest-style ``.spec.*`` and ``.test.*`` suffixes.
        assert _is_test_path("/repo/foo.spec.ts") is True
        assert _is_test_path("/repo/foo.test.tsx") is True

    def test_conftest(self):
        assert _is_test_path("/repo/conftest.py") is True

    def test_tests_directory(self):
        # The basenames are neutral; only the directory segment matches.
        assert _is_test_path("/repo/tests/sub/file.py") is True
        assert _is_test_path("/repo/__tests__/file.js") is True

    def test_normal_file(self):
        assert _is_test_path("/repo/src/foo.py") is False
excluded is True + + def test_test_file_included_when_opted_in(self): + excluded, _ = _is_excluded_from_dead_code( + "helper", "function", "/repo/tests/foo.py", True + ) + assert excluded is False + + def test_normal_function_not_excluded(self): + excluded, _ = _is_excluded_from_dead_code( + "compute", "function", "/repo/src/x.py", False + ) + assert excluded is False + + +# --------------------------------------------------------------------------- +# Helper: _score_dead_code_candidate +# --------------------------------------------------------------------------- + + +class TestScoreDeadCodeCandidate: + def test_public_function_mentions_api_caveat(self): + conf, reasons = _score_dead_code_candidate( + "compute", "function", "/repo/src/x.py", 1, False, + ) + assert 0.0 < conf < 1.0 + assert any( + "public" in r.lower() or "exported api" in r.lower() for r in reasons + ) + + def test_private_higher_than_public(self): + public_conf, _ = _score_dead_code_candidate( + "compute", "function", "/repo/src/x.py", 1, False, + ) + private_conf, _ = _score_dead_code_candidate( + "_compute", "function", "/repo/src/x.py", 1, False, + ) + assert private_conf > public_conf + + def test_name_mangled_highest_privacy(self): + mangled_conf, reasons = _score_dead_code_candidate( + "__internal", "function", "/repo/src/x.py", 1, False, + ) + assert mangled_conf > 0.7 + assert any("name-mangled" in r.lower() for r in reasons) + + def test_shared_name_lowers_confidence(self): + unique_conf, _ = _score_dead_code_candidate( + "_helper", "function", "/repo/src/x.py", 1, False, + ) + shared_conf, reasons = _score_dead_code_candidate( + "_helper", "function", "/repo/src/x.py", 5, False, + ) + assert shared_conf < unique_conf + assert any("shared by 5" in r for r in reasons) + + def test_method_in_member_blind_lang_lower(self): + py_conf, _ = _score_dead_code_candidate( + "_helper", "method", "/repo/src/x.py", 1, False, + ) + js_conf, reasons = _score_dead_code_candidate( + "_helper", "method", 
"/repo/src/x.js", 1, False, + ) + assert js_conf < py_conf + assert any("member access" in r.lower() for r in reasons) + + def test_class_kind_slightly_lower(self): + fn_conf, _ = _score_dead_code_candidate( + "_X", "function", "/repo/src/x.py", 1, False, + ) + cls_conf, reasons = _score_dead_code_candidate( + "_X", "class", "/repo/src/x.py", 1, False, + ) + assert cls_conf < fn_conf + assert any("dynamic instantiation" in r.lower() for r in reasons) + + def test_init_py_lowers_confidence(self): + normal_conf, _ = _score_dead_code_candidate( + "compute", "function", "/repo/src/foo.py", 1, False, + ) + init_conf, reasons = _score_dead_code_candidate( + "compute", "function", "/repo/src/__init__.py", 1, False, + ) + assert init_conf < normal_conf + assert any("__init__.py" in r for r in reasons) + + def test_decorator_lowers_confidence(self): + plain_conf, _ = _score_dead_code_candidate( + "_compute", "function", "/repo/src/x.py", 1, False, + ) + deco_conf, reasons = _score_dead_code_candidate( + "_compute", "function", "/repo/src/x.py", 1, True, + ) + assert deco_conf < plain_conf + assert any("decorat" in r.lower() for r in reasons) + + def test_confidence_clamped_to_range(self): + # All penalties stacked: should still be in [0, 0.99] + conf, _ = _score_dead_code_candidate( + "Foo", "method", "/repo/src/__init__.py", 50, True, + ) + assert 0.0 <= conf <= 0.99 + + def test_confidence_never_exceeds_99(self): + # All boosts: should still be capped at 0.99 + conf, _ = _score_dead_code_candidate( + "_internal", "function", "/repo/src/x.py", 1, False, + ) + assert conf <= 0.99 + + +# --------------------------------------------------------------------------- +# Helper: _source_excerpt +# --------------------------------------------------------------------------- + + +class TestSourceExcerpt: + def test_first_nonempty_line(self): + assert _source_excerpt("\n\ndef foo():\n return 1") == "def foo():" + + def test_truncates_long_line(self): + long_line = "x" * 200 + out = 
class TestHasDecoratorAbove:
    """Filesystem-backed checks for ``_has_decorator_above``."""

    @staticmethod
    def _write_module(directory, text):
        """Write *text* to ``x.py`` inside *directory* and return its path as str."""
        target = directory / "x.py"
        target.write_text(text)
        return str(target)

    def test_with_decorator(self, temp_dir):
        module = self._write_module(temp_dir, "@decorator\ndef foo():\n pass\n")
        assert _has_decorator_above(module, 2) is True

    def test_no_decorator(self, temp_dir):
        module = self._write_module(temp_dir, "def foo():\n pass\n")
        assert _has_decorator_above(module, 1) is False

    def test_blank_lines_skipped(self, temp_dir):
        # Two blank lines separate the decorator from the def on line 4.
        module = self._write_module(temp_dir, "@decorator\n\n\ndef foo():\n pass\n")
        assert _has_decorator_above(module, 4) is True

    def test_missing_file_returns_false(self, temp_dir):
        absent = temp_dir / "missing.py"
        assert _has_decorator_above(str(absent), 1) is False
"compute", "function", 1, 3, "def compute():\n return 1") + _add_ref(dc_db, fid, "compute", 1) + # External call at line 10 outside the function body + _add_ref(dc_db, fid, "compute", 10) + + result = queries.find_dead_code(dc_db) + assert result["candidates"] == [] + + def test_recursive_function_still_flagged(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "compute", "function", 1, 5, "def compute():\n compute()") + _add_ref(dc_db, fid, "compute", 1) # def line + _add_ref(dc_db, fid, "compute", 2) # recursion within body — still internal + + result = queries.find_dead_code(dc_db) + assert len(result["candidates"]) == 1 + assert result["candidates"][0]["name"] == "compute" + + def test_method_called_from_sibling_alive(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + cls_id = _add_symbol(dc_db, fid, "Foo", "class", 1, 10, "class Foo: ...") + _add_symbol(dc_db, fid, "bar", "method", 2, 4, "def bar(self): ...", parent_id=cls_id) + _add_symbol(dc_db, fid, "baz", "method", 5, 7, "def baz(self): ...", parent_id=cls_id) + # bar referenced from baz (line 6) — outside bar's [2,4] range + _add_ref(dc_db, fid, "bar", 2) + _add_ref(dc_db, fid, "bar", 6) + _add_ref(dc_db, fid, "baz", 5) + _add_ref(dc_db, fid, "Foo", 1) + + result = queries.find_dead_code(dc_db) + names = {c["name"] for c in result["candidates"]} + assert "bar" not in names + assert "baz" in names + assert "Foo" in names + + +# --------------------------------------------------------------------------- +# find_dead_code: exclusions +# --------------------------------------------------------------------------- + + +class TestFindDeadCodeExclusions: + def test_dunder_excluded(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "__init__", "method", 1, 3, "def __init__(self): ...") + _add_ref(dc_db, fid, "__init__", 1) + + assert queries.find_dead_code(dc_db)["candidates"] == [] + + def test_main_excluded(self, dc_db): + fid = _add_file(dc_db, 
"/repo/foo.py") + _add_symbol(dc_db, fid, "main", "function", 1, 3, "def main(): ...") + _add_ref(dc_db, fid, "main", 1) + + assert queries.find_dead_code(dc_db)["candidates"] == [] + + def test_test_files_excluded_by_default(self, dc_db): + fid = _add_file(dc_db, "/repo/tests/test_foo.py") + _add_symbol(dc_db, fid, "helper", "function", 1, 3, "def helper(): ...") + _add_ref(dc_db, fid, "helper", 1) + + assert queries.find_dead_code(dc_db)["candidates"] == [] + + def test_test_files_included_when_opted_in(self, dc_db): + fid = _add_file(dc_db, "/repo/tests/test_foo.py") + _add_symbol(dc_db, fid, "helper", "function", 1, 3, "def helper(): ...") + _add_ref(dc_db, fid, "helper", 1) + + result = queries.find_dead_code(dc_db, include_tests=True) + assert len(result["candidates"]) == 1 + + def test_anonymous_excluded(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.js") + _add_symbol(dc_db, fid, "", "function", 5, 7, "() => 1") + + assert queries.find_dead_code(dc_db)["candidates"] == [] + + def test_file_fallback_kind_excluded(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.unknown") + _add_symbol(dc_db, fid, "foo.unknown", "file", 1, 5, "...") + # 'file' isn't in the default kinds anyway, but the exclusion guards + # against an explicit kinds=['file'] request as well. 
+ result = queries.find_dead_code(dc_db, kinds=["file"]) + assert result["candidates"] == [] + + +# --------------------------------------------------------------------------- +# find_dead_code: filters and shape +# --------------------------------------------------------------------------- + + +class TestFindDeadCodeFilters: + def test_min_confidence_filters(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "compute", "function", 1, 3, "def compute(): ...") + _add_ref(dc_db, fid, "compute", 1) + + assert len(queries.find_dead_code(dc_db, min_confidence=0.0)["candidates"]) == 1 + assert queries.find_dead_code(dc_db, min_confidence=0.99)["candidates"] == [] + + def test_kinds_filter(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + cls_id = _add_symbol(dc_db, fid, "Foo", "class", 1, 5, "class Foo: ...") + _add_symbol(dc_db, fid, "bar", "method", 2, 4, "def bar(self): ...", parent_id=cls_id) + _add_ref(dc_db, fid, "Foo", 1) + _add_ref(dc_db, fid, "bar", 2) + + method_only = queries.find_dead_code(dc_db, kinds=["method"]) + assert {c["name"] for c in method_only["candidates"]} == {"bar"} + + class_only = queries.find_dead_code(dc_db, kinds=["class"]) + assert {c["name"] for c in class_only["candidates"]} == {"Foo"} + + def test_top_k_caps_results(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + for i in range(20): + name = f"_dead_{i}" + _add_symbol(dc_db, fid, name, "function", i * 5 + 1, i * 5 + 3, f"def {name}(): ...") + _add_ref(dc_db, fid, name, i * 5 + 1) + + result = queries.find_dead_code(dc_db, top_k=5) + assert len(result["candidates"]) == 5 + + def test_empty_kinds_returns_empty(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "compute", "function", 1, 3, "def compute(): ...") + _add_ref(dc_db, fid, "compute", 1) + + result = queries.find_dead_code(dc_db, kinds=[]) + assert result["candidates"] == [] + assert result["total_symbols"] == 0 + + +class TestFindDeadCodeShape: + def 
test_response_shape(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "_internal", "function", 1, 3, "def _internal(): ...") + _add_ref(dc_db, fid, "_internal", 1) + + result = queries.find_dead_code(dc_db) + for top_key in ("candidates", "scanned_symbols", "total_symbols", "limitations"): + assert top_key in result + + c = result["candidates"][0] + for key in ( + "name", "kind", "file_path", "line_start", "line_end", + "confidence", "reasons", "source_excerpt", + ): + assert key in c + assert isinstance(c["reasons"], list) + assert all(isinstance(r, str) for r in c["reasons"]) + + def test_limitations_includes_member_access_caveat_for_js(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.js") + _add_symbol(dc_db, fid, "_helper", "method", 1, 3, "function _helper() {}") + _add_ref(dc_db, fid, "_helper", 1) + + result = queries.find_dead_code(dc_db) + assert any("member-access" in lim.lower() for lim in result["limitations"]) + + def test_sorted_by_confidence_desc(self, dc_db): + fid = _add_file(dc_db, "/repo/foo.py") + _add_symbol(dc_db, fid, "_priv_a", "function", 1, 3, "def _priv_a(): ...") + _add_ref(dc_db, fid, "_priv_a", 1) + _add_symbol(dc_db, fid, "public_b", "function", 5, 7, "def public_b(): ...") + _add_ref(dc_db, fid, "public_b", 5) + + result = queries.find_dead_code(dc_db, min_confidence=0.0) + confidences = [c["confidence"] for c in result["candidates"]] + assert confidences == sorted(confidences, reverse=True) + + def test_no_candidates_returns_empty_list_and_counts(self, dc_db): + # Empty DB: nothing to scan + result = queries.find_dead_code(dc_db) + assert result["candidates"] == [] + assert result["total_symbols"] == 0 + assert result["scanned_symbols"] == 0 + + +# --------------------------------------------------------------------------- +# find_dead_code: cross-file behavior +# --------------------------------------------------------------------------- + + +class TestFindDeadCodeCrossFile: + def 
test_external_ref_in_different_file(self, dc_db): + f1 = _add_file(dc_db, "/repo/a.py") + f2 = _add_file(dc_db, "/repo/b.py") + _add_symbol(dc_db, f1, "compute", "function", 1, 3, "def compute(): ...") + _add_ref(dc_db, f1, "compute", 1) # def line in a.py + _add_ref(dc_db, f2, "compute", 5) # used in b.py — external + + assert queries.find_dead_code(dc_db)["candidates"] == [] + + def test_shared_name_alive_via_either_caller(self, dc_db): + # Two same-named definitions in different files; if either has any + # non-self reference, both end up alive (the reference index can't + # disambiguate by signature). + f1 = _add_file(dc_db, "/repo/a.py") + f2 = _add_file(dc_db, "/repo/b.py") + _add_symbol(dc_db, f1, "process", "function", 1, 3, "def process(): ...") + _add_symbol(dc_db, f2, "process", "function", 1, 3, "def process(): ...") + _add_ref(dc_db, f1, "process", 1) + _add_ref(dc_db, f2, "process", 1) + _add_ref(dc_db, f2, "process", 8) # call in b.py + + result = queries.find_dead_code(dc_db) + assert all(c["name"] != "process" for c in result["candidates"]) + + +# --------------------------------------------------------------------------- +# server.find_dead_code: input validation +# --------------------------------------------------------------------------- + + +class TestFindDeadCodeServerValidation: + def test_nonexistent_directory_returns_error(self): + import server + + result = server.find_dead_code("/nonexistent/path") + assert result.get("error") is True + assert "ValidationError" in result.get("error_type", "") + + def test_min_confidence_above_one_returns_error(self, temp_dir): + import server + + result = server.find_dead_code(str(temp_dir), min_confidence=1.5) + assert result.get("error") is True + assert "ValidationError" in result.get("error_type", "") + + def test_negative_min_confidence_returns_error(self, temp_dir): + import server + + result = server.find_dead_code(str(temp_dir), min_confidence=-0.1) + assert result.get("error") is True + + def 
test_invalid_kind_returns_error(self, temp_dir): + import server + + result = server.find_dead_code(str(temp_dir), kinds=["function", "variable"]) + assert result.get("error") is True + assert "ValidationError" in result.get("error_type", "") + + def test_empty_kinds_list_returns_error(self, temp_dir): + import server + + result = server.find_dead_code(str(temp_dir), kinds=[]) + assert result.get("error") is True + + def test_top_k_too_large_returns_error(self, temp_dir): + import server + + result = server.find_dead_code(str(temp_dir), top_k=10000) + assert result.get("error") is True + + +# --------------------------------------------------------------------------- +# server.find_dead_code: end-to-end via real db_mod.get_db +# --------------------------------------------------------------------------- + + +@pytest.fixture +def prepopulated_directory(temp_dir): + """Directory with a code_memory.db that bypasses embedding-model loading. + + Pre-creates the schema and an index_metadata row matching the configured + ``EMBEDDING_MODEL_NAME`` so ``db_mod.get_db()`` short-circuits the model + load on the next open. We pick a tiny embedding dimension (8) — find_dead_code + never reads the embedding tables, so the value is irrelevant beyond + making the schema valid. 
+ """ + import sqlite_vec + + db_path = temp_dir / "code_memory.db" + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + conn.executescript(db_mod._SCHEMA_SQL) + db_mod._create_embedding_tables(conn, 8) + conn.execute( + "INSERT INTO index_metadata (key, value) VALUES ('embedding_model', ?)", + (db_mod.EMBEDDING_MODEL_NAME,), + ) + conn.execute( + "INSERT INTO index_metadata (key, value) VALUES ('embedding_dim', ?)", + ("8",), + ) + conn.commit() + yield temp_dir, conn + conn.close() + + +class TestFindDeadCodeServerEndToEnd: + def test_returns_candidates(self, prepopulated_directory): + import server + + directory, conn = prepopulated_directory + src_path = str(directory / "src.py") + conn.execute( + "INSERT INTO files (path, last_modified, file_hash) VALUES (?, ?, ?)", + (src_path, 0.0, "h"), + ) + fid = conn.execute( + "SELECT id FROM files WHERE path = ?", (src_path,) + ).fetchone()[0] + conn.execute( + "INSERT INTO symbols (name, kind, file_id, line_start, line_end, " + "parent_symbol_id, source_text) VALUES (?, ?, ?, ?, ?, ?, ?)", + ("_dead_function", "function", fid, 1, 3, None, + "def _dead_function():\n pass"), + ) + conn.execute( + "INSERT INTO references_ (symbol_name, file_id, line_number) VALUES (?, ?, ?)", + ("_dead_function", fid, 1), + ) + conn.commit() + + result = server.find_dead_code(str(directory)) + assert result.get("status") == "ok" + assert result.get("count") == 1 + assert result["candidates"][0]["name"] == "_dead_function" + assert isinstance(result["limitations"], list) + assert "directory" in result + + def test_empty_index_returns_hint(self, prepopulated_directory): + import server + + directory, _ = prepopulated_directory + result = server.find_dead_code(str(directory)) + assert result.get("status") == "ok" + assert result.get("count") == 0 + assert "hint" in result + assert "index_codebase" in result["hint"] diff --git a/uv.lock b/uv.lock index 
6c9723c..c50e728 100644 --- a/uv.lock +++ b/uv.lock @@ -109,7 +109,7 @@ wheels = [ [[package]] name = "code-memory" -version = "1.0.29" +version = "1.0.30" source = { editable = "." } dependencies = [ { name = "einops" },