diff --git a/.gitignore b/.gitignore index 8e979e9..3b7edae 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__/ *.pyc .coverage .pytest_cache/ -.env \ No newline at end of file +.env +.DS_Store \ No newline at end of file diff --git a/plumb/cli.py b/plumb/cli.py index 253f012..c9fcd99 100644 --- a/plumb/cli.py +++ b/plumb/cli.py @@ -115,7 +115,7 @@ def _init_clone_setup(repo_root: Path, cfg: PlumbConfig) -> None: hooks_dir = repo_root / ".git" / "hooks" hooks_dir.mkdir(exist_ok=True) hook_path = hooks_dir / "pre-commit" - hook_path.write_text("#!/bin/sh\nplumb hook\nexit $?\n") + hook_path.write_text('#!/bin/sh\n[ "$PLUMB_SKIP" = "1" ] && exit 0\nplumb hook\nexit $?\n') hook_path.chmod(0o755) post_commit_path = hooks_dir / "post-commit" post_commit_path.write_text("#!/bin/sh\nplumb post-commit\n") @@ -249,7 +249,7 @@ def init(): hooks_dir = repo_root / ".git" / "hooks" hooks_dir.mkdir(exist_ok=True) hook_path = hooks_dir / "pre-commit" - hook_path.write_text("#!/bin/sh\nplumb hook\nexit $?\n") + hook_path.write_text('#!/bin/sh\n[ "$PLUMB_SKIP" = "1" ] && exit 0\nplumb hook\nexit $?\n') hook_path.chmod(0o755) post_commit_path = hooks_dir / "post-commit" post_commit_path.write_text("#!/bin/sh\nplumb post-commit\n") @@ -806,11 +806,14 @@ def map_tests(dry_run): console.print(f"Found {len(test_summaries)} test functions and {len(requirements)} requirements.") console.print("Running LLM mapping...") - from plumb.programs import configure_dspy, run_chunked_mapper + from plumb.programs import configure_dspy, run_chunked_mapper, get_program_lm, get_program_config from plumb.programs.test_mapper import TestMapper configure_dspy() mapper = TestMapper() + override_lm = get_program_lm("test_mapper") + prog_cfg = get_program_config("test_mapper") or {} + budget = prog_cfg.get("budget", 60000) req_json = json.dumps([{"id": r["id"], "text": r["text"]} for r in requirements]) items = [(s["name"], json.dumps(s)) for s in test_summaries] @@ -819,9 +822,16 @@ def _combine(chunk): return json.dumps([json.loads(t) for _, t in chunk]) try: - mappings = run_chunked_mapper( - mapper, req_json, items, budget=60000, combine_fn=_combine, - ) + if override_lm: + import dspy + with dspy.context(lm=override_lm): + mappings = run_chunked_mapper( + mapper, req_json, items, budget=budget, combine_fn=_combine, + ) + else: + mappings = run_chunked_mapper( + mapper, req_json, items, budget=budget, combine_fn=_combine, + ) except Exception as e: console.print(f"[red]Mapping failed: {e}[/red]") raise SystemExit(1) diff --git a/plumb/coverage_reporter.py b/plumb/coverage_reporter.py index 6b1beb8..3da9993 100644 --- a/plumb/coverage_reporter.py +++ b/plumb/coverage_reporter.py @@ -11,6 +11,7 @@ from rich.table import Table from plumb.config import load_config +from plumb.ignore import is_ignored, parse_plumbignore PLUMB_MARKER_RE = re.compile(r'#\s*plumb:(req-[a-f0-9]+)') FUNC_NAME_RE = re.compile(r'def test_req_([a-f0-9]+)_') @@ -29,6 +30,7 @@ def run_pytest_coverage(repo_root: str | Path) -> dict | None: result = subprocess.run( [ sys.executable, "-m", "pytest", + "-m", "not slow", "--cov=.", f"--cov-report=json:{cov_json}", "--cov-report=", @@ -118,11 +120,14 @@ def _collect_source_summaries(repo_root: Path) -> dict[str, str]: """ import ast + ignore_patterns = parse_plumbignore(repo_root) per_file: dict[str, str] = {} for item in sorted(repo_root.rglob("*.py")): rel = str(item.relative_to(repo_root)) if ".plumb" in rel or "test_" in item.name or rel.startswith("tests/"): continue + if is_ignored(rel, ignore_patterns): + continue try: content = item.read_text() except Exception: @@ -324,11 +329,14 @@ def check_spec_to_code_coverage( return (0, len(requirements)) # --- LLM mapping --- - from plumb.programs import configure_dspy, run_chunked_mapper + from plumb.programs import configure_dspy, run_chunked_mapper, get_program_lm, get_program_config from plumb.programs.code_coverage_mapper import CodeCoverageMapper configure_dspy() mapper = CodeCoverageMapper() + override_lm = get_program_lm("code_coverage_mapper", repo_root) + prog_cfg = get_program_config("code_coverage_mapper", repo_root) or {} + budget = prog_cfg.get("budget", 60000) if full_remap: dirty_reqs = requirements @@ -346,10 +354,18 @@ def check_spec_to_code_coverage( def _combine(chunk): return "\n\n".join(text for _, text in chunk) - results = run_chunked_mapper( - mapper, req_json, items, budget=60000, - combine_fn=_combine, merge_fn=merge_coverage_results, - ) + if override_lm: + import dspy + with dspy.context(lm=override_lm): + results = run_chunked_mapper( + mapper, req_json, items, budget=budget, + combine_fn=_combine, merge_fn=merge_coverage_results, + ) + else: + results = run_chunked_mapper( + mapper, req_json, items, budget=budget, + combine_fn=_combine, merge_fn=merge_coverage_results, + ) # Build fresh results dict from LLM output fresh_results: dict[str, dict] = {} diff --git a/plumb/ignore.py b/plumb/ignore.py index 4f21137..16e84e2 100644 --- a/plumb/ignore.py +++ b/plumb/ignore.py @@ -50,12 +50,17 @@ def is_ignored(filepath: str, patterns: list[str]) -> bool: - Exact match: ``README.md`` - Glob matched against the basename: ``*.txt`` - Directory prefix (pattern ends with ``/``): ``docs/`` matches ``docs/foo`` + - Glob directory prefix: ``.venv*/`` matches ``.venv3.10/foo`` """ basename = Path(filepath).name + top_dir = filepath.split("/")[0] for pat in patterns: if pat.endswith("/"): - # Directory prefix — match if filepath starts with the prefix - if filepath == pat.rstrip("/") or filepath.startswith(pat): + prefix = pat.rstrip("/") + # Directory prefix — exact startswith or fnmatch on top directory + if filepath == prefix or filepath.startswith(pat): + return True + if fnmatch(top_dir, prefix): return True else: # Exact full-path match or fnmatch against basename diff --git a/plumb/programs/__init__.py b/plumb/programs/__init__.py index 2cb7334..43f80f2 100644 --- a/plumb/programs/__init__.py +++ b/plumb/programs/__init__.py @@ -5,24 +5,42 @@ import dspy from dspy.adapters import XMLAdapter +from dspy.clients.base_lm import BaseLM from plumb import PlumbAuthError, PlumbInferenceError _configured = False +_NO_BACKEND_MSG = ( + "No LLM backend available.\n" + "Option 1: Set ANTHROPIC_API_KEY in .env or environment (direct API, fastest)\n" + "Option 2: Install Claude Code CLI — https://claude.ai/code (uses your subscription)" +) -def get_lm() -> dspy.LM: - return dspy.LM("anthropic/claude-sonnet-4-20250514", max_tokens=28000) + +def get_lm() -> BaseLM: + """Return the best available LM: direct API if ANTHROPIC_API_KEY is set, + otherwise Claude Code CLI if available.""" + if os.environ.get("ANTHROPIC_API_KEY"): + return dspy.LM("anthropic/claude-sonnet-4-20250514", max_tokens=28000) + + from plumb.programs.claude_code_lm import ClaudeCodeLM, find_claude_cli + + if find_claude_cli(): + return ClaudeCodeLM(model="sonnet", max_tokens=28000) + + raise PlumbAuthError(_NO_BACKEND_MSG) def configure_dspy() -> None: """Lazy DSPy configuration. No-op if already configured. - Never call at import time — ANTHROPIC_API_KEY absence would break + Never call at import time — missing auth would break non-LLM commands like plumb status.""" global _configured if _configured: return from dotenv import load_dotenv + load_dotenv(override=False) lm = get_lm() dspy.configure(lm=lm, adapter=XMLAdapter()) @@ -30,38 +48,35 @@ def configure_dspy() -> None: def validate_api_access() -> None: - """Check that ANTHROPIC_API_KEY is set and works. Loads .env first, then - falls back to exported environment variables. Performs a smoke test to - verify the key is valid. Raises PlumbAuthError if not found or invalid.""" + """Check that an LLM backend is available and working. + + Tries ANTHROPIC_API_KEY first (direct API), then falls back to the + Claude Code CLI. Performs a smoke test to verify the backend works. + Raises PlumbAuthError if neither is available or working. + """ from dotenv import load_dotenv load_dotenv(override=False) - if not os.environ.get("ANTHROPIC_API_KEY"): - raise PlumbAuthError( - "ANTHROPIC_API_KEY is not set. " - "Plumb requires a valid Anthropic API key to analyze commits.\n" - "Set it in a .env file or export it: export ANTHROPIC_API_KEY=your-key-here" - ) - - # Smoke test: verify the key actually works - lm = get_lm() + + lm = get_lm() # raises PlumbAuthError if no backend available + try: response = lm("Reply with only the word: hello") if not response: - raise PlumbAuthError("API returned empty response - key may be invalid") + raise PlumbAuthError("LLM returned empty response - backend may be misconfigured") + except PlumbAuthError: + raise except Exception as e: err_str = str(e).lower() if "auth" in err_str or "api key" in err_str or "401" in err_str: raise PlumbAuthError( f"ANTHROPIC_API_KEY is invalid or rejected: {e}" ) from e - raise PlumbAuthError( - f"Failed to verify API access: {e}" - ) from e + raise PlumbAuthError(f"Failed to verify LLM access: {e}") from e -def get_program_lm(program_name: str, repo_root: str | Path | None = None) -> dspy.LM | None: - """Return a per-program LM override from config, or None for the default.""" +def get_program_config(program_name: str, repo_root: str | Path | None = None) -> dict | None: + """Return the raw program_models entry for a program, or None.""" from plumb.config import find_repo_root, load_config if repo_root is None: @@ -71,14 +86,29 @@ def get_program_lm(program_name: str, repo_root: str | Path | None = None) -> ds cfg = load_config(repo_root) if cfg is None: return None - entry = cfg.program_models.get(program_name) + return cfg.program_models.get(program_name) + + +def get_program_lm(program_name: str, repo_root: str | Path | None = None) -> BaseLM | None: + """Return a per-program LM override from config, or None for the default.""" + entry = get_program_config(program_name, repo_root) if entry is None: return None model = entry.get("model") if not model: return None max_tokens = entry.get("max_tokens", 8192) - return dspy.LM(model, max_tokens=max_tokens) + + if os.environ.get("ANTHROPIC_API_KEY"): + return dspy.LM(model, max_tokens=max_tokens) + + from plumb.programs.claude_code_lm import ClaudeCodeLM, find_claude_cli + + if find_claude_cli(): + cli_model = model.removeprefix("anthropic/") + return ClaudeCodeLM(model=cli_model, max_tokens=max_tokens) + + return None def run_with_retries(fn, *args, max_retries: int = 2, **kwargs): @@ -94,6 +124,7 @@ def run_with_retries(fn, *args, max_retries: int = 2, **kwargs): raise PlumbAuthError( f"API key is invalid or rejected: {e}" ) from e + print(f"[retry {attempt+1}/{max_retries+1}] {type(e).__name__}: {e}") last_error = e raise PlumbInferenceError( f"LLM inference failed after {max_retries + 1} attempts: {last_error}" diff --git a/plumb/programs/claude_code_lm.py b/plumb/programs/claude_code_lm.py new file mode 100644 index 0000000..03d06d6 --- /dev/null +++ b/plumb/programs/claude_code_lm.py @@ -0,0 +1,154 @@ +"""DSPy BaseLM subclass that routes completions through the claude CLI. + +Uses ``claude -p`` (non-interactive print mode) so that users with a Claude +Code subscription can run plumb without a separate ANTHROPIC_API_KEY. + +Pattern adapted from tinaudio/skills@b0cbd3d. +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +from types import SimpleNamespace +from typing import Any + +from dspy.clients.base_lm import BaseLM + +from plumb import PlumbInferenceError + + +def find_claude_cli() -> str | None: + """Return the path to the ``claude`` CLI binary, or None if not found.""" + return shutil.which("claude") + + +# GIT_* env vars that are safe to pass through to claude -p. +# Everything else starting with GIT_ is stripped to prevent claude -p's +# plugin init from corrupting a worktree's git index during pre-commit hooks. +# Pattern from pre-commit framework: +# https://github.com/pre-commit/pre-commit/blob/ec1928f37e8abd7bab0b7ed29a031e5fd8875be7/pre_commit/git.py#L27 +_GIT_ENV_WHITELIST = { + "GIT_EXEC_PATH", + "GIT_SSH", + "GIT_SSH_COMMAND", + "GIT_SSL_CAINFO", + "GIT_SSL_NO_VERIFY", + "GIT_CONFIG_COUNT", + "GIT_HTTP_PROXY_AUTHMETHOD", + "GIT_ALLOW_PROTOCOL", + "GIT_ASKPASS", +} + + +def _call_claude(prompt: str, model: str | None = None, timeout: int = 300) -> str: + """Run ``claude -p`` with *prompt* on stdin and return the text response. + + Strips ``CLAUDECODE`` and repo-local ``GIT_*`` env vars so that claude -p's + plugin init cannot corrupt a worktree's git index during pre-commit hooks. + See https://github.com/ktinubu/plumb/issues/1. + """ + cmd = ["claude", "-p", "--output-format", "text"] + if model: + cmd.extend(["--model", model]) + + env = { + k: v for k, v in os.environ.items() + if k != "CLAUDECODE" and ( + not k.startswith("GIT_") + or k.startswith(("GIT_CONFIG_KEY_", "GIT_CONFIG_VALUE_")) + or k in _GIT_ENV_WHITELIST + ) + } + + result = subprocess.run( + cmd, + input=prompt, + capture_output=True, + text=True, + env=env, + timeout=timeout, + ) + if result.returncode != 0: + raise RuntimeError( + f"claude -p exited {result.returncode}\nstderr: {result.stderr}" + ) + return result.stdout + + +def _serialize_messages( + prompt: str | None = None, + messages: list[dict[str, str]] | None = None, +) -> str: + """Convert a DSPy messages list into a single text prompt for the CLI. + + System messages get ```` tags, multi-turn conversations get + ``[role]`` prefixes. Single user messages are passed through as-is. + """ + if not messages: + return prompt or "" + + # Single user message — pass through without decoration + if len(messages) == 1 and messages[0].get("role") == "user": + return messages[0]["content"] + + parts: list[str] = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "system": + parts.append(f"\n{content}\n") + else: + parts.append(f"[{role}]\n{content}") + return "\n\n".join(parts) + + +def _make_response(text: str, model: str) -> SimpleNamespace: + """Build a minimal OpenAI-compatible response object for BaseLM.""" + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content=text, role="assistant"), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + model=model, + ) + + +class ClaudeCodeLM(BaseLM): + """DSPy LM that routes completions through the ``claude`` CLI.""" + + def __init__( + self, + model: str = "sonnet", + max_tokens: int = 28000, + timeout: int = 300, + **kwargs: Any, + ): + super().__init__( + model=f"claude-code/{model}", + model_type="chat", + temperature=0.0, + max_tokens=max_tokens, + **kwargs, + ) + self.cli_model = model + self.timeout = timeout + + def forward( + self, + prompt: str | None = None, + messages: list[dict[str, Any]] | None = None, + **kwargs: Any, + ) -> SimpleNamespace: + import sys + + text_input = _serialize_messages(prompt, messages) + input_len = len(text_input) + print(f"[ClaudeCodeLM] Calling claude -p --model {self.cli_model} ({input_len} chars)...", file=sys.stderr) + response_text = _call_claude(text_input, model=self.cli_model, timeout=self.timeout) + print(f"[ClaudeCodeLM] Got response ({len(response_text)} chars)", file=sys.stderr) + return _make_response(response_text, self.model) diff --git a/plumb/programs/code_modifier.py b/plumb/programs/code_modifier.py index 53b844a..d569683 100644 --- a/plumb/programs/code_modifier.py +++ b/plumb/programs/code_modifier.py @@ -1,20 +1,34 @@ from __future__ import annotations import json +import os import re import anthropic from dotenv import load_dotenv +from plumb.programs.claude_code_lm import _call_claude, find_claude_cli + class CodeModifier: """Modify staged code to satisfy a rejected decision. Uses Anthropic API directly (not DSPy) because code modification - is inherently open-ended.""" + is inherently open-ended. Falls back to claude CLI when no API key.""" def __init__(self, client: anthropic.Anthropic | None = None): load_dotenv(override=False) - self.client = client or anthropic.Anthropic() + if client is not None: + self.client = client + self._use_cli = False + elif os.environ.get("ANTHROPIC_API_KEY"): + self.client = anthropic.Anthropic() + self._use_cli = False + elif find_claude_cli(): + self.client = None + self._use_cli = True + else: + self.client = anthropic.Anthropic() + self._use_cli = False def modify( self, @@ -50,6 +64,10 @@ def modify( }} ```""" + if self._use_cli: + text = _call_claude(prompt) + return self._parse_response(text) + response = self.client.messages.create( model="claude-sonnet-4-20250514", max_tokens=4096, diff --git a/plumb/sync.py b/plumb/sync.py index 076081e..0065440 100644 --- a/plumb/sync.py +++ b/plumb/sync.py @@ -146,7 +146,7 @@ def insert_new_sections( def parse_spec_files(repo_root: str | Path) -> list[dict]: """Read markdown spec files, run RequirementParser, assign stable IDs, write requirements.json.""" - from plumb.programs import configure_dspy, run_with_retries + from plumb.programs import configure_dspy, run_with_retries, get_program_lm from plumb.programs.requirement_parser import RequirementParser repo_root = Path(repo_root) @@ -169,6 +169,7 @@ def parse_spec_files(repo_root: str | Path) -> list[dict]: configure_dspy() parser = RequirementParser() + override_lm = get_program_lm("requirement_parser", repo_root) for spec_path_str in config.spec_paths: spec_path = repo_root / spec_path_str @@ -182,7 +183,12 @@ def parse_spec_files(repo_root: str | Path) -> list[dict]: for md_file in md_files: content = md_file.read_text() try: - parsed = run_with_retries(parser, content) + if override_lm: + import dspy + with dspy.context(lm=override_lm): + parsed = run_with_retries(parser, content) + else: + parsed = run_with_retries(parser, content) except Exception: continue diff --git a/pyproject.toml b/pyproject.toml index a36323d..185c682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,3 +29,5 @@ packages = ["plumb"] [tool.pytest.ini_options] testpaths = ["tests"] +addopts = "-m 'not slow'" +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/tests/test_claude_code_integration.py b/tests/test_claude_code_integration.py new file mode 100644 index 0000000..078a935 --- /dev/null +++ b/tests/test_claude_code_integration.py @@ -0,0 +1,59 @@ +"""Integration test for ClaudeCodeLM with a real claude -p call. + +Marked slow — skipped by default. Run with: pytest -m slow +Requires the claude CLI to be installed and authenticated. +""" + +import json +import shutil + +import dspy +import pytest +from dspy.adapters import XMLAdapter + +from plumb.programs.claude_code_lm import ClaudeCodeLM, find_claude_cli + +needs_claude_cli = pytest.mark.skipif( + shutil.which("claude") is None, + reason="claude CLI not installed", +) + + +@pytest.mark.slow +@needs_claude_cli +def test_claude_code_lm_parse_spec_single_file(): + """End-to-end: parse a tiny spec through ClaudeCodeLM → DSPy RequirementParser.""" + from plumb.programs.requirement_parser import RequirementParser + + lm = ClaudeCodeLM(model="sonnet", max_tokens=4000, timeout=60) + dspy.configure(lm=lm, adapter=XMLAdapter()) + + parser = RequirementParser() + + spec = """\ +# Widget API + +## Requirements + +The system must accept a widget name as a string. +The system must return a 400 error if the name is empty. +""" + + parsed = parser(markdown=spec) + assert len(parsed) >= 2, f"Expected at least 2 requirements, got {len(parsed)}" + + texts = [r.text.lower() for r in parsed] + assert any("name" in t for t in texts), f"No requirement mentions 'name': {texts}" + + +@pytest.mark.slow +@needs_claude_cli +def test_claude_code_lm_raw_call(): + """Smoke test: ClaudeCodeLM returns a non-empty response for a simple prompt.""" + lm = ClaudeCodeLM(model="sonnet", max_tokens=100, timeout=30) + + response = lm("Reply with only the word: hello") + assert response, "Got empty response from claude CLI" + assert isinstance(response, list) + assert len(response) > 0 + assert "hello" in response[0].lower() diff --git a/tests/test_claude_code_lm.py b/tests/test_claude_code_lm.py new file mode 100644 index 0000000..2ef67f7 --- /dev/null +++ b/tests/test_claude_code_lm.py @@ -0,0 +1,208 @@ +"""Tests for ClaudeCodeLM — DSPy BaseLM subclass that routes through claude CLI.""" + +import subprocess +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from plumb.programs.claude_code_lm import ( + ClaudeCodeLM, + _call_claude, + _make_response, + _serialize_messages, + find_claude_cli, +) + + +class TestFindClaudeCli: + def test_returns_path_when_found(self): + with patch("shutil.which", return_value="/usr/local/bin/claude"): + assert find_claude_cli() == "/usr/local/bin/claude" + + def test_returns_none_when_missing(self): + with patch("shutil.which", return_value=None): + assert find_claude_cli() is None + + +class TestSerializeMessages: + def test_prompt_only(self): + result = _serialize_messages(prompt="hello", messages=None) + assert result == "hello" + + def test_single_user_message(self): + msgs = [{"role": "user", "content": "hello"}] + result = _serialize_messages(prompt=None, messages=msgs) + assert result == "hello" + + def test_system_and_user_messages(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ] + result = _serialize_messages(prompt=None, messages=msgs) + assert "\nYou are helpful.\n" in result + assert "hello" in result + + def test_multi_turn_with_assistant(self): + msgs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hey"}, + {"role": "user", "content": "bye"}, + ] + result = _serialize_messages(prompt=None, messages=msgs) + assert "[user]\nhi" in result + assert "[assistant]\nhey" in result + assert "[user]\nbye" in result + + def test_empty_messages_falls_back_to_prompt(self): + result = _serialize_messages(prompt="fallback", messages=[]) + assert result == "fallback" + + +class TestMakeResponse: + def test_has_correct_structure(self): + resp = _make_response("hello world", "claude-sonnet") + assert resp.choices[0].message.content == "hello world" + assert resp.choices[0].message.role == "assistant" + assert resp.choices[0].finish_reason == "stop" + assert resp.model == "claude-sonnet" + + def test_usage_is_dictable(self): + resp = _make_response("text", "model") + usage = dict(resp.usage) + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + + +class TestCallClaude: + def test_success(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="hello\n", stderr="" + ) + with patch("subprocess.run", return_value=mock_result) as mock_run: + result = _call_claude("say hello") + assert result == "hello\n" + args = mock_run.call_args + assert args[0][0][:2] == ["claude", "-p"] + assert "--output-format" in args[0][0] + assert "text" in args[0][0] + assert args[1]["input"] == "say hello" + + def test_strips_claudecode_env_var(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="ok", stderr="" + ) + with patch("subprocess.run", return_value=mock_result) as mock_run, \ + patch.dict("os.environ", {"CLAUDECODE": "1", "PATH": "/usr/bin"}): + _call_claude("test") + env = mock_run.call_args[1]["env"] + assert "CLAUDECODE" not in env + assert "PATH" in env + + def test_passes_model_flag(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="ok", stderr="" + ) + with patch("subprocess.run", return_value=mock_result) as mock_run: + _call_claude("test", model="opus") + cmd = mock_run.call_args[0][0] + assert "--model" in cmd + idx = cmd.index("--model") + assert cmd[idx + 1] == "opus" + + def test_raises_on_nonzero_exit(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=1, stdout="", stderr="auth failed" + ) + with patch("subprocess.run", return_value=mock_result): + with pytest.raises(RuntimeError, match="auth failed"): + _call_claude("test") + + def test_raises_on_timeout(self): + with patch("subprocess.run", side_effect=subprocess.TimeoutExpired("claude", 300)): + with pytest.raises(subprocess.TimeoutExpired): + _call_claude("test") + + def test_strips_repo_local_git_env_vars(self): + """Repo-local GIT_* vars must be stripped to prevent claude -p + from corrupting a worktree's git index during pre-commit hooks. + Safe transport/config vars (GIT_SSH, GIT_CONFIG_*, etc.) are kept.""" + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="ok", stderr="" + ) + with patch("subprocess.run", return_value=mock_result) as mock_run, \ + patch.dict("os.environ", { + "GIT_INDEX_FILE": "/tmp/.git/worktrees/wt/index", + "GIT_DIR": "/tmp/.git/worktrees/wt", + "GIT_WORK_TREE": "/tmp/worktree", + "GIT_SSH_COMMAND": "ssh -i ~/.ssh/id_rsa", + "GIT_CONFIG_COUNT": "1", + "GIT_CONFIG_KEY_0": "user.name", + "GIT_CONFIG_VALUE_0": "Test", + "PATH": "/usr/bin", + }): + _call_claude("test") + env = mock_run.call_args[1]["env"] + # Repo-local vars stripped + assert "GIT_INDEX_FILE" not in env + assert "GIT_DIR" not in env + assert "GIT_WORK_TREE" not in env + # Transport/config vars kept + assert env["GIT_SSH_COMMAND"] == "ssh -i ~/.ssh/id_rsa" + assert env["GIT_CONFIG_COUNT"] == "1" + assert env["GIT_CONFIG_KEY_0"] == "user.name" + assert env["GIT_CONFIG_VALUE_0"] == "Test" + # Non-GIT vars kept + assert "PATH" in env + + +class TestClaudeCodeLM: + def test_is_base_lm_subclass(self): + import dspy + with patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + lm = ClaudeCodeLM() + assert isinstance(lm, dspy.BaseLM) + + def test_forward_calls_claude_cli(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="response text", stderr="" + ) + with patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"), \ + patch("subprocess.run", return_value=mock_result): + lm = ClaudeCodeLM() + response = lm.forward(prompt="hello") + assert response.choices[0].message.content == "response text" + + def test_forward_serializes_messages(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=0, stdout="answer", stderr="" + ) + messages = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "What is 1+1?"}, + ] + with patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"), \ + patch("subprocess.run", return_value=mock_result) as mock_run: + lm = ClaudeCodeLM() + lm.forward(messages=messages) + stdin_input = mock_run.call_args[1]["input"] + assert "Be concise." in stdin_input + assert "What is 1+1?" in stdin_input + + def test_forward_raises_on_cli_error(self): + mock_result = subprocess.CompletedProcess( + args=["claude"], returncode=1, stdout="", stderr="error" + ) + with patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"), \ + patch("subprocess.run", return_value=mock_result): + lm = ClaudeCodeLM() + with pytest.raises(RuntimeError, match="error"): + lm.forward(prompt="hello") + + def test_model_name_stored(self): + with patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + lm = ClaudeCodeLM(model="opus") + assert lm.cli_model == "opus" + assert "claude-code/" in lm.model diff --git a/tests/test_cli.py b/tests/test_cli.py index a20b46e..3e3b72a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -50,6 +50,22 @@ def test_successful_init(self, runner, tmp_repo): hook = tmp_repo / ".git" / "hooks" / "pre-commit" assert os.access(str(hook), os.X_OK) + def test_pre_commit_hook_checks_plumb_skip(self, runner, tmp_repo): + """The pre-commit hook must exit 0 when PLUMB_SKIP=1 so users + can bypass Plumb in worktrees or automated scripts.""" + spec = tmp_repo / "spec.md" + spec.write_text("# Spec\n") + (tmp_repo / "tests").mkdir(exist_ok=True) + + with patch("plumb.cli.find_repo_root", return_value=tmp_repo), \ + patch("plumb.sync.parse_spec_files", return_value=[]): + runner.invoke(cli, ["init"], input="spec.md\ntests/\n") + + hook = tmp_repo / ".git" / "hooks" / "pre-commit" + content = hook.read_text() + assert 'PLUMB_SKIP' in content + assert 'exit 0' in content.split('PLUMB_SKIP')[1].split('\n')[0] + class TestInitPlumbignore: def test_init_creates_plumbignore(self, runner, tmp_repo): diff --git a/tests/test_program_model_overrides.py b/tests/test_program_model_overrides.py new file mode 100644 index 0000000..435c9dd --- /dev/null +++ b/tests/test_program_model_overrides.py @@ -0,0 +1,210 @@ +"""Tests that program_models config overrides actually reach the LLM call site. + +The contract: when a user puts an entry in program_models for a given program, +that LM — not the global default — must be the one that receives the prompt. + +These tests don't verify get_program_lm() in isolation (that's in test_programs.py). +They verify the end-to-end wiring: config → get_program_lm → dspy.context → program call. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import dspy +import pytest + +from plumb.config import PlumbConfig, save_config, ensure_plumb_dir + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_repo_with_override(tmp_repo: Path, program_name: str, model: str) -> Path: + """Set up a plumb repo with a single program_models override.""" + ensure_plumb_dir(tmp_repo) + cfg = PlumbConfig( + spec_paths=["spec.md"], + test_paths=["tests/"], + program_models={program_name: {"model": model}}, + ) + save_config(tmp_repo, cfg) + return tmp_repo + + +def _make_requirements_file(repo: Path, reqs: list[dict]) -> None: + """Write a requirements.json that check_spec_to_code_coverage expects.""" + req_path = repo / ".plumb" / "requirements.json" + req_path.write_text(json.dumps(reqs)) + + +def _make_source_file(repo: Path, name: str, content: str) -> None: + """Create a Python source file in the repo.""" + path = repo / name + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +# --------------------------------------------------------------------------- +# Core principle: the override LM must be the one that receives the call +# --------------------------------------------------------------------------- + + +class TestCoverageMapperUsesOverride: + """When program_models has 'code_coverage_mapper', coverage mapping + must use that LM, not the global default.""" + + def test_override_lm_receives_the_call(self, tmp_repo): + """The configured override LM should be invoked, not the default.""" + repo = _make_repo_with_override(tmp_repo, "code_coverage_mapper", "anthropic/claude-haiku-4-5-20251001") + _make_requirements_file(repo, [ + {"id": "req-1", "text": "The system must do X."}, + ]) + _make_source_file(repo, "app.py", "def do_x():\n pass\n") + + # Track which LM actually gets called + called_models = [] + + original_forward = dspy.Predict.forward + + def tracking_forward(self, **kwargs): + # Inside dspy.context, dspy.settings.lm reflects the active LM + active_lm = dspy.settings.lm + called_models.append(active_lm.model) + # Return a plausible result so the pipeline doesn't crash + from plumb.programs.code_coverage_mapper import RequirementCoverage + mock_result = MagicMock() + mock_result.coverage = [ + RequirementCoverage(requirement_id="req-1", implemented=False, evidence=""), + ] + return mock_result + + with patch.object(dspy.Predict, "forward", tracking_forward), \ + patch("plumb.programs.configure_dspy"), \ + patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + + from plumb.coverage_reporter import check_spec_to_code_coverage + check_spec_to_code_coverage(repo, use_llm=True) + + assert len(called_models) >= 1, "DSPy Predict was never called" + # The override model should have been used (ClaudeCodeLM strips 'anthropic/' prefix) + assert any("haiku" in m for m in called_models), ( + f"Expected haiku override to be active, but saw: {called_models}" + ) + +class TestTestMapperUsesOverride: + """When program_models has 'test_mapper', the test mapping command + must use that LM.""" + + def test_override_lm_receives_the_call(self, tmp_repo): + repo = _make_repo_with_override(tmp_repo, "test_mapper", "anthropic/claude-haiku-4-5-20251001") + + called_models = [] + + def tracking_forward(self, **kwargs): + active_lm = dspy.settings.lm + called_models.append(active_lm.model) + mock_result = MagicMock() + mock_result.mappings = [] + return mock_result + + with patch.object(dspy.Predict, "forward", tracking_forward), \ + patch("plumb.programs.configure_dspy"), \ + patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + + from plumb.programs import run_chunked_mapper, get_program_lm + from plumb.programs.test_mapper import TestMapper + + mapper = TestMapper() + override_lm = get_program_lm("test_mapper", repo) + + assert override_lm is not None, "Override should have been returned" + + req_json = json.dumps([{"id": "req-1", "text": "Must do X"}]) + items = [("test_foo", json.dumps({"name": "test_foo", "file": "tests/test_foo.py"}))] + + def _combine(chunk): + return json.dumps([json.loads(t) for _, t in chunk]) + + with dspy.context(lm=override_lm): + run_chunked_mapper(mapper, req_json, items, budget=60000, combine_fn=_combine) + + assert len(called_models) >= 1 + assert any("haiku" in m for m in called_models), ( + f"Expected haiku override, but saw: {called_models}" + ) + + +class TestRequirementParserUsesOverride: + """When program_models has 'requirement_parser', spec parsing + must use that LM.""" + + def test_override_lm_receives_the_call(self, tmp_repo): + repo = _make_repo_with_override(tmp_repo, "requirement_parser", "anthropic/claude-haiku-4-5-20251001") + + # Create a spec file the parser will read + spec = repo / "spec.md" + spec.write_text("# Spec\n\n## Features\n\nThe system must do X.\n") + + called_models = [] + + def tracking_forward(self, **kwargs): + active_lm = dspy.settings.lm + called_models.append(active_lm.model) + from plumb.programs.requirement_parser import ParsedRequirement + mock_result = MagicMock() + mock_result.requirements = [ + ParsedRequirement(text="The system must do X.", ambiguous=False), + ] + return mock_result + + with patch.object(dspy.Predict, "forward", tracking_forward), \ + patch("plumb.programs.configure_dspy"), \ + patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + + from plumb.sync import parse_spec_files + parse_spec_files(repo) + + assert len(called_models) >= 1 + assert any("haiku" in m for m in called_models), ( + f"Expected haiku override, but saw: {called_models}" + ) + + +# --------------------------------------------------------------------------- +# Negative case: override for one program must not leak to another +# --------------------------------------------------------------------------- + + +class TestOverrideIsolation: + """An override for program A must not affect program B.""" + + def test_coverage_mapper_override_does_not_affect_other_programs(self, tmp_repo): + """Configuring code_coverage_mapper should not change the LM for + requirement_parser.""" + repo = _make_repo_with_override(tmp_repo, "code_coverage_mapper", "anthropic/claude-haiku-4-5-20251001") + + from plumb.programs import get_program_lm + + with patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + + coverage_lm = get_program_lm("code_coverage_mapper", repo) + parser_lm = get_program_lm("requirement_parser", repo) + + assert coverage_lm is not None, "Coverage mapper override should exist" + assert parser_lm is None, "Requirement parser should have no override" diff --git a/tests/test_programs.py b/tests/test_programs.py index 55b9e3a..81dee87 100644 --- a/tests/test_programs.py +++ b/tests/test_programs.py @@ -7,9 +7,10 @@ import dspy import pytest -from plumb.programs import run_with_retries, configure_dspy, validate_api_access, get_program_lm +from plumb.programs import run_with_retries, configure_dspy, validate_api_access, get_lm, get_program_lm from plumb.config import PlumbConfig, save_config, ensure_plumb_dir from plumb import PlumbAuthError, PlumbInferenceError +from plumb.programs.claude_code_lm import ClaudeCodeLM from plumb.programs.diff_analyzer import ( ChangeSummary, DiffAnalyzerSignature, @@ -39,23 +40,36 @@ class TestValidateApiAccess: - def test_raises_when_key_missing(self): + def test_raises_when_key_missing_and_no_cli(self): # plumb:req-60f97012 # plumb:req-ab686eaa # plumb:req-222ddbbd with patch("dotenv.load_dotenv"), \ - patch.dict("os.environ", {}, clear=True): + patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value=None): import os os.environ.pop("ANTHROPIC_API_KEY", None) - with pytest.raises(PlumbAuthError, match="ANTHROPIC_API_KEY is not set"): + with pytest.raises(PlumbAuthError, match="No LLM backend available"): validate_api_access() - def test_raises_when_key_empty(self): + def test_raises_when_key_empty_and_no_cli(self): with patch("dotenv.load_dotenv"), \ - patch.dict("os.environ", {"ANTHROPIC_API_KEY": ""}): - with pytest.raises(PlumbAuthError, match="ANTHROPIC_API_KEY is not set"): + patch.dict("os.environ", {"ANTHROPIC_API_KEY": ""}), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value=None): + with pytest.raises(PlumbAuthError, match="No LLM backend available"): validate_api_access() + def test_passes_with_cli_when_no_key(self): + """CLI fallback works when ANTHROPIC_API_KEY is not set.""" + mock_lm = MagicMock(return_value=["hello"]) + with patch("dotenv.load_dotenv"), \ + patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"), \ + patch("plumb.programs.claude_code_lm.ClaudeCodeLM", return_value=mock_lm): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + validate_api_access() # should not raise + def test_passes_when_key_set_and_api_works(self): mock_lm = MagicMock(return_value="hello") with patch("dotenv.load_dotenv"), \ @@ -324,6 +338,41 @@ def test_prompt_includes_all_inputs(self): assert "reason text" in prompt assert "spec text" in prompt + def test_uses_cli_when_no_api_key(self): + """CodeModifier falls back to claude CLI when no API key.""" + json_response = '```json\n{"src/a.py": "modified via cli"}\n```' + with patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.code_modifier.find_claude_cli", return_value="/usr/bin/claude"), \ + patch("plumb.programs.code_modifier._call_claude", return_value=json_response) as mock_call: + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + modifier = CodeModifier() + result = modifier.modify( + staged_diff="diff", + decision="Use async", + rejection_reason="Too complex", + spec_content="# Spec", + ) + assert result == {"src/a.py": "modified via cli"} + mock_call.assert_called_once() + prompt = mock_call.call_args[0][0] + assert "diff" in prompt + assert "Use async" in prompt + + def test_uses_api_when_key_set(self): + """CodeModifier uses Anthropic API when ANTHROPIC_API_KEY is set.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"a.py": "content"}')] + mock_client.messages.create.return_value = mock_response + + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "sk-ant-test"}), \ + patch("plumb.programs.code_modifier.anthropic") as mock_anthropic: + mock_anthropic.Anthropic.return_value = mock_client + modifier = CodeModifier() + result = modifier.modify("diff", "dec", "reason", "spec") + mock_client.messages.create.assert_called_once() + class TestGetProgramLm: def test_returns_none_when_no_config(self, tmp_path): @@ -339,8 +388,8 @@ def test_returns_none_when_program_not_listed(self, tmp_repo): result = get_program_lm("decision_deduplicator", repo_root=tmp_repo) assert result is None - def test_returns_lm_when_override_exists(self, tmp_repo): - """Config has an override → returns a dspy.LM.""" + def test_returns_lm_when_override_exists_with_api_key(self, tmp_repo): + """Config has an override + API key → returns a dspy.LM.""" ensure_plumb_dir(tmp_repo) cfg = PlumbConfig( spec_paths=["spec.md"], @@ -349,13 +398,66 @@ def test_returns_lm_when_override_exists(self, tmp_repo): }, ) save_config(tmp_repo, cfg) - lm = get_program_lm("decision_deduplicator", repo_root=tmp_repo) - assert isinstance(lm, dspy.LM) - assert lm.model == "openai/gpt-4o-mini" - assert lm.kwargs["max_tokens"] == 4096 + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "sk-ant-test"}): + lm = get_program_lm("decision_deduplicator", repo_root=tmp_repo) + assert isinstance(lm, dspy.LM) + assert lm.model == "openai/gpt-4o-mini" + assert lm.kwargs["max_tokens"] == 4096 + + def test_returns_claude_code_lm_when_override_exists_no_key(self, tmp_repo): + """Config has an override + no API key + CLI available → returns ClaudeCodeLM.""" + ensure_plumb_dir(tmp_repo) + cfg = PlumbConfig( + spec_paths=["spec.md"], + program_models={ + "decision_deduplicator": {"model": "anthropic/claude-sonnet-4-20250514", "max_tokens": 4096}, + }, + ) + save_config(tmp_repo, cfg) + with patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + lm = get_program_lm("decision_deduplicator", repo_root=tmp_repo) + assert isinstance(lm, ClaudeCodeLM) + assert lm.cli_model == "claude-sonnet-4-20250514" def test_returns_none_when_no_repo_root(self): """No repo root found → returns None.""" with patch("plumb.config.find_repo_root", return_value=None): result = get_program_lm("decision_deduplicator") assert result is None + + +class TestGetLm: + def test_returns_dspy_lm_with_api_key(self): + """ANTHROPIC_API_KEY set → returns dspy.LM.""" + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "sk-ant-test"}): + lm = get_lm() + assert isinstance(lm, dspy.LM) + + def test_returns_claude_code_lm_without_api_key(self): + """No API key + CLI available → returns ClaudeCodeLM.""" + with patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + lm = get_lm() + assert isinstance(lm, ClaudeCodeLM) + + def test_raises_when_neither_available(self): + """No API key + no CLI → raises PlumbAuthError.""" + with patch.dict("os.environ", {}, clear=True), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value=None): + import os + os.environ.pop("ANTHROPIC_API_KEY", None) + with pytest.raises(PlumbAuthError, match="No LLM backend available"): + get_lm() + + def test_api_key_takes_precedence_over_cli(self): + """When both API key and CLI exist, API key wins.""" + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "sk-ant-test"}), \ + patch("plumb.programs.claude_code_lm.find_claude_cli", return_value="/usr/bin/claude"): + lm = get_lm() + assert isinstance(lm, dspy.LM) + assert not isinstance(lm, ClaudeCodeLM) diff --git a/tests/test_worktree_index_corruption.py b/tests/test_worktree_index_corruption.py new file mode 100644 index 0000000..5026596 --- /dev/null +++ b/tests/test_worktree_index_corruption.py @@ -0,0 +1,221 @@ +"""E2E tests: git worktree index corruption caused by GIT_INDEX_FILE inheritance. + +During git commit in a worktree, git sets GIT_INDEX_FILE to the worktree's +index path. claude -p inherits this env var, and Claude Code's plugin init +runs git operations that overwrite the worktree's index with plugin cache +entries. Result: "error: Error building trees" and a destroyed index. + +Test A: Pre-commit hook calls claude -p directly. +Test B: Pre-commit hook calls plumb hook (the real code path). + +All tests: real git, real claude -p, real worktree, real commit. No mocks. +Marked slow — requires claude CLI installed and authenticated. + +See: https://github.com/ktinubu/plumb/issues/1 +""" + +import shutil +import subprocess +from datetime import datetime, timezone +from pathlib import Path + +import pytest +from git import Repo + +from plumb.config import PlumbConfig, save_config, ensure_plumb_dir + +needs_claude_cli = pytest.mark.skipif( + shutil.which("claude") is None, + reason="claude CLI not installed", +) + + +def _create_repo_with_worktree(tmp_path, num_files=20): + """Create a main repo with files and a worktree. + + Returns (main_repo_path, worktree_path). + """ + main_dir = tmp_path / "main-repo" + main_dir.mkdir() + repo = Repo.init(main_dir) + + for i in range(num_files): + (main_dir / f"file_{i}.txt").write_text(f"content {i}\n") + repo.index.add([f"file_{i}.txt" for i in range(num_files)]) + repo.index.commit("initial commit") + + wt_dir = tmp_path / "worktree" + repo.git.worktree("add", str(wt_dir), "-b", "wt-branch", "HEAD") + + return main_dir, wt_dir + + +def _count_index_entries(repo_path): + """Return the number of entries in the git index.""" + result = subprocess.run( + ["git", "ls-files"], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + lines = result.stdout.strip().splitlines() + return len(lines) if lines != [""] else 0 + + +def _install_hook(main_repo_path, hook_script): + """Install a pre-commit hook in the main repo (shared with worktrees).""" + hooks_dir = main_repo_path / ".git" / "hooks" + hooks_dir.mkdir(exist_ok=True) + hook_path = hooks_dir / "pre-commit" + hook_path.write_text(hook_script) + hook_path.chmod(0o755) + + +def _stage_and_commit(wt_dir): + """Stage a new file and attempt git commit. Returns CompletedProcess.""" + (wt_dir / "new_file.txt").write_text("trigger commit\n") + subprocess.run(["git", "add", "new_file.txt"], cwd=str(wt_dir)) + return subprocess.run( + ["git", "commit", "-m", "test commit"], + cwd=str(wt_dir), + capture_output=True, + text=True, + timeout=300, + ) + + +def _init_plumb(repo_path): + """Initialize plumb in a repo programmatically (same as plumb init).""" + ensure_plumb_dir(repo_path) + (repo_path / ".plumb" / "decisions").mkdir(exist_ok=True) + + spec = repo_path / "spec.md" + spec.write_text("# Spec\n\n## Features\n\nThe system must do X.\n") + + tests_dir = repo_path / "tests" + tests_dir.mkdir(exist_ok=True) + + cfg = PlumbConfig( + spec_paths=["spec.md"], + test_paths=["tests/"], + initialized_at=datetime.now(timezone.utc).isoformat(), + ) + save_config(repo_path, cfg) + + # Install the real plumb pre-commit hook (same string as cli.py) + hooks_dir = repo_path / ".git" / "hooks" + hooks_dir.mkdir(exist_ok=True) + hook_path = hooks_dir / "pre-commit" + hook_path.write_text( + '#!/bin/sh\n[ "$PLUMB_SKIP" = "1" ] && exit 0\nplumb hook\nexit $?\n' + ) + hook_path.chmod(0o755) + + +@pytest.mark.slow +@needs_claude_cli +class TestClaudePWorktreeIndex: + """Test A: claude -p called directly from a pre-commit hook. + + Documents the upstream Claude Code CLI bug: claude -p corrupts + worktree indexes when GIT_INDEX_FILE is inherited. This test + asserts the buggy behavior so it will break (become a passing + test) if/when Claude Code fixes the upstream issue. + """ + + def test_raw_claude_p_corrupts_worktree_index(self, tmp_path): + """claude -p called directly from a hook (no plumb) corrupts + the worktree index — this is an upstream Claude Code bug.""" + main_dir, wt_dir = _create_repo_with_worktree(tmp_path) + baseline = _count_index_entries(wt_dir) + assert baseline == 20 + + _install_hook( + main_dir, + '#!/bin/sh\necho "say hello" | claude -p --output-format text >/dev/null 2>&1\nexit 0\n', + ) + + result = _stage_and_commit(wt_dir) + after = _count_index_entries(wt_dir) + + # Upstream bug: claude -p corrupts the index + assert after != baseline, ( + "Expected corruption (upstream bug) but index stayed intact. " + "If this fails, Claude Code may have fixed the upstream issue!" + ) + assert result.returncode != 0, ( + "Expected commit failure (upstream bug) but it succeeded. " + "If this fails, Claude Code may have fixed the upstream issue!" + ) + + +@pytest.mark.slow +@needs_claude_cli +class TestShellLevelStrippingPreventsCorruption: + """Test B: stripping GIT_INDEX_FILE and GIT_DIR at the shell level + before calling claude -p prevents the corruption.""" + + def test_unset_git_env_vars_before_claude_p(self, tmp_path): + """Unsetting GIT_INDEX_FILE and GIT_DIR in the hook script + before calling claude -p keeps the index intact.""" + main_dir, wt_dir = _create_repo_with_worktree(tmp_path) + baseline = _count_index_entries(wt_dir) + assert baseline == 20 + + _install_hook( + main_dir, + '#!/bin/sh\n' + 'env -u GIT_INDEX_FILE -u GIT_DIR ' + 'sh -c \'echo "say hello" | claude -p --output-format text >/dev/null 2>&1\'\n' + 'exit 0\n', + ) + + result = _stage_and_commit(wt_dir) + after = _count_index_entries(wt_dir) + + assert result.returncode == 0, ( + f"Commit failed: {result.stderr[:300]}" + ) + assert after == baseline + 1, ( + f"Expected {baseline + 1} index entries, got {after}" + ) + + +@pytest.mark.slow +@needs_claude_cli +class TestPlumbHookWorktreeIndex: + """Test C: plumb hook called from a pre-commit hook (real code path). + + Verifies that plumb's fix (stripping GIT_INDEX_FILE/GIT_DIR) protects + the worktree index when plumb hook runs during git commit. + """ + + def test_commit_succeeds_with_index_intact(self, tmp_path): + """git commit in a worktree must succeed with index intact when + plumb hook calls _call_claude() during pre-commit.""" + main_dir, wt_dir = _create_repo_with_worktree(tmp_path) + + _init_plumb(main_dir) + + repo = Repo(main_dir) + repo.index.add([ + ".plumb/config.json", + "spec.md", + ]) + repo.index.commit("add plumb config") + + wt_repo = Repo(wt_dir) + wt_repo.git.merge("main", "--no-edit") + + baseline = _count_index_entries(wt_dir) + + result = _stage_and_commit(wt_dir) + after = _count_index_entries(wt_dir) + + assert after == baseline + 1, ( + f"Expected {baseline + 1} index entries (original + new file), got {after}. " + f"rc={result.returncode}, stderr={result.stderr[:1000]}" + ) + assert "Error building trees" not in result.stderr, ( + f"Index was corrupted: {result.stderr[:1000]}" + )