Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions src/autolean/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""SHA256-based LLM response cache.

Caches API responses by (prompt_hash, model) to avoid redundant calls.
Stored as JSON files in a .autolean_cache/ directory.
"""

from __future__ import annotations

import hashlib
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Optional

from .util import CommandResult

_DEFAULT_CACHE_DIR = ".autolean_cache"


class ResponseCache:
"""Disk-backed cache for LLM API responses."""

def __init__(self, cache_dir: Optional[Path] = None, *, enabled: bool = True):
self.enabled = enabled
self.cache_dir = cache_dir or Path(_DEFAULT_CACHE_DIR)
self._hits = 0
self._misses = 0

    @property
    def hits(self) -> int:
        """Number of successful cache lookups made by this instance."""
        return self._hits

    @property
    def misses(self) -> int:
        """Number of lookups that found no (readable) cached entry."""
        return self._misses

@staticmethod
def _cache_key(prompt: str, model: str) -> str:
"""Generate a deterministic cache key from prompt + model."""
content = json.dumps({"prompt": prompt, "model": model}, sort_keys=True)
return hashlib.sha256(content.encode("utf-8")).hexdigest()

def _cache_path(self, key: str) -> Path:
# Use first 2 chars as subdirectory to avoid flat directory with thousands of files
subdir = self.cache_dir / key[:2]
subdir.mkdir(parents=True, exist_ok=True)
return subdir / f"{key}.json"
Comment on lines +45 to +49
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling mkdir on every cache lookup is inefficient and unnecessary for a read operation. This should be moved to the put method and only executed when a write is actually performed.

References
  1. Avoid unnecessary filesystem operations in hot paths or read-only lookups.


def get(self, prompt: str, model: str) -> Optional[CommandResult]:
"""Look up a cached response. Returns None on miss."""
if not self.enabled:
return None

key = self._cache_key(prompt, model)
path = self._cache_path(key)

if not path.exists():
self._misses += 1
return None

try:
data = json.loads(path.read_text(encoding="utf-8"))
self._hits += 1
return CommandResult(
argv=data.get("argv", []),
returncode=data.get("returncode", 0),
stdout=data.get("stdout", ""),
stderr=data.get("stderr", ""),
)
except (json.JSONDecodeError, OSError, KeyError):
self._misses += 1
return None

def put(self, prompt: str, model: str, result: CommandResult) -> None:
"""Store a successful response in the cache."""
if not self.enabled:
return
# Only cache successful responses
if result.returncode != 0:
return

key = self._cache_key(prompt, model)
path = self._cache_path(key)

data = {
"prompt_sha256": hashlib.sha256(prompt.encode("utf-8")).hexdigest(),
"model": model,
"cached_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"argv": result.argv,
"returncode": result.returncode,
"stdout": result.stdout,
"stderr": result.stderr,
}
try:
# Atomic write: write to temp file then rename to prevent corruption
path.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=True)
os.replace(tmp_path, path)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
except OSError:
pass # Cache write failure is non-fatal
Comment on lines +76 to +111
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Writing directly to the cache file can lead to corrupted entries if the process is interrupted or if there is concurrent access. It is safer to write to a temporary file and then perform an atomic replace. Additionally, indent=2 is generally unnecessary for cache files and increases storage usage.

References
  1. Use atomic file operations (write to temp + rename) to prevent data corruption.


def clear(self) -> int:
"""Remove all cached entries. Returns count of files removed."""
if not self.cache_dir.exists():
return 0
count = 0
for f in self.cache_dir.rglob("*.json"):
try:
f.unlink()
count += 1
except OSError:
pass
return count

def stats(self) -> dict[str, int]:
"""Return cache hit/miss statistics."""
total = 0
if self.cache_dir.exists():
total = sum(1 for _ in self.cache_dir.rglob("*.json"))
return {
"hits": self._hits,
"misses": self._misses,
"total_entries": total,
}
14 changes: 14 additions & 0 deletions src/autolean/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,18 @@ def _build_parser() -> argparse.ArgumentParser:
help="Working directory in which to run the compiler (e.g., your Lean project root).",
)
run.add_argument("--force", action="store_true", help="Re-run even if output file exists.")
run.add_argument(
"--no-cache",
action="store_true",
default=False,
help="Disable LLM response caching (by default, identical prompts are cached to save API costs).",
)
run.add_argument(
"--cache-dir",
type=Path,
default=None,
help="Directory for LLM response cache (default: .autolean_cache/).",
)

return p

Expand Down Expand Up @@ -556,6 +568,8 @@ def main(argv: list[str] | None = None) -> int:
live_logs=live_logs,
compile_cmd=args.compile_cmd,
cwd=args.cwd,
cache_enabled=not args.no_cache,
cache_dir=args.cache_dir,
)

if cfg.formalization_only and cfg.require_no_sorry:
Expand Down
208 changes: 208 additions & 0 deletions src/autolean/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""Lean compiler interaction, error extraction, and error memory."""

