diff --git a/.claude-plugin/marketplace.json b/.claude-plugin/marketplace.json new file mode 100644 index 00000000..dbb15b88 --- /dev/null +++ b/.claude-plugin/marketplace.json @@ -0,0 +1,22 @@ +{ + "name": "ml-intern", + "owner": { + "name": "Hugging Face", + "url": "https://github.com/huggingface" + }, + "metadata": { + "description": "ML engineering plugins for Claude Code, built on the Hugging Face ecosystem.", + "version": "0.1.0" + }, + "plugins": [ + { + "name": "ml-intern", + "source": "./plugin", + "description": "ML engineering assistant for Hugging Face — research-first methodology, dataset auditing, HF Jobs orchestration, sandbox-driven development. Slash commands, a research subagent, and content-aware approval policy ported from the standalone ml-intern CLI.", + "version": "0.1.0", + "category": "ml", + "tags": ["huggingface", "ml", "training", "fine-tuning", "datasets"], + "strict": false + } + ] +} diff --git a/.claude/agents/research.md b/.claude/agents/research.md new file mode 100644 index 00000000..d430b1fd --- /dev/null +++ b/.claude/agents/research.md @@ -0,0 +1,119 @@ +--- +name: research +description: Use proactively before writing any ML implementation code. Mines the literature to find the best training recipes backed by published results, then validates them with working code and current docs. The main agent uses these findings to implement the actual solution. Spawn with a specific brief — name anchor papers or arxiv IDs when you have them. +tools: Read, Bash, Grep, Glob, WebFetch, mcp__ml-intern-tools__explore_hf_docs, mcp__ml-intern-tools__fetch_hf_docs, mcp__ml-intern-tools__hf_papers, mcp__ml-intern-tools__hf_inspect_dataset, mcp__ml-intern-tools__github_find_examples, mcp__ml-intern-tools__github_list_repos, mcp__ml-intern-tools__github_read_file, mcp__ml-intern-tools__hf_repo_files +--- + +You are a research sub-agent for an ML engineering assistant. Your primary job: mine the literature to find the best training recipes — then back them up with working code and up-to-date documentation. The main agent will use your findings to implement the actual solution. + +# Start from the literature + +Your default approach is a deep literature crawl. Do not start from docs or example scripts — start from papers. Papers contain the results, and results tell you what actually works. + +## The crawl + +1. **Find anchor papers**: Search for the task/domain. Identify the landmark paper(s) — high citations, recent, or both. +2. **Crawl the citation graph**: Use `citation_graph` on the anchor paper(s). Look DOWNSTREAM (papers that cite it) — these are the ones that built on it, improved it, or applied it to new domains. Prioritize recent papers and papers with many citations. +3. **Read methodology sections**: For the most promising papers (strong results, recent, relevant), use `read_paper` with section parameter to read sections 3, 4, 5 (Methodology, Experiments, Results — not the abstract). Extract: + - The exact dataset(s) used (name, source, size, any filtering/preprocessing) + - The training method and configuration (optimizer, lr, schedule, epochs, batch size) + - The results those choices produced (benchmark scores, metrics, comparisons) +4. **Attribute results to recipes**: This is the critical step. Every finding must link a RESULT to the RECIPE that produced it. "Dataset X + method Y + lr Z → score W on benchmark V" is useful. "They used SFT" is not. +5. 
**Validate datasets**: For the most promising datasets, check if they exist on HF Hub with `hf_inspect_dataset`. Verify format matches the training method. Report if it doesn't. +6. **Find code**: Now find working implementation code via `github_find_examples` and `github_read_file`. Use docs (`explore_hf_docs`, `fetch_hf_docs`) to fill in API details. + +## When to go deeper + +- If the anchor paper is old (>1 year), its citation graph is your main source — the downstream papers will have better methods. +- If a downstream paper reports significantly better results, crawl ITS citation graph too. +- Use `snippet_search` to find specific claims across papers (e.g., "does dataset X consistently outperform Y for this task?"). +- Use `recommend` to find related papers the citation graph might miss. + +# How to use your tools + +## Papers & citations (USE FIRST) +- `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML) +- `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar +- `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter +- `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR +- `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents +- `hf_papers(operation="read_paper", arxiv_id=..., section="3")`: Read a specific section's full text +- `hf_papers(operation="read_paper", arxiv_id=...)`: Get TOC (abstract + section list) — use this to find which section numbers contain methodology/experiments +- `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages +- `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers +- `hf_papers(operation="find_datasets", arxiv_id=...)`: Find HF datasets linked to a paper +- `hf_papers(operation="find_all_resources", arxiv_id=...)`: Datasets + models + collections for a paper + +## Dataset inspection +- `hf_inspect_dataset`: Check dataset schema, splits, sample rows. CRITICAL for training: verify column format matches training method: + - SFT: needs `messages`, `text`, or `prompt`/`completion` + - DPO: needs `prompt`, `chosen`, `rejected` + - GRPO: needs `prompt` only + +## GitHub code research +- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) +- `github_read_file`: Read the actual implementation code. Use `line_start`/`line_end` for large files. + +## Documentation +- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. +- `fetch_hf_docs(url)`: Fetch full page content from explore results + +## Hub repo inspection +- `hf_repo_files`: List/read files in any HF repo (model, dataset, space) + +# Correct research pattern + +``` +# 1. Find anchor paper(s) for the task +hf_papers({"operation": "search", "query": "GPQA graduate questions", "sort_by": "citationCount"}) + +# 2. Crawl citation graph — look downstream +hf_papers({"operation": "citation_graph", "arxiv_id": "2311.12022", "direction": "citations"}) + +# 3. Read methodology of promising downstream papers +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348"}) # TOC first +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "3"}) # Methodology +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "4"}) # Experiments + +# 4. 
Find datasets used by these papers +hf_papers({"operation": "find_datasets", "arxiv_id": "2604.01348"}) +hf_papers({"operation": "find_all_resources", "arxiv_id": "2604.01348"}) + +# 5. Validate datasets exist and have correct format +hf_inspect_dataset({"dataset": "org/dataset-name", "split": "train", "sample_rows": 3}) + +# 6. Now get working code for the training method +github_find_examples({"repo": "trl", "keyword": "sft"}) +github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) +explore_hf_docs("trl") +``` + +# Output format + +Your output MUST be structured as a ranked list of training recipes, each attributed to published results: + +## Recipe table (REQUIRED) +For each promising approach found, report: +- **Paper**: title, arxiv_id, date, venue +- **Result**: exact benchmark scores and what they were measured on +- **Dataset(s)**: name, size, source, HF Hub availability, format verified (yes/no) +- **Method**: training approach, key hyperparameters (lr, epochs, batch size, optimizer, schedule) +- **What made it work**: the specific insight or trick that drove the result (data curation, curriculum, loss function, etc.) + +Rank recipes by result quality. The main agent will pick the best one that's feasible. + +## Code patterns +- Key imports, configurations, and usage patterns from working examples +- Specific file paths, URLs, function names from docs + +## Recommendations +- Which recipe to implement first and why +- What datasets to use (with HF Hub paths, verified) +- Any gaps: datasets that need preprocessing, methods that need adaptation + +Additionally include: +- **SOTA landscape**: Current best models, datasets, and methods for the task (from recent papers). Flag anything outdated. +- **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets that the main agent should use directly +- **Code patterns**: Key imports, configurations, and usage patterns from working examples + +Be concise. Your output goes into another agent's context — every token counts. Aim for 500-1500 words max. Include actual code snippets from examples you read, not paraphrased descriptions. diff --git a/.claude/commands/finetune.md b/.claude/commands/finetune.md new file mode 100644 index 00000000..67234db5 --- /dev/null +++ b/.claude/commands/finetune.md @@ -0,0 +1,47 @@ +--- +description: Fine-tune a model on a dataset, end-to-end (research → validate → train → push). +argument-hint: +--- + +Fine-tune the model described in: $ARGUMENTS + +Fine-tuning is never trivial. Follow this sequence in order. Do **not** skip steps even if the request looks simple — `CLAUDE.md` lists the specific failures that happen when you do. + +**1. Research first (mandatory).** Delegate to the `research` subagent via the Task tool with `subagent_type: "research"`. Brief it: + +> Find the best fine-tuning recipe for: $ARGUMENTS. +> Identify the model architecture and intended task. Crawl the citation graph for recent papers that fine-tuned this (or a comparable) model on this (or a comparable) dataset. Read methodology sections (3, 4, 5) of the top 3 candidates. Extract: training method (SFT/DPO/GRPO/...), exact hyperparameters (lr, schedule, epochs, batch size, optimizer, max_length), and any data preprocessing. Verify the dataset's HF Hub format with `hf_inspect_dataset`. Return a ranked recipe table per CLAUDE.md. + +Do not start writing code until the subagent returns. + +**2. 
Validate dataset and model.** Independently of the research output, run: +- `mcp__ml-intern-tools__hf_inspect_dataset` on the target dataset — confirm columns match the chosen training method (SFT: `messages`/`text`/`prompt`+`completion`; DPO: `prompt`+`chosen`+`rejected`; GRPO: `prompt`). +- `mcp__ml-intern-tools__hf_repo_files` on the target model — confirm it exists and note tokenizer/architecture. + +**3. Develop in a sandbox.** For non-trivial scripts, call `mcp__ml-intern-tools__sandbox_create` with a GPU flavor (`t4-small` minimum if the code touches CUDA/bf16/model loading). Write the script, install deps, run a tiny smoke test (1–2 steps), fix errors. Do not skip the smoke test. + +**4. Pre-flight check (mandatory output before `hf_jobs`).** Print this checklist and verify every line is filled: + +``` +Reference implementation: +Dataset format verified: +Training method: +Hyperparameters: +push_to_hub: True +hub_model_id: +hardware_flavor: +timeout: <≥ 2h for any training> +Trackio monitoring: +disable_tqdm=True, logging_strategy="steps", logging_first_step=True: yes +``` + +If any line is missing, **stop and complete it** before submitting. + +**5. Submit ONE job.** Call `mcp__ml-intern-tools__hf_jobs` (operation `run` or `uv`) with the verified config. Watch the first 60s of logs to confirm training started (loss values printing as plain text, not stuck on tokenizer/model load). Only then submit any sweep/ablation runs. + +**6. Report.** Provide: +- Direct Hub URL of the job (`https://huggingface.co/jobs/...`) +- Trackio dashboard URL +- Hub URL of the model that will appear on completion (`https://huggingface.co/`) + +If anything fails, do not silently switch training methods, reduce `max_length`, or substitute datasets. Diagnose, fix the minimal thing, or ask the user. diff --git a/.claude/commands/inspect-dataset.md b/.claude/commands/inspect-dataset.md new file mode 100644 index 00000000..be214c82 --- /dev/null +++ b/.claude/commands/inspect-dataset.md @@ -0,0 +1,18 @@ +--- +description: Audit a HF dataset — schema, splits, sample rows, and red flags. Direct port of `hf_inspect_dataset`. +argument-hint: +--- + +Inspect the dataset `$ARGUMENTS` using `mcp__ml-intern-tools__hf_inspect_dataset`. + +Report back with: +- schema and column types +- number of rows per split +- 3 sample rows +- red flags: class imbalance, missing values, unexpected formats, duplicates +- training-method compatibility: + - SFT-ready? (has `messages` / `text` / `prompt`+`completion`) + - DPO-ready? (has `prompt` + `chosen` + `rejected`) + - GRPO-ready? (has `prompt`) + +Include the direct Hub URL: `https://huggingface.co/datasets/$ARGUMENTS` diff --git a/.claude/commands/ml-intern.md b/.claude/commands/ml-intern.md new file mode 100644 index 00000000..614fab95 --- /dev/null +++ b/.claude/commands/ml-intern.md @@ -0,0 +1,12 @@ +--- +description: Default ML Intern entrypoint — equivalent to running `ml-intern ""` headlessly. +argument-hint: +--- + +You are running as ML Intern. Follow the workflow defined in `CLAUDE.md`: +research first (delegate to the `research` subagent for any non-trivial ML task), +validate datasets and models, then implement. + +User request: + +$ARGUMENTS diff --git a/.claude/commands/research.md b/.claude/commands/research.md new file mode 100644 index 00000000..9c27dac1 --- /dev/null +++ b/.claude/commands/research.md @@ -0,0 +1,22 @@ +--- +description: Force a literature-first research crawl — delegates immediately to the `research` subagent without doing anything else. 
+argument-hint: +--- + +Delegate this research task to the `research` subagent **immediately**. Do not +attempt the research yourself — the subagent has its own context window and +returns a structured recipe table. + +Use the Task tool with `subagent_type: "research"`. Brief: + +> Literature crawl for: $ARGUMENTS +> +> Start from anchor paper(s). Crawl citation graph for recent downstream +> papers. Read their methodology sections (3, 4, 5) — extract the exact +> datasets, training methods, and hyperparameters that produced their +> best results. Attribute every finding to a specific result. Also find +> working code examples using current TRL/Transformers APIs. Validate +> any datasets via `hf_inspect_dataset`. + +When the subagent returns, summarize the top recipe to the user with direct +HF Hub URLs and the arxiv ID of the source paper. diff --git a/.claude/commands/run-job.md b/.claude/commands/run-job.md new file mode 100644 index 00000000..d7745c36 --- /dev/null +++ b/.claude/commands/run-job.md @@ -0,0 +1,40 @@ +--- +description: Submit an HF Job (training, eval, batch inference) with the ml-intern pre-flight checklist. +argument-hint: +--- + +Submit an HF Job for: $ARGUMENTS + +Before calling `mcp__ml-intern-tools__hf_jobs`, produce the pre-flight check below. **Do not call `hf_jobs` until every line is filled in.** If you cannot fill a line, complete the missing step (research, dataset inspection, sandbox test) first. + +``` +Job purpose: +Reference implementation: +Dataset format verified: +Model verified: +push_to_hub: +hardware_flavor: +timeout: +Trackio monitoring: +Packages to install: +``` + +**Hardware sizing** (from `CLAUDE.md`): +- 1–3B params → `a10g-largex2` +- 7–13B params → `a100-large` +- 30B+ params → `l40sx4` or `a100x4` +- 70B+ params → `a100x8` +- CPU-only data prep → `cpu-basic` or `cpu-upgrade` + +Note: `a10g-small` and `a10g-large` have the SAME 24GB GPU memory — the difference is CPU/RAM only. + +**Timeout floor:** for any training job, set timeout ≥ `2h`. The default 30m kills training. If your timeout is < 2h and the job is training, **stop and revise** unless the user explicitly justified a shorter run (e.g. a smoke test). + +**Hooks will gate this call:** GPU jobs always prompt for confirmation. CPU jobs prompt by default (override with `ML_INTERN_CONFIRM_CPU_JOBS=0`). That is expected — present the pre-flight check clearly so the user can approve in one read. + +**For batch / ablation work:** submit ONE job first. Watch the first ~60 seconds of logs (look for plain-text loss lines — `disable_tqdm=True, logging_strategy="steps", logging_first_step=True` should be set). Only after that one starts training successfully, submit the rest. Never submit all at once. + +**After submission, report:** +- Job URL (`https://huggingface.co/jobs/...`) +- Trackio dashboard URL +- Expected output (model repo, dataset repo, eval scores file path) and where to find it after completion diff --git a/.claude/hooks/pre_tool_use_approval.py b/.claude/hooks/pre_tool_use_approval.py new file mode 100755 index 00000000..dabeabb4 --- /dev/null +++ b/.claude/hooks/pre_tool_use_approval.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +""" +PreToolUse hook — port of agent/core/agent_loop.py::_needs_approval. + +Claude Code's static permission lists can't express ml-intern's +content-aware approval rules (e.g. "auto-approve CPU jobs but require +confirmation for GPU jobs"). 
This hook reads the tool input from stdin +and either: + - exits 0 (allow without prompt) — equivalent to ml-intern auto-execute + - prints a JSON `ask` decision so Claude Code prompts the user + +Fail-safe: malformed payloads, non-dict tool_input, or empty tool_name +all result in `ask` (never silent allow). For an approval hook, falling +through to allow on error would defeat the policy. + +Env knobs (hook-layer equivalents of fields in `agent.config.Config` — +the standalone CLI reads these from configs/main_agent_config.json): + + ML_INTERN_YOLO=1 → skip ALL approvals (Config.yolo_mode) + ML_INTERN_CONFIRM_CPU_JOBS=0 → auto-approve CPU jobs (Config.confirm_cpu_jobs) +""" + +from __future__ import annotations + +import json +import os +import sys + +# Mirror agent/tools/jobs_tool.py::CPU_FLAVORS +CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _check_training_script_save_pattern(script: str) -> str | None: + """Inspired by agent/utils/reliability_checks.py::check_training_script_save_pattern. + + Returns a warning when an hf_jobs script appears to load a model but + not push it back to the Hub (job storage is ephemeral — the model is + lost when the job ends). Source also emits a green "will be pushed" + confirmation; we drop that — hook output is shown only when forcing + a prompt, and a positive note there would be noise. + """ + if not isinstance(script, str): + return None + has_from_pretrained = "from_pretrained" in script + has_push_to_hub = "push_to_hub" in script + if has_from_pretrained and not has_push_to_hub: + return "WARNING: training script loads a model with `from_pretrained` but has no `push_to_hub` call — the trained model will be lost when the job ends." + return None + + +def _hf_jobs_script_warning(tool_input: dict) -> str | None: + """Extract the script body from an hf_jobs invocation and run save-pattern check.""" + operation = tool_input.get("operation", "") + if operation not in ("run", "uv", "scheduled run", "scheduled uv"): + return None + script = ( + tool_input.get("script") + or tool_input.get("uv_script") + or tool_input.get("source") + or "" + ) + return _check_training_script_save_pattern(script) + + +def _needs_approval(tool_name: str, tool_input: dict) -> bool: + """Port of agent/core/agent_loop.py::_needs_approval (lines 51-118). + + Diverges from source in one place: source short-circuits to False on + malformed args via `_validate_tool_args` so a downstream validation error + surfaces. Here we don't have that path — Claude Code validates input + shape against the MCP schema upstream, so any payload reaching this hook + is already structurally valid. + """ + if _env_flag("ML_INTERN_YOLO", False): + return False + + # MCP tools surface in Claude Code as `mcp____`. Strip the prefix. 
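+    # e.g. "mcp__ml-intern-tools__hf_jobs" → short name "hf_jobs"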
+ short_name = tool_name.split("__")[-1] if tool_name.startswith("mcp__") else tool_name + + if short_name == "sandbox_create": + return True + + if short_name == "hf_jobs": + operation = tool_input.get("operation", "") + if operation not in ("run", "uv", "scheduled run", "scheduled uv"): + return False + + hardware_flavor = ( + tool_input.get("hardware_flavor") + or tool_input.get("flavor") + or tool_input.get("hardware") + or "cpu-basic" + ) + is_cpu_job = hardware_flavor in CPU_FLAVORS + + if is_cpu_job: + return _env_flag("ML_INTERN_CONFIRM_CPU_JOBS", True) + + return True # GPU jobs always prompt + + # Note: hf_private_repos is intentionally not handled. agent/core/tools.py + # disables it ("replaced by hf_repo_files and hf_repo_git"). The two + # rules below cover the same destructive operations on the live tools. + + if short_name == "hf_repo_files": + operation = tool_input.get("operation", "") + if operation in ("upload", "delete"): + return True + + if short_name == "hf_repo_git": + operation = tool_input.get("operation", "") + if operation in ("delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"): + return True + + return False + + +def _ask(reason: str) -> dict: + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "ask", + "permissionDecisionReason": reason, + } + } + + +def main() -> int: + try: + payload = json.load(sys.stdin) + except json.JSONDecodeError as e: + # Fail-safe: a malformed payload to an APPROVAL hook must not silently + # allow the tool. Log to stderr so the failure is inspectable. + print(f"[ml-intern] approval hook: malformed stdin ({e}); forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received malformed input — confirm before proceeding"))) + return 0 + + if not isinstance(payload, dict): + print(f"[ml-intern] approval hook: stdin is {type(payload).__name__}, expected dict; forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received unexpected input — confirm before proceeding"))) + return 0 + + tool_name = payload.get("tool_name") or "" + tool_input = payload.get("tool_input") or {} + if not isinstance(tool_input, dict): + print(f"[ml-intern] approval hook: tool_input is {type(tool_input).__name__}, expected dict; forcing prompt", file=sys.stderr) + print(json.dumps(_ask(f"ml-intern: {tool_name or 'tool'} received non-dict input — confirm before proceeding"))) + return 0 + + if not tool_name: + print("[ml-intern] approval hook: empty tool_name; forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received empty tool_name — confirm before proceeding"))) + return 0 + + needs = _needs_approval(tool_name, tool_input) + + # Reliability warnings ride along — surface them by forcing a prompt + # even when the rule would otherwise auto-approve. 
+ short_name = tool_name.split("__")[-1] if tool_name.startswith("mcp__") else tool_name + warning: str | None = None + if short_name == "hf_jobs": + warning = _hf_jobs_script_warning(tool_input) + if warning: + needs = True + + if needs: + reason_bits = [ + f"ml-intern policy: {tool_name} requires user confirmation " + f"(see .claude/hooks/pre_tool_use_approval.py)" + ] + if warning: + reason_bits.append(warning) + print(json.dumps(_ask(" | ".join(reason_bits)))) + return 0 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.claude/hooks/session_end_upload.py b/.claude/hooks/session_end_upload.py new file mode 100755 index 00000000..d1e560d2 --- /dev/null +++ b/.claude/hooks/session_end_upload.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +""" +SessionEnd hook — upload the Claude Code transcript to the HF Hub dataset +configured by `ML_INTERN_SESSION_REPO` (default: smolagents/ml-intern-sessions). + +Mirrors agent/core/session_uploader.py behavior: + - best-effort, write-only token preferred, never blocks the user + - applies agent/core/redact.py::scrub before upload to strip HF/Anthropic/ + OpenAI/GitHub/AWS tokens that users (or scripts) may have pasted into chat + - if redaction can't be loaded we skip upload entirely — losing a session + beats leaking a token + +Env knobs (hook-layer equivalents of fields in agent.config.Config): + + ML_INTERN_SAVE_SESSIONS=0 → disable session upload (Config.save_sessions) + ML_INTERN_SESSION_REPO=org/repo → override target dataset (Config.session_dataset_repo) + HF_SESSION_UPLOAD_TOKEN → preferred upload token (write-only, scoped) + HF_TOKEN → fallback + HF_ADMIN_TOKEN → last-resort fallback (parity with session_uploader.py) +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +from pathlib import Path + +_PROJECT_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_PROJECT_ROOT)) + +DEFAULT_REPO = "smolagents/ml-intern-sessions" + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _resolve_token() -> str | None: + """Match the fallback chain in agent/core/session_uploader.py.""" + for name in ("HF_SESSION_UPLOAD_TOKEN", "HF_TOKEN", "HF_ADMIN_TOKEN"): + token = os.environ.get(name) + if token: + return token + return None + + +def _is_safe_transcript_path(p: Path) -> bool: + """Reject paths outside the directories Claude Code normally uses for + transcripts. Defense in depth against a malformed payload pointing at, + e.g., ~/.ssh/id_rsa — which the redact pipeline would happily upload + after only scrubbing token-shaped strings. + """ + try: + resolved = p.resolve() + except OSError: + return False + + allowed_roots: list[Path] = [] + home = Path.home() + allowed_roots.append((home / ".claude").resolve()) + project_dir = os.environ.get("CLAUDE_PROJECT_DIR") + if project_dir: + try: + allowed_roots.append(Path(project_dir).resolve()) + except OSError: + pass + + for root in allowed_roots: + try: + resolved.relative_to(root) + return True + except ValueError: + continue + return False + + +def _redact_jsonl(src: Path) -> Path: + """Return a NamedTemporaryFile path containing the redacted transcript. + + Each line is JSON-decoded, run through agent.core.redact.scrub, and + re-encoded. Lines that fail to parse fall back to a string-level scrub + (covers plain log lines or partial flushes). 
+ """ + from agent.core.redact import scrub, scrub_string + + out = tempfile.NamedTemporaryFile( + prefix="ml-intern-session-", suffix=".jsonl", delete=False, mode="w", encoding="utf-8" + ) + fallback_lines = 0 + with src.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + line = line.rstrip("\n") + if not line: + out.write("\n") + continue + try: + obj = json.loads(line) + obj = scrub(obj) + out.write(json.dumps(obj, ensure_ascii=False)) + out.write("\n") + except json.JSONDecodeError: + fallback_lines += 1 + out.write(scrub_string(line)) + out.write("\n") + out.close() + if fallback_lines: + print( + f"[ml-intern] {fallback_lines} transcript line(s) fell back to string-scrub", + file=sys.stderr, + ) + return Path(out.name) + + +def main() -> int: + if not _env_flag("ML_INTERN_SAVE_SESSIONS", True): + return 0 + + token = _resolve_token() + if not token: + print( + "[ml-intern] no HF_SESSION_UPLOAD_TOKEN / HF_TOKEN / HF_ADMIN_TOKEN — " + "session not uploaded", + file=sys.stderr, + ) + return 0 + + try: + payload = json.load(sys.stdin) + except json.JSONDecodeError as e: + print(f"[ml-intern] session upload: malformed stdin ({e}); skipping", file=sys.stderr) + return 0 + if not isinstance(payload, dict): + print("[ml-intern] session upload: stdin is not a dict; skipping", file=sys.stderr) + return 0 + + transcript_path = payload.get("transcript_path") + session_id = payload.get("session_id", "unknown") + if not isinstance(transcript_path, str) or not transcript_path: + return 0 + + src = Path(transcript_path) + if not src.exists(): + return 0 + if not _is_safe_transcript_path(src): + print( + f"[ml-intern] refusing to upload transcript outside ~/.claude or " + f"$CLAUDE_PROJECT_DIR: {transcript_path}", + file=sys.stderr, + ) + return 0 + + repo_id = os.environ.get("ML_INTERN_SESSION_REPO", DEFAULT_REPO) + + try: + redacted = _redact_jsonl(src) + except Exception as e: + # Don't upload the raw transcript if redaction fails — better to lose + # the session than to leak a token. + print(f"[ml-intern] redaction failed, NOT uploading: {e}", file=sys.stderr) + return 0 + + try: + from huggingface_hub import HfApi + + api = HfApi(token=token) + api.upload_file( + path_or_fileobj=str(redacted), + path_in_repo=f"sessions/{session_id}.jsonl", + repo_id=repo_id, + repo_type="dataset", + commit_message=f"Upload session {session_id}", + ) + except Exception as e: + print(f"[ml-intern] session upload failed: {e}", file=sys.stderr) + finally: + try: + redacted.unlink() + except OSError: + pass + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.claude/hooks/session_start_context.py b/.claude/hooks/session_start_context.py new file mode 100755 index 00000000..29586a4b --- /dev/null +++ b/.claude/hooks/session_start_context.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +SessionStart hook — inject the dynamic session context that the standalone +CLI builds in agent/context_manager/manager.py: + + - HF username (so the agent uses the right namespace for hub_model_id) + - Local-mode banner (only when ML_INTERN_LOCAL_MODE=1, mirrors the + "CLI / Local mode" block injected into the system prompt) + +Output is JSON `additionalContext` per Claude Code's SessionStart hook +contract — Claude Code surfaces it to the model as a system reminder. 
+""" + +from __future__ import annotations + +import json +import os +import sys + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _hf_username(token: str | None) -> tuple[str | None, str | None]: + """Return (username, error_reason). Exactly one is non-None. + + The standalone CLI uses curl with `-4` to dodge IPv6 Happy-Eyeballs + hangs (see agent/context_manager/manager.py:27-30). `huggingface_hub` + is already a dep here and uses `requests`/`urllib3` which doesn't + have the same pathology in normal setups; we use it for KISS reasons + and accept that very-broken IPv6 environments will time out instead + of falling back instantly. + """ + if not token: + return None, "no HF_TOKEN in environment" + try: + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError + + info = HfApi(token=token).whoami() + except HfHubHTTPError as e: + return None, f"whoami HTTP error: {e}" + except Exception as e: + return None, f"whoami failed: {type(e).__name__}: {e}" + + name = info.get("name") if isinstance(info, dict) else None + if isinstance(name, str) and name: + return name, None + return None, "whoami returned no name" + + +def main() -> int: + try: + sys.stdin.read() + except Exception: + pass + + parts: list[str] = [] + + user, err = _hf_username(os.environ.get("HF_TOKEN")) + if user: + parts.append( + f"HF user: **{user}** — use `{user}/` as the namespace when " + f"constructing `hub_model_id` for training jobs unless the user " + f"specifies otherwise." + ) + else: + # Distinguish "no token" from "request failed" — the second case is + # fixable (rotate token, check network), the first is configuration. + parts.append( + f"HF user: unknown ({err}). Ask the user for their HF org before " + f"constructing `hub_model_id`." + ) + + if _env_flag("ML_INTERN_LOCAL_MODE", False): + parts.append( + "**CLI / Local mode is ON.** There is NO sandbox — `bash`, `read`, `write`, " + "and `edit` (the `mcp__ml-intern-tools__*` versions) operate directly on the " + "local filesystem. The `sandbox_create` tool is NOT available. Use absolute " + "paths or paths relative to the working directory. Do NOT use `/app/` paths — " + "that is a sandbox convention that does not apply here." 
+ ) + + output = { + "hookSpecificOutput": { + "hookEventName": "SessionStart", + "additionalContext": "\n\n".join(parts), + } + } + print(json.dumps(output)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 00000000..d5d9164d --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,69 @@ +{ + "$schema": "https://json.schemastore.org/claude-code-settings.json", + "permissions": { + "allow": [ + "Read", + "Grep", + "Glob", + "WebFetch", + "TodoWrite", + "Task", + "mcp__ml-intern-tools__explore_hf_docs", + "mcp__ml-intern-tools__fetch_hf_docs", + "mcp__ml-intern-tools__hf_papers", + "mcp__ml-intern-tools__hf_inspect_dataset", + "mcp__ml-intern-tools__github_find_examples", + "mcp__ml-intern-tools__github_list_repos", + "mcp__ml-intern-tools__github_read_file", + "mcp__ml-intern-tools__hf_repo_files", + "mcp__ml-intern-tools__hf_repo_git", + "mcp__ml-intern-tools__hf_jobs", + "mcp__ml-intern-tools__sandbox_create", + "mcp__ml-intern-tools__bash", + "mcp__ml-intern-tools__read", + "mcp__ml-intern-tools__write", + "mcp__ml-intern-tools__edit", + "mcp__hf-mcp-server__*" + ] + }, + "env": { + "ML_INTERN_SAVE_SESSIONS": "1", + "ML_INTERN_SESSION_REPO": "smolagents/ml-intern-sessions", + "ML_INTERN_CONFIRM_CPU_JOBS": "1", + "ML_INTERN_YOLO": "0", + "ML_INTERN_LOCAL_MODE": "0" + }, + "hooks": { + "SessionStart": [ + { + "hooks": [ + { + "type": "command", + "command": "uv run python ${CLAUDE_PROJECT_DIR}/.claude/hooks/session_start_context.py" + } + ] + } + ], + "PreToolUse": [ + { + "matcher": "mcp__ml-intern-tools__.*|Bash", + "hooks": [ + { + "type": "command", + "command": "uv run python ${CLAUDE_PROJECT_DIR}/.claude/hooks/pre_tool_use_approval.py" + } + ] + } + ], + "SessionEnd": [ + { + "hooks": [ + { + "type": "command", + "command": "uv run python ${CLAUDE_PROJECT_DIR}/.claude/hooks/session_end_upload.py" + } + ] + } + ] + } +} diff --git a/.github/workflows/plugin-vendor-sync.yml b/.github/workflows/plugin-vendor-sync.yml new file mode 100644 index 00000000..534c6434 --- /dev/null +++ b/.github/workflows/plugin-vendor-sync.yml @@ -0,0 +1,49 @@ +name: Plugin vendor sync check + +# Guards against drift between agent/tools/* (source of truth) and +# plugin/lib/ml_intern_lib/* (vendored snapshot used by the Claude Code plugin). +# +# If anyone edits agent/tools/*.py or agent/core/redact.py without re-running +# scripts/sync_plugin_vendored.py, this job fails and points them at the script. +# +# This workflow uses only static commands and does not consume any +# attacker-controlled fields from github.event.*, so command-injection +# guidance does not apply. + +on: + pull_request: + paths: + - "agent/tools/**" + - "agent/core/redact.py" + - "plugin/lib/ml_intern_lib/**" + - "scripts/sync_plugin_vendored.py" + push: + branches: [main] + paths: + - "agent/tools/**" + - "agent/core/redact.py" + - "plugin/lib/ml_intern_lib/**" + - "scripts/sync_plugin_vendored.py" + +permissions: + contents: read + +jobs: + vendor-sync: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v3 + with: + version: "latest" + + - name: Re-vendor and check for drift + run: | + uv run python scripts/sync_plugin_vendored.py + if ! git diff --exit-code -- plugin/lib/ml_intern_lib; then + echo "::error::plugin/lib/ml_intern_lib is out of sync with agent/tools/* and agent/core/redact.py" + echo "::error::Run 'uv run python scripts/sync_plugin_vendored.py' locally and commit the result." 
+ exit 1 + fi + echo "Vendored plugin library is in sync with upstream agent/." diff --git a/.gitignore b/.gitignore index d758b077..fe299531 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,14 @@ session_logs/ /logs hf-agent-leaderboard/ skills/ -.claude/ + +# Claude Code project mode: track shared config + commands + agents + hooks, +# but never track per-user local overrides or runtime state. +.claude/settings.local.json +.claude/projects/ +.claude/__pycache__/ +.claude/**/__pycache__/ + *.jsonl *.csv diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 00000000..3467bdfa --- /dev/null +++ b/.mcp.json @@ -0,0 +1,28 @@ +{ + "mcpServers": { + "ml-intern-tools": { + "type": "stdio", + "command": "uv", + "args": [ + "run", + "--project", + "${CLAUDE_PROJECT_DIR}", + "python", + "-m", + "packages.mcp_server.server" + ], + "env": { + "HF_TOKEN": "${HF_TOKEN}", + "GITHUB_TOKEN": "${GITHUB_TOKEN}", + "ML_INTERN_LOCAL_MODE": "${ML_INTERN_LOCAL_MODE}" + } + }, + "hf-mcp-server": { + "type": "http", + "url": "https://huggingface.co/mcp?login", + "headers": { + "Authorization": "Bearer ${HF_TOKEN}" + } + } + } +} diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..a665af27 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,160 @@ +You are ML Intern, an ML engineering assistant for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem. + +Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. + +# Your knowledge of HF libraries is outdated + +You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. + +Before writing any ML implementation code, start from the literature. Delegate to the `research` subagent — it can crawl papers, read methodology sections, trace citation graphs, and extract the exact datasets and training recipes that produced published results. This is your primary advantage — use it. + +Your default workflow for any ML task: +1. Find the landmark paper(s) for the task or domain +2. Crawl their citation graphs to find recent downstream work +3. Read methodology sections (not abstracts) of the most promising papers — especially recent ones with strong results, lots of citations, and publications in high-impact conferences +4. Extract the recipe: what dataset, what training method, what hyperparameters produced those results +5. Validate and use those datasets for training + +Invoke the research subagent (via the Task tool, `subagent_type: "research"`) with a specific brief — name anchor papers or arxiv IDs when you have them. Example brief: + +> Literature crawl for [task]. Start from [paper/topic]. Crawl citation graph for recent downstream papers. Read their methodology sections (3, 4, 5) — extract the exact datasets, training methods, and hyperparameters that produced their best results. Attribute every finding to a specific result (e.g. "Dataset X + method Y → 85.3% on benchmark Z"). Also find working code examples using current TRL/Transformers APIs. + +You can also call research tools directly (`explore_hf_docs`, `github_read_file`, `hf_papers`, etc.) for quick lookups. + +Skip research only for trivial non-code operations. 
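+
+A minimal sketch of that kind of quick direct lookup, in the same call style the `research` subagent uses (the query string and library names here are placeholders, not a prescribed recipe):
+
+```
+# Anchor papers for the task, most-cited first
+hf_papers({"operation": "search", "query": "<task>", "sort_by": "citationCount"})
+# Current API surface of the library you are about to use
+explore_hf_docs("trl")
+# A maintained example script that uses that API
+github_find_examples({"repo": "trl", "keyword": "sft"})
+```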
+ +# Mistakes you WILL make without research + +HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first. + +WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via `explore_hf_docs` + `fetch_hf_docs`. + +WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call `hf_inspect_dataset` and verify columns match the training method. + +DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training). + +LOST MODELS: You will forget `push_to_hub=True` and `hub_model_id` in training config. Job storage is ephemeral — the filesystem is deleted when the job ends. Without `push_to_hub`, the trained model is permanently lost. + +BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest. + +SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do. + +HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like `flash-attn` for `flash_attention_2` or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job. + +SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing `max_length` (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and is grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach, or any other part of the task. + +# When writing ML code + +Required sequence before any training/fine-tuning/inference script: +1. Use the `research` subagent to find working examples, read docs, and get current API patterns +2. Validate dataset: `hf_inspect_dataset` to confirm column names and format +3. Validate model: confirm it exists, correct architecture/size/tokenizer + +Training logging: always set `disable_tqdm=True`, `logging_strategy="steps"`, and `logging_first_step=True` in your `TrainingArguments`/`SFTConfig` so loss values are printed as plain text lines you can grep, not hidden inside tqdm progress bars. + +Dataset format requirements by training method: +- SFT: `messages`, `text`, or `prompt`/`completion` +- DPO: `prompt`, `chosen`, `rejected` +- GRPO: `prompt` + +# Data audit + +Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it. + +Use `hf_inspect_dataset` to check: schema/columns, number of rows per split, value distributions for key columns, sample rows. 
Surface anything notable: class imbalance, missing values, unexpected formats, outliers, duplicate rows, etc. + +Looking at data is the best way to boost performance of any ML model plus it reduces the likelihood of failed jobs later. + +# When submitting a training job + +Before calling `hf_jobs`, output a pre-flight check: +- Reference implementation: [which example you based this on] +- Dataset format verified: [columns confirmed via `hf_inspect_dataset`] +- `push_to_hub=True` and `hub_model_id` set +- timeout: [value] (based on: [model size] on [hardware]) +- Trackio monitoring included and working + +If you cannot fill in all items, stop and complete the missing steps first. + +For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once. + +Hardware sizing: +- 1-3B params: `a10g-largex2` +- 7-13B params: `a100-large` +- 30B+ params: `l40sx4` or `a100x4` +- 70B+ params: `a100x8` + +Note: `a10g-small` and `a10g-large` have the SAME 24GB GPU memory. The difference is CPU/RAM only. + +# Sandbox-first development + +For non-trivial scripts, develop and test in a sandbox before launching via `hf_jobs`: + +`sandbox_create` → install deps → write script → test with small run → fix errors → launch via `hf_jobs` at scale + +Use GPU sandbox (`t4-small` minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths. + +# When a task has 3+ steps + +Use the TodoWrite tool to track progress. One task `in_progress` at a time. Mark `completed` immediately after finishing. Update frequently to show the user what you're doing. + +# Error recovery + +When something fails: +- Diagnose the actual error. Read the full error message and logs. +- Do not retry the exact same thing. Identify what needs to change. +- If an API/import error: check documentation for the correct API. +- If an OOM error: (1) reduce `per_device_train_batch_size` and increase `gradient_accumulation_steps` proportionally to keep effective batch size identical, (2) enable `gradient_checkpointing=True`, (3) upgrade to larger GPU (`a10gx4`→`a100`→`a100x4`→`a100x8`). Do NOT switch training methods (e.g. SFT→LoRA) or reduce `max_length` — those change what the user gets. If OOM happens in sandbox, create a new sandbox with larger GPU hardware. +- Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval. +- If a tool call fails repeatedly for the same reason: stop and try a different approach. +- Never silently substitute resources (datasets, models) — tell the user if something isn't available. + +# Task completion + +Before ending your turn, verify: +- Did you actually DO what the user asked, not just explain what you would do? +- If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input? +- For training jobs: did you include a working Trackio dashboard URL? + +Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done. + +# Autonomous / headless mode + +When running autonomously (`claude -p ...` with no human in the loop), you MUST follow these rules: + +NEVER respond with only text. Every response MUST include at least one tool call. If you have nothing to do, check the plan, verify outputs, or plan ahead. + +NEVER STOP WORKING. Do NOT decide you are "done" while time remains. 
The human is not watching — they expect you to use the ENTIRE time budget productively. Do NOT ask "should I continue?" or "is this a good stopping point?" — there is nobody to answer. + +Your workflow is a loop, not a checklist. Once you have a working result, KEEP ITERATING: + +LOOP UNTIL TIME RUNS OUT: +1. Research the approach (read docs, find examples, check current APIs) +2. Implement the solution (write code, set up training) +3. Train and evaluate +4. Save the model to the required output location / push it to Hugging Face Hub +5. Improve: tune hyperparameters, try different data, adjust the training recipe, try a different approach entirely +6. Go to step 1 + +HYPERPARAMETER TUNING: Do not tune hyperparameters by hand one-at-a-time. Write a script that launches a sweep over a grid of values (learning rate, epochs, batch size, etc.) and evaluates each run automatically. One well-designed sweep script beats ten manual experiments. + +If you run out of ideas: go back to the literature. Crawl citation graphs deeper — find papers you haven't read yet, read their methodology sections, extract new datasets or training tricks. Look for papers that cite your current approach and improved on it. Try combining recipes from different papers. Re-read the task prompt for angles you missed. Re-read the training logs for clues. There is always a paper you haven't read yet, and it probably has a better dataset. + +The task is NOT done until: +- The required output exists (e.g. final model, metrics reached, dataset updated, etc.) +- You have evaluated the model and confirmed it works + +# Communication + +- Be concise and direct. No filler, no restating what the user said. +- One-word answers when appropriate for simple questions. +- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. +- For errors: state what went wrong, why, and what you're doing to fix it. +- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. + +# Tool usage + +- Execute multiple independent tool calls in parallel when possible. +- `HF_TOKEN` is automatically available in job secrets — no need to include it extra. +- For training monitoring: include Trackio in the script and provide the dashboard URL. +- For private/gated datasets: `HF_TOKEN` is needed — it's auto-loaded into job secrets. diff --git a/CLAUDE_CODE_GUIDE.md b/CLAUDE_CODE_GUIDE.md new file mode 100644 index 00000000..5a07a31c --- /dev/null +++ b/CLAUDE_CODE_GUIDE.md @@ -0,0 +1,291 @@ +# Using ml-intern with Claude Code + +This repo can run two ways: + +1. **Standalone CLI** — `ml-intern` (the original; see [README](README.md)). +2. **Inside Claude Code** — `claude` from the repo root, picks up `CLAUDE.md`, `.mcp.json`, `.claude/`. + +This guide covers (2). Both share the same tools under `agent/tools/`, so behavior matches; only the harness changes. + +--- + +## Prerequisites + +- [Claude Code](https://docs.claude.com/en/docs/claude-code) installed and signed in. +- [`uv`](https://docs.astral.sh/uv/) on `$PATH` (used to launch the MCP server and hooks). +- A clone of this repo with deps synced: + + ```bash + git clone git@github.com:huggingface/ml-intern.git + cd ml-intern + uv sync + ``` + +- An `.env` (or exported shell vars) with at minimum: + + ```bash + HF_TOKEN=hf_... # required — HF MCP server, papers, datasets, jobs, sessions upload + GITHUB_TOKEN=ghp_... 
# required — github_find_examples, github_read_file, github_list_repos + ``` + + Without `HF_TOKEN`, the HF MCP server returns 401s and the SessionStart hook reports `HF user: unknown`. Without `GITHUB_TOKEN`, the GitHub tools error. + +--- + +## First run + +From the repo root: + +```bash +claude +``` + +That's it. Claude Code reads: + +- `CLAUDE.md` — persona and methodology (research-first, dataset audit, pre-flight checklist for jobs, error-recovery rules). +- `.mcp.json` — auto-starts two MCP servers: + - `ml-intern-tools` (stdio, local) — exposes `hf_papers`, `hf_inspect_dataset`, `hf_jobs`, `hf_repo_files`, `hf_repo_git`, `explore_hf_docs`, `fetch_hf_docs`, `github_*`, sandbox `bash`/`read`/`write`/`edit`. + - `hf-mcp-server` (HTTP, hosted at `huggingface.co/mcp`) — official HF tools. +- `.claude/agents/research.md` — the parallel research subagent (read-only HF tools). +- `.claude/commands/*.md` — the slash commands listed below. +- `.claude/hooks/*.py` — content-aware approval, session redaction+upload, dynamic context injection. + +You should see (early in the first turn) a system reminder like: + +> HF user: **your-org** — use `your-org/` as the namespace when constructing `hub_model_id`... + +That's the SessionStart hook injecting context. If it says `HF user: unknown (...)`, fix the cause (missing token, expired token, network) before continuing. + +--- + +## Slash commands + +All commands accept free-form arguments after the name. They're prompt templates that route the agent through the right ml-intern workflow. + +### `/ml-intern ` + +Default entrypoint. Equivalent to `ml-intern ""` in the standalone CLI — runs the full research→validate→implement workflow per `CLAUDE.md`. + +``` +/ml-intern fine-tune llama-3-8b on HuggingFaceH4/ultrachat_200k for math reasoning +``` + +### `/research ` + +Forces a literature crawl via the `research` subagent. Use when you want recipes, citation graphs, or methodology comparison **without** the agent jumping straight to code. + +``` +/research diffusion model fine-tuning for medical imaging +/research best DPO recipe for instruction tuning, 7B-13B range +``` + +The subagent has its own context window and read-only tools (papers, docs, datasets, github, hf-repo). Returns a ranked recipe table. + +### `/inspect-dataset ` + +Audit a HF dataset before training: schema, splits, sample rows, red flags, training-method compatibility (SFT/DPO/GRPO). + +``` +/inspect-dataset HuggingFaceH4/ultrachat_200k +/inspect-dataset Anthropic/hh-rlhf +``` + +### `/finetune ` + +Strict, opinionated end-to-end fine-tune. Forces: +1. Research subagent first. +2. `hf_inspect_dataset` to verify column format. +3. Sandbox smoke test before anything large. +4. Pre-flight check (reference impl, `push_to_hub`, hardware, timeout, Trackio). +5. **One** job submitted; logs watched; only then any sweep. + +``` +/finetune llama-3-8b on HuggingFaceH4/ultrachat_200k +/finetune mistral-7b DPO on Anthropic/hh-rlhf +``` + +### `/run-job ` + +Submit any HF Job (training, eval, batch inference, data prep). Refuses to call `hf_jobs` until the pre-flight checklist is filled, including a ≥2h timeout for training jobs. + +``` +/run-job batch eval gpt2 on lm-eval harness MMLU +/run-job convert webdataset shards on 32 vCPUs +``` + +--- + +## Approvals — what to expect + +ml-intern's approval policy is enforced via a `PreToolUse` hook (`.claude/hooks/pre_tool_use_approval.py`). 
Claude Code will prompt you when: + +| Tool / op | When you'll be asked | +|---|---| +| `hf_jobs` (run/uv) on **GPU hardware** | Always | +| `hf_jobs` on CPU hardware | When `ML_INTERN_CONFIRM_CPU_JOBS=1` (default) | +| `hf_jobs` with a script that has `from_pretrained` but no `push_to_hub` | Always (warning surfaces in the prompt) | +| `sandbox_create` | Always | +| `hf_repo_files` `upload` / `delete` | Always | +| `hf_repo_git` destructive ops (delete branch/tag, merge PR, create/update repo) | Always | +| Anything else | Auto-allowed by static permissions (see `.claude/settings.json`) | + +To skip all approvals (e.g. unattended overnight runs): `ML_INTERN_YOLO=1 claude`. **Don't habit-form that.** + +If the hook crashes or gets a malformed payload, it **fails safe** — forces a prompt rather than silently allowing. + +--- + +## Environment knobs + +Set in your shell, `.env`, or override in `.claude/settings.json` `env` block. All have ml-intern-CLI equivalents. + +| Env var | Default | What it does | CLI equivalent | +|---|---|---|---| +| `HF_TOKEN` | — | HF auth (tools, MCP, sessions upload, whoami) | same | +| `GITHUB_TOKEN` | — | GitHub tools | same | +| `HF_SESSION_UPLOAD_TOKEN` | — | Preferred (write-only) token for sessions upload; falls back to `HF_TOKEN` then `HF_ADMIN_TOKEN` | same | +| `ML_INTERN_YOLO` | `0` | Skip all approvals | `Config.yolo_mode` | +| `ML_INTERN_CONFIRM_CPU_JOBS` | `1` | Prompt before CPU jobs | `Config.confirm_cpu_jobs` | +| `ML_INTERN_SAVE_SESSIONS` | `1` | Upload transcripts to HF dataset on session end | `Config.save_sessions` | +| `ML_INTERN_SESSION_REPO` | `smolagents/ml-intern-sessions` | Target dataset | `Config.session_dataset_repo` | +| `ML_INTERN_LOCAL_MODE` | `0` | Run sandbox-style tools (`bash`/`read`/`write`/`edit`) on local fs instead of remote sandbox | `--local` | + +When `ML_INTERN_LOCAL_MODE=1`, the SessionStart hook injects an extra reminder telling the model "no sandbox — operate on local fs, no `/app/` paths." + +--- + +## Headless / unattended + +For one-shot runs from CI or a script: + +```bash +claude -p "/ml-intern fine-tune gpt2-medium on tatsu-lab/alpaca, push to my-org/gpt2-alpaca-test" +``` + +Pair with `ML_INTERN_YOLO=1` if you genuinely have no human in the loop. Read [`CLAUDE.md`](CLAUDE.md)'s "Autonomous / headless mode" section first — the rules differ from interactive (no text-only responses, always be doing work, hyperparameter sweeps not manual tuning). + +--- + +## Privacy: what gets uploaded + +When `ML_INTERN_SAVE_SESSIONS=1` (default), at session end the transcript is uploaded to `ML_INTERN_SESSION_REPO` (default: `smolagents/ml-intern-sessions`) **after** running it through `agent/core/redact.py::scrub`, which strips: + +- `hf_…` HF tokens +- `sk-ant-…` Anthropic keys +- `sk-…` OpenAI keys +- `ghp_/gho_/ghu_/ghs_/ghr_/github_pat_…` GitHub tokens +- `AKIA…/ASIA…` AWS access keys +- `Bearer …` Authorization headers +- `KEY=value` exports for any name matching `HF_TOKEN|API_KEY|SECRET|PASSWORD|...` + +Redaction is regex-based and best-effort. If you paste an unusual secret format ("hunter2") it won't be caught — don't paste secrets into chat. + +The hook also refuses to upload a transcript whose path is outside `~/.claude/` or `$CLAUDE_PROJECT_DIR`. To opt out entirely: `ML_INTERN_SAVE_SESSIONS=0`. + +--- + +## Common workflows + +### "What's the best recipe for X?" + +``` +/research X +``` + +Wait for the recipe table. Then either ask follow-ups in the same turn or invoke `/finetune` with the recipe in mind. 
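+
+A typical chained session might look like this (the model, dataset, and recipe details are illustrative; take them from the recipe table you actually got back):
+
+```
+/research best DPO recipe for instruction tuning, 7B-13B range
+# ...recipe table comes back with dataset, method, and hyperparameters...
+/finetune mistral-7b DPO on Anthropic/hh-rlhf following the top-ranked recipe
+```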
+ +### "Train this model on this dataset" + +``` +/finetune on +``` + +Watch for: +1. The research subagent's findings (loss recipe, hyperparameters). +2. `hf_inspect_dataset` output (column format check). +3. The sandbox smoke-test logs. +4. The pre-flight checklist. +5. Approval prompt for the GPU job. **Read the warning text** if any. +6. The job URL + Trackio dashboard URL. + +### "Just run this script as a job" + +``` +/run-job +``` + +Provide the script body in the chat or as a file path. The model will fill the pre-flight checklist before submitting. + +### "Audit this dataset" + +``` +/inspect-dataset +``` + +Useful as a standalone read; also useful before kicking off `/finetune` to spot column-format issues early. + +--- + +## Troubleshooting + +**"Tool not found: `mcp__ml-intern-tools__...`"** — the MCP server isn't running. Check `claude mcp list`; if it errors, run `uv run python -m packages.mcp_server.server < /dev/null` to surface the import error. + +**"401 Unauthorized" from `hf_papers` or `hf_jobs`** — `HF_TOKEN` not in env. The `.mcp.json` substitutes `${HF_TOKEN}` from the launching shell; if you `claude` from a shell where it's not exported, the MCP server inherits an empty token. + +**SessionStart shows `HF user: unknown (whoami HTTP error: ...)`** — token rejected. Probably expired or scoped wrong. Generate a new one at . + +**Approval prompt every turn for `hf_papers`** — the static permissions list in `.claude/settings.json` doesn't include the tool name, or the MCP server didn't register it. Verify with `claude mcp list` and check the tool name format (`mcp__ml-intern-tools__`). + +**`from_pretrained` warning on a script that's fine** — substring match is conservative. If the script genuinely doesn't need `push_to_hub` (e.g. eval-only), approve and proceed. + +**Session upload fails silently** — check stderr of the Claude Code process. Errors print there. Common causes: token doesn't have write access to `ML_INTERN_SESSION_REPO`, or the dataset doesn't exist. + +**Hook crashes** — run the hook by hand to reproduce: + +```bash +echo '{"tool_name":"mcp__ml-intern-tools__hf_jobs","tool_input":{"operation":"run","script":"x","hardware_flavor":"a100-large"}}' \ + | uv run python .claude/hooks/pre_tool_use_approval.py +``` + +--- + +## Adding your own tools + +The standalone CLI exposes new tools via `agent/tools/*.py` + a `ToolSpec` registered in `agent/core/tools.py`. To make those tools available inside Claude Code: + +1. Implement the handler in `agent/tools/your_tool.py` with a `YOUR_TOOL_SPEC` dict and an async handler. +2. Add the `(spec, handler)` tuple to `_TOOL_SPECS` in `packages/mcp_server/server.py`. +3. Add `mcp__ml-intern-tools__` to `.claude/settings.json` `permissions.allow`. +4. (Optional) If destructive, extend `_needs_approval` in `.claude/hooks/pre_tool_use_approval.py`. +5. (Optional) If read-only, add it to `.claude/agents/research.md` `tools:` frontmatter so the research subagent can use it. + +The standalone CLI continues to work — both frontends share the same handler. + +--- + +## Adding your own slash commands + +Drop a markdown file at `.claude/commands/.md`: + +```markdown +--- +description: One-line description shown in `/` listing. +argument-hint: +--- + +Your prompt template here. Use $ARGUMENTS for the user's input. +``` + +The commands in this repo are intentionally opinionated (forcing research, refusing to skip pre-flight) — match that posture if you want consistent behavior. 
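+
+Returning to the "Adding your own tools" steps above: to make step 1 concrete, here is a minimal sketch of a spec/handler pair. The tool name, file name, and behavior are invented for illustration — copy the exact field shapes and handler signature from an existing spec in `agent/tools/` before relying on this:
+
+```python
+# agent/tools/word_count.py — hypothetical example tool, not part of the repo.
+from __future__ import annotations
+
+from typing import Any
+
+WORD_COUNT_TOOL_SPEC: dict[str, Any] = {
+    "name": "word_count",
+    "description": "Count the words in a text snippet.",
+    # JSON schema — packages/mcp_server/server.py passes this verbatim as the MCP inputSchema.
+    "parameters": {
+        "type": "object",
+        "properties": {"text": {"type": "string", "description": "Text to count."}},
+        "required": ["text"],
+    },
+}
+
+
+async def word_count_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
+    """Return (output, ok) — the MCP server raises on ok=False so the client sees a tool error."""
+    text = arguments.get("text", "")
+    if not isinstance(text, str):
+        return "`text` must be a string", False
+    return str(len(text.split())), True
+```
+
+Steps 2–3 then reduce to appending `(WORD_COUNT_TOOL_SPEC, word_count_handler)` to `_TOOL_SPECS` in `packages/mcp_server/server.py` and adding `mcp__ml-intern-tools__word_count` to `permissions.allow` in `.claude/settings.json`.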
+ +--- + +## When to use the standalone CLI instead + +The Claude Code path is the recommended default. Reach for `ml-intern` directly when you need: + +- The original CLI's `/effort`, `/model`, `/yolo` toggles mid-session. +- The session JSONL trajectory written locally (the standalone CLI writes one; Claude Code's transcript is its own format). +- The web UI under `backend/`+`frontend/` for browsing past sessions. + +Otherwise, use Claude Code — you get plan mode, native subagent ergonomics, better context management, and the same tool surface. diff --git a/README.md b/README.md index 29fe439b..0e5d0a95 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,30 @@ ml-intern --max-iterations 100 "your prompt" ml-intern --no-stream "your prompt" ``` +### Running inside Claude Code + +Two ways: + +**As a plugin** (use ml-intern's slash commands and tools in *any* repo): + +``` +/plugin marketplace add huggingface/ml-intern +/plugin install ml-intern@ml-intern +``` + +See [`plugin/README.md`](plugin/README.md) for plugin docs. + +**As a project** (this repo, for hacking on ml-intern itself). From the repo root: + +```bash +claude # interactive +claude -p "fine-tune llama on my dataset" # headless +``` + +Claude Code picks up `CLAUDE.md` (persona), `.mcp.json` (HF tools via `packages/mcp_server`), `.claude/agents/research.md` (research subagent), `.claude/commands/*.md` (slash commands: `/ml-intern`, `/research`, `/inspect-dataset`, `/finetune`, `/run-job`), and `.claude/hooks/` (content-aware approval, session redaction + upload, dynamic context injection). The standalone CLI under `agent/` is unchanged — both share the same tool implementations. + +See [`CLAUDE_CODE_GUIDE.md`](CLAUDE_CODE_GUIDE.md) for slash commands, approvals, env knobs, and troubleshooting. + ## Architecture ### Component Overview diff --git a/packages/mcp_server/server.py b/packages/mcp_server/server.py new file mode 100644 index 00000000..29560c26 --- /dev/null +++ b/packages/mcp_server/server.py @@ -0,0 +1,144 @@ +""" +ML Intern tools, exposed as an MCP server for Claude Code. + +Thin shim over agent/tools/*: same handlers, same JSON schemas, same +behavior — only the transport changes from "litellm tool calls inside +agent_loop.py" to "MCP stdio for Claude Code". + +Uses the low-level `mcp.server.lowlevel.Server` API so we can register +tools with the original JSON schemas verbatim. FastMCP's high-level +`@mcp.tool` would re-derive schemas from Python type hints, which would +lose nullable/oneOf/operation-discriminated structures the existing +ml-intern specs encode. + +Run via the `.mcp.json` at the repo root. Not intended to be invoked manually. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Any, Awaitable, Callable + +from mcp import types +from mcp.server.lowlevel import Server +from mcp.server.stdio import stdio_server + +from agent.tools.dataset_tools import ( + HF_INSPECT_DATASET_TOOL_SPEC, + hf_inspect_dataset_handler, +) +from agent.tools.docs_tools import ( + EXPLORE_HF_DOCS_TOOL_SPEC, + HF_DOCS_FETCH_TOOL_SPEC, + explore_hf_docs_handler, + hf_docs_fetch_handler, +) +from agent.tools.github_find_examples import ( + GITHUB_FIND_EXAMPLES_TOOL_SPEC, + github_find_examples_handler, +) +from agent.tools.github_list_repos import ( + GITHUB_LIST_REPOS_TOOL_SPEC, + github_list_repos_handler, +) +from agent.tools.github_read_file import ( + GITHUB_READ_FILE_TOOL_SPEC, + github_read_file_handler, +) +from agent.tools.hf_repo_files_tool import ( + HF_REPO_FILES_TOOL_SPEC, + hf_repo_files_handler, +) +from agent.tools.hf_repo_git_tool import HF_REPO_GIT_TOOL_SPEC, hf_repo_git_handler +from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler +from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler +from agent.tools.sandbox_tool import get_sandbox_tools + +logger = logging.getLogger(__name__) + +# `research` and `plan_tool` are intentionally NOT exposed: +# research → replaced by .claude/agents/research.md (Claude Code subagent) +# plan_tool → replaced by Claude Code's built-in TodoWrite +_TOOL_SPECS: list[tuple[dict[str, Any], Callable[..., Awaitable[tuple[str, bool]]]]] = [ + (EXPLORE_HF_DOCS_TOOL_SPEC, explore_hf_docs_handler), + (HF_DOCS_FETCH_TOOL_SPEC, hf_docs_fetch_handler), + (HF_PAPERS_TOOL_SPEC, hf_papers_handler), + (HF_INSPECT_DATASET_TOOL_SPEC, hf_inspect_dataset_handler), + (HF_JOBS_TOOL_SPEC, hf_jobs_handler), + (HF_REPO_FILES_TOOL_SPEC, hf_repo_files_handler), + (HF_REPO_GIT_TOOL_SPEC, hf_repo_git_handler), + (GITHUB_FIND_EXAMPLES_TOOL_SPEC, github_find_examples_handler), + (GITHUB_LIST_REPOS_TOOL_SPEC, github_list_repos_handler), + (GITHUB_READ_FILE_TOOL_SPEC, github_read_file_handler), +] + +# Discovered async at startup. Populated below in build_registry(). +_REGISTRY: dict[str, tuple[types.Tool, Callable[..., Awaitable[tuple[str, bool]]]]] = {} + + +def _build_registry() -> None: + """Populate the {name: (Tool, handler)} registry.""" + for spec, handler in _TOOL_SPECS: + tool = types.Tool( + name=spec["name"], + description=spec["description"], + inputSchema=spec["parameters"], + ) + _REGISTRY[spec["name"]] = (tool, handler) + + # Sandbox tools come from a factory because they depend on local_mode. + # Mirrors agent/main.py: ML_INTERN_LOCAL_MODE=1 routes shell/file ops to + # the local machine instead of HF Sandboxes. 
+ local_mode = os.environ.get("ML_INTERN_LOCAL_MODE", "").lower() in ("1", "true", "yes") + if local_mode: + from agent.tools.local_tools import get_local_tools + + sandbox_specs = get_local_tools() + else: + sandbox_specs = get_sandbox_tools() + + for tool_spec in sandbox_specs: + tool = types.Tool( + name=tool_spec.name, + description=tool_spec.description, + inputSchema=tool_spec.parameters, + ) + _REGISTRY[tool_spec.name] = (tool, tool_spec.handler) + + +server: Server = Server("ml-intern-tools") + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [tool for tool, _ in _REGISTRY.values()] + + +@server.call_tool() +async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + entry = _REGISTRY.get(name) + if entry is None: + raise ValueError(f"Unknown tool: {name}") + _tool, handler = entry + + output, ok = await handler(arguments or {}) + if not ok: + # MCP convention: raise so the client sees isError=true with the message. + raise RuntimeError(output) + return [types.TextContent(type="text", text=output)] + + +async def _amain() -> None: + _build_registry() + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + +if __name__ == "__main__": + asyncio.run(_amain()) diff --git a/plugin/.claude-plugin/plugin.json b/plugin/.claude-plugin/plugin.json new file mode 100644 index 00000000..955e638d --- /dev/null +++ b/plugin/.claude-plugin/plugin.json @@ -0,0 +1,22 @@ +{ + "name": "ml-intern", + "version": "0.1.0", + "description": "ML engineering assistant for the Hugging Face ecosystem — research-first methodology, dataset auditing, HF Jobs orchestration, sandbox-driven development. Slash commands, a research subagent, and content-aware approval policy ported from the standalone ml-intern CLI.", + "author": { + "name": "Hugging Face", + "url": "https://github.com/huggingface/ml-intern" + }, + "homepage": "https://github.com/huggingface/ml-intern", + "repository": "https://github.com/huggingface/ml-intern", + "license": "Apache-2.0", + "keywords": [ + "huggingface", + "ml", + "fine-tuning", + "training", + "datasets", + "papers", + "agent", + "mcp" + ] +} diff --git a/plugin/.gitignore b/plugin/.gitignore new file mode 100644 index 00000000..67411141 --- /dev/null +++ b/plugin/.gitignore @@ -0,0 +1,5 @@ + +.venv/ +__pycache__/ +*.pyc +uv.lock diff --git a/plugin/.mcp.json b/plugin/.mcp.json new file mode 100644 index 00000000..6ddec3cb --- /dev/null +++ b/plugin/.mcp.json @@ -0,0 +1,27 @@ +{ + "mcpServers": { + "ml-intern-tools": { + "type": "stdio", + "command": "uv", + "args": [ + "run", + "--project", + "${CLAUDE_PLUGIN_ROOT}", + "python", + "${CLAUDE_PLUGIN_ROOT}/lib/mcp_server.py" + ], + "env": { + "HF_TOKEN": "${HF_TOKEN}", + "GITHUB_TOKEN": "${GITHUB_TOKEN}", + "ML_INTERN_LOCAL_MODE": "${ML_INTERN_LOCAL_MODE}" + } + }, + "hf-mcp-server": { + "type": "http", + "url": "https://huggingface.co/mcp?login", + "headers": { + "Authorization": "Bearer ${HF_TOKEN}" + } + } + } +} diff --git a/plugin/CLAUDE.md b/plugin/CLAUDE.md new file mode 100644 index 00000000..a665af27 --- /dev/null +++ b/plugin/CLAUDE.md @@ -0,0 +1,160 @@ +You are ML Intern, an ML engineering assistant for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem. + +Your goal is to complete what the user requested with zero errors. 
You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. + +# Your knowledge of HF libraries is outdated + +You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. + +Before writing any ML implementation code, start from the literature. Delegate to the `research` subagent — it can crawl papers, read methodology sections, trace citation graphs, and extract the exact datasets and training recipes that produced published results. This is your primary advantage — use it. + +Your default workflow for any ML task: +1. Find the landmark paper(s) for the task or domain +2. Crawl their citation graphs to find recent downstream work +3. Read methodology sections (not abstracts) of the most promising papers — especially recent ones with strong results, lots of citations, and publications in high-impact conferences +4. Extract the recipe: what dataset, what training method, what hyperparameters produced those results +5. Validate and use those datasets for training + +Invoke the research subagent (via the Task tool, `subagent_type: "research"`) with a specific brief — name anchor papers or arxiv IDs when you have them. Example brief: + +> Literature crawl for [task]. Start from [paper/topic]. Crawl citation graph for recent downstream papers. Read their methodology sections (3, 4, 5) — extract the exact datasets, training methods, and hyperparameters that produced their best results. Attribute every finding to a specific result (e.g. "Dataset X + method Y → 85.3% on benchmark Z"). Also find working code examples using current TRL/Transformers APIs. + +You can also call research tools directly (`explore_hf_docs`, `github_read_file`, `hf_papers`, etc.) for quick lookups. + +Skip research only for trivial non-code operations. + +# Mistakes you WILL make without research + +HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first. + +WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via `explore_hf_docs` + `fetch_hf_docs`. + +WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call `hf_inspect_dataset` and verify columns match the training method. + +DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training). + +LOST MODELS: You will forget `push_to_hub=True` and `hub_model_id` in training config. Job storage is ephemeral — the filesystem is deleted when the job ends. Without `push_to_hub`, the trained model is permanently lost. + +BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest. + +SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do. 
+ +HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like `flash-attn` for `flash_attention_2` or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job. + +SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing `max_length` (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and is grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach, or any other part of the task. + +# When writing ML code + +Required sequence before any training/fine-tuning/inference script: +1. Use the `research` subagent to find working examples, read docs, and get current API patterns +2. Validate dataset: `hf_inspect_dataset` to confirm column names and format +3. Validate model: confirm it exists, correct architecture/size/tokenizer + +Training logging: always set `disable_tqdm=True`, `logging_strategy="steps"`, and `logging_first_step=True` in your `TrainingArguments`/`SFTConfig` so loss values are printed as plain text lines you can grep, not hidden inside tqdm progress bars. + +Dataset format requirements by training method: +- SFT: `messages`, `text`, or `prompt`/`completion` +- DPO: `prompt`, `chosen`, `rejected` +- GRPO: `prompt` + +# Data audit + +Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it. + +Use `hf_inspect_dataset` to check: schema/columns, number of rows per split, value distributions for key columns, sample rows. Surface anything notable: class imbalance, missing values, unexpected formats, outliers, duplicate rows, etc. + +Looking at data is the best way to boost performance of any ML model plus it reduces the likelihood of failed jobs later. + +# When submitting a training job + +Before calling `hf_jobs`, output a pre-flight check: +- Reference implementation: [which example you based this on] +- Dataset format verified: [columns confirmed via `hf_inspect_dataset`] +- `push_to_hub=True` and `hub_model_id` set +- timeout: [value] (based on: [model size] on [hardware]) +- Trackio monitoring included and working + +If you cannot fill in all items, stop and complete the missing steps first. + +For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once. + +Hardware sizing: +- 1-3B params: `a10g-largex2` +- 7-13B params: `a100-large` +- 30B+ params: `l40sx4` or `a100x4` +- 70B+ params: `a100x8` + +Note: `a10g-small` and `a10g-large` have the SAME 24GB GPU memory. The difference is CPU/RAM only. + +# Sandbox-first development + +For non-trivial scripts, develop and test in a sandbox before launching via `hf_jobs`: + +`sandbox_create` → install deps → write script → test with small run → fix errors → launch via `hf_jobs` at scale + +Use GPU sandbox (`t4-small` minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths. + +# When a task has 3+ steps + +Use the TodoWrite tool to track progress. One task `in_progress` at a time. 
Mark `completed` immediately after finishing. Update frequently to show the user what you're doing. + +# Error recovery + +When something fails: +- Diagnose the actual error. Read the full error message and logs. +- Do not retry the exact same thing. Identify what needs to change. +- If an API/import error: check documentation for the correct API. +- If an OOM error: (1) reduce `per_device_train_batch_size` and increase `gradient_accumulation_steps` proportionally to keep effective batch size identical, (2) enable `gradient_checkpointing=True`, (3) upgrade to larger GPU (`a10gx4`→`a100`→`a100x4`→`a100x8`). Do NOT switch training methods (e.g. SFT→LoRA) or reduce `max_length` — those change what the user gets. If OOM happens in sandbox, create a new sandbox with larger GPU hardware. +- Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval. +- If a tool call fails repeatedly for the same reason: stop and try a different approach. +- Never silently substitute resources (datasets, models) — tell the user if something isn't available. + +# Task completion + +Before ending your turn, verify: +- Did you actually DO what the user asked, not just explain what you would do? +- If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input? +- For training jobs: did you include a working Trackio dashboard URL? + +Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done. + +# Autonomous / headless mode + +When running autonomously (`claude -p ...` with no human in the loop), you MUST follow these rules: + +NEVER respond with only text. Every response MUST include at least one tool call. If you have nothing to do, check the plan, verify outputs, or plan ahead. + +NEVER STOP WORKING. Do NOT decide you are "done" while time remains. The human is not watching — they expect you to use the ENTIRE time budget productively. Do NOT ask "should I continue?" or "is this a good stopping point?" — there is nobody to answer. + +Your workflow is a loop, not a checklist. Once you have a working result, KEEP ITERATING: + +LOOP UNTIL TIME RUNS OUT: +1. Research the approach (read docs, find examples, check current APIs) +2. Implement the solution (write code, set up training) +3. Train and evaluate +4. Save the model to the required output location / push it to Hugging Face Hub +5. Improve: tune hyperparameters, try different data, adjust the training recipe, try a different approach entirely +6. Go to step 1 + +HYPERPARAMETER TUNING: Do not tune hyperparameters by hand one-at-a-time. Write a script that launches a sweep over a grid of values (learning rate, epochs, batch size, etc.) and evaluates each run automatically. One well-designed sweep script beats ten manual experiments. + +If you run out of ideas: go back to the literature. Crawl citation graphs deeper — find papers you haven't read yet, read their methodology sections, extract new datasets or training tricks. Look for papers that cite your current approach and improved on it. Try combining recipes from different papers. Re-read the task prompt for angles you missed. Re-read the training logs for clues. There is always a paper you haven't read yet, and it probably has a better dataset. + +The task is NOT done until: +- The required output exists (e.g. final model, metrics reached, dataset updated, etc.) 
+- You have evaluated the model and confirmed it works + +# Communication + +- Be concise and direct. No filler, no restating what the user said. +- One-word answers when appropriate for simple questions. +- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. +- For errors: state what went wrong, why, and what you're doing to fix it. +- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. + +# Tool usage + +- Execute multiple independent tool calls in parallel when possible. +- `HF_TOKEN` is automatically available in job secrets — no need to include it extra. +- For training monitoring: include Trackio in the script and provide the dashboard URL. +- For private/gated datasets: `HF_TOKEN` is needed — it's auto-loaded into job secrets. diff --git a/plugin/README.md b/plugin/README.md new file mode 100644 index 00000000..355b15b8 --- /dev/null +++ b/plugin/README.md @@ -0,0 +1,137 @@ +# ml-intern (Claude Code plugin) + +Brings the [ml-intern](https://github.com/huggingface/ml-intern) ML engineering experience to Claude Code in any repository. Research-first methodology, HF dataset/paper/jobs/sandbox tools, content-aware approval policy. + +## Install + +### Via marketplace + +``` +/plugin marketplace add huggingface/ml-intern +/plugin install ml-intern@ml-intern +``` + +### Or directly from this repo + +``` +/plugin install /plugin +``` + +After install, restart Claude Code. The plugin will: + +- Bootstrap a stdio MCP server (`ml-intern-tools`) the first time you use one of its tools — this triggers `uv sync` against the bundled `pyproject.toml` (~30s the first time, instant after). +- Register five slash commands: `/ml-intern`, `/research`, `/inspect-dataset`, `/finetune`, `/run-job`. +- Register a `research` subagent invoked via the Task tool. +- Wire three lifecycle hooks (SessionStart, PreToolUse, SessionEnd). + +## Required environment + +Set in your shell or `.env`: + +```bash +HF_TOKEN=hf_... # required — papers, datasets, jobs, repo tools, sessions upload +GITHUB_TOKEN=ghp_... # required — github_find_examples, github_read_file, github_list_repos +``` + +The plugin's MCP server needs these in the Claude Code launching shell. See [Troubleshooting](#troubleshooting) if you see 401s. + +## What it gives you + +### Slash commands + +| Command | Purpose | +|---|---| +| `/ml-intern ` | Default entrypoint — runs the full research → validate → implement workflow. | +| `/research ` | Force a literature crawl via the research subagent (returns a ranked recipe table). | +| `/inspect-dataset ` | Audit a HF dataset (schema, splits, sample rows, training-method compatibility). | +| `/finetune ` | Strict end-to-end fine-tune (research → dataset audit → sandbox → pre-flight → submit ONE job). | +| `/run-job ` | Submit any HF Job with the pre-flight checklist (≥2h timeout, push_to_hub, Trackio). | + +### MCP tools (10) + +`hf_papers`, `hf_inspect_dataset`, `hf_jobs`, `hf_repo_files`, `hf_repo_git`, `explore_hf_docs`, `fetch_hf_docs`, `github_find_examples`, `github_list_repos`, `github_read_file`, plus sandbox `bash`/`read`/`write`/`edit`/`sandbox_create`. + +These appear in Claude Code as `mcp__ml-intern__ml-intern-tools__`. 
+ +### Approval policy (PreToolUse hook) + +| Tool / op | Behavior | +|---|---| +| `hf_jobs` (run/uv) on **GPU hardware** | Always prompts | +| `hf_jobs` on CPU hardware | Prompts when `ML_INTERN_CONFIRM_CPU_JOBS=1` (default) | +| `hf_jobs` script with `from_pretrained` but no `push_to_hub` | Always prompts (warning surfaces) | +| `sandbox_create` | Always prompts | +| `hf_repo_files` `upload`/`delete` | Always prompts | +| `hf_repo_git` destructive ops | Always prompts | + +The hook **fails safe** — malformed payloads force a prompt, never silent allow. + +### Session redaction + upload (SessionEnd hook) + +When `ML_INTERN_SAVE_SESSIONS=1` (default), transcripts upload to `smolagents/ml-intern-sessions` (override with `ML_INTERN_SESSION_REPO`). The bundled `redact.py` strips HF/Anthropic/OpenAI/GitHub/AWS tokens before upload. Refuses to upload paths outside `~/.claude/` or `$CLAUDE_PROJECT_DIR`. + +### Dynamic context (SessionStart hook) + +Injects: +- HF username from `HF_TOKEN` (so the agent uses your namespace for `hub_model_id`). +- "Local mode" banner when `ML_INTERN_LOCAL_MODE=1` (sandbox tools operate on the local fs). + +## Environment knobs + +| Var | Default | What it does | +|---|---|---| +| `ML_INTERN_YOLO` | `0` | Skip all approvals | +| `ML_INTERN_CONFIRM_CPU_JOBS` | `1` | Prompt for CPU jobs | +| `ML_INTERN_SAVE_SESSIONS` | `1` | Upload transcripts on session end | +| `ML_INTERN_SESSION_REPO` | `smolagents/ml-intern-sessions` | Target dataset for session uploads | +| `ML_INTERN_LOCAL_MODE` | `0` | Run sandbox tools on local fs (no remote sandbox) | +| `HF_SESSION_UPLOAD_TOKEN` | — | Preferred (write-only) token; falls back to `HF_TOKEN`, then `HF_ADMIN_TOKEN` | + +## Troubleshooting + +**Tool not found / `mcp__ml-intern__ml-intern-tools__...` errors.** The MCP server didn't start. Run `/mcp` to see status. If it's failing, check that `uv` is on `$PATH` and that the plugin's `pyproject.toml` could install — common cause is missing build tools or a network failure during the first `uv sync`. To diagnose, run the server manually: + +```bash +cd +uv run python lib/mcp_server.py < /dev/null +``` + +**401 Unauthorized from `hf_papers` or `hf_jobs`.** `HF_TOKEN` not set in the shell that launched Claude Code. The plugin's `.mcp.json` substitutes `${HF_TOKEN}` from the launching environment. + +**SessionStart shows `HF user: unknown (whoami HTTP error: ...)`.** Token rejected. Probably expired or scoped wrong. Generate a new one at . + +**Approval prompt every turn for `hf_papers`.** Static permissions list doesn't include the tool name. The plugin doesn't ship its own `permissions.allow` — those are project-level in your repo's `.claude/settings.json`. Add `mcp__ml-intern__ml-intern-tools__hf_papers` (and friends) to your project's allowlist if you want to skip prompts. 
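+
+**Not sure whether the token itself is the problem?** Run the same `whoami` call the SessionStart hook makes (see `hooks/session_start_context.py`) from the shell you launch Claude Code in — a quick sketch, assuming `huggingface_hub` is importable there:
+
+```python
+# HF_TOKEN sanity check — mirrors what hooks/session_start_context.py does at session start.
+import os
+
+from huggingface_hub import HfApi
+
+print(HfApi(token=os.environ.get("HF_TOKEN")).whoami().get("name"))
+```
+
+If this raises or prints `None`, fix the token (or export it in that shell) before debugging the plugin further.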
+ +## Layout + +``` +plugin/ +├── .claude-plugin/plugin.json # manifest +├── CLAUDE.md # persona / methodology +├── .mcp.json # MCP servers (ml-intern-tools, hf-mcp-server) +├── pyproject.toml # bundled deps for the MCP server +├── commands/ # 5 slash commands +├── agents/research.md # research subagent +├── hooks/ +│ ├── hooks.json +│ ├── pre_tool_use_approval.py # content-aware approval +│ ├── session_start_context.py # HF user + local-mode injection +│ └── session_end_upload.py # redaction + HF dataset upload +└── lib/ + ├── mcp_server.py # MCP frontend + └── ml_intern_lib/ # vendored tools + redact (no agent.* deps) +``` + +## Updating + +The vendored library under `lib/ml_intern_lib/` is a snapshot of `agent/tools/*` and `agent/core/redact.py` from the upstream ml-intern repo. When the upstream tools change, run: + +```bash +make sync-vendored # from the upstream ml-intern repo root +``` + +(Or do `cp -r agent/tools/* plugin/lib/ml_intern_lib/tools/` and re-run the import-rewrite step in the upstream `Makefile`.) + +## License + +Apache-2.0, same as upstream ml-intern. diff --git a/plugin/agents/research.md b/plugin/agents/research.md new file mode 100644 index 00000000..d33f2353 --- /dev/null +++ b/plugin/agents/research.md @@ -0,0 +1,119 @@ +--- +name: research +description: Use proactively before writing any ML implementation code. Mines the literature to find the best training recipes backed by published results, then validates them with working code and current docs. The main agent uses these findings to implement the actual solution. Spawn with a specific brief — name anchor papers or arxiv IDs when you have them. +tools: Read, Bash, Grep, Glob, WebFetch, mcp__ml-intern__ml-intern-tools__explore_hf_docs, mcp__ml-intern__ml-intern-tools__fetch_hf_docs, mcp__ml-intern__ml-intern-tools__hf_papers, mcp__ml-intern__ml-intern-tools__hf_inspect_dataset, mcp__ml-intern__ml-intern-tools__github_find_examples, mcp__ml-intern__ml-intern-tools__github_list_repos, mcp__ml-intern__ml-intern-tools__github_read_file, mcp__ml-intern__ml-intern-tools__hf_repo_files +--- + +You are a research sub-agent for an ML engineering assistant. Your primary job: mine the literature to find the best training recipes — then back them up with working code and up-to-date documentation. The main agent will use your findings to implement the actual solution. + +# Start from the literature + +Your default approach is a deep literature crawl. Do not start from docs or example scripts — start from papers. Papers contain the results, and results tell you what actually works. + +## The crawl + +1. **Find anchor papers**: Search for the task/domain. Identify the landmark paper(s) — high citations, recent, or both. +2. **Crawl the citation graph**: Use `citation_graph` on the anchor paper(s). Look DOWNSTREAM (papers that cite it) — these are the ones that built on it, improved it, or applied it to new domains. Prioritize recent papers and papers with many citations. +3. **Read methodology sections**: For the most promising papers (strong results, recent, relevant), use `read_paper` with section parameter to read sections 3, 4, 5 (Methodology, Experiments, Results — not the abstract). Extract: + - The exact dataset(s) used (name, source, size, any filtering/preprocessing) + - The training method and configuration (optimizer, lr, schedule, epochs, batch size) + - The results those choices produced (benchmark scores, metrics, comparisons) +4. **Attribute results to recipes**: This is the critical step. 
Every finding must link a RESULT to the RECIPE that produced it. "Dataset X + method Y + lr Z → score W on benchmark V" is useful. "They used SFT" is not. +5. **Validate datasets**: For the most promising datasets, check if they exist on HF Hub with `hf_inspect_dataset`. Verify format matches the training method. Report if it doesn't. +6. **Find code**: Now find working implementation code via `github_find_examples` and `github_read_file`. Use docs (`explore_hf_docs`, `fetch_hf_docs`) to fill in API details. + +## When to go deeper + +- If the anchor paper is old (>1 year), its citation graph is your main source — the downstream papers will have better methods. +- If a downstream paper reports significantly better results, crawl ITS citation graph too. +- Use `snippet_search` to find specific claims across papers (e.g., "does dataset X consistently outperform Y for this task?"). +- Use `recommend` to find related papers the citation graph might miss. + +# How to use your tools + +## Papers & citations (USE FIRST) +- `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML) +- `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar +- `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter +- `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR +- `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents +- `hf_papers(operation="read_paper", arxiv_id=..., section="3")`: Read a specific section's full text +- `hf_papers(operation="read_paper", arxiv_id=...)`: Get TOC (abstract + section list) — use this to find which section numbers contain methodology/experiments +- `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages +- `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers +- `hf_papers(operation="find_datasets", arxiv_id=...)`: Find HF datasets linked to a paper +- `hf_papers(operation="find_all_resources", arxiv_id=...)`: Datasets + models + collections for a paper + +## Dataset inspection +- `hf_inspect_dataset`: Check dataset schema, splits, sample rows. CRITICAL for training: verify column format matches training method: + - SFT: needs `messages`, `text`, or `prompt`/`completion` + - DPO: needs `prompt`, `chosen`, `rejected` + - GRPO: needs `prompt` only + +## GitHub code research +- `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) +- `github_read_file`: Read the actual implementation code. Use `line_start`/`line_end` for large files. + +## Documentation +- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. +- `fetch_hf_docs(url)`: Fetch full page content from explore results + +## Hub repo inspection +- `hf_repo_files`: List/read files in any HF repo (model, dataset, space) + +# Correct research pattern + +``` +# 1. Find anchor paper(s) for the task +hf_papers({"operation": "search", "query": "GPQA graduate questions", "sort_by": "citationCount"}) + +# 2. Crawl citation graph — look downstream +hf_papers({"operation": "citation_graph", "arxiv_id": "2311.12022", "direction": "citations"}) + +# 3. 
Read methodology of promising downstream papers +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348"}) # TOC first +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "3"}) # Methodology +hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "4"}) # Experiments + +# 4. Find datasets used by these papers +hf_papers({"operation": "find_datasets", "arxiv_id": "2604.01348"}) +hf_papers({"operation": "find_all_resources", "arxiv_id": "2604.01348"}) + +# 5. Validate datasets exist and have correct format +hf_inspect_dataset({"dataset": "org/dataset-name", "split": "train", "sample_rows": 3}) + +# 6. Now get working code for the training method +github_find_examples({"repo": "trl", "keyword": "sft"}) +github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) +explore_hf_docs("trl") +``` + +# Output format + +Your output MUST be structured as a ranked list of training recipes, each attributed to published results: + +## Recipe table (REQUIRED) +For each promising approach found, report: +- **Paper**: title, arxiv_id, date, venue +- **Result**: exact benchmark scores and what they were measured on +- **Dataset(s)**: name, size, source, HF Hub availability, format verified (yes/no) +- **Method**: training approach, key hyperparameters (lr, epochs, batch size, optimizer, schedule) +- **What made it work**: the specific insight or trick that drove the result (data curation, curriculum, loss function, etc.) + +Rank recipes by result quality. The main agent will pick the best one that's feasible. + +## Code patterns +- Key imports, configurations, and usage patterns from working examples +- Specific file paths, URLs, function names from docs + +## Recommendations +- Which recipe to implement first and why +- What datasets to use (with HF Hub paths, verified) +- Any gaps: datasets that need preprocessing, methods that need adaptation + +Additionally include: +- **SOTA landscape**: Current best models, datasets, and methods for the task (from recent papers). Flag anything outdated. +- **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets that the main agent should use directly +- **Code patterns**: Key imports, configurations, and usage patterns from working examples + +Be concise. Your output goes into another agent's context — every token counts. Aim for 500-1500 words max. Include actual code snippets from examples you read, not paraphrased descriptions. diff --git a/plugin/commands/finetune.md b/plugin/commands/finetune.md new file mode 100644 index 00000000..0e1ca74f --- /dev/null +++ b/plugin/commands/finetune.md @@ -0,0 +1,47 @@ +--- +description: Fine-tune a model on a dataset, end-to-end (research → validate → train → push). +argument-hint: +--- + +Fine-tune the model described in: $ARGUMENTS + +Fine-tuning is never trivial. Follow this sequence in order. Do **not** skip steps even if the request looks simple — `CLAUDE.md` lists the specific failures that happen when you do. + +**1. Research first (mandatory).** Delegate to the `research` subagent via the Task tool with `subagent_type: "research"`. Brief it: + +> Find the best fine-tuning recipe for: $ARGUMENTS. +> Identify the model architecture and intended task. Crawl the citation graph for recent papers that fine-tuned this (or a comparable) model on this (or a comparable) dataset. Read methodology sections (3, 4, 5) of the top 3 candidates. 
Extract: training method (SFT/DPO/GRPO/...), exact hyperparameters (lr, schedule, epochs, batch size, optimizer, max_length), and any data preprocessing. Verify the dataset's HF Hub format with `hf_inspect_dataset`. Return a ranked recipe table per CLAUDE.md. + +Do not start writing code until the subagent returns. + +**2. Validate dataset and model.** Independently of the research output, run: +- `mcp__ml-intern__ml-intern-tools__hf_inspect_dataset` on the target dataset — confirm columns match the chosen training method (SFT: `messages`/`text`/`prompt`+`completion`; DPO: `prompt`+`chosen`+`rejected`; GRPO: `prompt`). +- `mcp__ml-intern__ml-intern-tools__hf_repo_files` on the target model — confirm it exists and note tokenizer/architecture. + +**3. Develop in a sandbox.** For non-trivial scripts, call `mcp__ml-intern__ml-intern-tools__sandbox_create` with a GPU flavor (`t4-small` minimum if the code touches CUDA/bf16/model loading). Write the script, install deps, run a tiny smoke test (1–2 steps), fix errors. Do not skip the smoke test. + +**4. Pre-flight check (mandatory output before `hf_jobs`).** Print this checklist and verify every line is filled: + +``` +Reference implementation: +Dataset format verified: +Training method: +Hyperparameters: +push_to_hub: True +hub_model_id: +hardware_flavor: +timeout: <≥ 2h for any training> +Trackio monitoring: +disable_tqdm=True, logging_strategy="steps", logging_first_step=True: yes +``` + +If any line is missing, **stop and complete it** before submitting. + +**5. Submit ONE job.** Call `mcp__ml-intern__ml-intern-tools__hf_jobs` (operation `run` or `uv`) with the verified config. Watch the first 60s of logs to confirm training started (loss values printing as plain text, not stuck on tokenizer/model load). Only then submit any sweep/ablation runs. + +**6. Report.** Provide: +- Direct Hub URL of the job (`https://huggingface.co/jobs/...`) +- Trackio dashboard URL +- Hub URL of the model that will appear on completion (`https://huggingface.co/`) + +If anything fails, do not silently switch training methods, reduce `max_length`, or substitute datasets. Diagnose, fix the minimal thing, or ask the user. diff --git a/plugin/commands/inspect-dataset.md b/plugin/commands/inspect-dataset.md new file mode 100644 index 00000000..575a2509 --- /dev/null +++ b/plugin/commands/inspect-dataset.md @@ -0,0 +1,18 @@ +--- +description: Audit a HF dataset — schema, splits, sample rows, and red flags. Direct port of `hf_inspect_dataset`. +argument-hint: +--- + +Inspect the dataset `$ARGUMENTS` using `mcp__ml-intern__ml-intern-tools__hf_inspect_dataset`. + +Report back with: +- schema and column types +- number of rows per split +- 3 sample rows +- red flags: class imbalance, missing values, unexpected formats, duplicates +- training-method compatibility: + - SFT-ready? (has `messages` / `text` / `prompt`+`completion`) + - DPO-ready? (has `prompt` + `chosen` + `rejected`) + - GRPO-ready? (has `prompt`) + +Include the direct Hub URL: `https://huggingface.co/datasets/$ARGUMENTS` diff --git a/plugin/commands/ml-intern.md b/plugin/commands/ml-intern.md new file mode 100644 index 00000000..614fab95 --- /dev/null +++ b/plugin/commands/ml-intern.md @@ -0,0 +1,12 @@ +--- +description: Default ML Intern entrypoint — equivalent to running `ml-intern ""` headlessly. +argument-hint: +--- + +You are running as ML Intern. 
Follow the workflow defined in `CLAUDE.md`: +research first (delegate to the `research` subagent for any non-trivial ML task), +validate datasets and models, then implement. + +User request: + +$ARGUMENTS diff --git a/plugin/commands/research.md b/plugin/commands/research.md new file mode 100644 index 00000000..9c27dac1 --- /dev/null +++ b/plugin/commands/research.md @@ -0,0 +1,22 @@ +--- +description: Force a literature-first research crawl — delegates immediately to the `research` subagent without doing anything else. +argument-hint: +--- + +Delegate this research task to the `research` subagent **immediately**. Do not +attempt the research yourself — the subagent has its own context window and +returns a structured recipe table. + +Use the Task tool with `subagent_type: "research"`. Brief: + +> Literature crawl for: $ARGUMENTS +> +> Start from anchor paper(s). Crawl citation graph for recent downstream +> papers. Read their methodology sections (3, 4, 5) — extract the exact +> datasets, training methods, and hyperparameters that produced their +> best results. Attribute every finding to a specific result. Also find +> working code examples using current TRL/Transformers APIs. Validate +> any datasets via `hf_inspect_dataset`. + +When the subagent returns, summarize the top recipe to the user with direct +HF Hub URLs and the arxiv ID of the source paper. diff --git a/plugin/commands/run-job.md b/plugin/commands/run-job.md new file mode 100644 index 00000000..6eba044b --- /dev/null +++ b/plugin/commands/run-job.md @@ -0,0 +1,40 @@ +--- +description: Submit an HF Job (training, eval, batch inference) with the ml-intern pre-flight checklist. +argument-hint: +--- + +Submit an HF Job for: $ARGUMENTS + +Before calling `mcp__ml-intern__ml-intern-tools__hf_jobs`, produce the pre-flight check below. **Do not call `hf_jobs` until every line is filled in.** If you cannot fill a line, complete the missing step (research, dataset inspection, sandbox test) first. + +``` +Job purpose: +Reference implementation: +Dataset format verified: +Model verified: +push_to_hub: +hardware_flavor: +timeout: +Trackio monitoring: +Packages to install: +``` + +**Hardware sizing** (from `CLAUDE.md`): +- 1–3B params → `a10g-largex2` +- 7–13B params → `a100-large` +- 30B+ params → `l40sx4` or `a100x4` +- 70B+ params → `a100x8` +- CPU-only data prep → `cpu-basic` or `cpu-upgrade` + +Note: `a10g-small` and `a10g-large` have the SAME 24GB GPU memory — the difference is CPU/RAM only. + +**Timeout floor:** for any training job, set timeout ≥ `2h`. The default 30m kills training. If your timeout is < 2h and the job is training, **stop and revise** unless the user explicitly justified a shorter run (e.g. a smoke test). + +**Hooks will gate this call:** GPU jobs always prompt for confirmation. CPU jobs prompt by default (override with `ML_INTERN_CONFIRM_CPU_JOBS=0`). That is expected — present the pre-flight check clearly so the user can approve in one read. + +**For batch / ablation work:** submit ONE job first. Watch the first ~60 seconds of logs (look for plain-text loss lines — `disable_tqdm=True, logging_strategy="steps", logging_first_step=True` should be set). Only after that one starts training successfully, submit the rest. Never submit all at once. 
+ +**After submission, report:** +- Job URL (`https://huggingface.co/jobs/...`) +- Trackio dashboard URL +- Expected output (model repo, dataset repo, eval scores file path) and where to find it after completion diff --git a/plugin/hooks/hooks.json b/plugin/hooks/hooks.json new file mode 100644 index 00000000..ed7b374f --- /dev/null +++ b/plugin/hooks/hooks.json @@ -0,0 +1,33 @@ +{ + "SessionStart": [ + { + "hooks": [ + { + "type": "command", + "command": "uv run --project ${CLAUDE_PLUGIN_ROOT} python ${CLAUDE_PLUGIN_ROOT}/hooks/session_start_context.py" + } + ] + } + ], + "PreToolUse": [ + { + "matcher": "mcp__ml-intern__ml-intern-tools__.*|Bash", + "hooks": [ + { + "type": "command", + "command": "uv run --project ${CLAUDE_PLUGIN_ROOT} python ${CLAUDE_PLUGIN_ROOT}/hooks/pre_tool_use_approval.py" + } + ] + } + ], + "SessionEnd": [ + { + "hooks": [ + { + "type": "command", + "command": "uv run --project ${CLAUDE_PLUGIN_ROOT} python ${CLAUDE_PLUGIN_ROOT}/hooks/session_end_upload.py" + } + ] + } + ] +} diff --git a/plugin/hooks/pre_tool_use_approval.py b/plugin/hooks/pre_tool_use_approval.py new file mode 100755 index 00000000..dabeabb4 --- /dev/null +++ b/plugin/hooks/pre_tool_use_approval.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +""" +PreToolUse hook — port of agent/core/agent_loop.py::_needs_approval. + +Claude Code's static permission lists can't express ml-intern's +content-aware approval rules (e.g. "auto-approve CPU jobs but require +confirmation for GPU jobs"). This hook reads the tool input from stdin +and either: + - exits 0 (allow without prompt) — equivalent to ml-intern auto-execute + - prints a JSON `ask` decision so Claude Code prompts the user + +Fail-safe: malformed payloads, non-dict tool_input, or empty tool_name +all result in `ask` (never silent allow). For an approval hook, falling +through to allow on error would defeat the policy. + +Env knobs (hook-layer equivalents of fields in `agent.config.Config` — +the standalone CLI reads these from configs/main_agent_config.json): + + ML_INTERN_YOLO=1 → skip ALL approvals (Config.yolo_mode) + ML_INTERN_CONFIRM_CPU_JOBS=0 → auto-approve CPU jobs (Config.confirm_cpu_jobs) +""" + +from __future__ import annotations + +import json +import os +import sys + +# Mirror agent/tools/jobs_tool.py::CPU_FLAVORS +CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _check_training_script_save_pattern(script: str) -> str | None: + """Inspired by agent/utils/reliability_checks.py::check_training_script_save_pattern. + + Returns a warning when an hf_jobs script appears to load a model but + not push it back to the Hub (job storage is ephemeral — the model is + lost when the job ends). Source also emits a green "will be pushed" + confirmation; we drop that — hook output is shown only when forcing + a prompt, and a positive note there would be noise. + """ + if not isinstance(script, str): + return None + has_from_pretrained = "from_pretrained" in script + has_push_to_hub = "push_to_hub" in script + if has_from_pretrained and not has_push_to_hub: + return "WARNING: training script loads a model with `from_pretrained` but has no `push_to_hub` call — the trained model will be lost when the job ends." 
+ return None + + +def _hf_jobs_script_warning(tool_input: dict) -> str | None: + """Extract the script body from an hf_jobs invocation and run save-pattern check.""" + operation = tool_input.get("operation", "") + if operation not in ("run", "uv", "scheduled run", "scheduled uv"): + return None + script = ( + tool_input.get("script") + or tool_input.get("uv_script") + or tool_input.get("source") + or "" + ) + return _check_training_script_save_pattern(script) + + +def _needs_approval(tool_name: str, tool_input: dict) -> bool: + """Port of agent/core/agent_loop.py::_needs_approval (lines 51-118). + + Diverges from source in one place: source short-circuits to False on + malformed args via `_validate_tool_args` so a downstream validation error + surfaces. Here we don't have that path — Claude Code validates input + shape against the MCP schema upstream, so any payload reaching this hook + is already structurally valid. + """ + if _env_flag("ML_INTERN_YOLO", False): + return False + + # MCP tools surface in Claude Code as `mcp____`. Strip the prefix. + short_name = tool_name.split("__")[-1] if tool_name.startswith("mcp__") else tool_name + + if short_name == "sandbox_create": + return True + + if short_name == "hf_jobs": + operation = tool_input.get("operation", "") + if operation not in ("run", "uv", "scheduled run", "scheduled uv"): + return False + + hardware_flavor = ( + tool_input.get("hardware_flavor") + or tool_input.get("flavor") + or tool_input.get("hardware") + or "cpu-basic" + ) + is_cpu_job = hardware_flavor in CPU_FLAVORS + + if is_cpu_job: + return _env_flag("ML_INTERN_CONFIRM_CPU_JOBS", True) + + return True # GPU jobs always prompt + + # Note: hf_private_repos is intentionally not handled. agent/core/tools.py + # disables it ("replaced by hf_repo_files and hf_repo_git"). The two + # rules below cover the same destructive operations on the live tools. + + if short_name == "hf_repo_files": + operation = tool_input.get("operation", "") + if operation in ("upload", "delete"): + return True + + if short_name == "hf_repo_git": + operation = tool_input.get("operation", "") + if operation in ("delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"): + return True + + return False + + +def _ask(reason: str) -> dict: + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "ask", + "permissionDecisionReason": reason, + } + } + + +def main() -> int: + try: + payload = json.load(sys.stdin) + except json.JSONDecodeError as e: + # Fail-safe: a malformed payload to an APPROVAL hook must not silently + # allow the tool. Log to stderr so the failure is inspectable. 
+ print(f"[ml-intern] approval hook: malformed stdin ({e}); forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received malformed input — confirm before proceeding"))) + return 0 + + if not isinstance(payload, dict): + print(f"[ml-intern] approval hook: stdin is {type(payload).__name__}, expected dict; forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received unexpected input — confirm before proceeding"))) + return 0 + + tool_name = payload.get("tool_name") or "" + tool_input = payload.get("tool_input") or {} + if not isinstance(tool_input, dict): + print(f"[ml-intern] approval hook: tool_input is {type(tool_input).__name__}, expected dict; forcing prompt", file=sys.stderr) + print(json.dumps(_ask(f"ml-intern: {tool_name or 'tool'} received non-dict input — confirm before proceeding"))) + return 0 + + if not tool_name: + print("[ml-intern] approval hook: empty tool_name; forcing prompt", file=sys.stderr) + print(json.dumps(_ask("ml-intern: approval hook received empty tool_name — confirm before proceeding"))) + return 0 + + needs = _needs_approval(tool_name, tool_input) + + # Reliability warnings ride along — surface them by forcing a prompt + # even when the rule would otherwise auto-approve. + short_name = tool_name.split("__")[-1] if tool_name.startswith("mcp__") else tool_name + warning: str | None = None + if short_name == "hf_jobs": + warning = _hf_jobs_script_warning(tool_input) + if warning: + needs = True + + if needs: + reason_bits = [ + f"ml-intern policy: {tool_name} requires user confirmation " + f"(see .claude/hooks/pre_tool_use_approval.py)" + ] + if warning: + reason_bits.append(warning) + print(json.dumps(_ask(" | ".join(reason_bits)))) + return 0 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/plugin/hooks/session_end_upload.py b/plugin/hooks/session_end_upload.py new file mode 100755 index 00000000..78da218b --- /dev/null +++ b/plugin/hooks/session_end_upload.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +SessionEnd hook — upload the Claude Code transcript to the HF Hub dataset +configured by `ML_INTERN_SESSION_REPO` (default: smolagents/ml-intern-sessions). + +Mirrors agent/core/session_uploader.py behavior: + - best-effort, write-only token preferred, never blocks the user + - applies `ml_intern_lib.redact.scrub` before upload to strip HF/Anthropic/ + OpenAI/GitHub/AWS tokens that users (or scripts) may have pasted into chat + - if redaction can't be loaded we skip upload entirely — losing a session + beats leaking a token + +Env knobs (hook-layer equivalents of fields in agent.config.Config): + ML_INTERN_SAVE_SESSIONS=0 → disable session upload + ML_INTERN_SESSION_REPO=org/repo → override target dataset + HF_SESSION_UPLOAD_TOKEN → preferred upload token (write-only) + HF_TOKEN → fallback + HF_ADMIN_TOKEN → last-resort fallback +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +from pathlib import Path + +# Add the vendored library to sys.path so we can import ml_intern_lib.redact. 
+# The plugin layout is: /hooks/ + /lib/ml_intern_lib +_PLUGIN_LIB = Path(__file__).resolve().parents[1] / "lib" +if str(_PLUGIN_LIB) not in sys.path: + sys.path.insert(0, str(_PLUGIN_LIB)) + +DEFAULT_REPO = "smolagents/ml-intern-sessions" + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _resolve_token() -> str | None: + for name in ("HF_SESSION_UPLOAD_TOKEN", "HF_TOKEN", "HF_ADMIN_TOKEN"): + token = os.environ.get(name) + if token: + return token + return None + + +def _is_safe_transcript_path(p: Path) -> bool: + """Reject paths outside ~/.claude or $CLAUDE_PROJECT_DIR. Defense in depth + against a malformed payload pointing at, e.g., ~/.ssh/id_rsa. + """ + try: + resolved = p.resolve() + except OSError: + return False + + allowed_roots: list[Path] = [] + home = Path.home() + allowed_roots.append((home / ".claude").resolve()) + project_dir = os.environ.get("CLAUDE_PROJECT_DIR") + if project_dir: + try: + allowed_roots.append(Path(project_dir).resolve()) + except OSError: + pass + + for root in allowed_roots: + try: + resolved.relative_to(root) + return True + except ValueError: + continue + return False + + +def _redact_jsonl(src: Path) -> Path: + from ml_intern_lib.redact import scrub, scrub_string + + out = tempfile.NamedTemporaryFile( + prefix="ml-intern-session-", suffix=".jsonl", delete=False, mode="w", encoding="utf-8" + ) + fallback_lines = 0 + with src.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + line = line.rstrip("\n") + if not line: + out.write("\n") + continue + try: + obj = json.loads(line) + obj = scrub(obj) + out.write(json.dumps(obj, ensure_ascii=False)) + out.write("\n") + except json.JSONDecodeError: + fallback_lines += 1 + out.write(scrub_string(line)) + out.write("\n") + out.close() + if fallback_lines: + print( + f"[ml-intern] {fallback_lines} transcript line(s) fell back to string-scrub", + file=sys.stderr, + ) + return Path(out.name) + + +def main() -> int: + if not _env_flag("ML_INTERN_SAVE_SESSIONS", True): + return 0 + + token = _resolve_token() + if not token: + print( + "[ml-intern] no HF_SESSION_UPLOAD_TOKEN / HF_TOKEN / HF_ADMIN_TOKEN — " + "session not uploaded", + file=sys.stderr, + ) + return 0 + + try: + payload = json.load(sys.stdin) + except json.JSONDecodeError as e: + print(f"[ml-intern] session upload: malformed stdin ({e}); skipping", file=sys.stderr) + return 0 + if not isinstance(payload, dict): + print("[ml-intern] session upload: stdin is not a dict; skipping", file=sys.stderr) + return 0 + + transcript_path = payload.get("transcript_path") + session_id = payload.get("session_id", "unknown") + if not isinstance(transcript_path, str) or not transcript_path: + return 0 + + src = Path(transcript_path) + if not src.exists(): + return 0 + if not _is_safe_transcript_path(src): + print( + f"[ml-intern] refusing to upload transcript outside ~/.claude or " + f"$CLAUDE_PROJECT_DIR: {transcript_path}", + file=sys.stderr, + ) + return 0 + + repo_id = os.environ.get("ML_INTERN_SESSION_REPO", DEFAULT_REPO) + + try: + redacted = _redact_jsonl(src) + except Exception as e: + print(f"[ml-intern] redaction failed, NOT uploading: {e}", file=sys.stderr) + return 0 + + try: + from huggingface_hub import HfApi + + api = HfApi(token=token) + api.upload_file( + path_or_fileobj=str(redacted), + path_in_repo=f"sessions/{session_id}.jsonl", + repo_id=repo_id, + repo_type="dataset", + commit_message=f"Upload 
session {session_id}", + ) + except Exception as e: + print(f"[ml-intern] session upload failed: {e}", file=sys.stderr) + finally: + try: + redacted.unlink() + except OSError: + pass + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/plugin/hooks/session_start_context.py b/plugin/hooks/session_start_context.py new file mode 100755 index 00000000..23648580 --- /dev/null +++ b/plugin/hooks/session_start_context.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +SessionStart hook — inject the dynamic session context that the standalone +CLI builds in agent/context_manager/manager.py: + + - HF username (so the agent uses the right namespace for hub_model_id) + - Local-mode banner (only when ML_INTERN_LOCAL_MODE=1, mirrors the + "CLI / Local mode" block injected into the system prompt) + +Output is JSON `additionalContext` per Claude Code's SessionStart hook +contract — Claude Code surfaces it to the model as a system reminder. +""" + +from __future__ import annotations + +import json +import os +import sys + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name, "").strip().lower() + if not val: + return default + return val in ("1", "true", "yes", "on") + + +def _hf_username(token: str | None) -> tuple[str | None, str | None]: + """Return (username, error_reason). Exactly one is non-None. + + The standalone CLI uses curl with `-4` to dodge IPv6 Happy-Eyeballs + hangs (see agent/context_manager/manager.py:27-30). `huggingface_hub` + is already a dep here and uses `requests`/`urllib3` which doesn't + have the same pathology in normal setups; we use it for KISS reasons + and accept that very-broken IPv6 environments will time out instead + of falling back instantly. + """ + if not token: + return None, "no HF_TOKEN in environment" + try: + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError + + info = HfApi(token=token).whoami() + except HfHubHTTPError as e: + return None, f"whoami HTTP error: {e}" + except Exception as e: + return None, f"whoami failed: {type(e).__name__}: {e}" + + name = info.get("name") if isinstance(info, dict) else None + if isinstance(name, str) and name: + return name, None + return None, "whoami returned no name" + + +def main() -> int: + try: + sys.stdin.read() + except Exception: + pass + + parts: list[str] = [] + + user, err = _hf_username(os.environ.get("HF_TOKEN")) + if user: + parts.append( + f"HF user: **{user}** — use `{user}/` as the namespace when " + f"constructing `hub_model_id` for training jobs unless the user " + f"specifies otherwise." + ) + else: + # Distinguish "no token" from "request failed" — the second case is + # fixable (rotate token, check network), the first is configuration. + parts.append( + f"HF user: unknown ({err}). Ask the user for their HF org before " + f"constructing `hub_model_id`." + ) + + if _env_flag("ML_INTERN_LOCAL_MODE", False): + parts.append( + "**CLI / Local mode is ON.** There is NO sandbox — `bash`, `read`, `write`, " + "and `edit` (the `mcp__ml-intern__ml-intern-tools__*` versions) operate directly on the " + "local filesystem. The `sandbox_create` tool is NOT available. Use absolute " + "paths or paths relative to the working directory. Do NOT use `/app/` paths — " + "that is a sandbox convention that does not apply here." 
+ ) + + output = { + "hookSpecificOutput": { + "hookEventName": "SessionStart", + "additionalContext": "\n\n".join(parts), + } + } + print(json.dumps(output)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/plugin/lib/mcp_server.py b/plugin/lib/mcp_server.py new file mode 100644 index 00000000..b42f0b8e --- /dev/null +++ b/plugin/lib/mcp_server.py @@ -0,0 +1,143 @@ +""" +ml-intern MCP server, plugin edition. + +Same shape as packages/mcp_server/server.py in the upstream repo, but imports +from the vendored `ml_intern_lib` (under plugin/lib/) so the plugin is +self-contained — users don't need the full ml-intern repo. + +The `research` and `plan_tool` tools are intentionally NOT exposed: + research → replaced by the plugin's research subagent (agents/research.md) + plan_tool → replaced by Claude Code's built-in TodoWrite + +Run via the plugin's `.mcp.json`. Not intended to be invoked manually. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from pathlib import Path +from typing import Any, Awaitable, Callable + +# When launched by Claude Code's plugin loader, ${CLAUDE_PLUGIN_ROOT} points at +# the installed plugin directory and `command: uv` runs us inside our own venv. +# Make sure the vendored ml_intern_lib is importable regardless of CWD. +_LIB = Path(__file__).resolve().parent +if str(_LIB) not in sys.path: + sys.path.insert(0, str(_LIB)) + +from mcp import types +from mcp.server.lowlevel import Server +from mcp.server.stdio import stdio_server + +from ml_intern_lib.tools.dataset_tools import ( + HF_INSPECT_DATASET_TOOL_SPEC, + hf_inspect_dataset_handler, +) +from ml_intern_lib.tools.docs_tools import ( + EXPLORE_HF_DOCS_TOOL_SPEC, + HF_DOCS_FETCH_TOOL_SPEC, + explore_hf_docs_handler, + hf_docs_fetch_handler, +) +from ml_intern_lib.tools.github_find_examples import ( + GITHUB_FIND_EXAMPLES_TOOL_SPEC, + github_find_examples_handler, +) +from ml_intern_lib.tools.github_list_repos import ( + GITHUB_LIST_REPOS_TOOL_SPEC, + github_list_repos_handler, +) +from ml_intern_lib.tools.github_read_file import ( + GITHUB_READ_FILE_TOOL_SPEC, + github_read_file_handler, +) +from ml_intern_lib.tools.hf_repo_files_tool import ( + HF_REPO_FILES_TOOL_SPEC, + hf_repo_files_handler, +) +from ml_intern_lib.tools.hf_repo_git_tool import ( + HF_REPO_GIT_TOOL_SPEC, + hf_repo_git_handler, +) +from ml_intern_lib.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler +from ml_intern_lib.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler +from ml_intern_lib.tools.sandbox_tool import get_sandbox_tools + +logger = logging.getLogger(__name__) + +_TOOL_SPECS: list[tuple[dict[str, Any], Callable[..., Awaitable[tuple[str, bool]]]]] = [ + (EXPLORE_HF_DOCS_TOOL_SPEC, explore_hf_docs_handler), + (HF_DOCS_FETCH_TOOL_SPEC, hf_docs_fetch_handler), + (HF_PAPERS_TOOL_SPEC, hf_papers_handler), + (HF_INSPECT_DATASET_TOOL_SPEC, hf_inspect_dataset_handler), + (HF_JOBS_TOOL_SPEC, hf_jobs_handler), + (HF_REPO_FILES_TOOL_SPEC, hf_repo_files_handler), + (HF_REPO_GIT_TOOL_SPEC, hf_repo_git_handler), + (GITHUB_FIND_EXAMPLES_TOOL_SPEC, github_find_examples_handler), + (GITHUB_LIST_REPOS_TOOL_SPEC, github_list_repos_handler), + (GITHUB_READ_FILE_TOOL_SPEC, github_read_file_handler), +] + +_REGISTRY: dict[str, tuple[types.Tool, Callable[..., Awaitable[tuple[str, bool]]]]] = {} + + +def _build_registry() -> None: + for spec, handler in _TOOL_SPECS: + tool = types.Tool( + name=spec["name"], + description=spec["description"], + 
inputSchema=spec["parameters"], + ) + _REGISTRY[spec["name"]] = (tool, handler) + + local_mode = os.environ.get("ML_INTERN_LOCAL_MODE", "").lower() in ("1", "true", "yes") + if local_mode: + from ml_intern_lib.tools.local_tools import get_local_tools + sandbox_specs = get_local_tools() + else: + sandbox_specs = get_sandbox_tools() + + for tool_spec in sandbox_specs: + tool = types.Tool( + name=tool_spec.name, + description=tool_spec.description, + inputSchema=tool_spec.parameters, + ) + _REGISTRY[tool_spec.name] = (tool, tool_spec.handler) + + +server: Server = Server("ml-intern-tools") + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [tool for tool, _ in _REGISTRY.values()] + + +@server.call_tool() +async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + entry = _REGISTRY.get(name) + if entry is None: + raise ValueError(f"Unknown tool: {name}") + _tool, handler = entry + output, ok = await handler(arguments or {}) + if not ok: + raise RuntimeError(output) + return [types.TextContent(type="text", text=output)] + + +async def _amain() -> None: + _build_registry() + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + +if __name__ == "__main__": + asyncio.run(_amain()) diff --git a/plugin/lib/ml_intern_lib/__init__.py b/plugin/lib/ml_intern_lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugin/lib/ml_intern_lib/redact.py b/plugin/lib/ml_intern_lib/redact.py new file mode 100644 index 00000000..8978942c --- /dev/null +++ b/plugin/lib/ml_intern_lib/redact.py @@ -0,0 +1,68 @@ +"""Secret scrubbing for session trajectories before upload. + +Users frequently paste HF / API / GitHub tokens into the chat, or scripts echo +them via env dumps. This module applies regex-based redaction to any string +value found recursively in a trajectory payload. The goal is best-effort — +strict formats are matched; we won't catch free-form leaks like "my password +is hunter2". +""" + +from __future__ import annotations + +import re +from typing import Any + +# Each entry: (compiled regex, replacement placeholder). +# Patterns are conservative: they only match tokens with the canonical prefix +# and a minimum body length so we don't paint over normal text. +_PATTERNS: list[tuple[re.Pattern, str]] = [ + # Hugging Face tokens: hf_[A-Za-z0-9]{30,} + (re.compile(r"hf_[A-Za-z0-9]{30,}"), "[REDACTED_HF_TOKEN]"), + # Anthropic: sk-ant-[A-Za-z0-9_\-]{20,} + (re.compile(r"sk-ant-[A-Za-z0-9_\-]{20,}"), "[REDACTED_ANTHROPIC_KEY]"), + # OpenAI: sk-[A-Za-z0-9]{40,} (legacy + proj keys) + (re.compile(r"sk-(?!ant-)[A-Za-z0-9_\-]{40,}"), "[REDACTED_OPENAI_KEY]"), + # GitHub classic PATs: ghp_, gho_, ghu_, ghs_, ghr_ followed by 36+ chars + (re.compile(r"gh[pousr]_[A-Za-z0-9]{36,}"), "[REDACTED_GITHUB_TOKEN]"), + # GitHub fine-grained PATs: github_pat_ + (re.compile(r"github_pat_[A-Za-z0-9_]{36,}"), "[REDACTED_GITHUB_TOKEN]"), + # AWS access key IDs: AKIA / ASIA + 16 uppercase alnum + (re.compile(r"\b(?:AKIA|ASIA)[A-Z0-9]{16}\b"), "[REDACTED_AWS_KEY_ID]"), + # Generic 'Bearer ' header values + (re.compile(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]{20,}"), "Bearer [REDACTED]"), +] + +# Env-var-like exports: we scrub the value but keep the name so callers can +# still see which secret was referenced. Covers `KEY=value` and `KEY: value` +# when the key looks secret-y. 
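+# Illustrative behaviour (token values below are made up; the regex that follows
+# is authoritative):
+#   "export HF_TOKEN=hf_abcdef0123456789abcdef0123456789"  → "export HF_TOKEN=[REDACTED]"
+#   "ANTHROPIC_API_KEY: sk-ant-xxxxxxxxxxxxxxxxxxxxxxxx"   → "ANTHROPIC_API_KEY=[REDACTED]"
+# Note the replacement always normalizes the separator to `=`, whichever form the
+# input used.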
+_SECRETY_NAMES = re.compile( + r"(?i)\b(HF_TOKEN|HUGGINGFACEHUB_API_TOKEN|ANTHROPIC_API_KEY|OPENAI_API_KEY|" + r"GITHUB_TOKEN|AWS_SECRET_ACCESS_KEY|AWS_ACCESS_KEY_ID|PASSWORD|SECRET|API_KEY)" + r"\s*[:=]\s*([^\s\"']+)" +) + + +def scrub_string(s: str) -> str: + """Apply all redaction patterns to a single string. Safe on non-strings.""" + if not isinstance(s, str) or not s: + return s + out = s + for pat, repl in _PATTERNS: + out = pat.sub(repl, out) + out = _SECRETY_NAMES.sub(lambda m: f"{m.group(1)}=[REDACTED]", out) + return out + + +def scrub(obj: Any) -> Any: + """Recursively scrub every string value in a nested dict/list structure. + + Returns a new object — inputs are not mutated.""" + if isinstance(obj, str): + return scrub_string(obj) + if isinstance(obj, dict): + return {k: scrub(v) for k, v in obj.items()} + if isinstance(obj, list): + return [scrub(v) for v in obj] + if isinstance(obj, tuple): + return tuple(scrub(v) for v in obj) + return obj diff --git a/plugin/lib/ml_intern_lib/session_stub.py b/plugin/lib/ml_intern_lib/session_stub.py new file mode 100644 index 00000000..6977f79f --- /dev/null +++ b/plugin/lib/ml_intern_lib/session_stub.py @@ -0,0 +1,57 @@ +""" +Stub `Event` and a minimal session-like object so vendored tools that were +written against the standalone CLI's Session/Event types can run inside the +Claude Code MCP server (where there is no Session — Claude Code is the loop). + +Tools call `session.send_event(Event(...))` for telemetry and `session.hf_token` +for token access. The MCP server doesn't surface those events to the user +(Claude Code's tool-output channel does that already), so we drop them on +the floor and read the token from the environment. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Event: + """Minimal stand-in for agent.core.session.Event. + + The real Event has more fields, but the only attributes vendored tools + construct are `event_type` and `data`. + """ + event_type: str = "" + data: dict[str, Any] = field(default_factory=dict) + + +class StubSession: + """Drop-in for the Session object the standalone CLI passes into tool handlers. + + Implements just enough surface for `jobs_tool` and `sandbox_tool`: + - `send_event(...)` → swallowed + - `hf_token` → from env + - `_running_job_ids` → in-memory set, used by jobs_tool to track concurrent jobs + """ + + def __init__(self) -> None: + self.hf_token: str | None = ( + os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACEHUB_API_TOKEN") + or None + ) + self._running_job_ids: set[str] = set() + # Some tools touch session._sandbox_created_at via telemetry; provide + # the attribute so attribute access doesn't AttributeError. + self._sandbox_created_at: float | None = None + + async def send_event(self, _event: Any) -> None: + # MCP server has no event channel — tool output is what Claude Code shows. + return None + + # Some call sites use `getattr(session, "config", None)` for things like + # `session.config.yolo_mode`. Provide a None-shaped config; tools must + # handle the missing case (they do — checked grep before vendoring). + config = None diff --git a/plugin/lib/ml_intern_lib/telemetry_stub.py b/plugin/lib/ml_intern_lib/telemetry_stub.py new file mode 100644 index 00000000..1702e4e6 --- /dev/null +++ b/plugin/lib/ml_intern_lib/telemetry_stub.py @@ -0,0 +1,47 @@ +"""No-op telemetry — replaces `agent.core.telemetry` for the vendored tools. 
+ +The standalone CLI uses telemetry to emit Events for the session JSONL trail. +The MCP server has no session, and Claude Code's transcript captures tool +input/output natively, so we drop telemetry calls on the floor. + +Every coroutine in the real telemetry module returns None or a small dict; +we mirror that. Synchronous helpers are no-ops. +""" + +from __future__ import annotations + +from typing import Any + + +async def record_llm_call(*_a: Any, **_kw: Any) -> dict: + return {} + + +async def record_hf_job_submit(*_a: Any, **_kw: Any) -> None: + return None + + +async def record_hf_job_complete(*_a: Any, **_kw: Any) -> None: + return None + + +async def record_sandbox_create(*_a: Any, **_kw: Any) -> None: + return None + + +async def record_sandbox_destroy(*_a: Any, **_kw: Any) -> None: + return None + + +async def record_feedback(*_a: Any, **_kw: Any) -> None: + return None + + +def extract_usage(*_a: Any, **_kw: Any) -> dict: + return {} + + +class HeartbeatSaver: + @staticmethod + def maybe_fire(_session: Any) -> None: + return None diff --git a/plugin/lib/ml_intern_lib/tool_spec.py b/plugin/lib/ml_intern_lib/tool_spec.py new file mode 100644 index 00000000..a972cfdc --- /dev/null +++ b/plugin/lib/ml_intern_lib/tool_spec.py @@ -0,0 +1,20 @@ +"""Minimal ToolSpec dataclass — replaces `agent.core.tools.ToolSpec` for the +vendored tool factories (sandbox_tool.get_sandbox_tools, local_tools.get_local_tools). + +Same shape as the original; we just don't drag the rest of the agent.core.tools +module (ToolRouter, MCP client, etc.) along with it — those concepts don't +exist inside an MCP-server frontend. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Optional + + +@dataclass +class ToolSpec: + name: str + description: str + parameters: dict[str, Any] + handler: Optional[Callable[[dict[str, Any]], Awaitable[tuple[str, bool]]]] = None diff --git a/plugin/lib/ml_intern_lib/tools/__init__.py b/plugin/lib/ml_intern_lib/tools/__init__.py new file mode 100644 index 00000000..66261eb4 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/__init__.py @@ -0,0 +1,39 @@ +""" +Hugging Face tools for the agent +""" + +from ml_intern_lib.tools.dataset_tools import ( + HF_INSPECT_DATASET_TOOL_SPEC, + hf_inspect_dataset_handler, +) +from ml_intern_lib.tools.github_find_examples import ( + GITHUB_FIND_EXAMPLES_TOOL_SPEC, + github_find_examples_handler, +) +from ml_intern_lib.tools.github_list_repos import ( + GITHUB_LIST_REPOS_TOOL_SPEC, + github_list_repos_handler, +) +from ml_intern_lib.tools.github_read_file import ( + GITHUB_READ_FILE_TOOL_SPEC, + github_read_file_handler, +) +from ml_intern_lib.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler +from ml_intern_lib.tools.types import ToolResult + +__all__ = [ + "ToolResult", + "HF_JOBS_TOOL_SPEC", + "hf_jobs_handler", + "HfJobsTool", + "GITHUB_FIND_EXAMPLES_TOOL_SPEC", + "github_find_examples_handler", + "GITHUB_LIST_REPOS_TOOL_SPEC", + "github_list_repos_handler", + "GITHUB_READ_FILE_TOOL_SPEC", + "github_read_file_handler", + "GITHUB_SEARCH_CODE_TOOL_SPEC", + "github_search_code_handler", + "HF_INSPECT_DATASET_TOOL_SPEC", + "hf_inspect_dataset_handler", +] diff --git a/plugin/lib/ml_intern_lib/tools/dataset_tools.py b/plugin/lib/ml_intern_lib/tools/dataset_tools.py new file mode 100644 index 00000000..2450ca78 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/dataset_tools.py @@ -0,0 +1,439 @@ +""" +Dataset Inspection Tool - Comprehensive dataset 
analysis in one call + +Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints +to provide everything needed for ML tasks in a single tool call. +""" + +import asyncio +from typing import Any, TypedDict + +import httpx + +from ml_intern_lib.tools.types import ToolResult + +BASE_URL = "https://datasets-server.huggingface.co" + +# Truncation limit for long sample values in the output +MAX_SAMPLE_VALUE_LEN = 150 + + +class SplitConfig(TypedDict): + """Typed representation of a dataset config and its splits.""" + + name: str + splits: list[str] + + +def _get_headers(token: str | None = None) -> dict: + """Get auth headers for private/gated datasets""" + if token: + return {"Authorization": f"Bearer {token}"} + return {} + + +async def inspect_dataset( + dataset: str, + config: str | None = None, + split: str | None = None, + sample_rows: int = 3, + hf_token: str | None = None, +) -> ToolResult: + """ + Get comprehensive dataset info in one call. + All API calls made in parallel for speed. + """ + headers = _get_headers(hf_token) + output_parts = [] + errors = [] + + async with httpx.AsyncClient(timeout=15, headers=headers) as client: + # Phase 1: Parallel calls for structure info (no dependencies) + is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset}) + splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset}) + parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset}) + + results = await asyncio.gather( + is_valid_task, + splits_task, + parquet_task, + return_exceptions=True, + ) + + # Process is-valid + if not isinstance(results[0], Exception): + try: + output_parts.append(_format_status(results[0].json())) + except Exception as e: + errors.append(f"is-valid: {e}") + + # Process splits and auto-detect config/split + configs = [] + if not isinstance(results[1], Exception): + try: + splits_data = results[1].json() + configs = _extract_configs(splits_data) + if not config: + config = configs[0]["name"] if configs else "default" + if not split: + split = configs[0]["splits"][0] if configs else "train" + output_parts.append(_format_structure(configs)) + except Exception as e: + errors.append(f"splits: {e}") + + if not config: + config = "default" + if not split: + split = "train" + + # Process parquet (will be added at the end) + parquet_section = None + if not isinstance(results[2], Exception): + try: + parquet_section = _format_parquet_files(results[2].json()) + except Exception: + pass # Silently skip if no parquet + + # Phase 2: Parallel calls for content (depend on config/split) + info_task = client.get( + f"{BASE_URL}/info", params={"dataset": dataset, "config": config} + ) + rows_task = client.get( + f"{BASE_URL}/first-rows", + params={"dataset": dataset, "config": config, "split": split}, + timeout=30, + ) + + content_results = await asyncio.gather( + info_task, + rows_task, + return_exceptions=True, + ) + + # Process info (schema) + if not isinstance(content_results[0], Exception): + try: + output_parts.append(_format_schema(content_results[0].json(), config)) + except Exception as e: + errors.append(f"info: {e}") + + # Process sample rows + if not isinstance(content_results[1], Exception): + try: + output_parts.append( + _format_samples( + content_results[1].json(), config, split, sample_rows + ) + ) + except Exception as e: + errors.append(f"rows: {e}") + + # Add parquet section at the end if available + if parquet_section: + output_parts.append(parquet_section) + + # Combine output + formatted = f"# 
{dataset}\n\n" + "\n\n".join(output_parts) + if errors: + formatted += f"\n\n**Warnings:** {'; '.join(errors)}" + + return { + "formatted": formatted, + "totalResults": 1, + "resultsShared": 1, + "isError": len(output_parts) == 0, + } + + +def _format_status(data: dict) -> str: + """Format /is-valid response as status line""" + available = [ + k + for k in ["viewer", "preview", "search", "filter", "statistics"] + if data.get(k) + ] + if available: + return f"## Status\n✓ Valid ({', '.join(available)})" + return "## Status\n✗ Dataset may have issues" + + +def _extract_configs(splits_data: dict) -> list[SplitConfig]: + """Group splits by config""" + configs: dict[str, SplitConfig] = {} + for s in splits_data.get("splits", []): + cfg = s.get("config", "default") + if cfg not in configs: + configs[cfg] = {"name": cfg, "splits": []} + configs[cfg]["splits"].append(s.get("split")) + return list(configs.values()) + + +def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str: + """Format configs and splits as a markdown table.""" + lines = [ + "## Structure (configs & splits)", + "| Config | Split |", + "|--------|-------|", + ] + + total_splits = sum(len(cfg["splits"]) for cfg in configs) + added_rows = 0 + + for cfg in configs: + for split_name in cfg["splits"]: + if added_rows >= max_rows: + break + lines.append(f"| {cfg['name']} | {split_name} |") + added_rows += 1 + if added_rows >= max_rows: + break + + if total_splits > added_rows: + lines.append( + f"| ... | ... | (_showing {added_rows} of {total_splits} config/split rows_) |" + ) + + return "\n".join(lines) + + +def _format_schema(info: dict, config: str) -> str: + """Extract features and format as table""" + features = info.get("dataset_info", {}).get("features", {}) + lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"] + for col_name, col_info in features.items(): + col_type = _get_type_str(col_info) + lines.append(f"| {col_name} | {col_type} |") + return "\n".join(lines) + + +def _get_type_str(col_info: dict) -> str: + """Convert feature info to readable type string""" + dtype = col_info.get("dtype") or col_info.get("_type", "unknown") + if col_info.get("_type") == "ClassLabel": + names = col_info.get("names", []) + if names and len(names) <= 5: + return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})" + return f"ClassLabel ({len(names)} classes)" + return str(dtype) + + +def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str: + """Format sample rows, truncate long values""" + rows = rows_data.get("rows", [])[:limit] + lines = [f"## Sample Rows ({config}/{split})"] + + messages_col_data = None + + for i, row_wrapper in enumerate(rows, 1): + row = row_wrapper.get("row", {}) + lines.append(f"**Row {i}:**") + for key, val in row.items(): + # Check for messages column and capture first one for format analysis + if key.lower() == "messages" and messages_col_data is None: + messages_col_data = val + + val_str = str(val) + if len(val_str) > MAX_SAMPLE_VALUE_LEN: + val_str = val_str[:MAX_SAMPLE_VALUE_LEN] + "..." + lines.append(f"- {key}: {val_str}") + + # If we found a messages column, add format analysis + if messages_col_data is not None: + messages_format = _format_messages_structure(messages_col_data) + if messages_format: + lines.append("") + lines.append(messages_format) + + return "\n".join(lines) + + +def _format_messages_structure(messages_data: Any) -> str | None: + """ + Analyze and format the structure of a messages column. 
+ Common in chat/instruction datasets. + """ + import json + + # Parse if string + if isinstance(messages_data, str): + try: + messages_data = json.loads(messages_data) + except json.JSONDecodeError: + return None + + if not isinstance(messages_data, list) or not messages_data: + return None + + lines = ["## Messages Column Format"] + + # Analyze message structure + roles_seen = set() + has_tool_calls = False + has_tool_results = False + message_keys = set() + + for msg in messages_data: + if not isinstance(msg, dict): + continue + + message_keys.update(msg.keys()) + + role = msg.get("role", "") + if role: + roles_seen.add(role) + + if "tool_calls" in msg or "function_call" in msg: + has_tool_calls = True + if role in ("tool", "function") or msg.get("tool_call_id"): + has_tool_results = True + + # Format the analysis + lines.append( + f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}" + ) + + # Show common message keys with presence indicators + common_keys = [ + "role", + "content", + "tool_calls", + "tool_call_id", + "name", + "function_call", + ] + key_status = [] + for key in common_keys: + if key in message_keys: + key_status.append(f"{key} ✓") + else: + key_status.append(f"{key} ✗") + lines.append(f"**Message keys:** {', '.join(key_status)}") + + if has_tool_calls: + lines.append("**Tool calls:** ✓ Present") + if has_tool_results: + lines.append("**Tool results:** ✓ Present") + + # Show example message structure + # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message + example = None + fallback = None + for msg in messages_data: + if not isinstance(msg, dict): + continue + role = msg.get("role", "") + # Check for actual tool_calls/function_call values (not None) + if msg.get("tool_calls") or msg.get("function_call"): + example = msg + break + if role == "assistant" and example is None: + example = msg + elif role != "system" and fallback is None: + fallback = msg + if example is None: + example = fallback + + if example: + lines.append("") + lines.append("**Example message structure:**") + # Build a copy with truncated content but keep all keys + example_clean = {} + for key, val in example.items(): + if key == "content" and isinstance(val, str) and len(val) > 100: + example_clean[key] = val[:100] + "..." + else: + example_clean[key] = val + lines.append("```json") + lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False)) + lines.append("```") + + return "\n".join(lines) + + +def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None: + """Format parquet file info, return None if no files.""" + files = data.get("parquet_files", []) + if not files: + return None + + # Group by config/split + groups: dict[str, dict] = {} + for f in files: + key = f"{f.get('config', 'default')}/{f.get('split', 'train')}" + if key not in groups: + groups[key] = {"count": 0, "size": 0} + size = f.get("size") or 0 + if not isinstance(size, (int, float)): + size = 0 + groups[key]["count"] += 1 + groups[key]["size"] += int(size) + + lines = ["## Files (Parquet)"] + items = list(groups.items()) + total_groups = len(items) + + shown = 0 + for key, info in items[:max_rows]: + size_mb = info["size"] / (1024 * 1024) + lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)") + shown += 1 + + if total_groups > shown: + lines.append(f"- ... 
(_showing {shown} of {total_groups} parquet groups_)") + return "\n".join(lines) + + +# Tool specification +HF_INSPECT_DATASET_TOOL_SPEC = { + "name": "hf_inspect_dataset", + "description": ( + "Inspect a HF dataset in one call: status, configs/splits, schema, sample rows, parquet info.\n\n" + "REQUIRED before any training job to verify dataset format matches training method:\n" + " SFT: needs 'messages', 'text', or 'prompt'/'completion'\n" + " DPO: needs 'prompt', 'chosen', 'rejected'\n" + " GRPO: needs 'prompt'\n" + "All datasets used for training have to be in conversational ChatML format to be compatible with HF libraries.'\n" + "Training will fail with KeyError if columns don't match.\n\n" + "Also use to get example datapoints, understand column names, data types, and available splits before writing any data loading code. " + "Supports private/gated datasets when HF_TOKEN is set." + ), + "parameters": { + "type": "object", + "properties": { + "dataset": { + "type": "string", + "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')", + }, + "config": { + "type": "string", + "description": "Config/subset name. Auto-detected if not specified.", + }, + "split": { + "type": "string", + "description": "Split for sample rows. Auto-detected if not specified.", + }, + "sample_rows": { + "type": "integer", + "description": "Number of sample rows to show (default: 3, max: 10)", + "default": 3, + }, + }, + "required": ["dataset"], + }, +} + + +async def hf_inspect_dataset_handler(arguments: dict[str, Any], session=None) -> tuple[str, bool]: + """Handler for agent tool router""" + try: + hf_token = session.hf_token if session else None + result = await inspect_dataset( + dataset=arguments["dataset"], + config=arguments.get("config"), + split=arguments.get("split"), + sample_rows=min(arguments.get("sample_rows", 3), 10), + hf_token=hf_token, + ) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error inspecting dataset: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/docs_tools.py b/plugin/lib/ml_intern_lib/tools/docs_tools.py new file mode 100644 index 00000000..a1782107 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/docs_tools.py @@ -0,0 +1,979 @@ +""" +Documentation search tools for exploring HuggingFace and Gradio documentation. 
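+
+Typical flow (illustrative sketch — in the plugin these handlers are invoked by
+the MCP server rather than called directly; `session` only needs an `hf_token`
+attribute, see ml_intern_lib.session_stub.StubSession):
+
+    # rank the TRL docs for pages about the DPO trainer, then fetch one in full
+    listing, ok = await explore_hf_docs_handler(
+        {"endpoint": "trl", "query": "DPO trainer", "max_results": 5}, session
+    )
+    page, ok = await hf_docs_fetch_handler(
+        {"url": "https://huggingface.co/docs/trl/dpo_trainer"}, session
+    )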
+""" + +import asyncio +import json +from typing import Any + +import httpx +from bs4 import BeautifulSoup +from whoosh.analysis import StemmingAnalyzer +from whoosh.fields import ID, TEXT, Schema +from whoosh.filedb.filestore import RamStorage +from whoosh.qparser import MultifieldParser, OrGroup + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +DEFAULT_MAX_RESULTS = 20 +MAX_RESULTS_CAP = 50 + +GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt" +GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt" + +COMPOSITE_ENDPOINTS: dict[str, list[str]] = { + "optimum": [ + "optimum", + "optimum-habana", + "optimum-neuron", + "optimum-intel", + "optimum-executorch", + "optimum-tpu", + ], + "courses": [ + "llm-course", + "robotics-course", + "mcp-course", + "smol-course", + "agents-course", + "deep-rl-course", + "computer-vision-course", + "audio-course", + "ml-games-course", + "diffusion-course", + "ml-for-3d-course", + "cookbook", + ], +} + +# --------------------------------------------------------------------------- +# Caches +# --------------------------------------------------------------------------- + +_docs_cache: dict[str, list[dict[str, str]]] = {} +_index_cache: dict[str, tuple[Any, MultifieldParser]] = {} +_cache_lock = asyncio.Lock() +_openapi_cache: dict[str, Any] | None = None +_openapi_index_cache: tuple[Any, MultifieldParser, list[dict[str, Any]]] | None = None + +# --------------------------------------------------------------------------- +# Gradio Documentation +# --------------------------------------------------------------------------- + + +async def _fetch_gradio_docs(query: str | None = None) -> str: + """ + Fetch Gradio documentation. 
+ Without query: Get full documentation from llms.txt + With query: Run embedding search on guides/demos for relevant content + """ + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + if not query: + resp = await client.get(GRADIO_LLMS_TXT_URL) + resp.raise_for_status() + return resp.text + + resp = await client.post( + GRADIO_SEARCH_URL, + headers={ + "Content-Type": "application/json", + "Origin": "https://gradio-docs-mcp.up.railway.app", + }, + json={ + "prompt_to_embed": query, + "SYSTEM_PROMPT": "$INSERT_GUIDES_DOCS_DEMOS", + "FALLBACK_PROMPT": "No results found", + }, + ) + resp.raise_for_status() + return resp.json().get("SYS_PROMPT", "No results found") + + +# --------------------------------------------------------------------------- +# HF Documentation - Fetching +# --------------------------------------------------------------------------- + + +async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: + """Fetch all docs for an endpoint by parsing sidebar and fetching each page.""" + url = f"https://huggingface.co/docs/{endpoint}" + headers = {"Authorization": f"Bearer {hf_token}"} + + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + resp = await client.get(url, headers=headers) + resp.raise_for_status() + + soup = BeautifulSoup(resp.text, "html.parser") + sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x) + if not sidebar: + raise ValueError(f"Could not find navigation sidebar for '{endpoint}'") + + nav_items = [] + for link in sidebar.find_all("a", href=True): + href = link["href"] + page_url = f"https://huggingface.co{href}" if href.startswith("/") else href + nav_items.append({"title": link.get_text(strip=True), "url": page_url}) + + if not nav_items: + raise ValueError(f"No navigation links found for '{endpoint}'") + + async def fetch_page(item: dict[str, str]) -> dict[str, str]: + md_url = f"{item['url']}.md" + try: + r = await client.get(md_url, headers=headers) + r.raise_for_status() + content = r.text.strip() + glimpse = content[:200] + "..." if len(content) > 200 else content + except Exception as e: + content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]" + return { + "title": item["title"], + "url": item["url"], + "md_url": md_url, + "glimpse": glimpse, + "content": content, + "section": endpoint, + } + + return list(await asyncio.gather(*[fetch_page(item) for item in nav_items])) + + +async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: + """Get docs for endpoint with caching. 
Expands composite endpoints.""" + async with _cache_lock: + if endpoint in _docs_cache: + return _docs_cache[endpoint] + + sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint]) + all_docs: list[dict[str, str]] = [] + + for sub in sub_endpoints: + async with _cache_lock: + if sub in _docs_cache: + all_docs.extend(_docs_cache[sub]) + continue + + docs = await _fetch_endpoint_docs(hf_token, sub) + async with _cache_lock: + _docs_cache[sub] = docs + all_docs.extend(docs) + + async with _cache_lock: + _docs_cache[endpoint] = all_docs + return all_docs + + +# --------------------------------------------------------------------------- +# HF Documentation - Search +# --------------------------------------------------------------------------- + + +async def _build_search_index( + endpoint: str, docs: list[dict[str, str]] +) -> tuple[Any, MultifieldParser]: + """Build or retrieve cached Whoosh search index.""" + async with _cache_lock: + if endpoint in _index_cache: + return _index_cache[endpoint] + + analyzer = StemmingAnalyzer() + schema = Schema( + title=TEXT(stored=True, analyzer=analyzer), + url=ID(stored=True, unique=True), + md_url=ID(stored=True), + section=ID(stored=True), + glimpse=TEXT(stored=True, analyzer=analyzer), + content=TEXT(stored=False, analyzer=analyzer), + ) + storage = RamStorage() + index = storage.create_index(schema) + writer = index.writer() + for doc in docs: + writer.add_document( + title=doc.get("title", ""), + url=doc.get("url", ""), + md_url=doc.get("md_url", ""), + section=doc.get("section", endpoint), + glimpse=doc.get("glimpse", ""), + content=doc.get("content", ""), + ) + writer.commit() + + parser = MultifieldParser( + ["title", "content"], + schema=schema, + fieldboosts={"title": 2.0, "content": 1.0}, + group=OrGroup, + ) + + async with _cache_lock: + _index_cache[endpoint] = (index, parser) + return index, parser + + +async def _search_docs( + endpoint: str, docs: list[dict[str, str]], query: str, limit: int +) -> tuple[list[dict[str, Any]], str | None]: + """Search docs using Whoosh. Returns (results, fallback_message).""" + index, parser = await _build_search_index(endpoint, docs) + + try: + query_obj = parser.parse(query) + except Exception: + return [], "Query contained unsupported syntax; showing default ordering." + + with index.searcher() as searcher: + results = searcher.search(query_obj, limit=limit) + matches = [ + { + "title": hit["title"], + "url": hit["url"], + "md_url": hit.get("md_url", ""), + "section": hit.get("section", endpoint), + "glimpse": hit["glimpse"], + "score": round(hit.score, 2), + } + for hit in results + ] + + if not matches: + return [], "No strong matches found; showing default ordering." 
+ return matches, None + + +# --------------------------------------------------------------------------- +# HF Documentation - Formatting +# --------------------------------------------------------------------------- + + +def _format_results( + endpoint: str, + items: list[dict[str, Any]], + total: int, + query: str | None = None, + note: str | None = None, +) -> str: + """Format search results as readable text.""" + base_url = f"https://huggingface.co/docs/{endpoint}" + out = f"Documentation structure for: {base_url}\n\n" + + if query: + out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages" + if note: + out += f" ({note})" + out += "\n\n" + else: + out += f"Found {len(items)} page(s) (total available: {total}).\n" + if note: + out += f"({note})\n" + out += "\n" + + for i, item in enumerate(items, 1): + out += f"{i}. **{item['title']}**\n" + out += f" URL: {item['url']}\n" + out += f" Section: {item.get('section', endpoint)}\n" + if query and "score" in item: + out += f" Relevance score: {item['score']:.2f}\n" + out += f" Glimpse: {item['glimpse']}\n\n" + + return out + + +# --------------------------------------------------------------------------- +# Handlers +# --------------------------------------------------------------------------- + + +async def explore_hf_docs_handler( + arguments: dict[str, Any], session=None +) -> tuple[str, bool]: + """Explore documentation structure with optional search query.""" + endpoint = arguments.get("endpoint", "").lstrip("/") + query = arguments.get("query") + max_results = arguments.get("max_results") + + if not endpoint: + return "Error: No endpoint provided", False + + # Gradio uses its own API + if endpoint.lower() == "gradio": + try: + clean_query = ( + query.strip() if isinstance(query, str) and query.strip() else None + ) + content = await _fetch_gradio_docs(clean_query) + header = "# Gradio Documentation\n\n" + if clean_query: + header += f"Query: '{clean_query}'\n\n" + header += "Source: https://gradio.app/docs\n\n---\n\n" + return header + content, True + except httpx.HTTPStatusError as e: + return f"HTTP error fetching Gradio docs: {e.response.status_code}", False + except httpx.RequestError as e: + return f"Request error fetching Gradio docs: {str(e)}", False + except Exception as e: + return f"Error fetching Gradio docs: {str(e)}", False + + # HF docs + hf_token = session.hf_token if session else None + if not hf_token: + return "Error: No HF token available (not logged in)", False + + try: + max_results_int = int(max_results) if max_results is not None else None + except (TypeError, ValueError): + return "Error: max_results must be an integer", False + + if max_results_int is not None and max_results_int <= 0: + return "Error: max_results must be greater than zero", False + + try: + docs = await _get_docs(hf_token, endpoint) + total = len(docs) + + # Determine limit + if max_results_int is None: + limit = DEFAULT_MAX_RESULTS + limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)." + elif max_results_int > MAX_RESULTS_CAP: + limit = MAX_RESULTS_CAP + limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)." 
+ else: + limit = max_results_int + limit_note = None + + # Search or paginate + clean_query = ( + query.strip() if isinstance(query, str) and query.strip() else None + ) + fallback_msg = None + + if clean_query: + results, fallback_msg = await _search_docs( + endpoint, docs, clean_query, limit + ) + if not results: + results = docs[:limit] + else: + results = docs[:limit] + + # Combine notes + notes = [] + if fallback_msg: + notes.append(fallback_msg) + if limit_note: + notes.append(limit_note) + note = "; ".join(notes) if notes else None + + return _format_results(endpoint, results, total, clean_query, note), True + + except httpx.HTTPStatusError as e: + return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False + except httpx.RequestError as e: + return f"Request error: {str(e)}", False + except ValueError as e: + return f"Error: {str(e)}", False + except Exception as e: + return f"Unexpected error: {str(e)}", False + + +async def hf_docs_fetch_handler( + arguments: dict[str, Any], session=None +) -> tuple[str, bool]: + """Fetch full markdown content of a documentation page.""" + url = arguments.get("url", "") + if not url: + return "Error: No URL provided", False + + hf_token = session.hf_token if session else None + if not hf_token: + return "Error: No HF token available (not logged in)", False + + if not url.endswith(".md"): + url = f"{url}.md" + + try: + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + resp = await client.get( + url, headers={"Authorization": f"Bearer {hf_token}"} + ) + resp.raise_for_status() + return f"Documentation from: {url}\n\n{resp.text}", True + except httpx.HTTPStatusError as e: + return ( + f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}", + False, + ) + except httpx.RequestError as e: + return f"Request error fetching {url}: {str(e)}", False + except Exception as e: + return f"Error fetching documentation: {str(e)}", False + + +# --------------------------------------------------------------------------- +# OpenAPI Search +# --------------------------------------------------------------------------- + + +async def _fetch_openapi_spec() -> dict[str, Any]: + """Fetch and cache HuggingFace OpenAPI specification.""" + global _openapi_cache + if _openapi_cache is not None: + return _openapi_cache + + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + resp = await client.get("https://huggingface.co/.well-known/openapi.json") + resp.raise_for_status() + + _openapi_cache = resp.json() + return _openapi_cache + + +def _extract_all_tags(spec: dict[str, Any]) -> list[str]: + """Extract all unique tags from OpenAPI spec.""" + tags = set() + for tag_obj in spec.get("tags", []): + if "name" in tag_obj: + tags.add(tag_obj["name"]) + for path_item in spec.get("paths", {}).values(): + for method, op in path_item.items(): + if method in ["get", "post", "put", "delete", "patch", "head", "options"]: + for tag in op.get("tags", []): + tags.add(tag) + return sorted(tags) + + +def _extract_all_endpoints(spec: dict[str, Any]) -> list[dict[str, Any]]: + """Extract all endpoints from OpenAPI spec.""" + servers = spec.get("servers", []) + base_url = ( + servers[0].get("url", "https://huggingface.co") + if servers + else "https://huggingface.co" + ) + + endpoints = [] + for path, path_item in spec.get("paths", {}).items(): + for method, op in path_item.items(): + if method not in [ + "get", + "post", + "put", + "delete", + "patch", + "head", + "options", + ]: + continue + 
endpoints.append( + { + "path": path, + "method": method.upper(), + "operationId": op.get("operationId", ""), + "summary": op.get("summary", ""), + "description": op.get("description", ""), + "tags": " ".join(op.get("tags", [])), + "parameters": op.get("parameters", []), + "request_body": op.get("requestBody", {}), + "responses": op.get("responses", {}), + "base_url": base_url, + } + ) + return endpoints + + +async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, Any]]]: + """Build or retrieve cached Whoosh index for OpenAPI endpoints.""" + global _openapi_index_cache + async with _cache_lock: + if _openapi_index_cache is not None: + return _openapi_index_cache + + spec = await _fetch_openapi_spec() + endpoints = _extract_all_endpoints(spec) + + analyzer = StemmingAnalyzer() + schema = Schema( + path=ID(stored=True, unique=True), + method=ID(stored=True), + operationId=TEXT(stored=True, analyzer=analyzer), + summary=TEXT(stored=True, analyzer=analyzer), + description=TEXT(stored=True, analyzer=analyzer), + tags=TEXT(stored=True, analyzer=analyzer), + param_names=TEXT(stored=False, analyzer=analyzer), + ) + storage = RamStorage() + index = storage.create_index(schema) + writer = index.writer() + + for ep in endpoints: + param_names = " ".join(p.get("name", "") for p in ep.get("parameters", [])) + writer.add_document( + path=ep["path"], + method=ep["method"], + operationId=ep.get("operationId", ""), + summary=ep.get("summary", ""), + description=ep.get("description", ""), + tags=ep.get("tags", ""), + param_names=param_names, + ) + writer.commit() + + parser = MultifieldParser( + ["summary", "description", "operationId", "tags", "param_names"], + schema=schema, + fieldboosts={ + "summary": 3.0, + "operationId": 2.0, + "description": 1.0, + "tags": 1.5, + }, + group=OrGroup, + ) + + async with _cache_lock: + _openapi_index_cache = (index, parser, endpoints) + return index, parser, endpoints + + +async def _search_openapi( + query: str, tag: str | None, limit: int = 20 +) -> tuple[list[dict[str, Any]], str | None]: + """Search OpenAPI endpoints using Whoosh. Returns (results, fallback_message).""" + index, parser, endpoints = await _build_openapi_index() + + try: + query_obj = parser.parse(query) + except Exception: + return [], "Query contained unsupported syntax." + + with index.searcher() as searcher: + results = searcher.search( + query_obj, limit=limit * 2 + ) # Get extra for tag filtering + matches = [] + for hit in results: + # Find full endpoint data + ep = next( + ( + e + for e in endpoints + if e["path"] == hit["path"] and e["method"] == hit["method"] + ), + None, + ) + if ep is None: + continue + # Filter by tag if provided + if tag and tag not in ep.get("tags", ""): + continue + matches.append({**ep, "score": round(hit.score, 2)}) + if len(matches) >= limit: + break + + return matches, None if matches else "No matches found for query." 
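+
+# Each match returned by `_search_openapi` is the endpoint dict built in
+# `_extract_all_endpoints` plus a Whoosh relevance score — roughly this shape
+# (field values are illustrative only):
+#
+#   {
+#       "path": "/api/models/{repo_id}",
+#       "method": "GET",
+#       "operationId": "...",
+#       "summary": "...",
+#       "description": "...",
+#       "tags": "models",
+#       "parameters": [...],
+#       "request_body": {},
+#       "responses": {...},
+#       "base_url": "https://huggingface.co",
+#       "score": 7.42,
+#   }
+#
+# `_format_openapi_results` and `_generate_curl_example` below consume exactly
+# this shape.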
+ + +def _generate_curl_example(endpoint: dict[str, Any]) -> str: + """Generate curl command example for an endpoint.""" + method = endpoint["method"] + path = endpoint["path"] + base_url = endpoint["base_url"] + + # Build URL with path parameters + full_path = path + for param in endpoint.get("parameters", []): + if param.get("in") == "path" and param.get("required"): + name = param["name"] + example = param.get( + "example", param.get("schema", {}).get("example", f"<{name}>") + ) + full_path = full_path.replace(f"{{{name}}}", str(example)) + + curl = f"curl -X {method} \\\n '{base_url}{full_path}'" + + # Add query parameters + query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"] + if query_params and query_params[0].get("required"): + param = query_params[0] + example = param.get("example", param.get("schema", {}).get("example", "value")) + curl += f"?{param['name']}={example}" + + curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'" + + # Add request body + if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"): + content = endpoint["request_body"].get("content", {}) + if "application/json" in content: + curl += " \\\n -H 'Content-Type: application/json'" + schema = content["application/json"].get("schema", {}) + example = schema.get("example", "{}") + if isinstance(example, dict): + example = json.dumps(example, indent=2) + curl += f" \\\n -d '{example}'" + + return curl + + +def _format_parameters(parameters: list[dict[str, Any]]) -> str: + """Format parameter information from OpenAPI spec.""" + if not parameters: + return "" + + path_params = [p for p in parameters if p.get("in") == "path"] + query_params = [p for p in parameters if p.get("in") == "query"] + header_params = [p for p in parameters if p.get("in") == "header"] + + output = [] + + for label, params in [ + ("Path Parameters", path_params), + ("Query Parameters", query_params), + ("Header Parameters", header_params), + ]: + if not params: + continue + if output: + output.append("") + output.append(f"**{label}:**") + for p in params: + name = p.get("name", "") + required = " (required)" if p.get("required") else " (optional)" + desc = p.get("description", "") + ptype = p.get("schema", {}).get("type", "string") + example = p.get("example") or p.get("schema", {}).get("example", "") + + output.append(f"- `{name}` ({ptype}){required}: {desc}") + if example: + output.append(f" Example: `{example}`") + + return "\n".join(output) + + +def _format_response_info(responses: dict[str, Any]) -> str: + """Format response information from OpenAPI spec.""" + if not responses: + return "No response information available" + + output = [] + for status, resp_obj in list(responses.items())[:3]: + desc = resp_obj.get("description", "") + output.append(f"- **{status}**: {desc}") + content = resp_obj.get("content", {}) + if "application/json" in content: + schema = content["application/json"].get("schema", {}) + if "type" in schema: + output.append(f" Returns: {schema.get('type', 'object')}") + + return "\n".join(output) + + +def _format_openapi_results( + results: list[dict[str, Any]], + tag: str | None = None, + query: str | None = None, + note: str | None = None, +) -> str: + """Format OpenAPI search results with curl examples.""" + if not results: + if query and tag: + return f"No API endpoints found matching '{query}' in tag '{tag}'" + elif query: + return f"No API endpoints found matching '{query}'" + elif tag: + return f"No API endpoints found with tag '{tag}'" + return "No API endpoints 
found" + + # Build header + if query and tag: + out = f"# API Endpoints matching '{query}' (tag: `{tag}`)\n\n" + elif query: + out = f"# API Endpoints matching '{query}'\n\n" + elif tag: + out = f"# API Endpoints for tag: `{tag}`\n\n" + else: + out = "# API Endpoints\n\n" + + out += f"Found {len(results)} endpoint(s)" + if note: + out += f" ({note})" + out += "\n\n---\n\n" + + for i, ep in enumerate(results, 1): + out += f"## {i}. {ep['method']} {ep['path']}\n\n" + + if query and "score" in ep: + out += f"**Relevance:** {ep['score']:.2f}\n\n" + + if ep.get("summary"): + out += f"**Summary:** {ep['summary']}\n\n" + + if ep.get("description"): + desc = ep["description"][:300] + if len(ep["description"]) > 300: + desc += "..." + out += f"**Description:** {desc}\n\n" + + if ep.get("tags"): + out += f"**Tags:** {ep['tags']}\n\n" + + params_info = _format_parameters(ep.get("parameters", [])) + if params_info: + out += params_info + "\n\n" + + out += "**Usage:**\n```bash\n" + out += _generate_curl_example(ep) + out += "\n```\n\n" + + out += "**Returns:**\n" + out += _format_response_info(ep["responses"]) + out += "\n\n---\n\n" + + return out + + +async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: + """Search HuggingFace OpenAPI specification by query and/or tag.""" + tag = arguments.get("tag", "").strip() or None + query = arguments.get("query", "").strip() or None + + if not tag and not query: + return ( + "Error: Provide either 'query' (keyword search) or 'tag' (category filter), or both.", + False, + ) + + try: + note = None + + # If query provided, try Whoosh search first + if query: + results, search_note = await _search_openapi(query, tag, limit=20) + + # If Whoosh found results, return them + if results: + return _format_openapi_results( + results, tag=tag, query=query, note=search_note + ), True + + # Whoosh found nothing - fall back to tag-based if tag provided + if tag: + note = f"No matches for '{query}'; showing all endpoints in tag '{tag}'" + else: + # No tag to fall back to + return _format_openapi_results([], query=query), True + + # Tag-based search (either as fallback or primary) + if tag: + _, _, endpoints = await _build_openapi_index() + results = [ep for ep in endpoints if tag in ep.get("tags", "")] + return _format_openapi_results( + results, tag=tag, query=None, note=note + ), True + + return "Error: No results found", False + + except httpx.HTTPStatusError as e: + return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False + except httpx.RequestError as e: + return f"Request error: {str(e)}", False + except Exception as e: + return f"Error searching OpenAPI spec: {str(e)}", False + + +async def _get_api_search_tool_spec() -> dict[str, Any]: + """Generate OpenAPI tool spec with tags populated at runtime.""" + spec = await _fetch_openapi_spec() + tags = _extract_all_tags(spec) + + return { + "name": "find_hf_api", + "description": ( + "Find HuggingFace Hub REST API endpoints to make HTTP requests. Returns curl examples with authentication. " + "⚠️ USE THIS TOOL when you need to call the HF Hub API directly - for operations like: " + "uploading/downloading files, managing repos, listing models/datasets, getting user info, " + "managing webhooks, collections, discussions, or any Hub interaction not covered by other tools. 
" + "**Use cases:** (1) 'Stream Space logs' → query='space logs', " + "(2) 'Get Space metrics/Zero-GPU usage' → query='space metrics', " + "(3) 'List organization members' → query='organization members', " + "(4) 'Generate repo access token' → query='jwt token', " + "(5) 'Check repo security scan' → query='security scan'. " + "**Search modes:** Use 'query' for keyword search, 'tag' to browse a category, or both. " + "If query finds no results, falls back to showing all endpoints in the tag. " + "**Output:** Full endpoint details with method, path, parameters, curl command, and response schema." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Keyword search across endpoint summaries, descriptions, and operation IDs. " + "Examples: 'upload file', 'create repository', 'list user models', 'delete branch', " + "'webhook', 'collection', 'discussion comments'. Supports stemming (upload/uploading both work)." + ), + }, + "tag": { + "type": "string", + "enum": tags, + "description": ( + "Filter by API category. Use alone to browse all endpoints in a category, " + "or combine with 'query' to search within a category." + ), + }, + }, + "required": [], + }, + } + + +# --------------------------------------------------------------------------- +# Tool Specifications +# --------------------------------------------------------------------------- + +DOC_ENDPOINTS = [ + "hub", + "transformers", + "diffusers", + "datasets", + "gradio", + "trackio", + "smolagents", + "huggingface_hub", + "huggingface.js", + "transformers.js", + "inference-providers", + "inference-endpoints", + "peft", + "accelerate", + "optimum", + "tokenizers", + "courses", + "evaluate", + "tasks", + "dataset-viewer", + "trl", + "simulate", + "sagemaker", + "timm", + "safetensors", + "tgi", + "setfit", + "lerobot", + "autotrain", + "tei", + "bitsandbytes", + "sentence_transformers", + "chat-ui", + "leaderboards", + "lighteval", + "argilla", + "distilabel", + "microsoft-azure", + "kernels", + "google-cloud", +] + +EXPLORE_HF_DOCS_TOOL_SPEC = { + "name": "explore_hf_docs", + "description": ( + "Browse HF documentation structure — discover all available documentation with 200-char previews.\n\n" + "Use this to find relevant documentation and/or examples with detailed parameter docs and API reference. " + "To be used together with github_find_examples and github_read_file to find working examples and documentation.\n\n" + "Pattern: explore_hf_docs (find relevant pages) → fetch_hf_docs (get full content).\n\n" + "For training tasks: fetch the trainer config docs (SFTConfig, DPOConfig, GRPOConfig) to verify parameter names. " + "Returns top 20 results by default; set max_results (max 50) to adjust." + ), + "parameters": { + "type": "object", + "properties": { + "endpoint": { + "type": "string", + "enum": DOC_ENDPOINTS, + "description": ( + "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n" + "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. 
Probably the best place for examples.\n" + "• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n" + "• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n" + "• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n" + "• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n" + "• gradio — UI components and demos for ML models. Uses Gradio's native API: without query returns full docs (llms.txt), with query uses embedding search for precise results.\n" + "• trackio — Experiment tracking, metrics logging, and run comparison.\n" + "• smolagents — Lightweight agent abstractions and tool-using patterns.\n" + "• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n" + "• huggingface.js — JS/TS client for Hub APIs in browser and Node.\n" + "• transformers.js — Run Transformer models in browser/Node via WebGPU/WASM.\n" + "• inference-providers — Unified interface for third-party inference backends.\n" + "• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n" + "• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n" + "• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n" + "• optimum — Hardware-aware optimization and model export tooling, including Habana, Neuron, Intel, ExecuTorch, and TPU variants.\n" + "• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n" + "• evaluate — Metrics, evaluation workflows, and training-loop integration.\n" + "• tasks — Canonical task definitions and model categorization.\n" + "• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n" + "• trl — RLHF, DPO, PPO, and SFT utilities for LLMs.\n" + "• simulate — Experimental simulation tools and workflows.\n" + "• sagemaker — Deploying Hugging Face models on AWS SageMaker.\n" + "• timm — Image model zoo and utilities via HF integrations.\n" + "• safetensors — Safe, fast tensor serialization format.\n" + "• tgi — High-throughput text generation server for LLMs.\n" + "• setfit — Few-shot text classification via sentence embeddings.\n" + "• lerobot — Robotics datasets, policies, and learning workflows.\n" + "• autotrain — No/low-code model training on Hugging Face.\n" + "• tei — Optimized inference server for embedding workloads.\n" + "• bitsandbytes — Quantization and memory-efficient optimizers.\n" + "• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n" + "• chat-ui — Reference chat interfaces for LLM deployment.\n" + "• leaderboards — Evaluation leaderboards and submission mechanics.\n" + "• lighteval — Lightweight, reproducible LLM evaluation framework.\n" + "• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n" + "• distilabel — Synthetic data generation and distillation pipelines.\n" + "• microsoft-azure — Azure deployment and integration guides.\n" + "• kernels — Lightweight execution environments and notebook-style workflows.\n" + "• google-cloud — GCP deployment and serving workflows.\n" + ), + }, + "query": { + "type": "string", + "description": ( + "Optional keyword query to rank and filter documentation pages. " + "For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'." + ), + }, + "max_results": { + "type": "integer", + "description": "Max results (default 20, max 50). 
Ignored for Gradio.", + "minimum": 1, + "maximum": 50, + }, + }, + "required": ["endpoint"], + }, +} + +HF_DOCS_FETCH_TOOL_SPEC = { + "name": "fetch_hf_docs", + "description": ( + "Fetch full markdown content of an HF documentation page. Use after explore_hf_docs.\n\n" + "Critical for finding documentation e.g. current trainer configuration parameters (SFTConfig, DPOConfig, etc.) " + "Use for researching solutions and before writing training scripts. Your internal knowledge is outdated.\n\n" + "Provide the full URL from explore_hf_docs results. The .md extension is added automatically." + ), + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": ( + "The full URL to the documentation page. " + "Example: 'https://huggingface.co/docs/trl/dpo_trainer' " + "The .md extension will be added automatically if not present." + ), + }, + }, + "required": ["url"], + }, +} diff --git a/plugin/lib/ml_intern_lib/tools/edit_utils.py b/plugin/lib/ml_intern_lib/tools/edit_utils.py new file mode 100644 index 00000000..6a9a3295 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/edit_utils.py @@ -0,0 +1,268 @@ +""" +Shared utilities for file editing tools — fuzzy matching, syntax validation, +and richer edit operations. + +Used by both local_tools.py and the embedded sandbox server. +""" + +from __future__ import annotations + +# ── Unicode normalization map ──────────────────────────────────────────── + +UNICODE_MAP = { + "\u2013": "-", # en-dash + "\u2014": "-", # em-dash + "\u2212": "-", # minus sign + "\u2018": "'", # left single quote + "\u2019": "'", # right single quote + "\u201c": '"', # left double quote + "\u201d": '"', # right double quote + "\u00a0": " ", # non-breaking space + "\u2003": " ", # em space + "\u2002": " ", # en space + "\u200b": "", # zero-width space + "\ufeff": "", # BOM +} + + +def _normalize_unicode(s: str) -> str: + return "".join(UNICODE_MAP.get(c, c) for c in s) + + +# ── 4-pass fuzzy matching ──────────────────────────────────────────────── + + +def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: + """Find *pattern* in *content* with increasingly relaxed matching. + + Returns (start_index_in_original_content, match_note) or (None, None). + The index always refers to the *original* content string so callers can + use ``content[idx : idx + len(matched_text)]`` for replacement. + + Strategy (mirrors Codex): + 1. Exact match + 2. Right-trim each line (trailing whitespace) + 3. Both-sides trim (all surrounding whitespace per line) + 4. Unicode normalization on top of both-sides trim + """ + # Pass 1 — exact + if pattern in content: + return content.index(pattern), None + + # Helper: build a line-stripped version *and* a mapping from stripped + # positions back to original positions. We need this so callers can + # apply the replacement on the original content, not the stripped copy. + + def _build_stripped(text: str, strip_fn): + """Return (stripped_text, line_start_map). + + line_start_map[i] = original byte offset of the start of line i. 
+ """ + orig_lines = text.split("\n") + stripped_lines = [strip_fn(l) for l in orig_lines] + return "\n".join(stripped_lines), orig_lines, stripped_lines + + # Pass 2 — right-trim + c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip) + p_rt = "\n".join(l.rstrip() for l in pattern.split("\n")) + idx = c_rt.find(p_rt) + if idx != -1: + orig_idx = _map_back(idx, c_orig_lines, c_rt_lines) + return orig_idx, "(matched after trimming trailing whitespace)" + + # Pass 3 — both-sides trim + c_st, _, c_st_lines = _build_stripped(content, str.strip) + p_st = "\n".join(l.strip() for l in pattern.split("\n")) + idx = c_st.find(p_st) + if idx != -1: + orig_idx = _map_back(idx, c_orig_lines, c_st_lines) + return orig_idx, "(matched after trimming whitespace)" + + # Pass 4 — unicode normalization + both-sides trim + c_norm = _normalize_unicode(c_st) + p_norm = _normalize_unicode(p_st) + idx = c_norm.find(p_norm) + if idx != -1: + orig_idx = _map_back(idx, c_orig_lines, c_st_lines) + return orig_idx, "(matched after unicode normalization)" + + return None, None + + +def _map_back( + stripped_idx: int, + orig_lines: list[str], + stripped_lines: list[str], +) -> int: + """Map a character index in the stripped/joined text back to the original text.""" + # Walk through stripped lines to find which line the index falls on + pos = 0 + for i, sl in enumerate(stripped_lines): + line_end = pos + len(sl) + if stripped_idx <= line_end: + col_in_stripped = stripped_idx - pos + # Find where this stripped line's content starts in the original line + ol = orig_lines[i] + # The stripped line is a subset of the original line; find its offset + lstripped = len(ol) - len(ol.lstrip()) + orig_col = lstripped + col_in_stripped + # Compute absolute position in original text + orig_pos = sum(len(orig_lines[j]) + 1 for j in range(i)) + orig_col + return orig_pos + pos = line_end + 1 # +1 for the \n + # Fallback: return 0 (shouldn't happen if idx is valid) + return 0 + + +def fuzzy_find_original_match(content: str, pattern: str) -> tuple[str | None, str | None]: + """Find the *original* text in content that matches pattern fuzzily. + + Returns (original_matched_text, match_note) or (None, None). + This extracts the exact substring from the original content that + corresponds to the fuzzy match, preserving its original whitespace/unicode. + """ + if pattern in content: + return pattern, None + + idx, note = fuzzy_find(content, pattern) + if idx is None: + return None, None + + # We need to find the original text span that corresponds to the match. + # The match covers len(pattern) worth of *logical* content. + # Count how many original lines the pattern spans. + pattern_lines = pattern.split("\n") + n_lines = len(pattern_lines) + + # Find which original line the match starts on + orig_lines = content.split("\n") + char_pos = 0 + start_line = 0 + for i, ol in enumerate(orig_lines): + if char_pos + len(ol) >= idx: + start_line = i + break + char_pos += len(ol) + 1 + + end_line = min(start_line + n_lines, len(orig_lines)) + # Extract the original lines that were matched + matched_lines = orig_lines[start_line:end_line] + original_text = "\n".join(matched_lines) + return original_text, note + + +# ── Richer edit operations ─────────────────────────────────────────────── + + +def apply_edit( + content: str, + old_str: str, + new_str: str, + mode: str = "replace", + replace_all: bool = False, +) -> tuple[str, int, str | None]: + """Apply an edit operation to content. 
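+
+    If *old_str* is not an exact substring, the fuzzy matcher
+    (``fuzzy_find_original_match``) is tried and the returned note records
+    which relaxation matched. A minimal illustrative call (values made up):
+
+        apply_edit(content, old_str="x = 1", new_str="x = 2")
+        # -> (new_content, 1, None) for an exact, unique match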
+ + Modes: + - replace: replace first occurrence (or all if replace_all=True) + - replace_all: replace all occurrences (alias) + - append_after: insert new_str after old_str + - prepend_before: insert new_str before old_str + + Returns (new_content, num_replacements, fuzzy_note). + Raises ValueError if old_str not found. + """ + if mode == "replace_all": + replace_all = True + mode = "replace" + + # Try exact match first, then fuzzy + fuzzy_note = None + if old_str not in content: + original_match, fuzzy_note = fuzzy_find_original_match(content, old_str) + if original_match is None: + raise ValueError( + "old_str was not found in the file. Make sure old_str matches " + "the file contents exactly, including whitespace and indentation. " + "Use the read tool to verify the current file contents before retrying." + ) + old_str = original_match + + count = content.count(old_str) + + if mode == "replace": + if count > 1 and not replace_all: + raise ValueError( + f"Found {count} matches of old_str in the file, but replace_all is " + f"false. To replace all occurrences, set replace_all to true. To " + f"replace only one, provide a larger old_str with more surrounding " + f"context to uniquely identify the instance." + ) + if replace_all: + new_content = content.replace(old_str, new_str) + return new_content, count, fuzzy_note + else: + new_content = content.replace(old_str, new_str, 1) + return new_content, 1, fuzzy_note + + elif mode == "append_after": + if replace_all: + new_content = content.replace(old_str, old_str + new_str) + return new_content, count, fuzzy_note + else: + idx = content.index(old_str) + len(old_str) + new_content = content[:idx] + new_str + content[idx:] + return new_content, 1, fuzzy_note + + elif mode == "prepend_before": + if replace_all: + new_content = content.replace(old_str, new_str + old_str) + return new_content, count, fuzzy_note + else: + idx = content.index(old_str) + new_content = content[:idx] + new_str + content[idx:] + return new_content, 1, fuzzy_note + + else: + raise ValueError(f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before.") + + +# ── Syntax validation (Python) ─────────────────────────────────────────── + + +def validate_python(content: str, path: str = "") -> list[str]: + """Lightweight post-write validation for Python files. + + Checks syntax and training script conventions. This runs on the host + (not in the sandbox), so it only does static checks — no import resolution + or signature inspection since packages are installed in the sandbox, not here. + + The sandbox server has its own richer version that does real signature + inspection against installed packages. + + Returns a list of warning strings (empty = all good). + Never raises — validation failures are advisory only. + """ + import ast + + warnings = [] + + # 1. Syntax check via ast.parse + try: + ast.parse(content) + except SyntaxError as e: + warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") + return warnings + + # 2. 
Training script heuristics + if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): + if "push_to_hub" not in content: + warnings.append( + "Training script warning: no 'push_to_hub' found — model may be lost when job ends" + ) + if "hub_model_id" not in content: + warnings.append( + "Training script warning: no 'hub_model_id' found" + ) + + return warnings diff --git a/plugin/lib/ml_intern_lib/tools/github_find_examples.py b/plugin/lib/ml_intern_lib/tools/github_find_examples.py new file mode 100644 index 00000000..88321a14 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/github_find_examples.py @@ -0,0 +1,460 @@ +""" +GitHub Find Examples Tool - Discover examples, tutorials, and guides for any library + +Lists all files in a repository and performs deterministic keyword search. +""" + +import os +from typing import Any, Dict, List + +import requests +from thefuzz import fuzz + +from ml_intern_lib.tools.types import ToolResult + +# In order of priority (lower index = higher priority for sorting) +EXAMPLE_PATTERNS = [ + "scripts", + # General example patterns (catch-all, lower priority) + "examples", + "example", + # Notebook patterns + "notebooks", + "notebook", + # Tutorial/learning patterns + "tutorials", + "tutorial", + "quickstart", + "walkthroughs", + "walkthrough", + # Cookbook/recipe patterns + "cookbook", + "cookbooks", + "recipes", + "recipe", + # Demo/sample patterns + "demos", + "demo", + "samples", + "sample", + # Other patterns + "guides", + "guide", + "getting-started", + "getting_started", + "playground", + "howto", + "how-to", + "use-cases", + "usecases", + "use_cases", + "sandbox", + "showcase", +] + + +def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any]], str]: + """Get all files in a repository recursively. 
Returns (files, error_message)""" + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {token}", + } + + full_repo = f"{org}/{repo}" + + # Get default branch + try: + response = requests.get( + f"https://api.github.com/repos/{full_repo}", headers=headers, timeout=10 + ) + if response.status_code == 404: + return [], "not_found" + if response.status_code != 200: + return [], f"API error: {response.status_code}" + + repo_data = response.json() + default_branch = repo_data.get("default_branch", "main") + except Exception as e: + return [], f"Error fetching repo: {str(e)}" + + # Get repository tree recursively + try: + response = requests.get( + f"https://api.github.com/repos/{full_repo}/git/trees/{default_branch}", + headers=headers, + params={"recursive": "1"}, + timeout=30, + ) + if response.status_code != 200: + return [], f"Error fetching tree: {response.status_code}" + + data = response.json() + tree = data.get("tree", []) + + # Filter to only include files (not directories) + files = [ + { + "path": item["path"], + "ref": item["sha"], + "size": item.get("size", 0), + "url": f"https://github.com/{full_repo}/blob/{default_branch}/{item['path']}", + } + for item in tree + if item["type"] == "blob" + ] + + return files, "" + except Exception as e: + return [], f"Error processing tree: {str(e)}" + + +def _search_similar_repos(org: str, repo: str, token: str) -> List[Dict[str, Any]]: + """Search for similar repository names in the organization""" + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {token}", + } + + # Search for repos in the org with similar name + query = f"org:{org} {repo}" + + try: + response = requests.get( + "https://api.github.com/search/repositories", + headers=headers, + params={"q": query, "sort": "stars", "order": "desc", "per_page": 10}, + timeout=30, + ) + + if response.status_code != 200: + return [] + + data = response.json() + items = data.get("items", []) + + return [ + { + "name": item.get("name"), + "full_name": item.get("full_name"), + "description": item.get("description"), + "stars": item.get("stargazers_count", 0), + "url": item.get("html_url"), + } + for item in items + ] + except Exception: + return [] + + +def _score_against_example_patterns(file_path: str) -> int: + """Score file against example patterns using token_set_ratio""" + scores = [] + for pattern in EXAMPLE_PATTERNS: + score = fuzz.token_set_ratio(pattern.lower(), file_path.lower()) + scores.append(score) + return max(scores) if scores else 0 + + +def _score_against_keyword(file_path: str, keyword: str) -> int: + """Calculate fuzzy match score for a file path against a keyword""" + # Use partial_ratio for substring matching (good for paths) + # Also check token_set_ratio for word-level matching + partial_score = fuzz.partial_ratio(keyword.lower(), file_path.lower()) + token_score = fuzz.token_set_ratio(keyword.lower(), file_path.lower()) + + # Return the higher of the two + return max(partial_score, token_score) + + +def _get_pattern_priority(file_path: str) -> tuple[int, int, int]: + """ + Get priority of a file path based on which example pattern directory it's in. 
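+
+    Worked example (illustrative): "examples/scripts/train.py" scores (0, 0, 3)
+    under the scheme below (in examples/, hits the top-priority "scripts"
+    pattern, three path segments), so it sorts ahead of "scripts/util.py",
+    which scores (1, 0, 2).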
+ + Returns: (in_examples_dir, pattern_priority, path_depth) + - in_examples_dir: 0 if in examples/ directory, 1 otherwise (lower is better) + - pattern_priority: Index in EXAMPLE_PATTERNS (lower is better), or 999 if no match + - path_depth: Number of path segments (lower is better) + + Note: Prioritizes files in "examples/" directory first, then by most specific pattern match. + E.g., "examples/scripts/train.py" is better than "scripts/util.py" + """ + path_lower = file_path.lower() + path_parts = path_lower.split("/") + + # Check if file is in examples/ directory (highest priority) + in_examples_dir = 0 if (path_parts[0] in ["examples", "example"]) else 1 + + # Find ALL matching patterns and use the best (lowest index) one + # But prefer deeper matches (more specific) over shallow ones + best_priority = 999 + best_depth_at_match = -1 + + for i, pattern in enumerate(EXAMPLE_PATTERNS): + # Check if pattern appears as a directory component in the path + if pattern in path_parts: + # Find the depth where this pattern appears (rightmost occurrence) + depth = len(path_parts) - 1 - path_parts[::-1].index(pattern) + + # Prefer deeper matches, or better priority if at same depth + if depth > best_depth_at_match or ( + depth == best_depth_at_match and i < best_priority + ): + best_priority = i + best_depth_at_match = depth + + return (in_examples_dir, best_priority, len(path_parts)) + + +def _handle_repo_tree_errors( + all_files: List[Dict[str, Any]], + error: str, + org: str, + repo: str, + token: str, +) -> ToolResult | None: + """Handle errors from repo tree fetch. Returns ToolResult if error, None if OK.""" + if error == "not_found": + similar_repos = _search_similar_repos(org, repo, token) + + if not similar_repos: + return { + "formatted": f"Repository '{org}/{repo}' not found and no similar repositories found.", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Format similar repos + lines = [f"**Repository '{org}/{repo}' not found. Similar repositories:**\n"] + for i, r in enumerate(similar_repos, 1): + lines.append(f"{i}. **{r['full_name']}** (⭐ {r['stars']:,} stars)") + if r["description"]: + desc = ( + r["description"][:100] + "..." + if len(r["description"]) > 100 + else r["description"] + ) + lines.append(f" {desc}") + lines.append(f" {r['url']}\n") + + return { + "formatted": "\n".join(lines), + "totalResults": len(similar_repos), + "resultsShared": len(similar_repos), + "isError": True, + } + + if error: + return { + "formatted": f"Error accessing repository '{org}/{repo}': {error}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + if not all_files: + return { + "formatted": f"No files found in repository '{org}/{repo}'", + "totalResults": 0, + "resultsShared": 0, + } + + return None + + +def find_examples( + keyword: str = "", + repo: str = "", + org: str = "huggingface", + max_results: int = 10, + min_score: int = 80, +) -> ToolResult: + """ + Find example files in a repository using fuzzy matching. 
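+
+    Illustrative call (repo/keyword are examples; results depend on the live
+    repository contents and thefuzz scores at runtime):
+        find_examples(keyword="sft", repo="trl", max_results=5)
+        # formatted output lists paths such as examples/scripts/sft.py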
+ + Args: + keyword: Keyword to fuzzy match against file paths (e.g., "grpo") + repo: Repository name (e.g., "trl") + org: GitHub organization (default: "huggingface") + max_results: Maximum number of results (default 50) + min_score: Minimum fuzzy match score (0-100, default 60) + + Returns: + ToolResult with matching files, or similar repos if repo not found + """ + token = os.environ.get("GITHUB_TOKEN") + if not token: + return { + "formatted": "Error: GITHUB_TOKEN environment variable is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + if not repo: + return { + "formatted": "Error: repo parameter is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Get all files in the repository + all_files, error = _get_repo_tree(org, repo, token) + + # Handle errors (not found, API errors, empty repo) + if error_result := _handle_repo_tree_errors(all_files, error, org, repo, token): + return error_result + + # Step 1: Filter files by example patterns (score >= 60) + example_threshold = 60 + example_files = [] + for file in all_files: + example_score = _score_against_example_patterns(file["path"]) + if example_score >= example_threshold: + example_files.append({**file, "example_score": example_score}) + + if not example_files: + return { + "formatted": f"No example files found in {org}/{repo} (no files match example patterns with score >= {example_threshold}).", + "totalResults": 0, + "resultsShared": 0, + } + + # Step 2: If keyword provided, score and filter by keyword + if keyword: + scored_files = [] + for file in example_files: + keyword_score = _score_against_keyword(file["path"], keyword) + if keyword_score >= min_score: + scored_files.append({**file, "score": keyword_score}) + + if not scored_files: + return { + "formatted": f"No files found in {org}/{repo} matching keyword '{keyword}' (min score: {min_score}) among {len(example_files)} example files.", + "totalResults": 0, + "resultsShared": 0, + } + + # Sort by keyword score (descending) for best matches first + scored_files.sort(key=lambda x: x["score"], reverse=True) + else: + # No keyword: prioritize by pattern directory, then path depth + scored_files = [] + for file in example_files: + in_examples_dir, pattern_priority, path_depth = _get_pattern_priority( + file["path"] + ) + scored_files.append( + { + **file, + "score": file["example_score"], + "in_examples_dir": in_examples_dir, + "pattern_priority": pattern_priority, + "path_depth": path_depth, + } + ) + + if not scored_files: + return { + "formatted": f"No example files found in {org}/{repo}.", + "totalResults": 0, + "resultsShared": 0, + } + + # Sort by: 1) files in examples/ dir first, 2) pattern priority (scripts > datasets > etc), 3) path depth, 4) path name + scored_files.sort( + key=lambda x: ( + x["in_examples_dir"], + x["pattern_priority"], + x["path_depth"], + x["path"], + ) + ) + + # Limit results + results = scored_files[:max_results] + + # Format output + keyword_desc = f" matching '{keyword}'" if keyword else "" + lines = [f"**Found {len(results)} example files in {org}/{repo}{keyword_desc}:**"] + if len(scored_files) > max_results: + lines[0] += f" (showing {max_results} of {len(scored_files)})" + lines.append("") + + for i, file in enumerate(results, 1): + lines.append(f"{i}. 
**{file['path']}**") + lines.append(f" Size: {file['size']:,} bytes | Ref: {file['ref'][:7]}") + lines.append(f" URL: {file['url']}") + + # Copyable parameters for read_file tool + read_params = f"{{'repo': '{org}/{repo}', 'path': '{file['path']}'}}" + lines.append(f" To read, use: {read_params}") + lines.append("") + + return { + "formatted": "\n".join(lines), + "totalResults": len(results), + "resultsShared": len(results), + } + + +# Tool specification +GITHUB_FIND_EXAMPLES_TOOL_SPEC = { + "name": "github_find_examples", + "description": ( + "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). " + "Uses fuzzy keyword matching.\n\n" + "MANDATORY before writing any ML training, fine-tuning, or inference code. " + "Your internal knowledge of library APIs is outdated — working examples show current API patterns.\n\n" + "Sequence: github_find_examples → github_read_file (study the example) → implement based on what you found.\n\n" + "Skip this only for: simple data queries, status checks, non-code tasks.\n\n" + "Examples:\n" + " {keyword: 'sft', repo: 'trl'} → finds examples/scripts/sft.py\n" + " {keyword: 'grpo', repo: 'trl'} → finds GRPO training examples\n" + " {repo: 'trl', max_results: 20} → lists all available training method examples" + ), + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "Keyword to fuzzy match against file paths (e.g., 'grpo', 'sft').", + }, + "repo": { + "type": "string", + "description": "Repository name (e.g., 'trl', 'transformers'). Required.", + }, + "org": { + "type": "string", + "description": "GitHub organization or username. Default: 'huggingface'.", + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results to return. Default: 50.", + }, + "min_score": { + "type": "integer", + "description": "Minimum fuzzy match score (0-100). Default: 60.", + }, + }, + "required": ["repo"], + }, +} + + +async def github_find_examples_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: + """Handler for agent tool router""" + try: + result = find_examples( + keyword=arguments.get("keyword", ""), + repo=arguments["repo"], + org=arguments.get("org", "huggingface"), + max_results=arguments.get("max_results", 50), + min_score=arguments.get("min_score", 60), + ) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error finding examples: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/github_list_repos.py b/plugin/lib/ml_intern_lib/tools/github_list_repos.py new file mode 100644 index 00000000..76022de5 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/github_list_repos.py @@ -0,0 +1,287 @@ +""" +GitHub List Repositories Tool - List and sort repositories for any user or organization + +Efficiently discover repositories with flexible sorting options. +""" + +import os +from typing import Any, Dict, Literal, Optional + +import requests + +from ml_intern_lib.tools.types import ToolResult + + +def list_repos( + owner: str, + owner_type: Literal["user", "org"] = "org", + sort: Literal["stars", "forks", "updated", "created"] = "stars", + order: Literal["asc", "desc"] = "desc", + limit: Optional[int] = 30, +) -> ToolResult: + """ + List repositories for a user or organization using GitHub REST API. 
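+
+    Illustrative call (requires GITHUB_TOKEN in the environment; owner and
+    limit are example values):
+        list_repos(owner="huggingface", owner_type="org", sort="stars", limit=10)
+        # formatted output lists the ten most-starred repos with stars/forks/URL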
+ + Args: + owner: GitHub username or organization name + owner_type: Whether the owner is a "user" or "org" (default: "org") + sort: Sort field - "stars", "forks", "updated", or "created" + order: Sort order - "asc" or "desc" (default: "desc") + limit: Maximum number of repositories to return + + Returns: + ToolResult with repository information + """ + token = os.environ.get("GITHUB_TOKEN") + if not token: + return { + "formatted": "Error: GITHUB_TOKEN environment variable is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + if owner_type == "org": + url = f"https://api.github.com/orgs/{owner}/repos" + else: + url = f"https://api.github.com/users/{owner}/repos" + + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {token}", + } + + all_repos = [] + page = 1 + per_page = 100 # Maximum allowed by GitHub + + # Map our sort values to GitHub API sort values + # Note: GitHub list repos API doesn't support sorting by stars/forks + # We'll fetch all repos and sort in memory for those cases + api_sort_map = { + "created": "created", + "updated": "updated", + "stars": None, # Not supported by list API + "forks": None, # Not supported by list API + } + + api_sort = api_sort_map.get(sort) + need_manual_sort = api_sort is None + + try: + while True: + params = { + "page": page, + "per_page": per_page, + } + + # Only add sort/direction if API supports it + if api_sort: + params["sort"] = api_sort + params["direction"] = order + + response = requests.get( + url, + headers=headers, + params=params, + timeout=30, + ) + + if response.status_code == 403: + error_data = response.json() + return { + "formatted": f"GitHub API rate limit or permission error: {error_data.get('message', 'Unknown error')}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + if response.status_code != 200: + error_msg = f"GitHub API error (status {response.status_code})" + try: + error_data = response.json() + if "message" in error_data: + error_msg += f": {error_data['message']}" + except Exception: + pass + return { + "formatted": error_msg, + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + items = response.json() + + if not items: + break + + for item in items: + all_repos.append( + { + "name": item.get("name"), + "full_name": item.get("full_name"), + "description": item.get("description"), + "html_url": item.get("html_url"), + "language": item.get("language"), + "stars": item.get("stargazers_count", 0), + "forks": item.get("forks_count", 0), + "open_issues": item.get("open_issues_count", 0), + "topics": item.get("topics", []), + "updated_at": item.get("updated_at"), + "created_at": item.get("created_at"), + } + ) + + # Check if we got fewer results than requested (last page) + if len(items) < per_page: + break + + # Stop if we have enough repos + if limit and len(all_repos) >= limit: + break + + page += 1 + + except requests.exceptions.RequestException as e: + return { + "formatted": f"Failed to connect to GitHub API: {str(e)}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Manual sorting if needed (for stars/forks) + if need_manual_sort and all_repos: + reverse = order == "desc" + all_repos.sort(key=lambda x: x[sort], reverse=reverse) + + # Apply limit after sorting + if limit: + all_repos = all_repos[:limit] + + if not all_repos: + return { + "formatted": f"No repositories found for {owner_type} '{owner}'", + "totalResults": 0, + "resultsShared": 0, + } + + # 
Format output + lines = [f"**Found {len(all_repos)} repositories for {owner}:**\n"] + + for i, repo in enumerate(all_repos, 1): + lines.append(f"{i}. **{repo['full_name']}**") + lines.append( + f" ⭐ {repo['stars']:,} stars | 🍴 {repo['forks']:,} forks | Language: {repo['language'] or 'N/A'}" + ) + if repo["description"]: + desc = ( + repo["description"][:100] + "..." + if len(repo["description"]) > 100 + else repo["description"] + ) + lines.append(f" {desc}") + lines.append(f" URL: {repo['html_url']}") + if repo["topics"]: + lines.append(f" Topics: {', '.join(repo['topics'][:5])}") + + # Copyable parameters for other tools + lines.append(f" Use in tools: {{'repo': '{repo['full_name']}'}}") + lines.append("") + + return { + "formatted": "\n".join(lines), + "totalResults": len(all_repos), + "resultsShared": len(all_repos), + } + + +# Tool specification +GITHUB_LIST_REPOS_TOOL_SPEC = { + "name": "github_list_repos", + "description": ( + "List and discover repositories for GitHub organizations or users with flexible sorting. " + "**Use when:** (1) Exploring what libraries exist for a task, (2) Finding the right library to use, " + "(3) Discovering popular or active projects, (4) Checking recently updated repos for latest features, " + "(5) Finding alternative libraries in an organization. " + "**Pattern:** github_list_repos (discover libraries) → github_find_examples (find usage examples) → implement. " + "Returns: Comprehensive repository information (stars, forks, language, topics, URLs), sorted by preference. " + "**Then:** Use github_find_examples on selected repo to discover example code. " + "Sorts by: stars (popularity), forks (community), updated (activity), created (age).\n\n" + "## When to use this tool\n\n" + "- When you need to find libraries to use in your implementation\n" + "- When exploring what repositories exist for a task or domain\n" + "- When debugging an error and looking up if others have similar issues in repos\n" + "- When finding the most popular or actively maintained projects for a user/org\n" + "## Examples\n\n" + "\n" + "// ML Workflow Step: Discover HF libraries for RLHF/alignment\n" + "// Use case: Find the right library for training with human feedback\n" + "{\n" + " owner: 'huggingface',\n" + " owner_type: 'org',\n" + " sort: 'stars',\n" + " limit: 10\n" + "}\n" + "// Returns: transformers, trl, peft, accelerate, diffusers...\n" + "\n\n" + "\n" + "// ML Workflow Step: Check for recently updated HF repos\n" + "// Use case: Find actively maintained libraries with latest features\n" + "{\n" + " owner: 'huggingface',\n" + " owner_type: 'org',\n" + " sort: 'updated',\n" + " order: 'desc',\n" + " limit: 15\n" + "}\n" + "// Helps identify which repos have recent improvements/fixes\n" + "" + ), + "parameters": { + "type": "object", + "properties": { + "owner": { + "type": "string", + "description": "GitHub username or organization name. Required.", + }, + "owner_type": { + "type": "string", + "enum": ["user", "org"], + "description": "Whether the owner is a 'user' or 'org'. Default: 'org'.", + }, + "sort": { + "type": "string", + "enum": ["stars", "forks", "updated", "created"], + "description": "Sort field. Options: 'stars', 'forks', 'updated', 'created'. Default: 'stars'.", + }, + "order": { + "type": "string", + "enum": ["asc", "desc"], + "description": "Sort order. Options: 'asc', 'desc'. Default: 'desc'.", + }, + "limit": { + "type": "integer", + "description": "Maximum number of repositories to return. No limit if not specified. 
Default: 30.", + }, + }, + "required": ["owner"], + }, +} + + +async def github_list_repos_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: + """Handler for agent tool router""" + try: + result = list_repos( + owner=arguments["owner"], + owner_type=arguments.get("owner_type", "org"), + sort=arguments.get("sort", "stars"), + order=arguments.get("order", "desc"), + limit=arguments.get("limit"), + ) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error listing repositories: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/github_read_file.py b/plugin/lib/ml_intern_lib/tools/github_read_file.py new file mode 100644 index 00000000..425e8185 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/github_read_file.py @@ -0,0 +1,302 @@ +""" +GitHub Read File Tool - Read file contents from any GitHub repository with line range support + +Fetch exact file contents with metadata, supporting line ranges for efficient reading. +""" + +import base64 +import json +import os +from typing import Any, Dict, Optional + +import nbformat +import requests +from nbconvert import MarkdownExporter +from nbconvert.preprocessors import ClearOutputPreprocessor, TagRemovePreprocessor + +from ml_intern_lib.tools.types import ToolResult + + +def _convert_ipynb_to_markdown(content: str) -> str: + """ + Convert Jupyter notebook JSON to LLM-friendly Markdown. + + Args: + content: Raw notebook JSON string + + Returns: + Converted Markdown string + """ + try: + # Parse notebook JSON + nb_dict = json.loads(content) + + # Normalize cell sources (can be string or list of strings) + if "cells" in nb_dict: + for cell in nb_dict["cells"]: + if "source" in cell and isinstance(cell["source"], list): + cell["source"] = "".join(cell["source"]) + + # Read notebook with explicit version + nb = nbformat.reads(json.dumps(nb_dict), as_version=4) + + # Strip outputs for LLM readability (outputs can be noisy/large) + clear = ClearOutputPreprocessor() + nb, _ = clear.preprocess(nb, {}) + + # Optionally remove cells tagged with "hide" or similar + remove = TagRemovePreprocessor( + remove_cell_tags={"hide", "hidden", "remove"}, + remove_input_tags=set(), + remove_all_outputs_tags=set(), + ) + nb, _ = remove.preprocess(nb, {}) + + # Convert to markdown + exporter = MarkdownExporter() + markdown, _ = exporter.from_notebook_node(nb) + + return markdown + + except json.JSONDecodeError: + return content + except Exception: + return content + + +def read_file( + repo: str, + path: str, + ref: str = "HEAD", + line_start: Optional[int] = None, + line_end: Optional[int] = None, +) -> ToolResult: + """ + Read file contents from a GitHub repository with line range support. 
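+
+    Illustrative call (repo/path are examples; requires GITHUB_TOKEN):
+        read_file(repo="huggingface/trl", path="examples/scripts/sft.py",
+                  line_start=1, line_end=80)
+        # returns only the requested line range, wrapped in a fenced block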
+ + Args: + repo: Repository in format "owner/repo" (e.g., "github/github-mcp-server") + path: Path to file in repository (e.g., "pkg/github/search.go") + ref: Git reference - branch name, tag, or commit SHA (default: "HEAD") + line_start: Starting line number (1-indexed, inclusive) + line_end: Ending line number (1-indexed, inclusive) + + Returns: + ToolResult with file contents and metadata + """ + token = os.environ.get("GITHUB_TOKEN") + if not token: + return { + "formatted": "Error: GITHUB_TOKEN environment variable is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Parse repo + if "/" not in repo: + return { + "formatted": "Error: repo must be in format 'owner/repo'", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + owner, repo_name = repo.split("/", 1) + + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {token}", + } + + # Fetch file contents + url = f"https://api.github.com/repos/{owner}/{repo_name}/contents/{path}" + params = {} + if ref and ref != "HEAD": + params["ref"] = ref + + try: + response = requests.get(url, headers=headers, params=params, timeout=30) + + if response.status_code == 404: + return { + "formatted": f"File not found: {path} in {repo} (ref: {ref})", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + if response.status_code != 200: + error_msg = f"GitHub API error (status {response.status_code})" + try: + error_data = response.json() + if "message" in error_data: + error_msg += f": {error_data['message']}" + except Exception: + pass + return { + "formatted": error_msg, + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + data = response.json() + + # Check if it's a file + if data.get("type") != "file": + return { + "formatted": f"Path {path} is not a file (type: {data.get('type')})", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Decode content + content_b64 = data.get("content", "") + if content_b64: + content_b64 = content_b64.replace("\n", "").replace(" ", "") + content = base64.b64decode(content_b64).decode("utf-8", errors="replace") + else: + # For large files, fetch raw content + raw_headers = { + "Accept": "application/vnd.github.raw", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {token}", + } + raw_response = requests.get( + url, headers=raw_headers, params=params, timeout=30 + ) + if raw_response.status_code != 200: + return { + "formatted": "Failed to fetch file content", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + content = raw_response.text + + if path.lower().endswith(".ipynb"): + content = _convert_ipynb_to_markdown(content) + + # Process line ranges + lines = content.split("\n") + total_lines = len(lines) + + truncated = False + + if line_start is None and line_end is None: + # No range specified + if total_lines > 300: + line_start = 1 + line_end = 300 + truncated = True + else: + line_start = 1 + line_end = total_lines + else: + # Range specified + if line_start is None: + line_start = 1 + if line_end is None: + line_end = total_lines + + # Validate range + line_start = max(1, line_start) + line_end = min(total_lines, line_end) + if line_start > line_end: + return { + "formatted": f"Invalid range: line_start ({line_start}) > line_end ({line_end})", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Extract lines + selected_lines = lines[line_start - 1 : line_end] + selected_content = 
"\n".join(selected_lines) + + # Format output + lines_output = [f"**Reading file from repo: {repo}, path: {path}**"] + + if ref and ref != "HEAD": + lines_output.append(f"Ref: {ref}") + + lines_output.append("\n**File content:") + lines_output.append("```") + lines_output.append(selected_content) + lines_output.append("```") + if truncated: + lines_output.append( + f"Currently showing lines {line_start}-{line_end} out of {total_lines} total lines. Use line_start and line_end to view more lines." + ) + return { + "formatted": "\n".join(lines_output), + "totalResults": 1, + "resultsShared": 1, + } + + except requests.exceptions.RequestException as e: + return { + "formatted": f"Failed to connect to GitHub API: {str(e)}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + +# Tool specification +GITHUB_READ_FILE_TOOL_SPEC = { + "name": "github_read_file", + "description": ( + "Read file contents from GitHub repositories. Returns first 300 lines by default. " + "Auto-converts Jupyter notebooks to markdown.\n\n" + "Use AFTER github_find_examples to study the working implementation. " + "The purpose is to learn current API patterns — imports, trainer configs, dataset handling — " + "so your implementation uses correct, up-to-date code.\n\n" + "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n" + "When NOT to use: when you don't know the file path (use github_find_examples first)." + ), + "parameters": { + "type": "object", + "properties": { + "repo": { + "type": "string", + "description": "Repository in format 'owner/repo' (e.g., 'github/github-mcp-server'). Required.", + }, + "path": { + "type": "string", + "description": "Path to file in repository (e.g., 'src/index.js'). Required.", + }, + "ref": { + "type": "string", + "description": "Git reference - branch name, tag, or commit SHA. Default: 'HEAD'.", + }, + "line_start": { + "type": "integer", + "description": "Starting line number (1-indexed, inclusive). Optional.", + }, + "line_end": { + "type": "integer", + "description": "Ending line number (1-indexed, inclusive). 
Optional.", + }, + }, + "required": ["repo", "path"], + }, +} + + +async def github_read_file_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: + """Handler for agent tool router""" + try: + result = read_file( + repo=arguments["repo"], + path=arguments["path"], + ref=arguments.get("ref", "HEAD"), + line_start=arguments.get("line_start"), + line_end=arguments.get("line_end"), + ) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error reading file: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/hf_repo_files_tool.py b/plugin/lib/ml_intern_lib/tools/hf_repo_files_tool.py new file mode 100644 index 00000000..7637c901 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/hf_repo_files_tool.py @@ -0,0 +1,323 @@ +""" +HF Repo Files Tool - File operations on Hugging Face repositories + +Operations: list, read, upload, delete +""" + +import asyncio +from typing import Any, Dict, Literal, Optional + +from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + +from ml_intern_lib.tools.types import ToolResult + +OperationType = Literal["list", "read", "upload", "delete"] + + +async def _async_call(func, *args, **kwargs): + """Wrap synchronous HfApi calls for async context.""" + return await asyncio.to_thread(func, *args, **kwargs) + + +def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: + """Build the Hub URL for a repository.""" + if repo_type == "model": + return f"https://huggingface.co/{repo_id}" + return f"https://huggingface.co/{repo_type}s/{repo_id}" + + +def _format_size(size_bytes: int) -> str: + """Format file size in human-readable form.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024: + return f"{size_bytes:.1f}{unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f}PB" + + +class HfRepoFilesTool: + """Tool for file operations on HF repos.""" + + def __init__(self, hf_token: Optional[str] = None): + self.api = HfApi(token=hf_token) + + async def execute(self, args: Dict[str, Any]) -> ToolResult: + """Execute the specified operation.""" + operation = args.get("operation") + + if not operation: + return self._help() + + try: + handlers = { + "list": self._list, + "read": self._read, + "upload": self._upload, + "delete": self._delete, + } + + handler = handlers.get(operation) + if handler: + return await handler(args) + else: + return self._error(f"Unknown operation: {operation}. 
Valid: list, read, upload, delete") + + except RepositoryNotFoundError: + return self._error(f"Repository not found: {args.get('repo_id')}") + except EntryNotFoundError: + return self._error(f"File not found: {args.get('path')}") + except Exception as e: + return self._error(f"Error: {str(e)}") + + def _help(self) -> ToolResult: + """Show usage instructions.""" + return { + "formatted": """**hf_repo_files** - File operations on HF repos + +**Operations:** +- `list` - List files: `{"operation": "list", "repo_id": "gpt2"}` +- `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}` +- `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}` +- `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}` + +**Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""", + "totalResults": 1, + "resultsShared": 1, + } + + async def _list(self, args: Dict[str, Any]) -> ToolResult: + """List files in a repository.""" + repo_id = args.get("repo_id") + if not repo_id: + return self._error("repo_id is required") + + repo_type = args.get("repo_type", "model") + revision = args.get("revision", "main") + path = args.get("path", "") + + items = list(await _async_call( + self.api.list_repo_tree, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + path_in_repo=path, + recursive=True, + )) + + if not items: + return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0} + + lines = [] + total_size = 0 + for item in sorted(items, key=lambda x: x.path): + if hasattr(item, "size") and item.size: + total_size += item.size + lines.append(f"{item.path} ({_format_size(item.size)})") + else: + lines.append(f"{item.path}/") + + url = _build_repo_url(repo_id, repo_type) + response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines) + + return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)} + + async def _read(self, args: Dict[str, Any]) -> ToolResult: + """Read file content from a repository.""" + repo_id = args.get("repo_id") + path = args.get("path") + + if not repo_id: + return self._error("repo_id is required") + if not path: + return self._error("path is required") + + repo_type = args.get("repo_type", "model") + revision = args.get("revision", "main") + max_chars = args.get("max_chars", 50000) + + file_path = await _async_call( + hf_hub_download, + repo_id=repo_id, + filename=path, + repo_type=repo_type, + revision=revision, + token=self.api.token, + ) + + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + truncated = len(content) > max_chars + if truncated: + content = content[:max_chars] + + url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}" + response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + except UnicodeDecodeError: + import os + size = os.path.getsize(file_path) + return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1} + + async def _upload(self, args: Dict[str, Any]) -> ToolResult: + """Upload content to a repository.""" + repo_id = args.get("repo_id") + path = args.get("path") + content = args.get("content") + + if not repo_id: + return self._error("repo_id is required") + if not path: + return self._error("path is 
required") + if content is None: + return self._error("content is required") + + repo_type = args.get("repo_type", "model") + revision = args.get("revision", "main") + create_pr = args.get("create_pr", False) + commit_message = args.get("commit_message", f"Upload {path}") + + file_bytes = content.encode("utf-8") if isinstance(content, str) else content + + result = await _async_call( + self.api.upload_file, + path_or_fileobj=file_bytes, + path_in_repo=path, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + commit_message=commit_message, + create_pr=create_pr, + ) + + url = _build_repo_url(repo_id, repo_type) + if create_pr and hasattr(result, "pr_url"): + response = f"**Uploaded as PR**\n{result.pr_url}" + else: + response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + async def _delete(self, args: Dict[str, Any]) -> ToolResult: + """Delete files from a repository.""" + repo_id = args.get("repo_id") + patterns = args.get("patterns") + + if not repo_id: + return self._error("repo_id is required") + if not patterns: + return self._error("patterns is required (list of paths/wildcards)") + + if isinstance(patterns, str): + patterns = [patterns] + + repo_type = args.get("repo_type", "model") + revision = args.get("revision", "main") + create_pr = args.get("create_pr", False) + commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}") + + await _async_call( + self.api.delete_files, + repo_id=repo_id, + delete_patterns=patterns, + repo_type=repo_type, + revision=revision, + commit_message=commit_message, + create_pr=create_pr, + ) + + response = f"**Deleted:** {', '.join(patterns)} from {repo_id}" + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + def _error(self, message: str) -> ToolResult: + """Return an error result.""" + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} + + +# Tool specification +HF_REPO_FILES_TOOL_SPEC = { + "name": "hf_repo_files", + "description": ( + "Read and write files in HF repos (models/datasets/spaces).\n\n" + "## Operations\n" + "- **list**: List files with sizes and structure\n" + "- **read**: Read file content (text files only)\n" + "- **upload**: Upload content to repo (can create PR)\n" + "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n" + "## Use when\n" + "- Need to see what files exist in a repo\n" + "- Want to read config.json, README.md, or other text files\n" + "- Uploading training scripts, configs, or results to a repo\n" + "- Cleaning up temporary files from a repo\n\n" + "## Examples\n" + '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n' + '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n' + '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n' + '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n' + '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n' + "## Notes\n" + "- For binary files (safetensors, bin), use list to see them but can't read content\n" + "- upload/delete require approval (can overwrite/destroy data)\n" + "- Use create_pr=true to propose changes instead of direct commit\n" + ), + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["list", "read", "upload", "delete"], + "description": "Operation: list, read, upload, delete", + 
}, + "repo_id": { + "type": "string", + "description": "Repository ID (e.g., 'username/repo-name')", + }, + "repo_type": { + "type": "string", + "enum": ["model", "dataset", "space"], + "description": "Repository type (default: model)", + }, + "revision": { + "type": "string", + "description": "Branch/tag/commit (default: main)", + }, + "path": { + "type": "string", + "description": "File path for read/upload", + }, + "content": { + "type": "string", + "description": "File content for upload", + }, + "patterns": { + "type": "array", + "items": {"type": "string"}, + "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])", + }, + "create_pr": { + "type": "boolean", + "description": "Create PR instead of direct commit", + }, + "commit_message": { + "type": "string", + "description": "Custom commit message", + }, + }, + "required": ["operation"], + }, +} + + +async def hf_repo_files_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: + """Handler for agent tool router.""" + try: + hf_token = session.hf_token if session else None + tool = HfRepoFilesTool(hf_token=hf_token) + result = await tool.execute(arguments) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/hf_repo_git_tool.py b/plugin/lib/ml_intern_lib/tools/hf_repo_git_tool.py new file mode 100644 index 00000000..04571887 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/hf_repo_git_tool.py @@ -0,0 +1,664 @@ +""" +HF Repo Git Tool - Git-like operations on Hugging Face repositories + +Operations: branches, tags, PRs, repo management +""" + +import asyncio +from typing import Any, Dict, Literal, Optional + +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError + +from ml_intern_lib.tools.types import ToolResult + +OperationType = Literal[ + "create_branch", "delete_branch", + "create_tag", "delete_tag", + "list_refs", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", +] + + +async def _async_call(func, *args, **kwargs): + """Wrap synchronous HfApi calls for async context.""" + return await asyncio.to_thread(func, *args, **kwargs) + + +def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: + """Build the Hub URL for a repository.""" + if repo_type == "model": + return f"https://huggingface.co/{repo_id}" + return f"https://huggingface.co/{repo_type}s/{repo_id}" + + +class HfRepoGitTool: + """Tool for git-like operations on HF repos.""" + + def __init__(self, hf_token: Optional[str] = None): + self.api = HfApi(token=hf_token) + + async def execute(self, args: Dict[str, Any]) -> ToolResult: + """Execute the specified operation.""" + operation = args.get("operation") + + if not operation: + return self._help() + + try: + handlers = { + "create_branch": self._create_branch, + "delete_branch": self._delete_branch, + "create_tag": self._create_tag, + "delete_tag": self._delete_tag, + "list_refs": self._list_refs, + "create_pr": self._create_pr, + "list_prs": self._list_prs, + "get_pr": self._get_pr, + "merge_pr": self._merge_pr, + "close_pr": self._close_pr, + "comment_pr": self._comment_pr, + "change_pr_status": self._change_pr_status, + "create_repo": self._create_repo, + "update_repo": self._update_repo, + } + + handler = handlers.get(operation) + if handler: + return await handler(args) + else: + ops = ", ".join(handlers.keys()) + return self._error(f"Unknown operation: 
{operation}. Valid: {ops}") + + except RepositoryNotFoundError: + return self._error(f"Repository not found: {args.get('repo_id')}") + except Exception as e: + return self._error(f"Error: {str(e)}") + + def _help(self) -> ToolResult: + """Show usage instructions.""" + return { + "formatted": """**hf_repo_git** - Git-like operations on HF repos + +**Branch/Tag:** +- `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}` +- `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}` +- `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}` +- `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}` +- `list_refs`: `{"operation": "list_refs", "repo_id": "..."}` + +**PRs:** +- `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR) +- `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed) +- `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status) +- `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open) +- `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}` +- `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}` +- `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}` + +**Repo:** +- `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}` +- `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""", + "totalResults": 1, + "resultsShared": 1, + } + + # ========================================================================= + # BRANCH OPERATIONS + # ========================================================================= + + async def _create_branch(self, args: Dict[str, Any]) -> ToolResult: + """Create a new branch.""" + repo_id = args.get("repo_id") + branch = args.get("branch") + + if not repo_id: + return self._error("repo_id is required") + if not branch: + return self._error("branch is required") + + repo_type = args.get("repo_type", "model") + from_rev = args.get("from_rev", "main") + + await _async_call( + self.api.create_branch, + repo_id=repo_id, + branch=branch, + revision=from_rev, + repo_type=repo_type, + exist_ok=args.get("exist_ok", False), + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}" + return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1} + + async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult: + """Delete a branch.""" + repo_id = args.get("repo_id") + branch = args.get("branch") + + if not repo_id: + return self._error("repo_id is required") + if not branch: + return self._error("branch is required") + + repo_type = args.get("repo_type", "model") + + await _async_call( + self.api.delete_branch, + repo_id=repo_id, + branch=branch, + repo_type=repo_type, + ) + + return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1} + + # ========================================================================= + # TAG OPERATIONS + # ========================================================================= + + async def _create_tag(self, args: Dict[str, Any]) -> ToolResult: + """Create a tag.""" + repo_id = args.get("repo_id") + tag = args.get("tag") + + if not repo_id: + return self._error("repo_id is required") + if not tag: + return self._error("tag is required") 
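+
+        # Defaults mirror the branch handlers above: the tag is cut from "main"
+        # unless a revision is given, and exist_ok stays False unless the caller
+        # opts in, so re-creating an existing tag is surfaced as an error.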
+ + repo_type = args.get("repo_type", "model") + revision = args.get("revision", "main") + tag_message = args.get("tag_message", "") + + await _async_call( + self.api.create_tag, + repo_id=repo_id, + tag=tag, + revision=revision, + tag_message=tag_message, + repo_type=repo_type, + exist_ok=args.get("exist_ok", False), + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}" + return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1} + + async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult: + """Delete a tag.""" + repo_id = args.get("repo_id") + tag = args.get("tag") + + if not repo_id: + return self._error("repo_id is required") + if not tag: + return self._error("tag is required") + + repo_type = args.get("repo_type", "model") + + await _async_call( + self.api.delete_tag, + repo_id=repo_id, + tag=tag, + repo_type=repo_type, + ) + + return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1} + + # ========================================================================= + # LIST REFS + # ========================================================================= + + async def _list_refs(self, args: Dict[str, Any]) -> ToolResult: + """List branches and tags.""" + repo_id = args.get("repo_id") + + if not repo_id: + return self._error("repo_id is required") + + repo_type = args.get("repo_type", "model") + + refs = await _async_call( + self.api.list_repo_refs, + repo_id=repo_id, + repo_type=repo_type, + ) + + branches = [b.name for b in refs.branches] if refs.branches else [] + tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else [] + + url = _build_repo_url(repo_id, repo_type) + lines = [f"**{repo_id}**", url, ""] + + if branches: + lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches)) + else: + lines.append("**Branches:** none") + + if tags: + lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags)) + else: + lines.append("**Tags:** none") + + return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)} + + # ========================================================================= + # PR OPERATIONS + # ========================================================================= + + async def _create_pr(self, args: Dict[str, Any]) -> ToolResult: + """Create a pull request.""" + repo_id = args.get("repo_id") + title = args.get("title") + + if not repo_id: + return self._error("repo_id is required") + if not title: + return self._error("title is required") + + repo_type = args.get("repo_type", "model") + description = args.get("description", "") + + result = await _async_call( + self.api.create_pull_request, + repo_id=repo_id, + title=title, + description=description, + repo_type=repo_type, + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}" + return { + "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"", + "totalResults": 1, + "resultsShared": 1, + } + + async def _list_prs(self, args: Dict[str, Any]) -> ToolResult: + """List PRs and discussions.""" + repo_id = args.get("repo_id") + + if not repo_id: + return self._error("repo_id is required") + + repo_type = args.get("repo_type", "model") + status = args.get("status", "all") # open, closed, all + + discussions = list(self.api.get_repo_discussions( + repo_id=repo_id, + repo_type=repo_type, + discussion_status=status if status != "all" else None, + 
)) + + if not discussions: + return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0} + + url = _build_repo_url(repo_id, repo_type) + lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""] + + for d in discussions[:20]: + if d.status == "draft": + status_label = "[DRAFT]" + elif d.status == "open": + status_label = "[OPEN]" + elif d.status == "merged": + status_label = "[MERGED]" + else: + status_label = "[CLOSED]" + type_label = "PR" if d.is_pull_request else "D" + lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}") + + return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))} + + async def _get_pr(self, args: Dict[str, Any]) -> ToolResult: + """Get PR details.""" + repo_id = args.get("repo_id") + pr_num = args.get("pr_num") + + if not repo_id: + return self._error("repo_id is required") + if not pr_num: + return self._error("pr_num is required") + + repo_type = args.get("repo_type", "model") + + pr = await _async_call( + self.api.get_discussion_details, + repo_id=repo_id, + discussion_num=int(pr_num), + repo_type=repo_type, + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" + status_map = { + "draft": "Draft", + "open": "Open", + "merged": "Merged", + "closed": "Closed" + } + status = status_map.get(pr.status, pr.status.capitalize()) + type_label = "Pull Request" if pr.is_pull_request else "Discussion" + + lines = [ + f"**{type_label} #{pr_num}:** {pr.title}", + f"**Status:** {status}", + f"**Author:** {pr.author}", + url, + ] + + if pr.is_pull_request: + if pr.status == "draft": + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") + elif pr.status == "open": + lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") + + return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1} + + async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult: + """Merge a pull request.""" + repo_id = args.get("repo_id") + pr_num = args.get("pr_num") + + if not repo_id: + return self._error("repo_id is required") + if not pr_num: + return self._error("pr_num is required") + + repo_type = args.get("repo_type", "model") + comment = args.get("comment", "") + + await _async_call( + self.api.merge_pull_request, + repo_id=repo_id, + discussion_num=int(pr_num), + comment=comment, + repo_type=repo_type, + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" + return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1} + + async def _close_pr(self, args: Dict[str, Any]) -> ToolResult: + """Close a PR/discussion.""" + repo_id = args.get("repo_id") + pr_num = args.get("pr_num") + + if not repo_id: + return self._error("repo_id is required") + if not pr_num: + return self._error("pr_num is required") + + repo_type = args.get("repo_type", "model") + comment = args.get("comment", "") + + await _async_call( + self.api.change_discussion_status, + repo_id=repo_id, + discussion_num=int(pr_num), + new_status="closed", + comment=comment, + repo_type=repo_type, + ) + + return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1} + + async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult: + """Add a comment to a PR/discussion.""" + repo_id = args.get("repo_id") + pr_num = args.get("pr_num") + comment = args.get("comment") + + if not repo_id: + return self._error("repo_id is required") + if not pr_num: 
+ return self._error("pr_num is required") + if not comment: + return self._error("comment is required") + + repo_type = args.get("repo_type", "model") + + await _async_call( + self.api.comment_discussion, + repo_id=repo_id, + discussion_num=int(pr_num), + comment=comment, + repo_type=repo_type, + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" + return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1} + + async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult: + """Change PR/discussion status (mainly to convert draft to open).""" + repo_id = args.get("repo_id") + pr_num = args.get("pr_num") + new_status = args.get("new_status") + + if not repo_id: + return self._error("repo_id is required") + if not pr_num: + return self._error("pr_num is required") + if not new_status: + return self._error("new_status is required (open or closed)") + + repo_type = args.get("repo_type", "model") + comment = args.get("comment", "") + + await _async_call( + self.api.change_discussion_status, + repo_id=repo_id, + discussion_num=int(pr_num), + new_status=new_status, + comment=comment, + repo_type=repo_type, + ) + + url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" + return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1} + + # ========================================================================= + # REPO MANAGEMENT + # ========================================================================= + + async def _create_repo(self, args: Dict[str, Any]) -> ToolResult: + """Create a new repository.""" + repo_id = args.get("repo_id") + + if not repo_id: + return self._error("repo_id is required") + + repo_type = args.get("repo_type", "model") + private = args.get("private", True) + space_sdk = args.get("space_sdk") + + if repo_type == "space" and not space_sdk: + return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)") + + kwargs = { + "repo_id": repo_id, + "repo_type": repo_type, + "private": private, + "exist_ok": args.get("exist_ok", False), + } + if space_sdk: + kwargs["space_sdk"] = space_sdk + + result = await _async_call(self.api.create_repo, **kwargs) + + return { + "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}", + "totalResults": 1, + "resultsShared": 1, + } + + async def _update_repo(self, args: Dict[str, Any]) -> ToolResult: + """Update repository settings.""" + repo_id = args.get("repo_id") + + if not repo_id: + return self._error("repo_id is required") + + repo_type = args.get("repo_type", "model") + private = args.get("private") + gated = args.get("gated") + + if private is None and gated is None: + return self._error("Specify private (bool) or gated ('auto'/'manual'/false)") + + kwargs = {"repo_id": repo_id, "repo_type": repo_type} + if private is not None: + kwargs["private"] = private + if gated is not None: + kwargs["gated"] = gated + + await _async_call(self.api.update_repo_settings, **kwargs) + + changes = [] + if private is not None: + changes.append(f"private={private}") + if gated is not None: + changes.append(f"gated={gated}") + + url = f"{_build_repo_url(repo_id, repo_type)}/settings" + return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1} + + def _error(self, message: str) -> ToolResult: + """Return an error result.""" + return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": 
True} + + +# Tool specification +HF_REPO_GIT_TOOL_SPEC = { + "name": "hf_repo_git", + "description": ( + "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n" + "## Operations\n" + "**Branches:** create_branch, delete_branch, list_refs\n" + "**Tags:** create_tag, delete_tag\n" + "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n" + "**Repo:** create_repo, update_repo\n\n" + "## Use when\n" + "- Creating feature branches for experiments\n" + "- Tagging model versions (v1.0, v2.0)\n" + "- Opening PRs to contribute to repos you don't own\n" + "- Reviewing and merging PRs on your repos\n" + "- Creating new model/dataset/space repos\n" + "- Changing repo visibility (public/private) or gated access\n\n" + "## Examples\n" + '{"operation": "list_refs", "repo_id": "my-model"}\n' + '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n' + '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n' + '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n' + '{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n' + '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n' + '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n' + '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n' + "## PR Workflow\n" + "1. create_pr → creates draft PR (empty by default)\n" + "2. Upload files with revision='refs/pr/N' to add commits\n" + "3. change_pr_status with new_status='open' to publish (convert draft to open)\n" + "4. merge_pr when ready\n\n" + "## Notes\n" + "- PR status: draft (default), open, merged, closed\n" + "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n" + "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n" + "- gated options: 'auto' (instant), 'manual' (review), false (open)\n" + ), + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": [ + "create_branch", "delete_branch", + "create_tag", "delete_tag", "list_refs", + "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", + "create_repo", "update_repo", + ], + "description": "Operation to execute", + }, + "repo_id": { + "type": "string", + "description": "Repository ID (e.g., 'username/repo-name')", + }, + "repo_type": { + "type": "string", + "enum": ["model", "dataset", "space"], + "description": "Repository type (default: model)", + }, + "branch": { + "type": "string", + "description": "Branch name (create_branch, delete_branch)", + }, + "from_rev": { + "type": "string", + "description": "Create branch from this revision (default: main)", + }, + "tag": { + "type": "string", + "description": "Tag name (create_tag, delete_tag)", + }, + "revision": { + "type": "string", + "description": "Revision for tag (default: main)", + }, + "tag_message": { + "type": "string", + "description": "Tag description", + }, + "title": { + "type": "string", + "description": "PR title (create_pr)", + }, + "description": { + "type": "string", + "description": "PR description (create_pr)", + }, + "pr_num": { + "type": "integer", + "description": "PR/discussion number", + }, + "comment": { + "type": "string", + "description": "Comment text", + }, + "status": { + "type": "string", + "enum": ["open", "closed", "all"], + "description": "Filter PRs by status (list_prs)", + }, + 
"new_status": { + "type": "string", + "enum": ["open", "closed"], + "description": "New status for PR/discussion (change_pr_status)", + }, + "private": { + "type": "boolean", + "description": "Make repo private (create_repo, update_repo)", + }, + "gated": { + "type": "string", + "enum": ["auto", "manual", "false"], + "description": "Gated access setting (update_repo)", + }, + "space_sdk": { + "type": "string", + "enum": ["gradio", "streamlit", "docker", "static"], + "description": "Space SDK (required for create_repo with space)", + }, + }, + "required": ["operation"], + }, +} + + +async def hf_repo_git_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: + """Handler for agent tool router.""" + try: + hf_token = session.hf_token if session else None + tool = HfRepoGitTool(hf_token=hf_token) + result = await tool.execute(arguments) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/jobs_tool.py b/plugin/lib/ml_intern_lib/tools/jobs_tool.py new file mode 100644 index 00000000..f18faaf1 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/jobs_tool.py @@ -0,0 +1,1114 @@ +""" +Hugging Face Jobs Tool - Using huggingface-hub library + +Refactored to use official huggingface-hub library instead of custom HTTP client +""" + +import asyncio +import base64 +import http.client +import os +import re +from typing import Any, Dict, Literal, Optional, Callable, Awaitable + +import logging + +import httpx +from huggingface_hub import HfApi +from huggingface_hub.utils import HfHubHTTPError + +from ml_intern_lib.session_stub import Event +from ml_intern_lib.tools.types import ToolResult + +logger = logging.getLogger(__name__) +from ml_intern_lib.tools.utilities import ( + format_job_details, + format_jobs_table, + format_scheduled_job_details, + format_scheduled_jobs_table, +) + +# Hardware flavors +CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] +GPU_FLAVORS = [ + "t4-small", + "t4-medium", + "a10g-small", + "a10g-large", + "a10g-largex2", + "a10g-largex4", + "a100-large", + "a100x4", + "a100x8", + "l4x1", + "l4x4", + "l40sx1", + "l40sx4", + "l40sx8", +] + +# Detailed specs for display (vCPU/RAM/GPU VRAM) +CPU_FLAVORS_DESC = "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB)" +GPU_FLAVORS_DESC = ( + "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), " + "a10g-small(4vCPU/15GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " + "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " + "a100-large(12vCPU/142GB/GPU 80GB), a100x4(48vCPU/568GB/GPU 320GB), a100x8(96vCPU/1136GB/GPU 640GB), " + "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), " + "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB)" +) +SPECIALIZED_FLAVORS = ["inf2x6"] +ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS + +# Operation names +OperationType = Literal[ + "run", + "ps", + "logs", + "inspect", + "cancel", + "scheduled run", + "scheduled ps", + "scheduled inspect", + "scheduled delete", + "scheduled suspend", + "scheduled resume", +] + +# Constants +UV_DEFAULT_IMAGE = "ghcr.io/astral-sh/uv:python3.12-bookworm" + + +def _filter_uv_install_output(logs: list[str]) -> list[str]: + """ + Filter out UV package installation output from logs. + + Replaces installation details with "[installs truncated]" and keeps + the "Installed X packages in Y ms/s" summary line. 
+ + Args: + logs: List of log lines + + Returns: + Filtered list of log lines + """ + if not logs: + return logs + + # Regex pattern to match: "Installed X packages in Y ms" or "Installed X package in Y s" + install_pattern = re.compile( + r"^Installed\s+\d+\s+packages?\s+in\s+\d+(?:\.\d+)?\s*(?:ms|s)$" + ) + + # Find the index of the "Installed X packages" line + install_line_idx = None + for idx, line in enumerate(logs): + if install_pattern.match(line.strip()): + install_line_idx = idx + break + + # If pattern found, replace installation details with truncation message + if install_line_idx is not None and install_line_idx > 0: + # Keep logs from the "Installed X packages" line onward + # Add truncation message before the "Installed" line + return ["[installs truncated]"] + logs[install_line_idx:] + + # If pattern not found, return original logs + return logs + + +_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') + + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub('', text) + + +_DEFAULT_ENV = { + "HF_HUB_DISABLE_PROGRESS_BARS": "1", + "TQDM_DISABLE": "1", + "TRANSFORMERS_VERBOSITY": "warning", + "HF_HUB_ENABLE_HF_TRANSFER": "1", + "UV_NO_PROGRESS": "1", +} + + +def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]: + """Inject default env vars for clean, agent-friendly output.""" + result = dict(_DEFAULT_ENV) + result.update(params or {}) # user-provided values override defaults + return result + + +def _add_environment_variables( + params: Dict[str, Any] | None, user_token: str | None = None +) -> Dict[str, Any]: + token = user_token or "" + + # Start with user-provided env vars, then force-set token last + result = dict(params or {}) + + # If the caller passed HF_TOKEN="$HF_TOKEN", ignore it. + if result.get("HF_TOKEN", "").strip().startswith("$"): + result.pop("HF_TOKEN", None) + + # Set both names to be safe (different libs check different vars) + if token: + result["HF_TOKEN"] = token + result["HUGGINGFACE_HUB_TOKEN"] = token + + return result + + +def _build_uv_command( + script: str, + with_deps: list[str] | None = None, + python: str | None = None, + script_args: list[str] | None = None, +) -> list[str]: + """Build UV run command""" + parts = ["uv", "run"] + + if with_deps: + for dep in with_deps: + parts.extend(["--with", dep]) + + if python: + parts.extend(["-p", python]) + + parts.append(script) + + if script_args: + parts.extend(script_args) + + # add defaults + # parts.extend(["--push_to_hub"]) + return parts + + +def _wrap_inline_script( + script: str, + with_deps: list[str] | None = None, + python: str | None = None, + script_args: list[str] | None = None, +) -> str: + """Wrap inline script with base64 encoding to avoid file creation""" + encoded = base64.b64encode(script.encode("utf-8")).decode("utf-8") + # Build the uv command with stdin (-) + uv_command = _build_uv_command("-", with_deps, python, script_args) + # Join command parts with proper spacing + uv_command_str = " ".join(uv_command) + return f'echo "{encoded}" | base64 -d | {uv_command_str}' + + +def _ensure_hf_transfer_dependency(deps: list[str] | None) -> list[str]: + """Ensure hf-transfer is included in the dependencies list""" + + if isinstance(deps, list): + deps_copy = deps.copy() # Don't modify the original + if "hf-transfer" not in deps_copy: + deps_copy.append("hf-transfer") + return deps_copy + + return ["hf-transfer"] + + +def _resolve_uv_command( + script: str, + with_deps: list[str] | None = None, + python: str | None = None, + script_args: list[str] | None = 
None, +) -> list[str]: + """Resolve UV command based on script source (URL, inline, or file path)""" + # If URL, use directly + if script.startswith("http://") or script.startswith("https://"): + return _build_uv_command(script, with_deps, python, script_args) + + # If contains newline, treat as inline script + if "\n" in script: + wrapped = _wrap_inline_script(script, with_deps, python, script_args) + return ["/bin/sh", "-lc", wrapped] + + # Otherwise, treat as file path + return _build_uv_command(script, with_deps, python, script_args) + + +async def _async_call(func, *args, **kwargs): + """Wrap synchronous HfApi calls for async context""" + return await asyncio.to_thread(func, *args, **kwargs) + + +def _job_info_to_dict(job_info) -> Dict[str, Any]: + """Convert JobInfo object to dictionary for formatting functions""" + return { + "id": job_info.id, + "status": {"stage": job_info.status.stage, "message": job_info.status.message}, + "command": job_info.command, + "createdAt": job_info.created_at.isoformat(), + "dockerImage": job_info.docker_image, + "spaceId": job_info.space_id, + "hardware_flavor": job_info.flavor, + "owner": {"name": job_info.owner.name}, + } + + +def _scheduled_job_info_to_dict(scheduled_job_info) -> Dict[str, Any]: + """Convert ScheduledJobInfo object to dictionary for formatting functions""" + job_spec = scheduled_job_info.job_spec + + # Extract last run and next run from status + last_run = None + next_run = None + if scheduled_job_info.status: + if scheduled_job_info.status.last_job: + last_run = scheduled_job_info.status.last_job.created_at + if last_run: + last_run = ( + last_run.isoformat() + if hasattr(last_run, "isoformat") + else str(last_run) + ) + if scheduled_job_info.status.next_job_run_at: + next_run = scheduled_job_info.status.next_job_run_at + next_run = ( + next_run.isoformat() + if hasattr(next_run, "isoformat") + else str(next_run) + ) + + return { + "id": scheduled_job_info.id, + "schedule": scheduled_job_info.schedule, + "suspend": scheduled_job_info.suspend, + "lastRun": last_run, + "nextRun": next_run, + "jobSpec": { + "dockerImage": job_spec.docker_image, + "spaceId": job_spec.space_id, + "command": job_spec.command or [], + "hardware_flavor": job_spec.flavor or "cpu-basic", + }, + } + + +class HfJobsTool: + """Tool for managing Hugging Face compute jobs using huggingface-hub library""" + + def __init__( + self, + hf_token: Optional[str] = None, + namespace: Optional[str] = None, + log_callback: Optional[Callable[[str], Awaitable[None]]] = None, + session: Any = None, + tool_call_id: Optional[str] = None, + ): + self.hf_token = hf_token + self.api = HfApi(token=hf_token) + self.namespace = namespace + self.log_callback = log_callback + self.session = session + self.tool_call_id = tool_call_id + + async def execute(self, params: Dict[str, Any]) -> ToolResult: + """Execute the specified operation""" + operation = params.get("operation") + + args = params + + # If no operation provided, return error + if not operation: + return { + "formatted": "Error: 'operation' parameter is required. 
See tool description for available operations and usage examples.", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + # Normalize operation name + operation = operation.lower() + + try: + # Route to appropriate handler + if operation == "run": + return await self._run_job(args) + elif operation == "ps": + return await self._list_jobs(args) + elif operation == "logs": + return await self._get_logs(args) + elif operation == "inspect": + return await self._inspect_job(args) + elif operation == "cancel": + return await self._cancel_job(args) + elif operation == "scheduled run": + return await self._scheduled_run(args) + elif operation == "scheduled ps": + return await self._list_scheduled_jobs(args) + elif operation == "scheduled inspect": + return await self._inspect_scheduled_job(args) + elif operation == "scheduled delete": + return await self._delete_scheduled_job(args) + elif operation == "scheduled suspend": + return await self._suspend_scheduled_job(args) + elif operation == "scheduled resume": + return await self._resume_scheduled_job(args) + else: + return { + "formatted": f'Unknown operation: "{operation}"\n\n' + "Available operations:\n" + "- run, ps, logs, inspect, cancel\n" + "- scheduled run, scheduled ps, scheduled inspect, " + "scheduled delete, scheduled suspend, scheduled resume\n\n" + "Call this tool with no operation for full usage instructions.", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + except HfHubHTTPError as e: + return { + "formatted": f"API Error: {str(e)}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + except Exception as e: + return { + "formatted": f"Error executing {operation}: {str(e)}", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + async def _wait_for_job_completion( + self, job_id: str, namespace: Optional[str] = None + ) -> tuple[str, list[str]]: + """ + Stream job logs until completion, printing them in real-time. + Implements retry logic to handle connection drops during long-running jobs. 
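+        If the log stream drops, the job status is re-checked and streaming
+        reconnects unless the job has already reached a terminal state.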
+ + Returns: + tuple: (final_status, all_logs) + """ + all_logs = [] + terminal_states = {"COMPLETED", "FAILED", "CANCELED", "ERROR"} + max_retries = 100 # Allow many retries for 8h+ jobs + retry_delay = 5 # Seconds between retries + + for _ in range(max_retries): + try: + # Use a queue to bridge sync generator to async consumer + queue = asyncio.Queue() + loop = asyncio.get_running_loop() + + def log_producer(): + try: + # fetch_job_logs is a blocking sync generator + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) + for line in logs_gen: + # Push line to queue thread-safely + loop.call_soon_threadsafe(queue.put_nowait, line) + # Signal EOF + loop.call_soon_threadsafe(queue.put_nowait, None) + except Exception as e: + # Signal error + loop.call_soon_threadsafe(queue.put_nowait, e) + + # Start producer in a background thread so it doesn't block the event loop + producer_future = loop.run_in_executor(None, log_producer) + + # Consume logs from the queue as they arrive + while True: + item = await queue.get() + + # EOF sentinel + if item is None: + break + + # Error occurred in producer + if isinstance(item, Exception): + raise item + + # Process log line + log_line = item + logger.debug(log_line) + if self.log_callback: + await self.log_callback(log_line) + all_logs.append(log_line) + + # If we get here, streaming completed normally (EOF received) + # Wait for thread to cleanup (should be done) + await producer_future + break + + except ( + ConnectionError, + TimeoutError, + OSError, + http.client.IncompleteRead, + httpx.RemoteProtocolError, + httpx.ReadError, + HfHubHTTPError, + ) as e: + # Connection dropped - check if job is still running + try: + job_info = await _async_call( + self.api.inspect_job, job_id=job_id, namespace=namespace + ) + current_status = job_info.status.stage + + if current_status in terminal_states: + # Job finished, no need to retry + logger.info(f"Job reached terminal state: {current_status}") + break + + # Job still running, retry connection + logger.warning( + f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." + ) + await asyncio.sleep(retry_delay) + continue + + except (ConnectionError, TimeoutError, OSError): + # Can't even check job status, wait and retry + logger.warning(f"Connection error, retrying in {retry_delay}s...") + await asyncio.sleep(retry_delay) + continue + + # Fetch final job status — retry briefly if still RUNNING + # (the API may lag a few seconds behind the log stream ending) + final_status = "UNKNOWN" + for _ in range(6): + job_info = await _async_call( + self.api.inspect_job, job_id=job_id, namespace=namespace + ) + final_status = job_info.status.stage + if final_status in terminal_states: + break + await asyncio.sleep(2.5) + + return final_status, all_logs + + async def _run_job(self, args: Dict[str, Any]) -> ToolResult: + """Run a job using HfApi.run_job() - smart detection of Python vs Docker mode""" + try: + script = args.get("script") + command = args.get("command") + + # Validate mutually exclusive parameters + if script and command: + raise ValueError( + "'script' and 'command' are mutually exclusive. Provide one or the other, not both." + ) + + if not script and not command: + raise ValueError( + "Either 'script' (for Python) or 'command' (for Docker) must be provided." 
+ ) + + # Python mode: script provided + if script: + # Get dependencies and ensure hf-transfer is included + deps = _ensure_hf_transfer_dependency(args.get("dependencies")) + + # Resolve the command based on script type (URL, inline, or file) + command = _resolve_uv_command( + script=script, + with_deps=deps, + python=args.get("python"), + script_args=args.get("script_args"), + ) + + # Use UV image unless overridden + image = args.get("image", UV_DEFAULT_IMAGE) + job_type = "Python" + + # Docker mode: command provided + else: + image = args.get("image", "python:3.12") + job_type = "Docker" + + # Run the job + flavor = args.get("hardware_flavor", "cpu-basic") + timeout_str = args.get("timeout", "30m") + job = await _async_call( + self.api.run_job, + image=image, + command=command, + env=_add_default_env(args.get("env")), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), + flavor=flavor, + timeout=timeout_str, + namespace=self.namespace, + ) + + # Track job ID for cancellation on interrupt + if self.session: + self.session._running_job_ids.add(job.id) + + # Send job URL immediately after job creation (before waiting for completion) + if self.session and self.tool_call_id: + await self.session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": "running", + "jobUrl": job.url, + }, + ) + ) + + # Telemetry: job submission + completion (infra consumption signal). + submit_ts = None + if self.session: + from ml_intern_lib import telemetry_stub as telemetry + submit_ts = await telemetry.record_hf_job_submit( + self.session, job, + {**args, "hardware_flavor": flavor, "timeout": timeout_str}, + image=image, job_type=job_type, + ) + + # Wait for completion and stream logs + logger.info(f"{job_type} job started: {job.url}") + logger.info("Streaming logs...") + + final_status, all_logs = await self._wait_for_job_completion( + job_id=job.id, + namespace=self.namespace, + ) + + if self.session and submit_ts is not None: + from ml_intern_lib import telemetry_stub as telemetry + await telemetry.record_hf_job_complete( + self.session, job, + flavor=flavor, final_status=final_status, submit_ts=submit_ts, + ) + + # Untrack job ID (completed or failed, no longer needs cancellation) + if self.session: + self.session._running_job_ids.discard(job.id) + + # Notify frontend of final status + if self.session and self.tool_call_id: + await self.session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": final_status.lower(), + "jobUrl": job.url, + }, + ) + ) + + # Filter out UV package installation output + filtered_logs = _filter_uv_install_output(all_logs) + + # Format all logs for the agent + log_text = _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" + + response = f"""{job_type} job completed! 
+ +**Job ID:** {job.id} +**Final Status:** {final_status} +**View at:** {job.url} + +**Logs:** +``` +{log_text} +```""" + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + except Exception as e: + raise Exception(f"Failed to run job: {str(e)}") + + async def _list_jobs(self, args: Dict[str, Any]) -> ToolResult: + """List jobs using HfApi.list_jobs()""" + jobs_list = await _async_call(self.api.list_jobs, namespace=self.namespace) + + # Filter jobs + if not args.get("all", False): + jobs_list = [j for j in jobs_list if j.status.stage == "RUNNING"] + + if args.get("status"): + status_filter = args["status"].upper() + jobs_list = [j for j in jobs_list if status_filter in j.status.stage] + + # Convert JobInfo objects to dicts for formatting + jobs_dicts = [_job_info_to_dict(j) for j in jobs_list] + + table = format_jobs_table(jobs_dicts) + + if len(jobs_list) == 0: + if args.get("all", False): + return { + "formatted": "No jobs found.", + "totalResults": 0, + "resultsShared": 0, + } + return { + "formatted": 'No running jobs found. Use `{"operation": "ps", "all": true}` to show all jobs.', + "totalResults": 0, + "resultsShared": 0, + } + + response = f"**Jobs ({len(jobs_list)} total):**\n\n{table}" + return { + "formatted": response, + "totalResults": len(jobs_list), + "resultsShared": len(jobs_list), + } + + async def _get_logs(self, args: Dict[str, Any]) -> ToolResult: + """Fetch logs using HfApi.fetch_job_logs()""" + job_id = args.get("job_id") + if not job_id: + return { + "formatted": "job_id is required", + "isError": True, + "totalResults": 0, + "resultsShared": 0, + } + + try: + # Fetch logs (returns generator, convert to list) + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=self.namespace) + logs = await _async_call(list, logs_gen) + + if not logs: + return { + "formatted": f"No logs available for job {job_id}", + "totalResults": 0, + "resultsShared": 0, + } + + log_text = _strip_ansi("\n".join(logs)) + return { + "formatted": f"**Logs for {job_id}:**\n\n```\n{log_text}\n```", + "totalResults": 1, + "resultsShared": 1, + } + + except Exception as e: + return { + "formatted": f"Failed to fetch logs: {str(e)}", + "isError": True, + "totalResults": 0, + "resultsShared": 0, + } + + async def _inspect_job(self, args: Dict[str, Any]) -> ToolResult: + """Inspect job using HfApi.inspect_job()""" + job_id = args.get("job_id") + if not job_id: + return { + "formatted": "job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + job_ids = job_id if isinstance(job_id, list) else [job_id] + + jobs = [] + for jid in job_ids: + try: + job = await _async_call( + self.api.inspect_job, + job_id=jid, + namespace=self.namespace, + ) + jobs.append(_job_info_to_dict(job)) + except Exception as e: + raise Exception(f"Failed to inspect job {jid}: {str(e)}") + + formatted_details = format_job_details(jobs) + response = f"**Job Details** ({len(jobs)} job{'s' if len(jobs) > 1 else ''}):\n\n{formatted_details}" + + return { + "formatted": response, + "totalResults": len(jobs), + "resultsShared": len(jobs), + } + + async def _cancel_job(self, args: Dict[str, Any]) -> ToolResult: + """Cancel job using HfApi.cancel_job()""" + job_id = args.get("job_id") + if not job_id: + return { + "formatted": "job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + await _async_call( + self.api.cancel_job, + job_id=job_id, + namespace=self.namespace, + ) + + response = f"""✓ Job {job_id} has been cancelled. 
+ +To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}`""" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + async def _scheduled_run(self, args: Dict[str, Any]) -> ToolResult: + """Create scheduled job using HfApi.create_scheduled_job() - smart detection of Python vs Docker mode""" + try: + script = args.get("script") + command = args.get("command") + schedule = args.get("schedule") + + if not schedule: + raise ValueError("schedule is required for scheduled jobs") + + # Validate mutually exclusive parameters + if script and command: + raise ValueError( + "'script' and 'command' are mutually exclusive. Provide one or the other, not both." + ) + + if not script and not command: + raise ValueError( + "Either 'script' (for Python) or 'command' (for Docker) must be provided." + ) + + # Python mode: script provided + if script: + # Get dependencies and ensure hf-transfer is included + deps = _ensure_hf_transfer_dependency(args.get("dependencies")) + + # Resolve the command based on script type + command = _resolve_uv_command( + script=script, + with_deps=deps, + python=args.get("python"), + script_args=args.get("script_args"), + ) + + # Use UV image unless overridden + image = args.get("image", UV_DEFAULT_IMAGE) + job_type = "Python" + + # Docker mode: command provided + else: + image = args.get("image", "python:3.12") + job_type = "Docker" + + # Create scheduled job + scheduled_job = await _async_call( + self.api.create_scheduled_job, + image=image, + command=command, + schedule=schedule, + env=_add_default_env(args.get("env")), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), + flavor=args.get("hardware_flavor", "cpu-basic"), + timeout=args.get("timeout", "30m"), + namespace=self.namespace, + ) + + scheduled_dict = _scheduled_job_info_to_dict(scheduled_job) + + response = f"""✓ Scheduled {job_type} job created successfully! + +**Scheduled Job ID:** {scheduled_dict["id"]} +**Schedule:** {scheduled_dict["schedule"]} +**Suspended:** {"Yes" if scheduled_dict.get("suspend") else "No"} +**Next Run:** {scheduled_dict.get("nextRun", "N/A")} + +To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_dict["id"]}"}}` +To list all, call this tool with `{{"operation": "scheduled ps"}}`""" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + except Exception as e: + raise Exception(f"Failed to create scheduled job: {str(e)}") + + async def _list_scheduled_jobs(self, args: Dict[str, Any]) -> ToolResult: + """List scheduled jobs using HfApi.list_scheduled_jobs()""" + scheduled_jobs_list = await _async_call( + self.api.list_scheduled_jobs, + namespace=self.namespace, + ) + + # Filter jobs - default: hide suspended jobs unless --all is specified + if not args.get("all", False): + scheduled_jobs_list = [j for j in scheduled_jobs_list if not j.suspend] + + # Convert to dicts for formatting + scheduled_dicts = [_scheduled_job_info_to_dict(j) for j in scheduled_jobs_list] + + table = format_scheduled_jobs_table(scheduled_dicts) + + if len(scheduled_jobs_list) == 0: + if args.get("all", False): + return { + "formatted": "No scheduled jobs found.", + "totalResults": 0, + "resultsShared": 0, + } + return { + "formatted": 'No active scheduled jobs found. 
Use `{"operation": "scheduled ps", "all": true}` to show suspended jobs.', + "totalResults": 0, + "resultsShared": 0, + } + + response = f"**Scheduled Jobs ({len(scheduled_jobs_list)} total):**\n\n{table}" + return { + "formatted": response, + "totalResults": len(scheduled_jobs_list), + "resultsShared": len(scheduled_jobs_list), + } + + async def _inspect_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: + """Inspect scheduled job using HfApi.inspect_scheduled_job()""" + scheduled_job_id = args.get("scheduled_job_id") + if not scheduled_job_id: + return { + "formatted": "scheduled_job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + scheduled_job = await _async_call( + self.api.inspect_scheduled_job, + scheduled_job_id=scheduled_job_id, + namespace=self.namespace, + ) + + scheduled_dict = _scheduled_job_info_to_dict(scheduled_job) + formatted_details = format_scheduled_job_details(scheduled_dict) + + return { + "formatted": f"**Scheduled Job Details:**\n\n{formatted_details}", + "totalResults": 1, + "resultsShared": 1, + } + + async def _delete_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: + """Delete scheduled job using HfApi.delete_scheduled_job()""" + scheduled_job_id = args.get("scheduled_job_id") + if not scheduled_job_id: + return { + "formatted": "scheduled_job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + await _async_call( + self.api.delete_scheduled_job, + scheduled_job_id=scheduled_job_id, + namespace=self.namespace, + ) + + return { + "formatted": f"✓ Scheduled job {scheduled_job_id} has been deleted.", + "totalResults": 1, + "resultsShared": 1, + } + + async def _suspend_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: + """Suspend scheduled job using HfApi.suspend_scheduled_job()""" + scheduled_job_id = args.get("scheduled_job_id") + if not scheduled_job_id: + return { + "formatted": "scheduled_job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + await _async_call( + self.api.suspend_scheduled_job, + scheduled_job_id=scheduled_job_id, + namespace=self.namespace, + ) + + response = f"""✓ Scheduled job {scheduled_job_id} has been suspended. + +To resume, call this tool with `{{"operation": "scheduled resume", "scheduled_job_id": "{scheduled_job_id}"}}`""" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + async def _resume_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: + """Resume scheduled job using HfApi.resume_scheduled_job()""" + scheduled_job_id = args.get("scheduled_job_id") + if not scheduled_job_id: + return { + "formatted": "scheduled_job_id is required", + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + await _async_call( + self.api.resume_scheduled_job, + scheduled_job_id=scheduled_job_id, + namespace=self.namespace, + ) + + response = f"""✓ Scheduled job {scheduled_job_id} has been resumed. + +To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_job_id}"}}`""" + + return {"formatted": response, "totalResults": 1, "resultsShared": 1} + + +# Tool specification for agent registration +HF_JOBS_TOOL_SPEC = { + "name": "hf_jobs", + "description": ( + "Execute Python scripts or Docker containers on HF cloud infrastructure.\n\n" + "Two modes (mutually exclusive): Python mode (script + dependencies) or Docker mode (command + image). 
" + "Provide exactly ONE of 'script' or 'command'.\n\n" + "BEFORE submitting training/fine-tuning jobs:\n" + "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. " + "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n" + "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" + "- Training config MUST include push_to_hub=True and hub_model_id. " + "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" + "- Include trackio monitoring and provide the dashboard URL to the user.\n\n" + "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " + "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n" + "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" + f"Hardware: CPU: {CPU_FLAVORS_DESC}. GPU: {GPU_FLAVORS_DESC}.\n" + "Common picks: t4-small ($0.60/hr, 1-3B), a10g-large ($2/hr, 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+). " + "Note: a10g-small and a10g-large have the SAME 24GB GPU — the difference is CPU/RAM only.\n\n" + "OOM RECOVERY: When a training job fails with CUDA OOM:\n" + "1. Reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally (keep effective batch size identical)\n" + "2. Enable gradient_checkpointing=True\n" + "3. Upgrade to larger GPU (a10g→a100→h100)\n" + "Do NOT switch training methods (e.g. full SFT to LoRA) or reduce max_length — those change what the user gets and require explicit approval.\n\n" + "Examples:\n" + "Training: {'operation': 'run', 'script': '/app/train.py', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a100-large', 'timeout': '8h'}\n" + "Monitor: {'operation': 'ps'}, {'operation': 'logs', 'job_id': 'xxx'}, {'operation': 'cancel', 'job_id': 'xxx'}" + "Docker: {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb', 'hardware_flavor': 'cpu-basic', 'timeout': '1h'}\n" + ), + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": [ + "run", + "ps", + "logs", + "inspect", + "cancel", + "scheduled run", + "scheduled ps", + "scheduled inspect", + "scheduled delete", + "scheduled suspend", + "scheduled resume", + ], + "description": "Operation to execute.", + }, + "script": { + "type": "string", + "description": ( + "Python code or sandbox file path (e.g. '/app/train.py') or URL. " + "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. " + "Mutually exclusive with 'command'." + ), + }, + "dependencies": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "Pip packages to install. Include ALL required packages. " + "Common training set: ['transformers', 'trl', 'torch', 'datasets', 'trackio', 'accelerate']. " + "Only used with 'script'." + ), + }, + "image": { + "type": "string", + "description": "Docker image. Optional — auto-selected if not provided. Use with 'command'.", + }, + "command": { + "type": "array", + "items": {"type": "string"}, + "description": "Command to execute as list. Triggers Docker mode. Mutually exclusive with 'script'.", + }, + "hardware_flavor": { + "type": "string", + "description": ( + "Hardware type. 
Sizing guide: 1-3B params → t4-small/a10g-small, " + "7-13B → a10g-large, 30B+ → a100-large, 70B+ → h100/h100x8. " + f"All options: CPU: {CPU_FLAVORS}. GPU: {GPU_FLAVORS}." + ), + }, + "timeout": { + "type": "string", + "description": ( + "Maximum job runtime. MUST be >2h for any training job — default 30m kills training mid-run. " + "Guidelines: 1-3B models: 3-4h, 7-13B: 6-8h, 30B+: 12-24h. " + "Use 30m-1h only for quick data processing or inference tasks. Default: '30m'." + ), + }, + "env": { + "type": "object", + "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", + }, + "job_id": { + "type": "string", + "description": "Job ID. Required for: logs, inspect, cancel.", + }, + "scheduled_job_id": { + "type": "string", + "description": "Scheduled job ID. Required for: scheduled inspect/delete/suspend/resume.", + }, + "schedule": { + "type": "string", + "description": "Cron schedule or preset (@hourly, @daily, @weekly, @monthly). Required for: scheduled run.", + }, + }, + "required": ["operation"], + }, +} + + +async def hf_jobs_handler( + arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None +) -> tuple[str, bool]: + """Handler for agent tool router""" + try: + + async def log_callback(log: str): + if session: + await session.send_event( + Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log}) + ) + + # If script is a sandbox file path, read it from the sandbox + script = arguments.get("script", "") + sandbox = getattr(session, "sandbox", None) if session else None + if sandbox and script: + from ml_intern_lib.tools.sandbox_tool import resolve_sandbox_script + content, error = await resolve_sandbox_script(sandbox, script) + if error: + return error, False + if content: + arguments = {**arguments, "script": content} + + hf_token = session.hf_token if session else None + namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) + + tool = HfJobsTool( + namespace=namespace, + hf_token=hf_token, + log_callback=log_callback if session else None, + session=session, + tool_call_id=tool_call_id, + ) + result = await tool.execute(arguments) + return result["formatted"], not result.get("isError", False) + except Exception as e: + return f"Error executing HF Jobs tool: {str(e)}", False diff --git a/plugin/lib/ml_intern_lib/tools/local_tools.py b/plugin/lib/ml_intern_lib/tools/local_tools.py new file mode 100644 index 00000000..b053b4b2 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/local_tools.py @@ -0,0 +1,426 @@ +""" +Local tool implementations — bash/read/write/edit running on the user's machine. + +Drop-in replacement for sandbox tools when running in CLI (local) mode. +Same tool specs (names, parameters) but handlers execute locally via +subprocess/pathlib instead of going through a remote sandbox. +""" + +from __future__ import annotations + +import os +import re +import subprocess +import tempfile +from pathlib import Path +from typing import Any + + +MAX_OUTPUT_CHARS = 25_000 +MAX_LINE_LENGTH = 4000 +DEFAULT_READ_LINES = 2000 +DEFAULT_TIMEOUT = 120 +MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. 
PostTrainBench) + +_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') + +# Track files that have been read this session (enforces read-before-write/edit) +_files_read: set[str] = set() + + +def _resolve_path(path: str) -> str: + try: + return str(Path(path).resolve()) + except Exception: + return path + + +def _atomic_write(path: Path, content: str) -> None: + """Write file atomically via temp file + os.replace(). + + Ensures the file is never left in a partial/corrupted state — it's either + the old content or the new content, never half-written. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd = None + tmp_path = None + try: + fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp") + os.write(fd, content.encode("utf-8")) + os.fsync(fd) + os.close(fd) + fd = None + os.replace(tmp_path, str(path)) + tmp_path = None # successfully replaced, nothing to clean up + finally: + if fd is not None: + os.close(fd) + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub('', text) + + +def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25) -> str: + """Tail-biased truncation with temp file spillover for full output access.""" + if len(output) <= max_chars: + return output + # Write full output to temp file so LLM can read specific sections + spill_path = None + try: + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', delete=False) as f: + f.write(output) + spill_path = f.name + except Exception: + pass + head_budget = int(max_chars * head_ratio) + tail_budget = max_chars - head_budget + head = output[:head_budget] + tail = output[-tail_budget:] + total = len(output) + omitted = total - max_chars + meta = f"\n\n... ({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\n" + if spill_path: + meta += f"Full output saved to {spill_path} — use the read tool with offset/limit to inspect specific sections.\n" + meta += "IMPORTANT: The command has finished. 
Analyze the output above and continue with your next action.\n"
+    return head + meta + tail
+
+
+# ── Handlers ────────────────────────────────────────────────────────────
+
+async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
+    command = args.get("command", "")
+    if not command:
+        return "No command provided.", False
+    work_dir = args.get("work_dir", ".")
+    timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
+    try:
+        result = subprocess.run(
+            command,
+            shell=True,
+            capture_output=True,
+            text=True,
+            cwd=work_dir,
+            timeout=timeout,
+        )
+        output = _strip_ansi(result.stdout + result.stderr)
+        output = _truncate_output(output)
+        if not output.strip():
+            output = "(no output)"
+        return output, result.returncode == 0
+    except subprocess.TimeoutExpired:
+        return (
+            f"Command timed out after {timeout}s and was killed.\n\n"
+            f"For long-running commands, run in the background and poll:\n"
+            f" nohup <command> > /tmp/output.log 2>&1 & echo $!\n"
+            f"Then check status with:\n"
+            f" kill -0 <pid> 2>/dev/null && echo 'running' || echo 'done'\n"
+            f" tail -n 50 /tmp/output.log"
+        ), False
+    except Exception as e:
+        return f"bash error: {e}", False
+
+
+async def _read_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
+    file_path = args.get("path", "")
+    if not file_path:
+        return "No path provided.", False
+    p = Path(file_path)
+    if not p.exists():
+        return f"File not found: {file_path}", False
+    if p.is_dir():
+        return "Cannot read a directory. Use bash with 'ls' instead.", False
+    try:
+        raw_content = p.read_text()
+    except Exception as e:
+        return f"read error: {e}", False
+
+    _files_read.add(_resolve_path(file_path))
+
+    lines = raw_content.splitlines()
+    offset = max((args.get("offset") or 1), 1)
+    limit = args.get("limit") or DEFAULT_READ_LINES
+
+    selected = lines[offset - 1 : offset - 1 + limit]
+    numbered = []
+    for i, line in enumerate(selected, start=offset):
+        if len(line) > MAX_LINE_LENGTH:
+            line = line[:MAX_LINE_LENGTH] + "..."
+        numbered.append(f"{i:>6}\t{line}")
+
+    return "\n".join(numbered), True
+
+
+async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
+    file_path = args.get("path", "")
+    content = args.get("content", "")
+    if not file_path:
+        return "No path provided.", False
+    p = Path(file_path)
+    if p.exists() and _resolve_path(file_path) not in _files_read:
+        return (
+            f"You must read {file_path} before overwriting it. "
+            f"Use the read tool first to see current contents."
+        ), False
+    try:
+        _atomic_write(p, content)
+        _files_read.add(_resolve_path(file_path))
+        msg = f"Wrote {len(content)} bytes to {file_path}"
+        # Syntax validation for Python files
+        if p.suffix == ".py":
+            from ml_intern_lib.tools.edit_utils import validate_python
+            warnings = validate_python(content, file_path)
+            if warnings:
+                msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings)
+        return msg, True
+    except Exception as e:
+        return f"write error: {e}", False
+
+
+async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
+    from ml_intern_lib.tools.edit_utils import apply_edit, validate_python
+
+    file_path = args.get("path", "")
+    old_str = args.get("old_str", "")
+    new_str = args.get("new_str", "")
+    replace_all = args.get("replace_all", False)
+    mode = args.get("mode", "replace")
+
+    if not file_path:
+        return "No path provided.", False
+    if old_str == new_str:
+        return "old_str and new_str must differ.", False
+
+    p = Path(file_path)
+    if not p.exists():
+        return f"File not found: {file_path}", False
+    if _resolve_path(file_path) not in _files_read:
+        return (
+            f"You must read {file_path} before editing it. "
+            f"Use the read tool first to see current contents."
+        ), False
+
+    try:
+        text = p.read_text()
+    except Exception as e:
+        return f"edit read error: {e}", False
+
+    try:
+        new_text, replacements, fuzzy_note = apply_edit(
+            text, old_str, new_str, mode=mode, replace_all=replace_all
+        )
+    except ValueError as e:
+        return str(e), False
+
+    try:
+        _atomic_write(p, new_text)
+    except Exception as e:
+        return f"edit write error: {e}", False
+
+    msg = f"Edited {file_path} ({replacements} replacement{'s' if replacements > 1 else ''})"
+    if fuzzy_note:
+        msg += f" {fuzzy_note}"
+    # Syntax validation for Python files
+    if p.suffix == ".py":
+        warnings = validate_python(new_text, file_path)
+        if warnings:
+            msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings)
+    return msg, True
+
+
+# ── Local tool specs (override sandbox /app references) ────────────────
+
+_LOCAL_TOOL_SPECS = {
+    "bash": {
+        "description": (
+            "Run a shell command on the local machine and return stdout/stderr.\n"
+            "\n"
+            "IMPORTANT: Do NOT use bash for file operations — use the dedicated tools instead:\n"
+            "- To read files: use read (not cat/head/tail)\n"
+            "- To edit files: use edit (not sed/awk)\n"
+            "- To write files: use write (not echo/cat with shell redirection)\n"
+            "\n"
+            "For long-running commands, run in the background and poll:\n"
+            " nohup <command> > /tmp/output.log 2>&1 & echo $!\n"
+            "Then check status:\n"
+            " kill -0 <pid> 2>/dev/null && echo 'running' || echo 'done'\n"
+            " tail -n 50 /tmp/output.log\n"
+            "\n"
+            "Timeout default 120s, max 36000s."
+        ),
+        "parameters": {
+            "type": "object",
+            "required": ["command"],
+            "additionalProperties": False,
+            "properties": {
+                "command": {
+                    "type": "string",
+                    "description": "The shell command to execute.",
+                },
+                "description": {
+                    "type": "string",
+                    "description": "Short description (5-10 words, active voice).",
+                },
+                "work_dir": {
+                    "type": "string",
+                    "description": "Working directory (default: current directory).",
+                },
+                "timeout": {
+                    "type": "integer",
+                    "description": "Optional timeout in seconds (default: 120, max: 36000).",
+                },
+            },
+        },
+    },
+    "read": {
+        "description": (
+            "Reads a file from the local filesystem. 
Returns contents with line numbers " + "(cat -n format).\n" + "\n" + "Usage:\n" + "- By default, reads up to 2000 lines from the beginning of the file.\n" + "- You can optionally specify offset and limit for large files, but prefer " + "reading the whole file first.\n" + "- Lines longer than 4000 chars are truncated.\n" + "- Cannot read directories — use bash with 'ls' instead.\n" + "- You should read multiple potentially useful files in parallel when possible.\n" + "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " + "write tools will reject operations on files you haven't read." + ), + "parameters": { + "type": "object", + "required": ["path"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to read.", + }, + "offset": { + "type": "integer", + "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", + }, + "limit": { + "type": "integer", + "description": "The number of lines to read. Only provide if the file is too large to read at once.", + }, + }, + }, + }, + "write": { + "description": ( + "Writes a file to the local filesystem. Overwrites the existing file if one " + "exists at the path.\n" + "\n" + "- If this is an existing file, you MUST use the read tool first. This tool " + "will fail if you did not read the file first.\n" + "- ALWAYS prefer editing existing files with the edit tool over overwriting " + "with write.\n" + "- Creates parent directories as needed." + ), + "parameters": { + "type": "object", + "required": ["path", "content"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to write.", + }, + "content": { + "type": "string", + "description": "The complete file content to write.", + }, + }, + }, + }, + "edit": { + "description": ( + "Performs string replacements in files. Supports exact matching with " + "fuzzy fallback.\n" + "\n" + "Usage:\n" + "- You must read the file at least once before editing. This tool will " + "error if you attempt an edit without reading the file.\n" + "- The edit will FAIL if old_str is not unique in the file. Either provide " + "a larger string with more surrounding context to make it unique, or set " + "replace_all to true.\n" + "- old_str and new_str must differ.\n" + "- Preserve indentation exactly as it appears in the file.\n" + "- Do NOT include line number prefixes from read output in old_str or new_str.\n" + "- To delete code, set new_str to empty string.\n" + "- Use replace_all for renaming variables or strings across the file.\n" + "\n" + "Modes:\n" + "- replace (default): replace first occurrence of old_str with new_str.\n" + "- append_after: insert new_str immediately after old_str (old_str is kept).\n" + "- prepend_before: insert new_str immediately before old_str (old_str is kept)." + ), + "parameters": { + "type": "object", + "required": ["path", "old_str", "new_str"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to edit.", + }, + "old_str": { + "type": "string", + "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", + }, + "new_str": { + "type": "string", + "description": "The replacement text. 
For append_after/prepend_before modes, the text to insert.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences of old_str (default: false).", + "default": False, + }, + "mode": { + "type": "string", + "enum": ["replace", "append_after", "prepend_before"], + "description": "Edit mode (default: replace).", + "default": "replace", + }, + }, + }, + }, +} + +_HANDLERS = { + "bash": _bash_handler, + "read": _read_handler, + "write": _write_handler, + "edit": _edit_handler, +} + + +def get_local_tools(): + """Return local ToolSpecs for bash/read/write/edit (no sandbox_create).""" + from ml_intern_lib.tool_spec import ToolSpec + + tools = [] + for name, spec in _LOCAL_TOOL_SPECS.items(): + handler = _HANDLERS.get(name) + if handler is None: + continue + tools.append( + ToolSpec( + name=name, + description=spec["description"], + parameters=spec["parameters"], + handler=handler, + ) + ) + return tools diff --git a/plugin/lib/ml_intern_lib/tools/papers_tool.py b/plugin/lib/ml_intern_lib/tools/papers_tool.py new file mode 100644 index 00000000..33f9c249 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/papers_tool.py @@ -0,0 +1,1295 @@ +""" +HF Papers Tool — Discover papers, read their contents, and find linked resources. + +Operations: trending, search, paper_details, read_paper, + find_datasets, find_models, find_collections, find_all_resources, + citation_graph, snippet_search, recommend +""" + +import asyncio +import os +import re +import time +from typing import Any + +import httpx +from bs4 import BeautifulSoup, Tag + +from ml_intern_lib.tools.types import ToolResult + +HF_API = "https://huggingface.co/api" +ARXIV_HTML = "https://arxiv.org/html" +AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html" + +DEFAULT_LIMIT = 10 +MAX_LIMIT = 50 +MAX_SUMMARY_LEN = 300 +MAX_SECTION_PREVIEW_LEN = 280 +MAX_SECTION_TEXT_LEN = 8000 + +SORT_MAP = { + "downloads": "downloads", + "likes": "likes", + "trending": "trendingScore", +} + +# --------------------------------------------------------------------------- +# Semantic Scholar API +# --------------------------------------------------------------------------- + +S2_API = "https://api.semanticscholar.org" +S2_API_KEY = os.environ.get("S2_API_KEY") +S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {} +S2_TIMEOUT = 12 +_s2_last_request: float = 0.0 + +# Shared response cache (survives across sessions, keyed by (path, params_tuple)) +_s2_cache: dict[str, Any] = {} +_S2_CACHE_MAX = 500 + + +def _s2_paper_id(arxiv_id: str) -> str: + """Convert bare arxiv ID to S2 format.""" + return f"ARXIV:{arxiv_id}" + + +def _s2_cache_key(path: str, params: dict | None) -> str: + """Build a hashable cache key from path + sorted params.""" + p = tuple(sorted((params or {}).items())) + return f"{path}:{p}" + + +async def _s2_request( + client: httpx.AsyncClient, + method: str, + path: str, + **kwargs: Any, +) -> httpx.Response | None: + """S2 request with 2 retries on 429/5xx. 
Rate-limited only when using API key.""" + global _s2_last_request + url = f"{S2_API}{path}" + kwargs.setdefault("headers", {}).update(S2_HEADERS) + kwargs.setdefault("timeout", S2_TIMEOUT) + + for attempt in range(3): + # Rate limit only when authenticated (1 req/s for search, 10 req/s for others) + if S2_API_KEY: + min_interval = 1.0 if "search" in path else 0.1 + elapsed = time.monotonic() - _s2_last_request + if elapsed < min_interval: + await asyncio.sleep(min_interval - elapsed) + _s2_last_request = time.monotonic() + + try: + resp = await client.request(method, url, **kwargs) + if resp.status_code == 429: + if attempt < 2: + await asyncio.sleep(60) + continue + return None + if resp.status_code >= 500: + if attempt < 2: + await asyncio.sleep(3) + continue + return None + return resp + except (httpx.RequestError, httpx.HTTPStatusError): + if attempt < 2: + await asyncio.sleep(3) + continue + return None + return None + + +async def _s2_get_json( + client: httpx.AsyncClient, path: str, params: dict | None = None, +) -> dict | None: + """Cached S2 GET returning parsed JSON or None.""" + key = _s2_cache_key(path, params) + if key in _s2_cache: + return _s2_cache[key] + + resp = await _s2_request(client, "GET", path, params=params or {}) + if resp and resp.status_code == 200: + data = resp.json() + if len(_s2_cache) < _S2_CACHE_MAX: + _s2_cache[key] = data + return data + return None + + +async def _s2_get_paper( + client: httpx.AsyncClient, arxiv_id: str, fields: str, +) -> dict | None: + """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" + return await _s2_get_json( + client, + f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}", + {"fields": fields}, + ) + + +# --------------------------------------------------------------------------- +# HTML paper parsing +# --------------------------------------------------------------------------- + + +def _parse_paper_html(html: str) -> dict[str, Any]: + """Parse arxiv HTML into structured sections. 
+ + Returns: + { + "title": str, + "abstract": str, + "sections": [{"id": str, "title": str, "level": int, "text": str}], + } + """ + soup = BeautifulSoup(html, "html.parser") + + # Title + title_el = soup.find("h1", class_="ltx_title") + title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else "" + + # Abstract + abstract_el = soup.find("div", class_="ltx_abstract") + abstract = "" + if abstract_el: + # Skip the "Abstract" heading itself + for child in abstract_el.children: + if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"): + if child.get_text(strip=True).lower() == "abstract": + continue + if isinstance(child, Tag) and child.name == "p": + abstract += child.get_text(separator=" ", strip=True) + " " + abstract = abstract.strip() + + # Sections — collect h2/h3 headings and text between them + sections: list[dict[str, Any]] = [] + headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c) + + for heading in headings: + level = 2 if heading.name == "h2" else 3 + heading_text = heading.get_text(separator=" ", strip=True) + + # Collect text from siblings until next heading of same or higher level + text_parts: list[str] = [] + sibling = heading.find_next_sibling() + while sibling: + if isinstance(sibling, Tag): + if sibling.name in ("h2", "h3") and "ltx_title" in ( + sibling.get("class") or [] + ): + break + # Also stop at h2 if we're collecting h3 content + if sibling.name == "h2" and level == 3: + break + text_parts.append(sibling.get_text(separator=" ", strip=True)) + sibling = sibling.find_next_sibling() + + # Also check parent section element for contained paragraphs + parent_section = heading.find_parent("section") + if parent_section and not text_parts: + for p in parent_section.find_all("p", recursive=False): + text_parts.append(p.get_text(separator=" ", strip=True)) + + section_text = "\n\n".join(t for t in text_parts if t) + + # Extract section number from heading text (e.g., "4 Experiments" → "4") + num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text) + section_id = num_match.group(1) if num_match else "" + + sections.append( + { + "id": section_id, + "title": heading_text, + "level": level, + "text": section_text, + } + ) + + return {"title": title, "abstract": abstract, "sections": sections} + + +def _find_section(sections: list[dict], query: str) -> dict | None: + """Find a section by number or name (fuzzy).""" + query_lower = query.lower().strip() + + # Exact match on section number + for s in sections: + if s["id"] == query_lower or s["id"] == query: + return s + + # Exact match on title + for s in sections: + if query_lower == s["title"].lower(): + return s + + # Substring match on title + for s in sections: + if query_lower in s["title"].lower(): + return s + + # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. — return parent) + for s in sections: + if s["id"].startswith(query_lower + ".") or s["id"] == query_lower: + return s + + return None + + +# --------------------------------------------------------------------------- +# Formatting helpers +# --------------------------------------------------------------------------- + + +def _clean_description(text: str) -> str: + """Strip HTML card artifacts and collapse whitespace from HF API descriptions.""" + text = re.sub(r"[\t]+", " ", text) + text = re.sub(r"\n{2,}", "\n", text) + return text.strip() + + +def _truncate(text: str, max_len: int) -> str: + if len(text) <= max_len: + return text + return text[:max_len] + "..." 
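+
+# Illustrative sketch (hypothetical values) of the markdown shape produced by the
+# list formatter below; handy as a quick reference when adjusting these helpers.
+# The arxiv ID, date, and counts here are made up for illustration only:
+#
+#   # Trending Papers (2025-01-15)
+#   Showing 1 paper(s)
+#
+#   ## 1. Example Paper Title
+#   **arxiv_id:** 2501.01234 | **upvotes:** 42
+#   https://huggingface.co/papers/2501.01234
+#   **Keywords:** reasoning, distillation
+#   **Summary:** One-sentence AI summary of the paper...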
+ + +def _format_paper_list( + papers: list, title: str, date: str | None = None, query: str | None = None +) -> str: + lines = [f"# {title}"] + if date: + lines[0] += f" ({date})" + if query: + lines.append(f"Filtered by: '{query}'") + lines.append(f"Showing {len(papers)} paper(s)\n") + + for i, item in enumerate(papers, 1): + paper = item.get("paper", item) + arxiv_id = paper.get("id", "") + paper_title = paper.get("title", "Unknown") + upvotes = paper.get("upvotes", 0) + summary = paper.get("ai_summary") or _truncate( + paper.get("summary", ""), MAX_SUMMARY_LEN + ) + keywords = paper.get("ai_keywords") or [] + github = paper.get("githubRepo") or "" + stars = paper.get("githubStars") or 0 + + lines.append(f"## {i}. {paper_title}") + lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}") + lines.append(f"https://huggingface.co/papers/{arxiv_id}") + if keywords: + lines.append(f"**Keywords:** {', '.join(keywords[:5])}") + if github: + lines.append(f"**GitHub:** {github} ({stars} stars)") + if summary: + lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}") + lines.append("") + + return "\n".join(lines) + + +def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: + arxiv_id = paper.get("id", "") + title = paper.get("title", "Unknown") + upvotes = paper.get("upvotes", 0) + ai_summary = paper.get("ai_summary") or "" + summary = paper.get("summary", "") + keywords = paper.get("ai_keywords") or [] + github = paper.get("githubRepo") or "" + stars = paper.get("githubStars") or 0 + authors = paper.get("authors") or [] + + lines = [f"# {title}"] + meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"] + if s2_data: + cites = s2_data.get("citationCount", 0) + influential = s2_data.get("influentialCitationCount", 0) + meta_parts.append(f"**citations:** {cites} ({influential} influential)") + lines.append(" | ".join(meta_parts)) + lines.append(f"https://huggingface.co/papers/{arxiv_id}") + lines.append(f"https://arxiv.org/abs/{arxiv_id}") + + if authors: + names = [a.get("name", "") for a in authors[:10]] + author_str = ", ".join(n for n in names if n) + if len(authors) > 10: + author_str += f" (+{len(authors) - 10} more)" + lines.append(f"**Authors:** {author_str}") + + if keywords: + lines.append(f"**Keywords:** {', '.join(keywords)}") + if s2_data and s2_data.get("s2FieldsOfStudy"): + fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")] + if fields: + lines.append(f"**Fields:** {', '.join(fields)}") + if s2_data and s2_data.get("venue"): + lines.append(f"**Venue:** {s2_data['venue']}") + if github: + lines.append(f"**GitHub:** {github} ({stars} stars)") + + if s2_data and s2_data.get("tldr"): + tldr_text = s2_data["tldr"].get("text", "") + if tldr_text: + lines.append(f"\n## TL;DR\n{tldr_text}") + if ai_summary: + lines.append(f"\n## AI Summary\n{ai_summary}") + if summary: + lines.append(f"\n## Abstract\n{_truncate(summary, 500)}") + + lines.append( + "\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, " + "or citation_graph to trace references and citations." 
+ ) + return "\n".join(lines) + + +def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str: + """Format TOC view: abstract + section list with previews.""" + lines = [f"# {parsed['title']}"] + lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") + + if parsed["abstract"]: + lines.append(f"## Abstract\n{parsed['abstract']}\n") + + lines.append("## Sections") + for s in parsed["sections"]: + prefix = " " if s["level"] == 3 else "" + preview = ( + _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)" + ) + lines.append(f"{prefix}- **{s['title']}**: {preview}") + + lines.append( + '\nCall read_paper with section parameter (e.g. section="4" or section="Experiments") to read a specific section.' + ) + return "\n".join(lines) + + +def _format_read_paper_section(section: dict, arxiv_id: str) -> str: + """Format a single section's full text.""" + lines = [f"# {section['title']}"] + lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") + + text = section["text"] + if len(text) > MAX_SECTION_TEXT_LEN: + text = ( + text[:MAX_SECTION_TEXT_LEN] + + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)" + ) + + lines.append(text if text else "(This section has no extractable text content.)") + return "\n".join(lines) + + +def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: + lines = [f"# Datasets linked to paper {arxiv_id}"] + lines.append(f"https://huggingface.co/papers/{arxiv_id}") + lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n") + + for i, ds in enumerate(datasets, 1): + ds_id = ds.get("id", "unknown") + downloads = ds.get("downloads", 0) + likes = ds.get("likes", 0) + desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN) + tags = ds.get("tags") or [] + interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] + + lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**") + lines.append(f" Downloads: {downloads:,} | Likes: {likes}") + if interesting: + lines.append(f" Tags: {', '.join(interesting)}") + if desc: + lines.append(f" {desc}") + lines.append("") + + if datasets: + top = datasets[0].get("id", "") + lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")') + return "\n".join(lines) + + +def _format_datasets_compact(datasets: list) -> str: + if not datasets: + return "## Datasets\nNone found" + lines = [f"## Datasets ({len(datasets)})"] + for ds in datasets: + lines.append( + f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)" + ) + return "\n".join(lines) + + +def _format_models(models: list, arxiv_id: str, sort: str) -> str: + lines = [f"# Models linked to paper {arxiv_id}"] + lines.append(f"https://huggingface.co/papers/{arxiv_id}") + lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n") + + for i, m in enumerate(models, 1): + model_id = m.get("id", "unknown") + downloads = m.get("downloads", 0) + likes = m.get("likes", 0) + pipeline = m.get("pipeline_tag") or "" + library = m.get("library_name") or "" + + lines.append(f"**{i}. 
[{model_id}](https://huggingface.co/{model_id})**") + meta = f" Downloads: {downloads:,} | Likes: {likes}" + if pipeline: + meta += f" | Task: {pipeline}" + if library: + meta += f" | Library: {library}" + lines.append(meta) + lines.append("") + + return "\n".join(lines) + + +def _format_models_compact(models: list) -> str: + if not models: + return "## Models\nNone found" + lines = [f"## Models ({len(models)})"] + for m in models: + pipeline = m.get("pipeline_tag") or "" + suffix = f" ({pipeline})" if pipeline else "" + lines.append( + f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}" + ) + return "\n".join(lines) + + +def _format_collections(collections: list, arxiv_id: str) -> str: + lines = [f"# Collections containing paper {arxiv_id}"] + lines.append(f"Showing {len(collections)} collection(s)\n") + + for i, c in enumerate(collections, 1): + slug = c.get("slug", "") + title = c.get("title", "Untitled") + upvotes = c.get("upvotes", 0) + owner = c.get("owner", {}).get("name", "") + desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN) + num_items = len(c.get("items", [])) + + lines.append(f"**{i}. {title}**") + lines.append(f" By: {owner} | Upvotes: {upvotes} | Items: {num_items}") + lines.append(f" https://huggingface.co/collections/{slug}") + if desc: + lines.append(f" {desc}") + lines.append("") + + return "\n".join(lines) + + +def _format_collections_compact(collections: list) -> str: + if not collections: + return "## Collections\nNone found" + lines = [f"## Collections ({len(collections)})"] + for c in collections: + title = c.get("title", "Untitled") + owner = c.get("owner", {}).get("name", "") + upvotes = c.get("upvotes", 0) + lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Operation handlers +# --------------------------------------------------------------------------- + + +def _error(message: str) -> ToolResult: + return { + "formatted": message, + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } + + +def _validate_arxiv_id(args: dict) -> str | None: + """Return arxiv_id or None if missing.""" + return args.get("arxiv_id") + + +async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult: + date = args.get("date") + query = args.get("query") + + params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)} + if date: + params["date"] = date + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get(f"{HF_API}/daily_papers", params=params) + resp.raise_for_status() + papers = resp.json() + + if query: + q = query.lower() + papers = [ + p + for p in papers + if q in p.get("title", "").lower() + or q in p.get("paper", {}).get("title", "").lower() + or q in p.get("paper", {}).get("summary", "").lower() + or any( + q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or []) + ) + ] + + papers = papers[:limit] + if not papers: + msg = "No trending papers found" + if query: + msg += f" matching '{query}'" + if date: + msg += f" for {date}" + return {"formatted": msg, "totalResults": 0, "resultsShared": 0} + + formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query) + return { + "formatted": formatted, + "totalResults": len(papers), + "resultsShared": len(papers), + } + + +def _format_s2_paper_list(papers: list[dict], title: str) -> str: + """Format a list of S2 paper results.""" + lines = [f"# {title}"] + lines.append(f"Showing 
{len(papers)} result(s)\n") + + for i, paper in enumerate(papers, 1): + ptitle = paper.get("title") or "(untitled)" + year = paper.get("year") or "?" + cites = paper.get("citationCount", 0) + venue = paper.get("venue") or "" + ext_ids = paper.get("externalIds") or {} + aid = ext_ids.get("ArXiv", "") + tldr = (paper.get("tldr") or {}).get("text", "") + + lines.append(f"### {i}. {ptitle}") + meta = [f"Year: {year}", f"Citations: {cites}"] + if venue: + meta.append(f"Venue: {venue}") + if aid: + meta.append(f"arxiv_id: {aid}") + lines.append(" | ".join(meta)) + if aid: + lines.append(f"https://arxiv.org/abs/{aid}") + if tldr: + lines.append(f"**TL;DR:** {tldr}") + lines.append("") + + lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.") + return "\n".join(lines) + + +async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None: + """Search via S2 bulk endpoint with filters. Returns None on failure.""" + params: dict[str, Any] = { + "query": query, + "limit": limit, + "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate", + } + + # Date filter + date_from = args.get("date_from", "") + date_to = args.get("date_to", "") + if date_from or date_to: + params["publicationDateOrYear"] = f"{date_from}:{date_to}" + + # Fields of study + categories = args.get("categories") + if categories: + params["fieldsOfStudy"] = categories + + # Min citations + min_cites = args.get("min_citations") + if min_cites: + params["minCitationCount"] = str(min_cites) + + # Sort + sort_by = args.get("sort_by") + if sort_by and sort_by != "relevance": + params["sort"] = f"{sort_by}:desc" + + async with httpx.AsyncClient(timeout=15) as client: + resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params) + if not resp or resp.status_code != 200: + return None + data = resp.json() + + papers = data.get("data") or [] + if not papers: + return { + "formatted": f"No papers found for '{query}' with the given filters.", + "totalResults": 0, + "resultsShared": 0, + } + + formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)") + return { + "formatted": formatted, + "totalResults": data.get("total", len(papers)), + "resultsShared": min(limit, len(papers)), + } + + +async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: + query = args.get("query") + if not query: + return _error("'query' is required for search operation.") + + # Route to S2 when filters are present + use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")) + if use_s2: + result = await _s2_bulk_search(query, args, limit) + if result is not None: + return result + # Fall back to HF search (without filters) if S2 fails + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{HF_API}/papers/search", params={"q": query, "limit": limit} + ) + resp.raise_for_status() + papers = resp.json() + + if not papers: + return { + "formatted": f"No papers found for '{query}'", + "totalResults": 0, + "resultsShared": 0, + } + + formatted = _format_paper_list(papers, f"Papers matching '{query}'") + return { + "formatted": formatted, + "totalResults": len(papers), + "resultsShared": len(papers), + } + + +async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for paper_details.") + + async with 
httpx.AsyncClient(timeout=15) as client: + resp = await client.get(f"{HF_API}/papers/{arxiv_id}") + resp.raise_for_status() + paper = resp.json() + + return { + "formatted": _format_paper_detail(paper), + "totalResults": 1, + "resultsShared": 1, + } + + +async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for read_paper.") + + section_query = args.get("section") + + # Try fetching HTML from arxiv, then ar5iv, then fallback to abstract + parsed = None + async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: + for base_url in [ARXIV_HTML, AR5IV_HTML]: + try: + resp = await client.get(f"{base_url}/{arxiv_id}") + if resp.status_code == 200: + parsed = _parse_paper_html(resp.text) + if parsed["sections"]: # Only use if we got real sections + break + parsed = None + except httpx.RequestError: + continue + + # Fallback: return abstract from HF API + if not parsed or not parsed["sections"]: + try: + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get(f"{HF_API}/papers/{arxiv_id}") + resp.raise_for_status() + paper = resp.json() + abstract = paper.get("summary", "") + title = paper.get("title", "") + msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n" + msg += f"## Abstract\n{abstract}\n\n" + msg += "HTML version not available for this paper. Only abstract shown.\n" + msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}" + return {"formatted": msg, "totalResults": 1, "resultsShared": 1} + except Exception: + return _error( + f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct." + ) + + # Return TOC or specific section + if not section_query: + formatted = _format_read_paper_toc(parsed, arxiv_id) + return { + "formatted": formatted, + "totalResults": len(parsed["sections"]), + "resultsShared": len(parsed["sections"]), + } + + section = _find_section(parsed["sections"], section_query) + if not section: + available = "\n".join(f"- {s['title']}" for s in parsed["sections"]) + return _error( + f"Section '{section_query}' not found. Available sections:\n{available}" + ) + + formatted = _format_read_paper_section(section, arxiv_id) + return {"formatted": formatted, "totalResults": 1, "resultsShared": 1} + + +# --------------------------------------------------------------------------- +# Citation graph (Semantic Scholar) +# --------------------------------------------------------------------------- + + +def _format_citation_entry(entry: dict, show_context: bool = False) -> str: + """Format a single citation/reference entry.""" + paper = entry.get("citingPaper") or entry.get("citedPaper") or {} + title = paper.get("title") or "(untitled)" + year = paper.get("year") or "?" 
+ cites = paper.get("citationCount", 0) + ext_ids = paper.get("externalIds") or {} + aid = ext_ids.get("ArXiv", "") + influential = " **[influential]**" if entry.get("isInfluential") else "" + + parts = [f"- **{title}** ({year}, {cites} cites){influential}"] + if aid: + parts[0] += f" arxiv:{aid}" + + if show_context: + intents = entry.get("intents") or [] + if intents: + parts.append(f" Intent: {', '.join(intents)}") + contexts = entry.get("contexts") or [] + for ctx in contexts[:2]: + if ctx: + parts.append(f" > {_truncate(ctx, 200)}") + + return "\n".join(parts) + + +def _format_citation_graph( + arxiv_id: str, + references: list[dict] | None, + citations: list[dict] | None, +) -> str: + lines = [f"# Citation Graph for {arxiv_id}"] + lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") + + if references is not None: + lines.append(f"## References ({len(references)})") + if references: + for entry in references: + lines.append(_format_citation_entry(entry)) + else: + lines.append("No references found.") + lines.append("") + + if citations is not None: + lines.append(f"## Citations ({len(citations)})") + if citations: + for entry in citations: + lines.append(_format_citation_entry(entry, show_context=True)) + else: + lines.append("No citations found.") + lines.append("") + + lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.") + return "\n".join(lines) + + +async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for citation_graph.") + + direction = args.get("direction", "both") + s2_id = _s2_paper_id(arxiv_id) + fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential" + params = {"fields": fields, "limit": limit} + + async with httpx.AsyncClient(timeout=15) as client: + refs, cites = None, None + coros = [] + if direction in ("references", "both"): + coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)) + if direction in ("citations", "both"): + coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)) + + results = await asyncio.gather(*coros, return_exceptions=True) + idx = 0 + if direction in ("references", "both"): + r = results[idx] + if isinstance(r, dict): + refs = r.get("data", []) + idx += 1 + if direction in ("citations", "both"): + r = results[idx] + if isinstance(r, dict): + cites = r.get("data", []) + + if refs is None and cites is None: + return _error(f"Could not fetch citation data for {arxiv_id}. 
Paper may not be indexed by Semantic Scholar.") + + total = (len(refs) if refs else 0) + (len(cites) if cites else 0) + return { + "formatted": _format_citation_graph(arxiv_id, refs, cites), + "totalResults": total, + "resultsShared": total, + } + + +async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for find_datasets.") + + sort = args.get("sort", "downloads") + sort_key = SORT_MAP.get(sort, "downloads") + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{HF_API}/datasets", + params={ + "filter": f"arxiv:{arxiv_id}", + "limit": limit, + "sort": sort_key, + "direction": -1, + }, + ) + resp.raise_for_status() + datasets = resp.json() + + if not datasets: + return { + "formatted": f"No datasets found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", + "totalResults": 0, + "resultsShared": 0, + } + + return { + "formatted": _format_datasets(datasets, arxiv_id, sort), + "totalResults": len(datasets), + "resultsShared": len(datasets), + } + + +async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for find_models.") + + sort = args.get("sort", "downloads") + sort_key = SORT_MAP.get(sort, "downloads") + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get( + f"{HF_API}/models", + params={ + "filter": f"arxiv:{arxiv_id}", + "limit": limit, + "sort": sort_key, + "direction": -1, + }, + ) + resp.raise_for_status() + models = resp.json() + + if not models: + return { + "formatted": f"No models found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", + "totalResults": 0, + "resultsShared": 0, + } + + return { + "formatted": _format_models(models, arxiv_id, sort), + "totalResults": len(models), + "resultsShared": len(models), + } + + +async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for find_collections.") + + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id}) + resp.raise_for_status() + collections = resp.json() + + if not collections: + return { + "formatted": f"No collections found containing paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", + "totalResults": 0, + "resultsShared": 0, + } + + collections = collections[:limit] + return { + "formatted": _format_collections(collections, arxiv_id), + "totalResults": len(collections), + "resultsShared": len(collections), + } + + +async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult: + arxiv_id = _validate_arxiv_id(args) + if not arxiv_id: + return _error("'arxiv_id' is required for find_all_resources.") + + per_cat = min(limit, 10) + + async with httpx.AsyncClient(timeout=15) as client: + results = await asyncio.gather( + client.get( + f"{HF_API}/datasets", + params={ + "filter": f"arxiv:{arxiv_id}", + "limit": per_cat, + "sort": "downloads", + "direction": -1, + }, + ), + client.get( + f"{HF_API}/models", + params={ + "filter": f"arxiv:{arxiv_id}", + "limit": per_cat, + "sort": "downloads", + "direction": -1, + }, + ), + client.get(f"{HF_API}/collections", params={"paper": arxiv_id}), + return_exceptions=True, + ) + + sections = [] + total = 0 + + # Datasets + if 
isinstance(results[0], Exception): + sections.append(f"## Datasets\nError: {results[0]}") + else: + datasets = results[0].json() + total += len(datasets) + sections.append(_format_datasets_compact(datasets[:per_cat])) + + # Models + if isinstance(results[1], Exception): + sections.append(f"## Models\nError: {results[1]}") + else: + models = results[1].json() + total += len(models) + sections.append(_format_models_compact(models[:per_cat])) + + # Collections + if isinstance(results[2], Exception): + sections.append(f"## Collections\nError: {results[2]}") + else: + collections = results[2].json() + total += len(collections) + sections.append(_format_collections_compact(collections[:per_cat])) + + header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n" + formatted = header + "\n\n".join(sections) + return {"formatted": formatted, "totalResults": total, "resultsShared": total} + + +# --------------------------------------------------------------------------- +# Snippet search (Semantic Scholar) +# --------------------------------------------------------------------------- + + +def _format_snippets(snippets: list[dict], query: str) -> str: + lines = [f"# Snippet Search: '{query}'"] + lines.append(f"Found {len(snippets)} matching passage(s)\n") + + for i, item in enumerate(snippets, 1): + paper = item.get("paper") or {} + ptitle = paper.get("title") or "(untitled)" + year = paper.get("year") or "?" + cites = paper.get("citationCount", 0) + ext_ids = paper.get("externalIds") or {} + aid = ext_ids.get("ArXiv", "") + + snippet = item.get("snippet") or {} + text = snippet.get("text", "") + section = snippet.get("section") or "" + + lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)") + if aid: + lines.append(f"arxiv:{aid}") + if section: + lines.append(f"Section: {section}") + if text: + lines.append(f"> {_truncate(text, 400)}") + lines.append("") + + lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.") + return "\n".join(lines) + + +async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: + query = args.get("query") + if not query: + return _error("'query' is required for snippet_search.") + + params: dict[str, Any] = { + "query": query, + "limit": limit, + "fields": "title,externalIds,year,citationCount", + } + + # Optional filters (same as search) + date_from = args.get("date_from", "") + date_to = args.get("date_to", "") + if date_from or date_to: + params["publicationDateOrYear"] = f"{date_from}:{date_to}" + if args.get("categories"): + params["fieldsOfStudy"] = args["categories"] + if args.get("min_citations"): + params["minCitationCount"] = str(args["min_citations"]) + + async with httpx.AsyncClient(timeout=15) as client: + resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params) + if not resp or resp.status_code != 200: + return _error("Snippet search failed. 
Semantic Scholar may be unavailable.") + data = resp.json() + + snippets = data.get("data") or [] + if not snippets: + return { + "formatted": f"No snippets found for '{query}'.", + "totalResults": 0, + "resultsShared": 0, + } + + return { + "formatted": _format_snippets(snippets, query), + "totalResults": len(snippets), + "resultsShared": len(snippets), + } + + +# --------------------------------------------------------------------------- +# Recommendations (Semantic Scholar) +# --------------------------------------------------------------------------- + + +async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: + positive_ids = args.get("positive_ids") + arxiv_id = _validate_arxiv_id(args) + + if not arxiv_id and not positive_ids: + return _error("'arxiv_id' or 'positive_ids' is required for recommend.") + + fields = "title,externalIds,year,citationCount,tldr,venue" + + async with httpx.AsyncClient(timeout=15) as client: + if positive_ids and not arxiv_id: + # Multi-paper recommendations (POST, not cached) + pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()] + neg_raw = args.get("negative_ids", "") + neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else [] + resp = await _s2_request( + client, "POST", "/recommendations/v1/papers/", + json={"positivePaperIds": pos, "negativePaperIds": neg}, + params={"fields": fields, "limit": limit}, + ) + if not resp or resp.status_code != 200: + return _error("Recommendation request failed. Semantic Scholar may be unavailable.") + data = resp.json() + else: + # Single-paper recommendations (cached) + data = await _s2_get_json( + client, + f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}", + {"fields": fields, "limit": limit, "from": "recent"}, + ) + if not data: + return _error("Recommendation request failed. Semantic Scholar may be unavailable.") + + papers = data.get("recommendedPapers") or [] + if not papers: + return { + "formatted": "No recommendations found.", + "totalResults": 0, + "resultsShared": 0, + } + + title = f"Recommended papers based on {arxiv_id or positive_ids}" + return { + "formatted": _format_s2_paper_list(papers[:limit], title), + "totalResults": len(papers), + "resultsShared": min(limit, len(papers)), + } + + +# --------------------------------------------------------------------------- +# Operation dispatch +# --------------------------------------------------------------------------- + +_OPERATIONS = { + "trending": _op_trending, + "search": _op_search, + "paper_details": _op_paper_details, + "read_paper": _op_read_paper, + "citation_graph": _op_citation_graph, + "snippet_search": _op_snippet_search, + "recommend": _op_recommend, + "find_datasets": _op_find_datasets, + "find_models": _op_find_models, + "find_collections": _op_find_collections, + "find_all_resources": _op_find_all_resources, +} + + +# --------------------------------------------------------------------------- +# Tool spec + handler +# --------------------------------------------------------------------------- + +HF_PAPERS_TOOL_SPEC = { + "name": "hf_papers", + "description": ( + "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n" + "Combines HuggingFace Hub, arXiv, and Semantic Scholar. 
Use for exploring research areas, " + "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n" + "Typical flows:\n" + " search → read_paper → find_all_resources → hf_inspect_dataset\n" + " search → paper_details → citation_graph → read_paper (trace influence)\n" + " snippet_search → paper_details → read_paper (find specific claims)\n\n" + "Operations:\n" + "- trending: Get trending daily papers, optionally filter by topic keyword\n" + "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n" + "- paper_details: Metadata, abstract, AI summary, github link\n" + "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n" + "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n" + "- snippet_search: Semantic search over full-text passages from 12M+ papers\n" + "- recommend: Find similar papers (single paper or positive/negative examples)\n" + "- find_datasets: Find datasets linked to a paper\n" + "- find_models: Find models linked to a paper\n" + "- find_collections: Find collections that include a paper\n" + "- find_all_resources: Parallel fetch of datasets + models + collections for a paper" + ), + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": list(_OPERATIONS.keys()), + "description": "Operation to execute.", + }, + "query": { + "type": "string", + "description": ( + "Search query. Required for: search, snippet_search. " + "Optional for: trending (filters by keyword). " + "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'." + ), + }, + "arxiv_id": { + "type": "string", + "description": ( + "ArXiv paper ID (e.g. '2305.18290'). " + "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. " + "Optional for: recommend (single-paper recs). Get IDs from search results first." + ), + }, + "section": { + "type": "string", + "description": ( + "Section name or number to read (e.g. '3', 'Experiments', '4.2'). " + "Optional for: read_paper. Without this, returns abstract + TOC." + ), + }, + "direction": { + "type": "string", + "enum": ["citations", "references", "both"], + "description": "Direction for citation_graph. Default: both.", + }, + "date": { + "type": "string", + "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).", + }, + "date_from": { + "type": "string", + "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", + }, + "date_to": { + "type": "string", + "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", + }, + "categories": { + "type": "string", + "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.", + }, + "min_citations": { + "type": "integer", + "description": "Minimum citation count filter. Triggers Semantic Scholar search.", + }, + "sort_by": { + "type": "string", + "enum": ["relevance", "citationCount", "publicationDate"], + "description": "Sort order for Semantic Scholar search. Default: relevance.", + }, + "positive_ids": { + "type": "string", + "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.", + }, + "negative_ids": { + "type": "string", + "description": "Comma-separated arxiv IDs as negative examples. 
For: recommend.", + }, + "sort": { + "type": "string", + "enum": ["downloads", "likes", "trending"], + "description": ( + "Sort order for find_datasets and find_models. Default: downloads." + ), + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 10, max: 50).", + }, + }, + "required": ["operation"], + }, +} + + +async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]: + """Handler for agent tool router.""" + operation = arguments.get("operation") + if not operation: + return "'operation' parameter is required.", False + + handler = _OPERATIONS.get(operation) + if not handler: + valid = ", ".join(_OPERATIONS.keys()) + return f"Unknown operation: '{operation}'. Valid: {valid}", False + + limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT) + + try: + result = await handler(arguments, limit) + return result["formatted"], not result.get("isError", False) + except httpx.HTTPStatusError as e: + return f"API error: {e.response.status_code} — {e.response.text[:200]}", False + except httpx.RequestError as e: + return f"Request error: {e}", False + except Exception as e: + return f"Error in {operation}: {e}", False diff --git a/plugin/lib/ml_intern_lib/tools/sandbox_client.py b/plugin/lib/ml_intern_lib/tools/sandbox_client.py new file mode 100644 index 00000000..16982c76 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/sandbox_client.py @@ -0,0 +1,1054 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"] +# /// +""" +Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes. + +Architecture: + - Creates a sandbox by duplicating a template Space (runs sandbox_server.py) + - Waits for it to come online + - Communicates via HTTPS to the Space's API + - Optionally deletes the Space when done + +Lifecycle: + sb = Sandbox.create(owner="burtenshaw") # duplicate, wait, connect + sb = Sandbox.create(owner="burtenshaw", # with options + hardware="t4-small", + private=True, + sleep_time=3600) + sb = Sandbox.connect("burtenshaw/my-sandbox-abc") # attach to existing + + sb.bash("uv run train.py") + sb.read("/app/train.py") + sb.edit("/app/train.py", old_str="lr=1e-3", new_str="lr=1e-4") + + sb.delete() # tear down when done + + # Or use as a context manager for automatic cleanup + with Sandbox.create(owner="burtenshaw") as sb: + sb.bash("python train.py") + # Space deleted on exit + +Tools: bash, read, write, edit, upload +""" + +from __future__ import annotations + +import io +import sys +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable + +import httpx +from huggingface_hub import CommitOperationAdd, HfApi + +TEMPLATE_SPACE = "burtenshaw/sandbox" +HARDWARE_OPTIONS = [ + "cpu-basic", + "cpu-upgrade", + "t4-small", + "t4-medium", + "a10g-small", + "a10g-large", + "a100-large", +] +OUTPUT_LIMIT = 25000 +LINE_LIMIT = 4000 +DEFAULT_READ_LIMIT = 2000 +DEFAULT_TIMEOUT = 240 +MAX_TIMEOUT = 1200 +WAIT_TIMEOUT = 600 +WAIT_INTERVAL = 5 +API_WAIT_TIMEOUT = 180 + +_DOCKERFILE = """\ +FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim + +RUN apt-get update && \\ + apt-get install -y \\ + bash git git-lfs wget curl procps \\ + htop vim nano jq tmux \\ + build-essential && \\ + rm -rf /var/lib/apt/lists/* + +RUN uv pip install --system fastapi uvicorn python-multipart + +RUN useradd -m -u 1000 user +USER user + +ENV HOME=/home/user \\ + PATH=/home/user/.local/bin:$PATH \\ + PIP_USER=1 \\ + 
HF_HUB_DISABLE_PROGRESS_BARS=1 \\ + TQDM_DISABLE=1 \\ + HF_HUB_ENABLE_HF_TRANSFER=1 \\ + UV_NO_PROGRESS=1 \\ + PYTHONWARNINGS=ignore::DeprecationWarning + +WORKDIR /app +COPY --chown=user . /app + +EXPOSE 7860 + +CMD ["python", "sandbox_server.py"] +""" + +_SANDBOX_SERVER = '''\ +"""Minimal FastAPI server for sandbox operations.""" +import os, subprocess, pathlib, signal, threading, re, tempfile +from fastapi import FastAPI +from pydantic import BaseModel +from typing import Optional +import uvicorn + +_ANSI_RE = re.compile(r'\\x1b\\[[0-9;]*[a-zA-Z]|\\x1b\\].*?\\x07') + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub('', text) + +def _truncate_output(output: str, max_chars: int = 25000, head_ratio: float = 0.25) -> str: + if len(output) <= max_chars: + return output + # Write full output to temp file so LLM can read specific sections + spill_path = None + try: + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', dir='/tmp', delete=False) as f: + f.write(output) + spill_path = f.name + except Exception: + pass + head_budget = int(max_chars * head_ratio) + tail_budget = max_chars - head_budget + head = output[:head_budget] + tail = output[-tail_budget:] + total = len(output) + omitted = total - max_chars + meta = f"\\n\\n... ({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\\n" + if spill_path: + meta += f"Full output saved to {spill_path} — use the read tool with offset/limit to inspect specific sections.\\n" + return head + meta + tail + +def _atomic_write(path: pathlib.Path, content: str): + """Write atomically: temp file + fsync + os.replace.""" + path.parent.mkdir(parents=True, exist_ok=True) + fd = None + tmp_path = None + try: + fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") + os.write(fd, content.encode("utf-8")) + os.fsync(fd) + os.close(fd) + fd = None + os.replace(tmp_path, str(path)) + tmp_path = None + finally: + if fd is not None: + os.close(fd) + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + +app = FastAPI() + +# Track active bash processes so they can be killed on cancel +_active_procs = {} # pid -> subprocess.Popen +_proc_lock = threading.Lock() + +class BashReq(BaseModel): + command: str + work_dir: str = "/app" + timeout: int = 120 + +class ReadReq(BaseModel): + path: str + offset: Optional[int] = None + limit: Optional[int] = 2000 + +class WriteReq(BaseModel): + path: str + content: str + +class EditReq(BaseModel): + path: str + old_str: str + new_str: str + replace_all: bool = False + mode: str = "replace" + +class ExistsReq(BaseModel): + path: str + +# ── Fuzzy matching & edit utilities (embedded) ── + +UNICODE_MAP = { + "\\u2013": "-", "\\u2014": "-", "\\u2212": "-", + "\\u2018": "'", "\\u2019": "'", + "\\u201c": \'"\', "\\u201d": \'"\', + "\\u00a0": " ", "\\u2003": " ", "\\u2002": " ", + "\\u200b": "", "\\ufeff": "", +} + +def _normalize_unicode(s): + return "".join(UNICODE_MAP.get(c, c) for c in s) + +def _fuzzy_find_original(content, pattern): + """Find the original text in content that matches pattern fuzzily.""" + if pattern in content: + return pattern, None + # Pass 2: right-trim + c_lines = content.split("\\n") + c_rt = "\\n".join(l.rstrip() for l in c_lines) + p_rt = "\\n".join(l.rstrip() for l in pattern.split("\\n")) + if p_rt in c_rt: + idx = c_rt.index(p_rt) + start_line = c_rt[:idx].count("\\n") + n_lines = p_rt.count("\\n") + 1 + matched = "\\n".join(c_lines[start_line:start_line + n_lines]) + return matched, 
"(matched after trimming trailing whitespace)" + # Pass 3: both-sides trim + c_st = "\\n".join(l.strip() for l in c_lines) + p_st = "\\n".join(l.strip() for l in pattern.split("\\n")) + if p_st in c_st: + idx = c_st.index(p_st) + start_line = c_st[:idx].count("\\n") + n_lines = p_st.count("\\n") + 1 + matched = "\\n".join(c_lines[start_line:start_line + n_lines]) + return matched, "(matched after trimming whitespace)" + # Pass 4: unicode normalization + c_norm = _normalize_unicode(c_st) + p_norm = _normalize_unicode(p_st) + if p_norm in c_norm: + idx = c_norm.index(p_norm) + start_line = c_norm[:idx].count("\\n") + n_lines = p_norm.count("\\n") + 1 + matched = "\\n".join(c_lines[start_line:start_line + n_lines]) + return matched, "(matched after unicode normalization)" + return None, None + +def _apply_edit(content, old_str, new_str, mode="replace", replace_all=False): + """Apply edit. Returns (new_content, count, fuzzy_note) or raises ValueError.""" + if mode == "replace_all": + replace_all = True + mode = "replace" + fuzzy_note = None + if old_str not in content: + matched, fuzzy_note = _fuzzy_find_original(content, old_str) + if matched is None: + raise ValueError("old_str not found in file.") + old_str = matched + count = content.count(old_str) + if mode == "replace": + if count > 1 and not replace_all: + raise ValueError(f"old_str appears {count} times. Use replace_all=true or provide more context.") + if replace_all: + return content.replace(old_str, new_str), count, fuzzy_note + return content.replace(old_str, new_str, 1), 1, fuzzy_note + elif mode == "append_after": + if replace_all: + return content.replace(old_str, old_str + new_str), count, fuzzy_note + idx = content.index(old_str) + len(old_str) + return content[:idx] + new_str + content[idx:], 1, fuzzy_note + elif mode == "prepend_before": + if replace_all: + return content.replace(old_str, new_str + old_str), count, fuzzy_note + idx = content.index(old_str) + return content[:idx] + new_str + content[idx:], 1, fuzzy_note + raise ValueError(f"Unknown mode: {mode}") + +def _validate_python(content, path=""): + """Validate Python: syntax, kwargs against real installed signatures, training heuristics. + + Runs inside the sandbox where packages are pip-installed, so we can actually + import classes and inspect their __init__ signatures to catch kwarg mismatches + before runtime. + """ + import ast as _ast, inspect as _inspect, importlib as _il + warnings = [] + + # 1. Syntax check + try: + tree = _ast.parse(content) + except SyntaxError as e: + warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") + return warnings + + # 2. Build import map: name -> module path (from the script's own imports) + import_map = {} + for node in _ast.walk(tree): + if isinstance(node, _ast.ImportFrom) and node.module: + for alias in (node.names or []): + local_name = alias.asname or alias.name + import_map[local_name] = (node.module, alias.name) + elif isinstance(node, _ast.Import): + for alias in (node.names or []): + local_name = alias.asname or alias.name + import_map[local_name] = (alias.name, None) + + # 3. 
For each Call node, resolve the callable and check kwargs against signature + for node in _ast.walk(tree): + if not isinstance(node, _ast.Call): + continue + # Skip calls with **kwargs unpacking — we can't statically know those keys + if any(kw.arg is None for kw in node.keywords): + continue + call_kwargs = [kw.arg for kw in node.keywords if kw.arg] + if not call_kwargs: + continue + + # Resolve the callable name + func_name = None + if isinstance(node.func, _ast.Name): + func_name = node.func.id + elif isinstance(node.func, _ast.Attribute): + func_name = node.func.attr + if not func_name or func_name not in import_map: + continue + + # Try to import and inspect the real callable + module_path, attr_name = import_map[func_name] + try: + mod = _il.import_module(module_path) + obj = getattr(mod, attr_name, None) if attr_name else mod + if obj is None: + continue + sig = _inspect.signature(obj) + params = sig.parameters + # If **kwargs is in the signature, any kwarg is valid + if any(p.kind == _inspect.Parameter.VAR_KEYWORD for p in params.values()): + continue + valid_names = set(params.keys()) + for kw_name in call_kwargs: + if kw_name not in valid_names: + warnings.append( + f"Invalid kwarg: {func_name}({kw_name}=...) at line {node.lineno} " + f"-- not accepted by {module_path}.{attr_name or func_name}()" + ) + except Exception: + pass # can't import/inspect — skip silently + + # 4. Training script heuristics + if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): + if "push_to_hub" not in content: + warnings.append("Training script warning: no \'push_to_hub\' found") + if "hub_model_id" not in content: + warnings.append("Training script warning: no \'hub_model_id\' found") + return warnings + +@app.get("/api/health") +def health(): + return {"status": "ok"} + +@app.post("/api/bash") +def bash(req: BashReq): + try: + proc = subprocess.Popen( + req.command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, cwd=req.work_dir, start_new_session=True, + ) + with _proc_lock: + _active_procs[proc.pid] = proc + try: + stdout, stderr = proc.communicate(timeout=req.timeout) + output = _strip_ansi(stdout + stderr) + output = _truncate_output(output) + return {"success": proc.returncode == 0, "output": output, "error": "" if proc.returncode == 0 else f"Exit code {proc.returncode}"} + except subprocess.TimeoutExpired: + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except OSError: + proc.kill() + proc.wait() + return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"} + finally: + with _proc_lock: + _active_procs.pop(proc.pid, None) + except Exception as e: + return {"success": False, "output": "", "error": str(e)} + +@app.post("/api/kill") +def kill_all(): + """Kill all active bash processes. 
Called when user cancels.""" + with _proc_lock: + pids = list(_active_procs.keys()) + killed = [] + for pid in pids: + try: + os.killpg(os.getpgid(pid), signal.SIGTERM) + killed.append(pid) + except OSError: + try: + os.kill(pid, signal.SIGKILL) + killed.append(pid) + except OSError: + pass + return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""} + +@app.post("/api/read") +def read(req: ReadReq): + try: + p = pathlib.Path(req.path) + if not p.exists(): + return {"success": False, "output": "", "error": f"File not found: {req.path}"} + if p.is_dir(): + return {"success": False, "output": "", "error": f"Is a directory: {req.path}"} + lines = p.read_text().splitlines() + start = (req.offset or 1) - 1 + end = start + (req.limit or len(lines)) + selected = lines[start:end] + numbered = "\\n".join(f"{start + i + 1}\\t{line}" for i, line in enumerate(selected)) + return {"success": True, "output": numbered, "error": ""} + except Exception as e: + return {"success": False, "output": "", "error": str(e)} + +@app.post("/api/write") +def write(req: WriteReq): + try: + p = pathlib.Path(req.path) + _atomic_write(p, req.content) + msg = f"Wrote {len(req.content)} bytes to {req.path}" + if p.suffix == ".py": + warnings = _validate_python(req.content, req.path) + if warnings: + msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! {w}" for w in warnings) + return {"success": True, "output": msg, "error": ""} + except Exception as e: + return {"success": False, "output": "", "error": str(e)} + +@app.post("/api/edit") +def edit(req: EditReq): + try: + p = pathlib.Path(req.path) + if not p.exists(): + return {"success": False, "output": "", "error": f"File not found: {req.path}"} + content = p.read_text() + if req.old_str == req.new_str: + return {"success": False, "output": "", "error": "old_str and new_str must differ."} + try: + new_content, count, fuzzy_note = _apply_edit( + content, req.old_str, req.new_str, mode=req.mode, replace_all=req.replace_all + ) + except ValueError as e: + return {"success": False, "output": "", "error": str(e)} + _atomic_write(p, new_content) + msg = f"Edited {req.path} ({count} replacement{'s' if count > 1 else ''})" + if fuzzy_note: + msg += f" {fuzzy_note}" + if p.suffix == ".py": + warnings = _validate_python(new_content, req.path) + if warnings: + msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! {w}" for w in warnings) + return {"success": True, "output": msg, "error": ""} + except Exception as e: + return {"success": False, "output": "", "error": str(e)} + +@app.post("/api/exists") +def exists(req: ExistsReq): + return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=7860) +''' + + +@dataclass +class ToolResult: + success: bool + output: str = "" + error: str = "" + + def __str__(self): + if self.success: + return self.output or "(no output)" + return f"ERROR: {self.error}" + + def to_dict(self) -> dict: + return {"success": self.success, "output": self.output, "error": self.error} + + +@dataclass +class Sandbox: + """ + A handle to an HF Space sandbox. + + Use Sandbox.create() to spin up a new one, or Sandbox.connect() to + attach to an existing running Space. 
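+
+    Example (illustrative sketch; assumes a valid HF token in hf_token and
+    access to the template Space; "your-username" is a placeholder):
+
+        with Sandbox.create(owner="your-username", token=hf_token) as sb:
+            sb.write("/app/hello.py", "print('hello from the sandbox')")
+            print(sb.bash("python /app/hello.py"))
+        # The Space is deleted on exit because this instance created it.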
+ """ + + space_id: str + token: str | None = None + work_dir: str = "/app" + timeout: int = DEFAULT_TIMEOUT + _owns_space: bool = field(default=False, repr=False) + _base_url: str = field(init=False, repr=False) + _client: httpx.Client = field(init=False, repr=False) + _hf_api: HfApi = field(init=False, repr=False) + _files_read: set = field(init=False, repr=False, default_factory=set) + + def __post_init__(self): + slug = self.space_id.replace("/", "-") + # Trailing slash is critical: httpx resolves relative paths against base_url. + # Without it, client.get("health") resolves to /health instead of /api/health. + self._base_url = f"https://{slug}.hf.space/api/" + self._client = httpx.Client( + base_url=self._base_url, + headers={"Authorization": f"Bearer {self.token}"} if self.token else {}, + timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), + follow_redirects=True, + ) + self._hf_api = HfApi(token=self.token) + + # ── Lifecycle ───────────────────────────────────────────────── + + class Cancelled(Exception): + """Raised when sandbox creation is cancelled by the user.""" + + @classmethod + def create( + cls, + owner: str, + *, + name: str | None = None, + template: str = TEMPLATE_SPACE, + hardware: str = "cpu-basic", + private: bool = False, + sleep_time: int | None = None, + token: str | None = None, + secrets: dict[str, str] | None = None, + wait_timeout: int = WAIT_TIMEOUT, + log: "Callable[[str], object] | None" = None, + cancel_event: "Any | None" = None, + ) -> Sandbox: + """ + Create a new sandbox by duplicating the template Space. + + Generates a unique space name, duplicates the template, waits for it + to come online, then returns a connected Sandbox. + + Args: + owner: HF username or org (e.g. "burtenshaw"). + name: Base name for the space. Defaults to "sandbox". + A unique suffix is always appended. + template: Source Space to duplicate (default: burtenshaw/sandbox). + hardware: Hardware tier (cpu-basic, t4-small, etc.). + private: Whether the Space should be private. + sleep_time: Auto-sleep after N seconds of inactivity. + token: HF API token (from user's OAuth session). + wait_timeout: Max seconds to wait for Space to start (default: 300). + cancel_event: A threading.Event (or compatible) checked during + polling loops. When set, the Space is deleted and + Sandbox.Cancelled is raised. + + Returns: + A Sandbox instance connected to the running Space. + """ + _log = log or print + api = HfApi(token=token) + + def _check_cancel(): + if cancel_event and cancel_event.is_set(): + _log("Sandbox creation cancelled by user, cleaning up...") + try: + api.delete_repo(space_id, repo_type="space") + _log(f"Deleted Space {space_id}") + except Exception: + pass + raise cls.Cancelled(f"Sandbox creation cancelled: {space_id}") + + base = name or "sandbox" + suffix = uuid.uuid4().hex[:8] + space_id = f"{owner}/{base}-{suffix}" + + _log(f"Creating sandbox: {space_id} (from {template})...") + + kwargs = { + "from_id": template, + "to_id": space_id, + "private": private, + "hardware": hardware, + } + if sleep_time is not None: + kwargs["sleep_time"] = sleep_time + + api.duplicate_space(**kwargs) + _log(f"Space created: https://huggingface.co/spaces/{space_id}") + + _check_cancel() + + # Inject secrets BEFORE uploading server files (which triggers rebuild). + # Secrets added after a Space is running aren't available until restart, + # so they must be set before the build/start cycle. 
+ if secrets: + for key, val in secrets.items(): + api.add_space_secret(space_id, key, val) + + # Upload sandbox server and Dockerfile (triggers rebuild) + cls._setup_server(space_id, api, log=_log) + + _check_cancel() + + # Wait for it to come online (rebuild + start) + _log(f"Waiting for Space to start (timeout: {wait_timeout}s)...") + deadline = time.time() + wait_timeout + while time.time() < deadline: + _check_cancel() + runtime = api.get_space_runtime(space_id) + if runtime.stage == "RUNNING": + _log(f"Space is running (hardware: {runtime.hardware})") + break + if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"): + raise RuntimeError( + f"Space failed to start: {runtime.stage}. " + f"Check https://huggingface.co/spaces/{space_id}" + ) + _log(f" {runtime.stage}...") + time.sleep(WAIT_INTERVAL) + else: + raise TimeoutError( + f"Space did not start within {wait_timeout}s. " + f"Check https://huggingface.co/spaces/{space_id}" + ) + + _check_cancel() + + # Wait for the API server to be responsive (non-fatal) + sb = cls(space_id=space_id, token=token, _owns_space=True) + try: + sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log) + except TimeoutError as e: + _log( + f"Warning: API health check timed out ({e}), but Space is RUNNING. Continuing." + ) + return sb + + @staticmethod + def _setup_server(space_id: str, api: HfApi, *, log: Callable[[str], object] = print) -> None: + """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" + log(f"Uploading sandbox server to {space_id}...") + api.create_commit( + repo_id=space_id, + repo_type="space", + operations=[ + CommitOperationAdd( + path_in_repo="sandbox_server.py", + path_or_fileobj=io.BytesIO(_SANDBOX_SERVER.encode()), + ), + CommitOperationAdd( + path_in_repo="Dockerfile", + path_or_fileobj=io.BytesIO(_DOCKERFILE.encode()), + ), + ], + commit_message="Setup sandbox server", + ) + log("Server files uploaded, rebuild triggered.") + + @classmethod + def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox: + """ + Connect to an existing running Space. + + Does a health check to verify the Space is reachable. + """ + sb = cls(space_id=space_id, token=token, _owns_space=False) + sb._wait_for_api(timeout=60) + return sb + + def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print): + """Poll the health endpoint until the server responds.""" + deadline = time.time() + timeout + last_err = None + last_status = None + while time.time() < deadline: + try: + resp = self._client.get("health", timeout=10) + last_status = resp.status_code + if resp.status_code == 200: + log(f"API is responsive at {self._base_url}") + return + except Exception as e: + last_err = e + time.sleep(3) + raise TimeoutError( + f"Sandbox API at {self._base_url} not responding after {timeout}s. " + f"Last status: {last_status}, last error: {last_err}" + ) + + def delete(self): + """Delete the Space. Only works if this Sandbox created it.""" + if not self._owns_space: + raise RuntimeError( + f"This Sandbox did not create {self.space_id}. " + f"Use self._hf_api.delete_repo() directly if you're sure." 
+ ) + print(f"Deleting sandbox: {self.space_id}...") + self._hf_api.delete_repo(self.space_id, repo_type="space") + self._client.close() + print("Deleted.") + + def pause(self): + """Pause the Space (stops billing, preserves state).""" + self._hf_api.pause_space(self.space_id) + + def restart(self): + """Restart the Space.""" + self._hf_api.restart_space(self.space_id) + self._wait_for_api() + + @property + def url(self) -> str: + """Public URL of the Space.""" + return f"https://huggingface.co/spaces/{self.space_id}" + + @property + def status(self) -> str: + """Current Space stage (RUNNING, BUILDING, PAUSED, etc.).""" + return self._hf_api.get_space_runtime(self.space_id).stage + + def __enter__(self) -> Sandbox: + return self + + def __exit__(self, *exc): + if self._owns_space: + try: + self.delete() + except Exception as e: + print(f"Warning: failed to delete sandbox: {e}", file=sys.stderr) + self._client.close() + + # ── HTTP plumbing ───────────────────────────────────────────── + + def _call( + self, endpoint: str, payload: dict, timeout: float | None = None + ) -> ToolResult: + # Strip leading slash for correct httpx base_url resolution + endpoint = endpoint.lstrip("/") + effective_timeout = timeout or self.timeout + last_error = "" + + # Retry up to 3 times for transient failures (sandbox waking from + # sleep returns empty / non-JSON responses while it starts up). + for attempt in range(3): + try: + resp = self._client.post( + endpoint, + json=payload, + timeout=effective_timeout, + ) + try: + data = resp.json() + except (ValueError, UnicodeDecodeError): + # Non-JSON response — sandbox is likely still starting up. + body_preview = resp.text[:200] if resp.text else "(empty)" + last_error = ( + f"Sandbox returned non-JSON response (HTTP {resp.status_code}): " + f"{body_preview}" + ) + if attempt < 2: + time.sleep(3 * (attempt + 1)) + continue + return ToolResult(success=False, error=last_error) + + if resp.status_code == 200: + return ToolResult( + success=data.get("success", True), + output=data.get("output", ""), + error=data.get("error", ""), + ) + return ToolResult( + success=False, + error=data.get("error", f"HTTP {resp.status_code}"), + ) + except httpx.TimeoutException: + return ToolResult( + success=False, error=f"Timeout after {effective_timeout}s" + ) + except httpx.ConnectError: + last_error = ( + f"Cannot connect to sandbox. Is {self.space_id} running? 
" + f"Status: {self.status}" + ) + if attempt < 2: + time.sleep(3 * (attempt + 1)) + continue + return ToolResult(success=False, error=last_error) + except Exception as e: + return ToolResult(success=False, error=str(e)) + + return ToolResult(success=False, error=last_error or "Unknown error") + + # ── Tools ───────────────────────────────────────────────────── + + def bash( + self, + command: str, + *, + work_dir: str | None = None, + timeout: int | None = None, + description: str | None = None, + ) -> ToolResult: + return self._call( + "bash", + { + "command": command, + "work_dir": work_dir or self.work_dir, + "timeout": min(timeout or self.timeout, MAX_TIMEOUT), + }, + timeout=timeout, + ) + + def read( + self, path: str, *, offset: int | None = None, limit: int | None = None + ) -> ToolResult: + self._files_read.add(path) + return self._call( + "read", + { + "path": path, + "offset": offset, + "limit": limit or (DEFAULT_READ_LIMIT if offset is None else None), + }, + ) + + def write(self, path: str, content: str) -> ToolResult: + if path not in self._files_read: + check = self._call("exists", {"path": path}) + if check.success and check.output == "true": + return ToolResult( + success=False, + error=( + f"File {path} exists but has not been read this session. " + f"Read it first, or use sandbox_edit for targeted changes." + ), + ) + result = self._call("write", {"path": path, "content": content}) + if result.success: + self._files_read.add(path) + return result + + def edit( + self, path: str, old_str: str, new_str: str, *, replace_all: bool = False, + mode: str = "replace", + ) -> ToolResult: + if old_str == new_str: + return ToolResult(success=False, error="old_str and new_str are identical.") + if path not in self._files_read: + return ToolResult( + success=False, + error=f"File {path} has not been read this session. Read it first.", + ) + return self._call( + "edit", + { + "path": path, + "old_str": old_str, + "new_str": new_str, + "replace_all": replace_all, + "mode": mode, + }, + ) + + def kill_all(self) -> ToolResult: + """Kill all active bash processes on the sandbox. Used on cancellation.""" + return self._call("kill", {}) + + # ── Tool schemas & dispatch ─────────────────────────────────── + + TOOLS = { + "bash": { + "description": ( + "Run a shell command in the remote sandbox and return stdout/stderr.\n" + "\n" + "IMPORTANT: Do NOT use bash for file operations — use the dedicated tools instead:\n" + "- To read files: use read (not cat/head/tail)\n" + "- To edit files: use edit (not sed/awk)\n" + "- To write files: use write (not echo/cat < > /app/output.log 2>&1 & echo $!\n" + "Then check status:\n" + " kill -0 2>/dev/null && echo 'running' || echo 'done'\n" + " tail -n 50 /app/output.log\n" + "\n" + "Timeout default 240s, max 1200s." + ), + "parameters": { + "type": "object", + "required": ["command"], + "additionalProperties": False, + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute.", + }, + "description": { + "type": "string", + "description": "Short description (5-10 words, active voice).", + }, + "work_dir": { + "type": "string", + "description": "Working directory (default: /app).", + }, + "timeout": { + "type": "integer", + "description": "Optional timeout in seconds (default: 240, max: 1200).", + }, + }, + }, + }, + "read": { + "description": ( + "Reads a file from the sandbox filesystem. 
Returns contents with line " + "numbers (cat -n format).\n" + "\n" + "Usage:\n" + "- By default, reads up to 2000 lines from the beginning of the file.\n" + "- You can optionally specify offset and limit for large files, but prefer " + "reading the whole file first.\n" + "- Lines longer than 4000 chars are truncated.\n" + "- Cannot read directories — use bash with 'ls' instead.\n" + "- You should read multiple potentially useful files in parallel when possible.\n" + "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " + "write tools will reject operations on files you haven't read." + ), + "parameters": { + "type": "object", + "required": ["path"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to read.", + }, + "offset": { + "type": "integer", + "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", + }, + "limit": { + "type": "integer", + "description": "The number of lines to read. Only provide if the file is too large to read at once.", + }, + }, + }, + }, + "write": { + "description": ( + "Writes a file to the sandbox filesystem. Overwrites the existing file if " + "one exists at the path.\n" + "\n" + "- If this is an existing file, you MUST use the read tool first. This tool " + "will fail if you did not read the file first.\n" + "- ALWAYS prefer editing existing files with the edit tool over overwriting " + "with write.\n" + "- Creates parent directories as needed." + ), + "parameters": { + "type": "object", + "required": ["path", "content"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to write.", + }, + "content": { + "type": "string", + "description": "The complete file content to write.", + }, + }, + }, + }, + "edit": { + "description": ( + "Performs string replacements in files. Supports exact matching with " + "fuzzy fallback.\n" + "\n" + "Usage:\n" + "- You must read the file at least once before editing. This tool will " + "error if you attempt an edit without reading the file.\n" + "- The edit will FAIL if old_str is not unique in the file. Either provide " + "a larger string with more surrounding context to make it unique, or set " + "replace_all to true.\n" + "- old_str and new_str must differ.\n" + "- Preserve indentation exactly as it appears in the file.\n" + "- Do NOT include line number prefixes from read output in old_str or new_str.\n" + "- To delete code, set new_str to empty string.\n" + "- Use replace_all for renaming variables or strings across the file.\n" + "\n" + "Modes:\n" + "- replace (default): replace first occurrence of old_str with new_str.\n" + "- append_after: insert new_str immediately after old_str (old_str is kept).\n" + "- prepend_before: insert new_str immediately before old_str (old_str is kept)." + ), + "parameters": { + "type": "object", + "required": ["path", "old_str", "new_str"], + "additionalProperties": False, + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to edit.", + }, + "old_str": { + "type": "string", + "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", + }, + "new_str": { + "type": "string", + "description": "The replacement text. 
For append_after/prepend_before modes, the text to insert.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences of old_str (default: false).", + "default": False, + }, + "mode": { + "type": "string", + "enum": ["replace", "append_after", "prepend_before"], + "description": "Edit mode (default: replace).", + "default": "replace", + }, + }, + }, + }, + } + + @classmethod + def tool_definitions(cls) -> list[dict]: + return [{"name": name, **spec} for name, spec in cls.TOOLS.items()] + + def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: + dispatch = { + "bash": lambda a: self.bash( + a["command"], + work_dir=a.get("work_dir"), + timeout=a.get("timeout"), + description=a.get("description"), + ), + "read": lambda a: self.read( + a["path"], + offset=a.get("offset"), + limit=a.get("limit"), + ), + "write": lambda a: self.write(a["path"], a["content"]), + "edit": lambda a: self.edit( + a["path"], + a["old_str"], + a["new_str"], + replace_all=a.get("replace_all", False), + mode=a.get("mode", "replace"), + ), + } + fn = dispatch.get(name) + if not fn: + return ToolResult(success=False, error=f"Unknown tool: {name}") + return fn(arguments) diff --git a/plugin/lib/ml_intern_lib/tools/sandbox_tool.py b/plugin/lib/ml_intern_lib/tools/sandbox_tool.py new file mode 100644 index 00000000..67f00306 --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/sandbox_tool.py @@ -0,0 +1,301 @@ +""" +Sandbox tools — expose the Sandbox client as agent tools. + +5 tools total: + sandbox_create — explicit sandbox creation (requires approval) + bash, read, write, edit — operations on the sandbox + +Operation tools require an active sandbox: if none is running, they +return an error asking you to call sandbox_create first. +""" + +from __future__ import annotations + +import asyncio +import threading +from typing import Any + +from huggingface_hub import HfApi, SpaceHardware + +from ml_intern_lib.session_stub import Event +from ml_intern_lib.tools.sandbox_client import Sandbox + + +def _looks_like_path(script: str) -> bool: + """Return True if the script string looks like a file path (not inline code).""" + return ( + isinstance(script, str) + and script.strip() == script + and not any(c in script for c in "\r\n\0") + and ( + script.startswith("/") + or script.startswith("./") + or script.startswith("../") + ) + ) + + +async def resolve_sandbox_script( + sandbox: Any, script: str +) -> tuple[str | None, str | None]: + """Read a file from the sandbox if *script* looks like a path. + + Returns: + (content, error) — content is the file text on success, + error is a message on failure. Both None means *script* + is not a path (caller should use it as-is). + """ + if not sandbox or not _looks_like_path(script): + return None, None + try: + # Use the read endpoint instead of bash("cat ...") which truncates at 25KB. 
+ result = await asyncio.to_thread(sandbox.read, script, limit=100_000) + if result.success and result.output: + # Strip line number prefixes (read returns "N\tcontent" format) + lines = [] + for line in result.output.split("\n"): + parts = line.split("\t", 1) + lines.append(parts[1] if len(parts) == 2 else line) + return "\n".join(lines), None + return None, f"Failed to read {script} from sandbox: {result.error}" + except Exception as e: + return None, f"Failed to read {script} from sandbox: {e}" + + +# ── Tool name mapping (short agent names → Sandbox client names) ────── + + +async def _ensure_sandbox( + session: Any, hardware: str = "cpu-basic", **create_kwargs +) -> tuple[Sandbox | None, str | None]: + """ + Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. + + Returns: + (sandbox, error_message) — one will be None. + """ + if session and getattr(session, "sandbox", None): + return session.sandbox, None + + if not session: + return None, "No session available." + + token = session.hf_token + if not token: + return None, "No HF token available. Cannot create sandbox." + + api = HfApi(token=token) + user_info = api.whoami() + owner = user_info.get("name", user_info.get("user", "")) + if not owner: + return None, "Could not determine HF username from token." + + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "sandbox", + "log": f"Auto-creating sandbox for {owner} ({hardware})...", + }, + ) + ) + + # Thread-safe log callback: posts tool_log events from the worker thread + loop = asyncio.get_running_loop() + + def _log(msg: str) -> None: + loop.call_soon_threadsafe( + session.event_queue.put_nowait, + Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}), + ) + + # Bridge asyncio cancel event to a threading.Event for the blocking create call. + # We poll session._cancelled from the main loop in a background task and set + # a threading.Event that Sandbox.create checks during its polling loops. + cancel_flag = threading.Event() + + async def _watch_cancel(): + await session._cancelled.wait() + cancel_flag.set() + + watcher_task = asyncio.create_task(_watch_cancel()) + + kwargs = { + "owner": owner, + "hardware": hardware, + "token": token, + "secrets": {"HF_TOKEN": token}, + "log": _log, + "cancel_event": cancel_flag, + **create_kwargs, + } + if hardware != "cpu-basic": + kwargs["sleep_time"] = 2700 + import time as _t + _t_start = _t.monotonic() + try: + sb = await asyncio.to_thread(Sandbox.create, **kwargs) + except Sandbox.Cancelled: + return None, "Sandbox creation cancelled by user." 
+ finally: + watcher_task.cancel() + session.sandbox = sb + + # Telemetry: sandbox creation (infra consumption signal) + from ml_intern_lib import telemetry_stub as telemetry + await telemetry.record_sandbox_create( + session, sb, hardware=hardware, + create_latency_s=int(_t.monotonic() - _t_start), + ) + + # Set a descriptive title (template title is inherited on duplicate) + from huggingface_hub import metadata_update + + await asyncio.to_thread( + metadata_update, + sb.space_id, + {"title": "ml-intern sandbox"}, + repo_type="space", + overwrite=True, + token=token, + ) + + await session.send_event( + Event( + event_type="tool_log", + data={"tool": "sandbox", "log": f"Sandbox ready: {sb.space_id} ({sb.url})"}, + ) + ) + + return sb, None + + +# ── sandbox_create tool ────────────────────────────────────────────── + +SANDBOX_CREATE_TOOL_SPEC = { + "name": "sandbox_create", + "description": ( + "Create a persistent remote Linux environment for developing and testing scripts.\n\n" + "Workflow: sandbox_create → write script → pip install → test with small run → fix errors → hf_jobs at scale.\n" + "The sandbox persists across tool calls within the session. pip install works out of the box.\n\n" + "Use this when: you need to develop, test, and iterate on scripts before launching via hf_jobs. " + "Especially for training scripts where you need to verify imports, test on a small subset, and fix errors interactively.\n\n" + "Skip this when: the task is a simple one-shot operation (status check, resource search, quick data query), " + "or the script is copied from a verified working example with minimal changes.\n\n" + "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). " + "CPU sandboxes cannot run GPU code paths — your test will not catch GPU-related errors.\n\n" + "Before choosing hardware, estimate your VRAM needs (models you run, training data size). Rule of thumb: bf16/fp16 ≈ 2 bytes/param, " + "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n" + "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). " + "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n" + "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" + ), + "parameters": { + "type": "object", + "required": [], + "additionalProperties": False, + "properties": { + "hardware": { + "type": "string", + "enum": [e.value for e in SpaceHardware], + "description": "Hardware tier for the sandbox (default: cpu-basic)", + }, + "private": { + "type": "boolean", + "description": "If true, create a private Space", + }, + }, + }, +} + + +async def sandbox_create_handler( + args: dict[str, Any], session: Any = None +) -> tuple[str, bool]: + """Handle sandbox_create tool calls.""" + # If sandbox already exists, return its info + if session and getattr(session, "sandbox", None): + sb = session.sandbox + return ( + f"Sandbox already active: {sb.space_id}\n" + f"URL: {sb.url}\n" + f"Use bash/read/write/edit to interact with it." 
+ ), True + + hardware = args.get("hardware", "cpu-basic") + create_kwargs = {} + if "private" in args: + create_kwargs["private"] = args["private"] + + try: + sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs) + except Exception as e: + return f"Failed to create sandbox: {e}", False + + if error: + return error, False + + return ( + f"Sandbox created: {sb.space_id}\n" + f"URL: {sb.url}\n" + f"Hardware: {hardware}\n" + f"Use bash/read/write/edit to interact with it." + ), True + + +def _make_tool_handler(sandbox_tool_name: str): + """Factory: create a handler for a sandbox operation tool.""" + + async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]: + # Require sandbox to exist — user must approve sandbox_create first + if not session or not getattr(session, "sandbox", None): + return "No sandbox running. Call sandbox_create first to start one.", False + + sb = session.sandbox + + try: + result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) + if result.success: + output = result.output or "(no output)" + return output, True + else: + error_msg = result.error or "Unknown error" + output = result.output + if output: + return f"{output}\n\nERROR: {error_msg}", False + return f"ERROR: {error_msg}", False + except Exception as e: + return f"Sandbox operation failed: {e}", False + + return handler + + +def get_sandbox_tools(): + """Return all 5 sandbox ToolSpecs (sandbox_create + 4 operation tools).""" + from ml_intern_lib.tool_spec import ToolSpec + + tools = [] + + # sandbox_create (explicit creation, requires approval) + tools.append( + ToolSpec( + name=SANDBOX_CREATE_TOOL_SPEC["name"], + description=SANDBOX_CREATE_TOOL_SPEC["description"], + parameters=SANDBOX_CREATE_TOOL_SPEC["parameters"], + handler=sandbox_create_handler, + ) + ) + + # Operation tools (auto-execute, no approval needed) + for name in Sandbox.TOOLS.keys(): + spec = Sandbox.TOOLS[name] + tools.append( + ToolSpec( + name=name, + description=spec["description"], + parameters=spec["parameters"], + handler=_make_tool_handler(name), + ) + ) + + return tools diff --git a/plugin/lib/ml_intern_lib/tools/types.py b/plugin/lib/ml_intern_lib/tools/types.py new file mode 100644 index 00000000..c968e35b --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/types.py @@ -0,0 +1,16 @@ +""" +Types for Hugging Face tools + +Ported from: hf-mcp-server/packages/mcp/src/types/ +""" + +from typing import TypedDict + + +class ToolResult(TypedDict, total=False): + """Result returned by HF tool operations""" + + formatted: str + totalResults: int + resultsShared: int + isError: bool diff --git a/plugin/lib/ml_intern_lib/tools/utilities.py b/plugin/lib/ml_intern_lib/tools/utilities.py new file mode 100644 index 00000000..93b4229e --- /dev/null +++ b/plugin/lib/ml_intern_lib/tools/utilities.py @@ -0,0 +1,142 @@ +""" +Utility functions for Hugging Face tools + +Ported from: hf-mcp-server/packages/mcp/src/jobs/formatters.ts +Includes GPU memory validation for job submissions +""" + +import json +from datetime import datetime +from typing import Any, Dict, List, Optional + + +def truncate(text: str, max_length: int) -> str: + """Truncate a string to a maximum length with ellipsis""" + if len(text) <= max_length: + return text + return text[: max_length - 3] + "..." 
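+
+# Example (illustrative, not from the ported source): the ellipsis counts toward
+# the limit, so truncate("meta-llama/Llama-3.1-8B", 10) returns "meta-ll...".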
+ + +def format_date(date_str: Optional[str]) -> str: + """Format a date string to a readable format""" + if not date_str: + return "N/A" + try: + date = datetime.fromisoformat(date_str.replace("Z", "+00:00")) + return date.strftime("%Y-%m-%d %H:%M:%S") + except Exception: + return date_str + + +def format_command(command: Optional[List[str]]) -> str: + """Format command array as a single string""" + if not command or len(command) == 0: + return "N/A" + return " ".join(command) + + +def get_image_or_space(job: Dict[str, Any]) -> str: + """Get image/space identifier from job""" + if job.get("spaceId"): + return job["spaceId"] + if job.get("dockerImage"): + return job["dockerImage"] + return "N/A" + + +def format_jobs_table(jobs: List[Dict[str, Any]]) -> str: + """Format jobs as a markdown table""" + if len(jobs) == 0: + return "No jobs found." + + # Calculate dynamic ID column width + longest_id_length = max(len(job["id"]) for job in jobs) + id_column_width = max(longest_id_length, len("JOB ID")) + + # Define column widths + col_widths = { + "id": id_column_width, + "image": 20, + "command": 30, + "created": 19, + "status": 12, + } + + # Build header + header = f"| {'JOB ID'.ljust(col_widths['id'])} | {'IMAGE/SPACE'.ljust(col_widths['image'])} | {'COMMAND'.ljust(col_widths['command'])} | {'CREATED'.ljust(col_widths['created'])} | {'STATUS'.ljust(col_widths['status'])} |" + separator = f"|{'-' * (col_widths['id'] + 2)}|{'-' * (col_widths['image'] + 2)}|{'-' * (col_widths['command'] + 2)}|{'-' * (col_widths['created'] + 2)}|{'-' * (col_widths['status'] + 2)}|" + + # Build rows + rows = [] + for job in jobs: + job_id = job["id"] + image = truncate(get_image_or_space(job), col_widths["image"]) + command = truncate(format_command(job.get("command")), col_widths["command"]) + created = truncate(format_date(job.get("createdAt")), col_widths["created"]) + status = truncate(job["status"]["stage"], col_widths["status"]) + + rows.append( + f"| {job_id.ljust(col_widths['id'])} | {image.ljust(col_widths['image'])} | {command.ljust(col_widths['command'])} | {created.ljust(col_widths['created'])} | {status.ljust(col_widths['status'])} |" + ) + + return "\n".join([header, separator] + rows) + + +def format_scheduled_jobs_table(jobs: List[Dict[str, Any]]) -> str: + """Format scheduled jobs as a markdown table""" + if len(jobs) == 0: + return "No scheduled jobs found." 
+ + # Calculate dynamic ID column width + longest_id_length = max(len(job["id"]) for job in jobs) + id_column_width = max(longest_id_length, len("ID")) + + # Define column widths + col_widths = { + "id": id_column_width, + "schedule": 12, + "image": 18, + "command": 25, + "lastRun": 19, + "nextRun": 19, + "suspend": 9, + } + + # Build header + header = f"| {'ID'.ljust(col_widths['id'])} | {'SCHEDULE'.ljust(col_widths['schedule'])} | {'IMAGE/SPACE'.ljust(col_widths['image'])} | {'COMMAND'.ljust(col_widths['command'])} | {'LAST RUN'.ljust(col_widths['lastRun'])} | {'NEXT RUN'.ljust(col_widths['nextRun'])} | {'SUSPENDED'.ljust(col_widths['suspend'])} |" + separator = f"|{'-' * (col_widths['id'] + 2)}|{'-' * (col_widths['schedule'] + 2)}|{'-' * (col_widths['image'] + 2)}|{'-' * (col_widths['command'] + 2)}|{'-' * (col_widths['lastRun'] + 2)}|{'-' * (col_widths['nextRun'] + 2)}|{'-' * (col_widths['suspend'] + 2)}|" + + # Build rows + rows = [] + for job in jobs: + job_id = job["id"] + schedule = truncate(job["schedule"], col_widths["schedule"]) + image = truncate(get_image_or_space(job["jobSpec"]), col_widths["image"]) + command = truncate( + format_command(job["jobSpec"].get("command")), col_widths["command"] + ) + last_run = truncate(format_date(job.get("lastRun")), col_widths["lastRun"]) + next_run = truncate(format_date(job.get("nextRun")), col_widths["nextRun"]) + suspend = "Yes" if job.get("suspend") else "No" + + rows.append( + f"| {job_id.ljust(col_widths['id'])} | {schedule.ljust(col_widths['schedule'])} | {image.ljust(col_widths['image'])} | {command.ljust(col_widths['command'])} | {last_run.ljust(col_widths['lastRun'])} | {next_run.ljust(col_widths['nextRun'])} | {suspend.ljust(col_widths['suspend'])} |" + ) + + return "\n".join([header, separator] + rows) + + +def format_job_details(jobs: Any) -> str: + """Format job details as JSON in a markdown code block""" + + job_array = jobs if isinstance(jobs, list) else [jobs] + json_str = json.dumps(job_array, indent=2) + return f"```json\n{json_str}\n```" + + +def format_scheduled_job_details(jobs: Any) -> str: + """Format scheduled job details as JSON in a markdown code block""" + + job_array = jobs if isinstance(jobs, list) else [jobs] + json_str = json.dumps(job_array, indent=2) + return f"```json\n{json_str}\n```" diff --git a/plugin/pyproject.toml b/plugin/pyproject.toml new file mode 100644 index 00000000..a6df7ab9 --- /dev/null +++ b/plugin/pyproject.toml @@ -0,0 +1,28 @@ +[project] +name = "ml-intern-plugin" +version = "0.1.0" +description = "MCP server + vendored tools for the ml-intern Claude Code plugin" +readme = "README.md" +requires-python = ">=3.11" +license = { text = "Apache-2.0" } +dependencies = [ + "mcp>=1.0.0", + "huggingface-hub>=1.0.1", + "httpx>=0.27.0", + "requests>=2.33.0", + "thefuzz>=0.22.1", + "whoosh>=2.7.4", + "beautifulsoup4>=4.12.0", + "nbconvert>=7.16.6", + "nbformat>=5.10.4", +] + +[tool.uv] +package = false + +[tool.hatch.build.targets.wheel] +packages = ["lib/ml_intern_lib"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/scripts/sync_plugin_vendored.py b/scripts/sync_plugin_vendored.py new file mode 100755 index 00000000..4fe871de --- /dev/null +++ b/scripts/sync_plugin_vendored.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +"""Re-sync the plugin's vendored library from agent/tools and agent/core/redact. + +Run from the repo root: + + uv run python scripts/sync_plugin_vendored.py + +What it does: + 1. 
Copies agent/tools/*.py (minus research/plan/private) → plugin/lib/ml_intern_lib/tools/ + 2. Copies agent/core/redact.py → plugin/lib/ml_intern_lib/redact.py + 3. Rewrites imports inside the copies: + from agent.tools.X → from ml_intern_lib.tools.X + from agent.core.session import Event + → from ml_intern_lib.session_stub import Event + from agent.core.tools import ToolSpec + → from ml_intern_lib.tool_spec import ToolSpec + from agent.core import telemetry + → from ml_intern_lib import telemetry_stub as telemetry + +This is idempotent — running it twice produces the same output. + +Diff the result to confirm before committing. +""" + +from __future__ import annotations + +import re +import shutil +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +SRC_TOOLS = REPO_ROOT / "agent" / "tools" +SRC_REDACT = REPO_ROOT / "agent" / "core" / "redact.py" +DST_PKG = REPO_ROOT / "plugin" / "lib" / "ml_intern_lib" +DST_TOOLS = DST_PKG / "tools" + +# Tools NOT vendored (replaced by Claude Code natives or disabled upstream). +SKIP = {"research_tool.py", "plan_tool.py", "private_hf_repo_tools.py"} + + +def rewrite(src: str) -> str: + src = re.sub(r"\bfrom agent\.tools\.", "from ml_intern_lib.tools.", src) + src = re.sub( + r"\bfrom agent\.core\.session import Event\b", + "from ml_intern_lib.session_stub import Event", + src, + ) + src = re.sub( + r"\bfrom agent\.core\.tools import ToolSpec\b", + "from ml_intern_lib.tool_spec import ToolSpec", + src, + ) + src = re.sub( + r"\bfrom agent\.core import telemetry\b", + "from ml_intern_lib import telemetry_stub as telemetry", + src, + ) + return src + + +def main() -> int: + if not SRC_TOOLS.is_dir(): + print(f"FATAL: {SRC_TOOLS} not found — run from repo root", file=sys.stderr) + return 1 + if not DST_TOOLS.is_dir(): + print(f"FATAL: {DST_TOOLS} not found — plugin layout missing?", file=sys.stderr) + return 1 + + # Tools + copied = 0 + for src in SRC_TOOLS.glob("*.py"): + if src.name in SKIP: + continue + dst = DST_TOOLS / src.name + text = src.read_text() + rewritten = rewrite(text) + dst.write_text(rewritten) + copied += 1 + print(f"Vendored {copied} tool files to {DST_TOOLS}") + + # Redact + redact_src = SRC_REDACT.read_text() + (DST_PKG / "redact.py").write_text(rewrite(redact_src)) + print(f"Vendored redact.py to {DST_PKG}/redact.py") + + # Sanity check: no remaining `from agent.` imports + leftover = [] + for f in DST_PKG.rglob("*.py"): + for i, line in enumerate(f.read_text().splitlines(), start=1): + if re.search(r"\bfrom agent\.|\bimport agent\.", line): + leftover.append(f"{f}:{i}: {line.strip()}") + if leftover: + print("WARNING — leftover agent.* imports:", file=sys.stderr) + for line in leftover: + print(f" {line}", file=sys.stderr) + return 2 + + print("OK — no leftover agent.* imports") + return 0 + + +if __name__ == "__main__": + sys.exit(main())