from __future__ import annotations

import re
from collections import OrderedDict
from pathlib import Path
from typing import Optional

from .util import CommandResult
from .providers import run_subprocess

_LEAN_LOCATION_PREFIX_RE = re.compile(r"^(?:[A-Za-z]:)?[^:\s]*\.lean:\d+:\d+:\s*")
_WHITESPACE_RE = re.compile(r"\s+")
_LEAN_MODULE_PART_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_']*$")

REPAIR_ERROR_MEMORY_LIMIT = 6


# ---------------------------------------------------------------------------
# Compile
# ---------------------------------------------------------------------------

def compile_lean(
    argv: list[str],
    *,
    cwd: Path,
    live: bool = False,
    stdout_sink=None,
    stderr_sink=None,
) -> CommandResult:
    """Run the Lean compiler command *argv* in *cwd* and return its result.

    Thin forwarding wrapper over ``run_subprocess``: *live* streams output
    as it is produced, and the optional sinks receive stdout/stderr text.
    """
    options = {
        "cwd": cwd,
        "live": live,
        "stdout_sink": stdout_sink,
        "stderr_sink": stderr_sink,
    }
    return run_subprocess(argv, **options)


# ---------------------------------------------------------------------------
# Error extraction and memory
# ---------------------------------------------------------------------------

def extract_compact_error_lines(compiler_res: CommandResult) -> list[str]:
    """Condense compiler output down to the lines that matter.

    Returns every non-empty line whose lowercased text contains an error
    marker; if none match, falls back to the first non-empty line so the
    caller always has something to show.  Empty output yields [].
    """
    markers = (
        "error",
        "parse failure",
        "policy failure",
        "failed before producing lean output",
    )
    text = (compiler_res.stdout + "\n" + compiler_res.stderr).strip()
    if not text:
        return []

    stripped = [raw.strip() for raw in text.splitlines()]
    matched = [
        line for line in stripped
        if line and any(marker in line.lower() for marker in markers)
    ]
    if matched:
        return matched

    # Fallback: surface at least the first non-empty line.
    first = next((line for line in stripped if line), None)
    return [first] if first else []


def normalize_error_line(line: str) -> str:
    """Canonicalize a compiler error line for use as a dedup key.

    Strips any leading ``path.lean:row:col:`` location prefix (optionally
    with a Windows drive letter) and collapses internal whitespace runs to
    single spaces.
    """
    without_location = re.sub(
        r"^(?:[A-Za-z]:)?[^:\s]*\.lean:\d+:\d+:\s*", "", line.strip()
    )
    return re.sub(r"\s+", " ", without_location).strip()


def update_error_memory(
    memory: OrderedDict[str, tuple[str, int, int]],
    compiler_res: CommandResult,
    *,
    iter_no: int,
) -> None:
    """Fold this compile's errors into *memory* in place.

    *memory* maps normalized error text -> (display text, seen count,
    last iteration seen).  Repeated errors get their count bumped and are
    moved to the most-recent end; new errors are appended with count 1.
    """
    for raw in extract_compact_error_lines(compiler_res):
        normalized = normalize_error_line(raw)
        if not normalized:
            continue
        previous = memory.get(normalized)
        if previous is None:
            memory[normalized] = (normalized, 1, iter_no)
        else:
            memory[normalized] = (normalized, previous[1] + 1, iter_no)
            memory.move_to_end(normalized)


def format_error_memory(memory: OrderedDict[str, tuple[str, int, int]], *, limit: int) -> str:
    """Render the *limit* most recent errors, newest first, as a numbered list.

    Returns "" when *memory* is empty or *limit* is non-positive.
    """
    if not memory or limit <= 0:
        return ""
    newest_first = list(memory.values())[-limit:][::-1]
    rendered: list[str] = []
    for rank, (display, count, last_iter) in enumerate(newest_first, start=1):
        tag = (
            f"[seen {count}x, last iter {last_iter}]"
            if count > 1
            else f"[iter {last_iter}]"
        )
        rendered.append(f"{rank}. {tag} {display}")
    return "\n".join(rendered)


# ---------------------------------------------------------------------------
# Lean code analysis
# ---------------------------------------------------------------------------

def extract_top_level_prop_from_theorem_header(header: str) -> Optional[str]:
    """Return the proposition stated by a theorem/lemma header, or None.

    *header* is everything from the ``theorem``/``lemma`` keyword up to (but
    not including) ``:=``, e.g. ``theorem foo (x : Nat) : P x``.

    The proposition is whatever follows the FIRST colon at bracket depth 0.
    In Lean 4 every binder before the separator colon is bracketed
    (``( )``, ``{ }``, ``[ ]``), so colons inside binders sit at depth > 0,
    while the proposition itself may legally contain further depth-0 colons
    (e.g. ``∀ y : Nat, ...``) — which is why taking the LAST depth-0 colon
    would truncate such propositions.
    """
    depth = 0
    for i, ch in enumerate(header):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth = max(0, depth - 1)  # tolerate unbalanced input
        elif ch == ":" and depth == 0:
            # Skip ':=' — definition assignment, not the type separator.
            if i + 1 < len(header) and header[i + 1] == "=":
                continue
            return header[i + 1:].strip()
    return None
Comment on lines +112 to +126
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for extracting the proposition from the theorem header incorrectly uses the last colon at depth 0. In Lean 4, propositions often contain colons at the top level (e.g., in quantifiers like ∀ x : Nat, ...). Using the last colon will truncate the proposition. It should use the first colon at depth 0, which acts as the separator between the theorem name/arguments and the proposition.

References
  1. Ensure logic correctly handles language-specific syntax edge cases, such as colons in Lean 4 propositions.



def detect_trivialized_statement(lean_code: str, *, theorem_name: str) -> Optional[str]:
    """Detect whether *theorem_name* was stated as a bare True/False.

    Returns "True" or "False" when the theorem's top-level proposition
    starts with that literal (optionally parenthesized); otherwise None.
    Also returns None when the theorem or its ``:=`` cannot be located.
    """
    header_start = re.search(
        rf"\b(?:theorem|lemma)\s+{re.escape(theorem_name)}\b", lean_code
    )
    if header_start is None:
        return None
    assign_pos = lean_code.find(":=", header_start.end())
    if assign_pos < 0:
        return None
    prop = extract_top_level_prop_from_theorem_header(
        lean_code[header_start.start():assign_pos]
    )
    if not prop:
        return None
    trivial = re.match(r"^\(?\s*(True|False)\b", prop)
    return trivial.group(1) if trivial else None


def module_name_from_lean_path(lean_path: Path, *, run_cwd: Path) -> Optional[str]:
    """Derive a dot-separated Lean module name from a ``.lean`` file path.

    Returns None when the file lies outside *run_cwd*, is not a ``.lean``
    file, or any path component is not a valid Lean identifier.
    """
    try:
        relative = lean_path.resolve().relative_to(run_cwd.resolve())
    except ValueError:
        return None  # not under the project root
    if relative.suffix != ".lean":
        return None
    components = relative.with_suffix("").parts
    if not components:
        return None
    if any(
        re.fullmatch(r"[A-Za-z_][A-Za-z0-9_']*", part) is None
        for part in components
    ):
        return None
    return ".".join(components)


def inject_imports(lean_code: str, module_names: list[str]) -> str:
    """Ensure *lean_code* imports every module in *module_names*.

    Missing ``import`` lines are inserted right after the existing prelude
    (leading blank lines, ``--`` comments, and import lines).  Duplicates in
    *module_names* and modules already imported are skipped.  When nothing
    is missing, the input string is returned unchanged (preserving any
    trailing newline).
    """
    # De-duplicate the requested modules, preserving first-seen order.
    wanted: list[str] = []
    requested: set[str] = set()
    for name in module_names:
        name = name.strip()
        if name and name not in requested:
            requested.add(name)
            wanted.append(name)
    if not wanted:
        return lean_code

    lines = lean_code.splitlines()
    already: set[str] = set()
    insertion = 0
    for position, text in enumerate(lines):
        bare = text.strip()
        if bare and not bare.startswith("--"):
            if not bare.startswith("import "):
                break  # first real code line: the prelude ends here
            imported = bare[len("import "):].strip()
            if imported:
                already.add(imported)
        insertion = position + 1

    additions = [f"import {name}" for name in wanted if name not in already]
    if not additions:
        return lean_code

    result = "\n".join(lines[:insertion] + additions + lines[insertion:])
    return result + "\n" if lean_code.endswith("\n") else result
Loading