diff --git a/.gitignore b/.gitignore index e9abee3..bf59b5d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ node_modules .env -tmp \ No newline at end of file +tmp +.venv +__pycache__ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index c4debf4..c67c81f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,9 +9,10 @@ 1. Start at `src/cli//main.ts` and the matching `src/cli//README.md`. 2. Follow the pipeline classes under `src/cli//clients/*` and schemas under `src/cli//types/*`. 3. Reuse shared helpers: `src/utils/parse-args.ts`, `src/utils/question-handler.ts`, `src/clients/logger.ts`. -4. Keep changes minimal; add/update **Vitest** tests (`*.test.ts`) when behavior changes. -5. Run: `pnpm typecheck`, `pnpm lint`, `pnpm test` (and `pnpm format:check` if formatting changed). -6. All runtime artifacts go under `tmp/` (never commit them). +4. Keep `main.ts` focused on the basic agent flow; move non-trivial logic into `clients/` or `utils/`. +5. Keep changes minimal; add/update **Vitest** tests (`*.test.ts`) when behavior changes. +6. Run: `pnpm typecheck`, `pnpm lint`, `pnpm test` (and `pnpm format:check` if formatting changed). +7. All runtime artifacts go under `tmp/` (never commit them). **Scratch space:** Use `tmp/` for generated HTML/markdown/JSON/reports. 
@@ -31,6 +32,13 @@ - Install deps: `pnpm install` - Set `OPENAI_API_KEY` via env or `.env` (humans do this; agents must not read secrets) - If a task requires Playwright, follow the repo README for system deps +- If a task requires Python (e.g., `etf-backtest`), set up the venv: + ```bash + # On Debian/Ubuntu, install venv support first: sudo apt install python3-venv + python3 -m venv .venv + source .venv/bin/activate + pip install numpy pandas torch + ``` **Common scripts (see `package.json` for all):** @@ -86,6 +94,9 @@ All file tools are sandboxed to `tmp/` using path validation (`src/tools/utils/f - **`listFiles`** (`src/tools/list-files/list-files-tool.ts`) - Lists files/dirs under `tmp/`. - Params: `{ path?: string }` (defaults to `tmp/` root) +- **`runPython`** (`src/tools/run-python/run-python-tool.ts`) + - Runs a Python script from a configured scripts directory. + - Params: `{ scriptName: string, input: string }` (input is JSON string; pass `""` for no input) ### Safe web fetch tool @@ -99,9 +110,16 @@ All file tools are sandboxed to `tmp/` using path validation (`src/tools/utils/f ## 5) Coding conventions (how changes should look) - Initialize `Logger` in CLI entry points and pass it into clients/pipelines via constructor options. +- Use `Logger` instead of `console.log`/`console.error` for output. +- Use `AgentRunner` (`src/clients/agent-runner.ts`) as the default wrapper when running agents. - Prefer shared helpers in `src/utils` (`parse-args`, `question-handler`) over custom logic. +- `main.ts` should stay focused on the **basic agent flow**: argument parsing → agent setup → run loop → final output. Move helper logic into `clients/` or `utils/` - Prefer TypeScript path aliases over deep relative imports: `~tools/*`, `~clients/*`, `~utils/*`. - Use Zod schemas for CLI args and tool IO. +- Keep object field names in `camelCase` (e.g., `trainSamples`), not `snake_case`. 
+- Keep Zod schemas in a dedicated `schemas.ts` file for each CLI (avoid inline schemas in `main.ts`). +- Keep constants in a dedicated `constants.ts` file for each CLI. +- Move hardcoded numeric values into `constants.ts` (treat numbers as configuration). - For HTTP fetching in code, prefer `Fetch` (sanitized) or `PlaywrightScraper` for JS-heavy pages. - When adding tools that touch files, use `src/tools/utils/fs.ts` for path validation. - Comments should capture invariants or subtle behavior, not restate code. @@ -127,3 +145,7 @@ All file tools are sandboxed to `tmp/` using path validation (`src/tools/utils/f - [ ] Any generated artifacts are in `tmp/` only --- + +# ExecPlans + +When writing complex features or significant refactors, use an ExecPlan (as described in `agent/PLANS.md`) from design to implementation. diff --git a/README.md b/README.md index 3ceff76..2a5ca6e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # cli-agent-sandbox -A minimal TypeScript CLI sandbox for testing agent workflows and safe web scraping. This is a single-package repo built with [`@openai/agents`](https://github.com/openai/openai-agents-js), and it includes a guestbook demo, a Finnish name explorer CLI, a publication scraping pipeline with a Playwright-based scraper for JS-rendered pages, and agent tools scoped to `tmp` with strong safety checks. +A minimal TypeScript CLI sandbox for testing agent workflows and safe web scraping. This is a single-package repo built with [`@openai/agents`](https://github.com/openai/openai-agents-js), and it includes a guestbook demo, a Finnish name explorer CLI, a publication scraping pipeline with a Playwright-based scraper for JS-rendered pages, an ETF backtest CLI, and agent tools scoped to `tmp` with strong safety checks. ## Quick Start @@ -11,20 +11,33 @@ A minimal TypeScript CLI sandbox for testing agent workflows and safe web scrapi 5. Run the demo: `pnpm run:guestbook` 6. 
(Optional) Explore Finnish name stats: `pnpm run:name-explorer -- --mode ai|stats` 7. (Optional) Run publication scraping: `pnpm run:scrape-publications -- --url="https://example.com"` +8. (Optional) Run ETF backtest: `pnpm run:etf-backtest -- --isin=IE00B5BMR087` (requires Python setup below) + +### Python Setup (for ETF backtest) + +```bash +# On Debian/Ubuntu, install venv support first: +sudo apt install python3-venv + +python3 -m venv .venv +source .venv/bin/activate +pip install numpy pandas torch +``` ## Commands -| Command | Description | -| ------------------------------ | ------------------------------------------------- | -| `pnpm run:guestbook` | Run the interactive guestbook CLI demo | -| `pnpm run:name-explorer` | Explore Finnish name statistics (AI Q&A or stats) | -| `pnpm run:scrape-publications` | Scrape publication links and build a review page | -| `pnpm typecheck` | Run TypeScript type checking | -| `pnpm lint` | Run ESLint for code quality | -| `pnpm lint:fix` | Run ESLint and auto-fix issues | -| `pnpm format` | Format code with Prettier | -| `pnpm format:check` | Check code formatting | -| `pnpm test` | Run Vitest test suite | +| Command | Description | +| ------------------------------ | ------------------------------------------------------ | +| `pnpm run:guestbook` | Run the interactive guestbook CLI demo | +| `pnpm run:name-explorer` | Explore Finnish name statistics (AI Q&A or stats) | +| `pnpm run:scrape-publications` | Scrape publication links and build a review page | +| `pnpm run:etf-backtest` | Run ETF backtest + feature optimizer (requires Python) | +| `pnpm typecheck` | Run TypeScript type checking | +| `pnpm lint` | Run ESLint for code quality | +| `pnpm lint:fix` | Run ESLint and auto-fix issues | +| `pnpm format` | Format code with Prettier | +| `pnpm format:check` | Check code formatting | +| `pnpm test` | Run Vitest test suite | ## Publication scraping @@ -46,7 +59,7 @@ The publication pipeline uses `PlaywrightScraper` to 
render JavaScript-heavy pag The `run:name-explorer` script explores Finnish name statistics. It supports an AI Q&A mode (default) backed by SQL tools, plus a `stats` mode that generates an HTML report. -![Name Explorer demo](src/cli/name-explorer/demo-1.png) +Name Explorer demo Usage: @@ -56,22 +69,55 @@ pnpm run:name-explorer -- [--mode ai|stats] [--refetch] Outputs are written under `tmp/name-explorer/`, including `statistics.html` in stats mode. +## ETF backtest + +The `run:etf-backtest` CLI fetches ETF history from justetf.com (via Playwright), caches it under +`tmp/etf-backtest//data.json`, and runs the Python experiment loop via the `runPython` tool. + +ETF Backtest demo + +Usage: + +``` +pnpm run:etf-backtest -- --isin=IE00B5BMR087 [--maxIterations=5] [--seed=42] [--refresh] [--verbose] +``` + +Notes: + +- `--refresh` forces a refetch; otherwise cached data is reused. +- Python scripts live in `src/cli/etf-backtest/scripts/`. + ## Tools -File tools are sandboxed to the `tmp/` directory with path validation to prevent traversal and symlink attacks. The `fetchUrl` tool adds SSRF protections and HTML sanitization. +File tools are sandboxed to the `tmp/` directory with path validation to prevent traversal and symlink attacks. The `fetchUrl` tool adds SSRF protections and HTML sanitization, and `runPython` executes whitelisted Python scripts from a configured directory. 
+ +| Tool | Location | Description | +| ----------- | ----------------------------------------- | ------------------------------------------------------------------------------ | +| `fetchUrl` | `src/tools/fetch-url/fetch-url-tool.ts` | Fetches URLs safely and returns sanitized Markdown/text | +| `readFile` | `src/tools/read-file/read-file-tool.ts` | Reads file content from `tmp` directory | +| `writeFile` | `src/tools/write-file/write-file-tool.ts` | Writes content to files in `tmp` directory | +| `listFiles` | `src/tools/list-files/list-files-tool.ts` | Lists files and directories under `tmp` | +| `runPython` | `src/tools/run-python/run-python-tool.ts` | Runs Python scripts from a configured scripts directory (JSON stdin supported) | + +`runPython` details: -| Tool | Location | Description | -| ----------- | ----------------------------------------- | ------------------------------------------------------- | -| `fetchUrl` | `src/tools/fetch-url/fetch-url-tool.ts` | Fetches URLs safely and returns sanitized Markdown/text | -| `readFile` | `src/tools/read-file/read-file-tool.ts` | Reads file content from `tmp` directory | -| `writeFile` | `src/tools/write-file/write-file-tool.ts` | Writes content to files in `tmp` directory | -| `listFiles` | `src/tools/list-files/list-files-tool.ts` | Lists files and directories under `tmp` | +- `scriptName` must be a `.py` file name in the configured scripts directory (no subpaths). +- `input` is a JSON string passed to stdin (use `""` for no input). 
## Project Structure ``` src/ ├── cli/ +│ ├── etf-backtest/ +│ │ ├── main.ts # ETF backtest CLI entry point +│ │ ├── README.md # ETF backtest docs +│ │ ├── constants.ts # CLI constants +│ │ ├── schemas.ts # CLI args + agent output schemas +│ │ ├── clients/ # Data fetcher + Playwright capture +│ │ ├── utils/ # Scoring + formatting helpers +│ │ ├── types/ # ETF data types +│ │ └── scripts/ # Python backtest + prediction scripts │ ├── guestbook/ │ │ ├── main.ts # Guestbook CLI entry point │ │ └── README.md # Guestbook CLI docs @@ -90,15 +136,16 @@ src/ ├── clients/ │ ├── fetch.ts # Shared HTTP fetch + sanitization │ ├── logger.ts # Shared console logger +│ ├── agent-runner.ts # Default agent runner wrapper │ └── playwright-scraper.ts # Playwright-based web scraper ├── utils/ │ ├── parse-args.ts # Shared CLI arg parsing helper │ └── question-handler.ts # Shared CLI prompt + validation helper ├── tools/ -│ ├── index.ts # Tool exports │ ├── fetch-url/ # Safe fetch tool │ ├── list-files/ # List files tool │ ├── read-file/ # Read file tool +│ ├── run-python/ # Run Python scripts tool │ ├── write-file/ # Write file tool │ └── utils/ │ ├── fs.ts # Path safety utilities @@ -111,6 +158,7 @@ tmp/ # Runtime scratch space (tool I/O) ## CLI conventions - When using `Logger`, initialize it in the CLI entry point and pass it into clients/pipelines via constructor options. +- Use `AgentRunner` (`src/clients/agent-runner.ts`) as the default wrapper when running agents. - Prefer shared helpers in `src/utils` (`parse-args`, `question-handler`) over custom argument parsing or prompt logic. - Use the TypeScript path aliases for shared modules: `~tools/*`, `~clients/*`, `~utils/*`. Example: `import { readFileTool } from "~tools/read-file/read-file-tool";` diff --git a/agent/PLANS.md b/agent/PLANS.md new file mode 100644 index 0000000..ccc7566 --- /dev/null +++ b/agent/PLANS.md @@ -0,0 +1,136 @@ +# ExecPlans for cli-agent-sandbox + +This repo is a minimal TypeScript CLI sandbox. 
ExecPlans exist to make larger changes safe, reproducible, and testable by a novice who only has the repo and the plan. Keep plans tailored to this repository, not a generic template. + +Use an ExecPlan only for complex features or significant refactors. For small, localized changes, skip the plan and just implement. + +## Non-negotiables + +- Self-contained: the plan must include all context needed to execute it without external docs or prior plans. +- Observable outcomes: describe what a human can run and see to prove the change works. +- Living document: update the plan as work proceeds; never let it drift from reality. +- Repo-safe: never read `.env`, never write outside the repo or `tmp/`, never commit or push. +- Minimal, test-covered changes: update or add Vitest tests when behavior changes. + +## Repository context to embed in every plan + +Include a short orientation paragraph naming the key paths and how they relate: + +- Entry points live in `src/cli//main.ts` with a matching `src/cli//README.md`. +- Pipelines and clients live in `src/cli//clients/*`; schemas in `src/cli//types/*`. +- Shared helpers: `src/utils/parse-args.ts`, `src/utils/question-handler.ts`, `src/clients/logger.ts`. +- Tool sandboxing is under `src/tools/*` and path validation in `src/tools/utils/fs.ts`. +- Runtime artifacts belong under `tmp/` only. + +If the plan adds a new CLI, state that it must be scaffolded via: + + pnpm scaffold:cli -- --name=my-cli --description="What it does" + +Then add `"run:my-cli": "tsx src/cli/my-cli/main.ts"` to `package.json`. + +## Repo conventions to capture in plans (when relevant) + +- Initialize `Logger` in CLI entry points and pass it into clients/pipelines via constructor options. +- Use Zod schemas for CLI args and tool IO; name the schema files in the plan. +- Prefer TypeScript path aliases like `~tools/*`, `~clients/*`, `~utils/*` over deep relative imports. +- Avoid `index.ts` barrel exports; use explicit module paths. 
+- For HTTP fetching, prefer sanitized `Fetch` or `PlaywrightScraper` as appropriate. +- Any file-touching tool must use path validation from `src/tools/utils/fs.ts`. + +## Required sections in every ExecPlan + +Use these headings, in this order, and keep them up to date: + +1. **Purpose / Big Picture** — what the user gains and how they can see it working. +2. **Progress** — checklist with timestamps (UTC), split partial work into “done” vs “remaining”. +3. **Surprises & Discoveries** — unexpected behaviors or constraints with short evidence. +4. **Decision Log** — decision, rationale, date/author. +5. **Outcomes & Retrospective** — what was achieved, gaps, lessons learned. +6. **Context and Orientation** — repo-specific orientation and key files. +7. **Conventions and Contracts** — logging, schemas, imports, and tool safety expectations. +8. **Plan of Work** — prose describing edits, with precise file paths and locations. +9. **Concrete Steps** — exact commands to run (cwd included) and expected short outputs. +10. **Validation and Acceptance** — behavioral acceptance and tests; name new tests. +11. **Idempotence and Recovery** — how to rerun safely; rollback guidance if needed. +12. **Artifacts and Notes** — concise transcripts, diffs, or snippets as indented blocks. +13. **Interfaces and Dependencies** — required modules, types, function signatures, and why. + +## Formatting rules + +- The ExecPlan is a normal Markdown document (no outer code fence). +- Prefer prose over lists; the only mandatory checklist is in **Progress**. +- Define any non-obvious term the first time you use it. +- Use repo-relative paths and exact function/module names. +- Do not point to external docs; embed the needed context in the plan itself. 
+ +## Validation defaults for this repo + +State which of these apply, and include expected outcomes: + +- `pnpm typecheck` +- `pnpm lint` (or `pnpm lint:fix` if auto-fixing is intended) +- `pnpm test` +- `pnpm format:check` (if formatting changes) + +If the change affects a CLI, include a concrete CLI invocation and expected output. + +## ExecPlan skeleton (copy and fill) + + # + + This ExecPlan is a living document. Update **Progress**, **Surprises & Discoveries**, **Decision Log**, and **Outcomes & Retrospective** as work proceeds. + + ## Purpose / Big Picture + + Describe the user-visible behavior and how to observe it. + + ## Progress + + - [ ] (2026-01-25 00:00Z) Example incomplete step. + + ## Surprises & Discoveries + + - Observation: … + Evidence: … + + ## Decision Log + + - Decision: … + Rationale: … + Date/Author: … + + ## Outcomes & Retrospective + + Summarize results, gaps, and lessons learned. + + ## Context and Orientation + + Explain the relevant parts of `src/cli/...`, shared helpers, and tools. + + ## Conventions and Contracts + + Call out logging, Zod schemas, imports, and any tool safety expectations. + + ## Plan of Work + + Prose description of edits with precise file paths and locations. + + ## Concrete Steps + + State commands with cwd and short expected outputs. + + ## Validation and Acceptance + + Behavioral acceptance plus test commands and expectations. + + ## Idempotence and Recovery + + How to rerun safely and roll back if needed. + + ## Artifacts and Notes + + Short transcripts, diffs, or snippets as indented blocks. + + ## Interfaces and Dependencies + + Required types/modules/functions and why they exist. 
diff --git a/eslint.config.ts b/eslint.config.ts index c8d2d0a..c29cb54 100644 --- a/eslint.config.ts +++ b/eslint.config.ts @@ -89,6 +89,22 @@ export default defineConfig( ], }, ], + // Avoid template literals in logger calls for better structured logging + "no-restricted-syntax": [ + "error", + { + selector: + "CallExpression[callee.type='MemberExpression'][callee.object.name='logger'][callee.property.name=/^(debug|info|warn|error|tool|question|answer)$/] > TemplateLiteral", + message: + "Avoid template literals in logger calls. Use a plain string and pass data as extra args (e.g. logger.info('Saved file', { path })).", + }, + { + selector: + "CallExpression[callee.type='MemberExpression'][callee.object.type='MemberExpression'][callee.object.property.name='logger'][callee.property.name=/^(debug|info|warn|error|tool|question|answer)$/] > TemplateLiteral", + message: + "Avoid template literals in logger calls. Use a plain string and pass data as extra args (e.g. logger.info('Saved file', { path })).", + }, + ], }, }, { diff --git a/package.json b/package.json index bd0366e..e7a9c88 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,8 @@ "run:guestbook": "tsx src/cli/guestbook/main.ts", "run:name-explorer": "pnpm -s node:tsx -- src/cli/name-explorer/main.ts", "run:scrape-publications": "tsx src/cli/scrape-publications/main.ts", - "scaffold:cli": "pnpm -s node:tsx -- scripts/scaffold-cli.ts", + "run:etf-backtest": "tsx src/cli/etf-backtest/main.ts", + "scaffold:cli": "tsx scripts/scaffold-cli.ts", "node:tsx": "node --disable-warning=ExperimentalWarning --import tsx", "typecheck": "tsc --noEmit", "lint": "eslint .", diff --git a/scripts/scaffold-cli.ts b/scripts/scaffold-cli.ts index 40f43dd..1adb30f 100644 --- a/scripts/scaffold-cli.ts +++ b/scripts/scaffold-cli.ts @@ -4,12 +4,14 @@ * Scaffold a new CLI from the basic template. 
* * Usage: - * pnpm scaffold:cli -- --name=my-cli --description="My CLI description" - * pnpm scaffold:cli -- --name=my-cli # description defaults to "TODO: Add description" + * pnpm scaffold:cli --name=my-cli --description="My CLI description" + * pnpm scaffold:cli --name=my-cli # description defaults to "TODO: Add description" */ import fs from "node:fs/promises"; import path from "node:path"; -import { argv } from "zx"; +import { Logger } from "~clients/logger"; +import { parseArgs } from "~utils/parse-args"; +import { z } from "zod"; const TEMPLATE_DIR = path.join(process.cwd(), "templates", "cli-basic"); const CLI_DIR = path.join(process.cwd(), "src", "cli"); @@ -80,17 +82,14 @@ const copyTemplateFile = async ( }; const main = async (): Promise => { - const name = argv.name as string | undefined; - const description = - (argv.description as string | undefined) ?? "TODO: Add description"; + const logger = new Logger(); - if (!name) { - console.error("Error: --name is required"); - console.error( - 'Usage: pnpm scaffold:cli -- --name=my-cli --description="My description"' - ); - process.exit(1); - } + const argsSchema = z.object({ + name: z.string({ error: "Error: --name is required" }), + description: z.string().default("TODO: Add description"), + }); + + const { name, description } = parseArgs({ logger, schema: argsSchema }); validateCliName(name); diff --git a/src/cli/etf-backtest/README.md b/src/cli/etf-backtest/README.md new file mode 100644 index 0000000..d6b804f --- /dev/null +++ b/src/cli/etf-backtest/README.md @@ -0,0 +1,198 @@ +# ETF Backtest + +Iterative feature selection optimization agent for realistic 12-month ETF return predictions. + +The agent selects price-only features, runs experiments, and optimizes for **prediction accuracy** (not trading performance). It uses non-overlapping evaluation windows for honest assessment. 
+ +![ETF Backtest demo](./demo-1.png) + +## Requirements + +- Python 3 with `numpy`, `pandas`, and `torch` installed (see repo README for setup) +- Playwright system deps (Chromium) for data fetch (see repo README) +- ETF data cached under `tmp/etf-backtest//data.json` (auto-fetched; use `--refresh`) + +## Run + +```bash +# Run optimization (default: 5 iterations max) +pnpm run:etf-backtest + +# With options +pnpm run:etf-backtest --isin=IE00B5BMR087 --maxIterations=5 --seed=42 --verbose --refresh +``` + +## Arguments + +| Argument | Default | Description | +| ----------------- | -------------- | ------------------------------------ | +| `--isin` | `IE00B5BMR087` | ETF ISIN (used to fetch/cached data) | +| `--maxIterations` | `5` | Maximum optimization iterations | +| `--seed` | `42` | Random seed for reproducibility | +| `--refresh` | `false` | Force refetch even if cache exists | +| `--verbose` | `false` | Enable verbose logging | + +## Feature Menu + +The agent selects 8-12 features from these categories: + +| Category | Features | +| ----------- | -------------------------------------------------------- | +| Momentum | `mom_1m`, `mom_3m`, `mom_6m`, `mom_12m` | +| Trend | `px_sma50`, `px_sma200`, `sma50_sma200`, `dist_52w_high` | +| Risk | `vol_1m`, `vol_3m`, `vol_6m`, `dd_current`, `mdd_12m` | +| Oscillators | `rsi_14`, `bb_width` | + +## How It Works + +1. **Agent selects features** from the menu (starts with 8-12) +2. **Runs experiment** via `run_experiment.py` (backtest + prediction) +3. **Analyzes results**: R² (non-overlapping), direction accuracy, MAE +4. **Persists learnings** after each iteration (per-ISIN history + best score) +5. **Decides**: continue with tweaked features or stop +6. 
**Stops early** if no improvement for 2 iterations + +## Flowchart + +```mermaid +flowchart TD + A["Start CLI"] --> B["Parse args (zod)"] + B --> C["Init Logger, Fetcher, LearningsManager"] + C --> D["Fetch ETF data (cache or refresh)"] + D --> E["Load or create learnings"] + E --> F["Build agent + runPython tool"] + F --> G{"Iteration < maxIterations?"} + + G -->|yes| H["Run AgentRunner (1 experiment)"] + H --> I["Parse agent JSON output"] + I --> J{"Valid output?"} + J -->|no| K["Prompt agent to fix format"] + K --> H + J -->|yes| L["Extract experiment results"] + L --> M["Compute score & update best"] + M --> N["Save learnings"] + N --> O{"Stop conditions?"} + O -->|agent final| P["Set stop reason"] + O -->|no improvement| P + O -->|continue| G + + G -->|no| P + P --> Q["Finalize learnings"] + Q --> R["Print final report"] + R --> S["Close fetcher"] + S --> T["Done"] +``` + +## Metrics + +### Prediction Accuracy (Primary - Optimization Target) + +| Metric | Description | +| --------------------------------- | -------------------------------------------------- | +| `r2NonOverlapping` | R² on non-overlapping 12-month windows (honest) | +| `directionAccuracyNonOverlapping` | Sign prediction accuracy on independent periods | +| `mae` | Mean absolute error of 12-month return predictions | +| `calibrationRatio` | Predicted std / actual std (target: 0.8-1.2) | + +### Backtest Metrics (Informational Only) + +| Metric | Description | +| ------------- | -------------------------------------- | +| `sharpe` | Sharpe ratio of daily trading strategy | +| `maxDrawdown` | Maximum peak-to-trough decline | +| `cagr` | Compound annual growth rate | + +### Why Non-Overlapping? + +With 252-day (12-month) forward targets, consecutive data points overlap by 99.6%. This inflates apparent R² because the model sees nearly identical targets. Non-overlapping evaluation uses truly independent periods (~10 samples per decade) for realistic performance assessment. 
+ +## Output + +``` +============================================================ +OPTIMIZATION COMPLETE +============================================================ +Iterations: 3 +Best iteration: 2 +Stop reason: No improvement for 2 consecutive iterations + +Best Feature Set: + - mom_1m + - mom_3m + - vol_1m + - px_sma50 + ... + +Prediction Accuracy (Non-Overlapping - Honest Assessment): + R²: 0.045 + Direction Accuracy: 60.0% + Independent Samples: 10 + +Prediction Accuracy (Overlapping - Inflated): + R²: 0.152 + Direction Accuracy: 58.5% + MAE: 12.3% + Calibration: 0.95 + +Backtest Metrics (Informational): + Sharpe Ratio: 0.85 + Max Drawdown: -18.5% + CAGR: 12.3% + +12-Month Prediction: + Expected Return: 8.5% + 95% CI: [-12.5%, 29.5%] + +Uncertainty Details: + Base Std: 8.2% + Adjusted Std: 10.5% + Extrapolation: Yes (features outside training range) + +Confidence: MODERATE +Note: Non-overlapping metrics use only 10 independent periods. +Past performance does not guarantee future results. +============================================================ +``` + +Writes under `tmp/etf-backtest//`: + +- `data.json`: cached price series used for experiments +- `learnings.json`: per-iteration history and best-result tracking + +## Scripts + +| Script | Purpose | +| ------------------- | ------------------------------------------------- | +| `run_experiment.py` | Unified experiment runner (backtest + prediction) | +| `shared.py` | Feature registry and model training utilities | + +## Uncertainty Estimation + +The 95% confidence interval uses adjusted uncertainty that accounts for: + +1. **Base uncertainty**: Standard deviation of test set residuals +2. **Extrapolation penalty**: Increased when current features are >2 std from training mean +3. 
**Market floor**: Minimum 10% std (12-month returns are inherently uncertain) + +## Pitfall Avoidance + +- **Overlapping windows**: Evaluation uses non-overlapping periods for honest metrics +- **Lookahead**: Signal at t → position at t+1 +- **Data leakage**: Standardize using train mean/std only +- **No shuffle**: Chronological train/val/test split +- **Extrapolation**: Confidence intervals widen when features are outside training range + +## Data Format + +Expects `tmp/etf-backtest//data.json` when run via the CLI: + +```json +{ + "series": [ + { "date": "YYYY-MM-DD", "value": { "raw": } } + ] +} +``` + +If you run `run_experiment.py` directly, it defaults to `tmp/etf-backtest/data.json` +unless you pass `"dataPath"` in the JSON stdin payload. diff --git a/src/cli/etf-backtest/clients/etf-data-fetcher.ts b/src/cli/etf-backtest/clients/etf-data-fetcher.ts new file mode 100644 index 0000000..5bc3297 --- /dev/null +++ b/src/cli/etf-backtest/clients/etf-data-fetcher.ts @@ -0,0 +1,158 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import type { Logger } from "~clients/logger"; +import { PlaywrightScraper } from "~clients/playwright-scraper"; +import { resolveTmpPathForRead, resolveTmpPathForWrite } from "~tools/utils/fs"; + +import { + API_CAPTURE_TIMEOUT_MS, + ETF_CHART_PERIOD_KEY, + ETF_CHART_PERIOD_VALUE, + ETF_DATA_DIR, + ETF_DATA_FILENAME, + ETF_PROFILE_PATH, + getEtfApiPattern, + JUST_ETF_BASE_URL, +} from "../constants"; +import type { EtfDataResponse } from "../schemas"; +import { EtfDataResponseSchema, isEtfDataResponse } from "../schemas"; + +export type EtfDataFetcherConfig = { + logger: Logger; + headless?: boolean; +}; + +export type FetchResult = { + data: EtfDataResponse; + dataPath: string; + fromCache: boolean; +}; + +/** + * Fetches ETF data from justetf.com with caching. + * Data is stored in ISIN-specific folders under tmp/etf-backtest/{isin}/data.json. 
+ */ +export class EtfDataFetcher { + private logger: Logger; + private scraper: PlaywrightScraper; + + constructor(config: EtfDataFetcherConfig) { + this.logger = config.logger; + this.scraper = new PlaywrightScraper({ + logger: config.logger, + headless: config.headless ?? true, + defaultWaitStrategy: "domcontentloaded", + }); + } + + /** + * Build the relative path for cached data (relative to tmp/). + */ + private getDataPath(isin: string): string { + return path.join(ETF_DATA_DIR, isin, ETF_DATA_FILENAME); + } + + /** + * Build the justetf.com profile URL for the given ISIN. + */ + private buildProfileUrl(isin: string): string { + return `${JUST_ETF_BASE_URL}${ETF_PROFILE_PATH}?isin=${encodeURIComponent(isin)}`; + } + + /** + * Check if cached data exists for the given ISIN. + */ + private async hasCachedData(isin: string): Promise { + try { + await resolveTmpPathForRead(this.getDataPath(isin)); + return true; + } catch { + return false; + } + } + + /** + * Load cached data from disk. + */ + private async loadCachedData(isin: string): Promise { + const dataPath = await resolveTmpPathForRead(this.getDataPath(isin)); + const content = await fs.readFile(dataPath, "utf8"); + const json = JSON.parse(content) as unknown; + return EtfDataResponseSchema.parse(json); + } + + /** + * Save data to disk. + */ + private async saveData(isin: string, data: EtfDataResponse): Promise { + const dataPath = await resolveTmpPathForWrite(this.getDataPath(isin)); + await fs.writeFile(dataPath, JSON.stringify(data, null, 2), "utf8"); + return dataPath; + } + + /** + * Fetch ETF data from justetf.com by navigating to the profile page + * and intercepting the API response. 
private async fetchFromWeb(isin: string): Promise<EtfDataResponse> {
async close(): Promise<void> {
+ */ + async load(isin: string): Promise { + try { + const learningsPath = await resolveTmpPathForRead( + this.getLearningsPath(isin) + ); + const content = await fs.readFile(learningsPath, "utf8"); + const json = JSON.parse(content) as unknown; + const validated = LearningsSchema.parse(json); + this.logger.info("Loaded existing learnings", { + isin, + totalIterations: validated.totalIterations, + historyCount: validated.history.length, + }); + return validated; + } catch { + this.logger.info("No existing learnings found", { isin }); + return null; + } + } + + /** + * Create initial learnings structure for a new ISIN. + */ + createInitial(isin: string): Learnings { + const now = new Date().toISOString(); + return { + isin, + createdAt: now, + updatedAt: now, + totalIterations: 0, + bestResult: null, + history: [], + }; + } + + /** + * Add an iteration result to learnings. + * Updates bestResult if this iteration is better. + */ + addIteration( + learnings: Learnings, + iteration: number, + result: ExperimentResult + ): Learnings { + const score = computeScore(result.metrics); + const isBest = + learnings.bestResult === null || score > learnings.bestResult.score; + + const record: IterationRecord = { + iteration: learnings.totalIterations + iteration, + timestamp: new Date().toISOString(), + featureIds: result.featureIds, + score, + metrics: { + r2NonOverlapping: result.metrics.r2NonOverlapping, + directionAccuracyNonOverlapping: + result.metrics.directionAccuracyNonOverlapping, + mae: result.metrics.mae, + sharpe: result.metrics.sharpe, + }, + wasBest: isBest, + }; + + // Update history, trimming to max size (keep most recent) + const newHistory = [...learnings.history, record]; + if (newHistory.length > MAX_HISTORY_ITEMS) { + newHistory.shift(); + } + + const updatedLearnings: Learnings = { + ...learnings, + updatedAt: new Date().toISOString(), + history: newHistory, + }; + + // Update best result if this is better + if (isBest) { + updatedLearnings.bestResult = { 
+ iteration: record.iteration, + featureIds: result.featureIds, + score, + metrics: record.metrics, + }; + } + + return updatedLearnings; + } + + /** + * Increment total iterations counter (called at end of run). + */ + finishRun(learnings: Learnings, iterationsCompleted: number): Learnings { + return { + ...learnings, + totalIterations: learnings.totalIterations + iterationsCompleted, + updatedAt: new Date().toISOString(), + }; + } + + /** + * Save learnings to disk. + */ + async save(isin: string, learnings: Learnings): Promise { + const learningsPath = await resolveTmpPathForWrite( + this.getLearningsPath(isin) + ); + await fs.writeFile( + learningsPath, + JSON.stringify(learnings, null, 2), + "utf8" + ); + this.logger.debug("Saved learnings", { isin, path: learningsPath }); + } +} diff --git a/src/cli/etf-backtest/constants.ts b/src/cli/etf-backtest/constants.ts new file mode 100644 index 0000000..505acf4 --- /dev/null +++ b/src/cli/etf-backtest/constants.ts @@ -0,0 +1,102 @@ +import path from "node:path"; + +export const DEFAULT_VERBOSE = false; +export const DEFAULT_MAX_ITERATIONS = 5; +export const DEFAULT_SEED = 42; +export const DEFAULT_REFRESH = false; + +// Default ISIN: iShares Core S&P 500 UCITS ETF +export const DEFAULT_ISIN = "IE00B5BMR087"; + +// justetf.com configuration +export const JUST_ETF_BASE_URL = "https://www.justetf.com"; +export const ETF_PROFILE_PATH = "/en/etf-profile.html"; +// Match performance-chart requests with full historical data (dateFrom before 2020) +export const getEtfApiPattern = (isin: string): RegExp => + new RegExp(`/api/etfs/${isin}/performance-chart.*dateFrom=(19|200|201)`); +export const API_CAPTURE_TIMEOUT_MS = 15000; + +// localStorage key to set chart period to MAX for full historical data +export const ETF_CHART_PERIOD_KEY = "etfProfileChart.defaultPeriod"; +export const ETF_CHART_PERIOD_VALUE = "MAX"; + +// Data storage paths (relative to tmp/) +export const ETF_DATA_DIR = "etf-backtest"; +export const 
ETF_DATA_FILENAME = "data.json"; + +export const MAX_NO_IMPROVEMENT = 2; +export const ZERO = 0; +export const MAX_TURNS_PER_ITERATION = 3; + +export const MIN_FEATURES = 8; +export const MAX_FEATURES = 12; +export const PREDICTION_HORIZON_MONTHS = 12; +export const OVERLAP_PERCENT = 99; +export const SAMPLES_PER_DECADE = 10; +export const CI_LEVEL_PERCENT = 95; + +export const TARGET_R2_NON_OVERLAPPING = 0.05; +export const TARGET_DIR_ACC_NON_OVERLAPPING = 0.55; +export const TARGET_CALIBRATION_MIN = 0.8; +export const TARGET_CALIBRATION_MAX = 1.2; + +export const SCORE_WEIGHTS = { + r2NonOverlapping: 2, + directionAccuracyNonOverlapping: 1, + mae: -2, +} as const; + +export const NEGATIVE_SHARPE_THRESHOLD = 0; +export const NEGATIVE_SHARPE_PENALTY = -0.5; + +export const CONFIDENCE_THRESHOLDS = { + moderate: { + r2NonOverlapping: 0.03, + directionAccuracyNonOverlapping: 0.5, + maxCiWidth: 0.5, + }, + reasonable: { + r2NonOverlapping: 0.08, + directionAccuracyNonOverlapping: 0.6, + maxCiWidth: 0.4, + }, +} as const; + +export const PERCENT_MULTIPLIER = 100; + +export const DECIMAL_PLACES = { + r2: 3, + percent: 1, + calibration: 2, + sharpe: 2, + cagr: 1, + score: 3, +} as const; + +export const LINE_WIDTH = 60; +export const LINE_SEPARATOR = "=".repeat(LINE_WIDTH); + +export const NO_IMPROVEMENT_REASON = `No improvement for ${MAX_NO_IMPROVEMENT} consecutive iterations`; + +export const INDEX_NOT_FOUND = -1; +export const JSON_SLICE_END_OFFSET = 1; + +export const SCRIPTS_DIR = path.join( + process.cwd(), + "src", + "cli", + "etf-backtest", + "scripts" +); + +export const FEATURE_MENU = { + momentum: ["mom_1m", "mom_3m", "mom_6m", "mom_12m"], + trend: ["px_sma50", "px_sma200", "sma50_sma200", "dist_52w_high"], + risk: ["vol_1m", "vol_3m", "vol_6m", "dd_current", "mdd_12m"], + oscillators: ["rsi_14", "bb_width"], +} as const; + +// Learnings configuration +export const LEARNINGS_FILENAME = "learnings.json"; +export const MAX_HISTORY_ITEMS = 20; +export const 
LEARNINGS_SUMMARY_TOP_N = 5; diff --git a/src/cli/etf-backtest/demo-1.png b/src/cli/etf-backtest/demo-1.png new file mode 100644 index 0000000..38c0070 Binary files /dev/null and b/src/cli/etf-backtest/demo-1.png differ diff --git a/src/cli/etf-backtest/main.ts b/src/cli/etf-backtest/main.ts new file mode 100644 index 0000000..e11aca4 --- /dev/null +++ b/src/cli/etf-backtest/main.ts @@ -0,0 +1,347 @@ +// pnpm run:etf-backtest + +// Iterative ETF feature selection optimization agent +// Runs experiments with different feature combinations and finds the best set + +import "dotenv/config"; + +import { AgentRunner } from "~clients/agent-runner"; +import { Logger } from "~clients/logger"; +import { createRunPythonTool } from "~tools/run-python/run-python-tool"; +import { parseArgs } from "~utils/parse-args"; + +import { EtfDataFetcher } from "./clients/etf-data-fetcher"; +import { LearningsManager } from "./clients/learnings-manager"; +import { + DECIMAL_PLACES, + FEATURE_MENU, + MAX_FEATURES, + MAX_NO_IMPROVEMENT, + MAX_TURNS_PER_ITERATION, + MIN_FEATURES, + NO_IMPROVEMENT_REASON, + OVERLAP_PERCENT, + PREDICTION_HORIZON_MONTHS, + SAMPLES_PER_DECADE, + SCRIPTS_DIR, + TARGET_CALIBRATION_MAX, + TARGET_CALIBRATION_MIN, + TARGET_DIR_ACC_NON_OVERLAPPING, + TARGET_R2_NON_OVERLAPPING, + ZERO, +} from "./constants"; +import { AgentOutputSchema, CliArgsSchema } from "./schemas"; +import type { ExperimentResult, Learnings } from "./schemas"; +import { extractLastExperimentResult } from "./utils/experiment-extract"; +import { printFinalResults } from "./utils/final-report"; +import { formatFixed, formatPercent } from "./utils/formatters"; +import { formatLearningsForPrompt } from "./utils/learnings-formatter"; +import { + buildRecoveryPrompt, + buildRunPythonUsage, +} from "./utils/prompt-builders"; +import { computeScore } from "./utils/scoring"; + +const logger = new Logger(); + +// --- Parse CLI arguments --- +const { verbose, isin, refresh, maxIterations, seed } = parseArgs({ 
+ logger, + schema: CliArgsSchema, +}); + +// --- Build agent instructions --- +const buildInstructions = () => ` +You are an ETF feature selection optimization agent. Your goal is to find features that produce **accurate ${PREDICTION_HORIZON_MONTHS}-month return predictions**, not optimal trading strategies. + +## Important Distinction +- **Prediction accuracy** (R², direction accuracy, MAE) = Can we forecast the ${PREDICTION_HORIZON_MONTHS}-month return? +- **Trading performance** (Sharpe, drawdown) = Is this a good trading strategy? + +You are optimizing for PREDICTION ACCURACY. Trading metrics are informational only. + +## Feature Menu +Choose ${MIN_FEATURES}-${MAX_FEATURES} features from the following categories: + +**Momentum (price-based returns over periods):** +${FEATURE_MENU.momentum.map((f) => `- ${f}`).join("\n")} + +**Trend (price relative to moving averages):** +${FEATURE_MENU.trend.map((f) => `- ${f}`).join("\n")} + +**Risk (volatility and drawdown measures):** +${FEATURE_MENU.risk.map((f) => `- ${f}`).join("\n")} + +**Oscillators (optional, technical indicators):** +${FEATURE_MENU.oscillators.map((f) => `- ${f}`).join("\n")} + +## Metrics Priority (most to least important) +1. **r2NonOverlapping** - R² on non-overlapping ${PREDICTION_HORIZON_MONTHS}-month windows (honest assessment). Target > ${TARGET_R2_NON_OVERLAPPING} +2. **directionAccuracyNonOverlapping** - Did we predict the sign correctly? Target > ${formatPercent( + TARGET_DIR_ACC_NON_OVERLAPPING +)} +3. **mae** - Mean absolute error of predictions. Lower is better +4. **calibrationRatio** - Is predicted magnitude realistic? 
Target ${TARGET_CALIBRATION_MIN}-${TARGET_CALIBRATION_MAX} + +## Non-Overlapping vs Overlapping Metrics +- **Non-overlapping metrics** use truly independent ${PREDICTION_HORIZON_MONTHS}-month periods (~${SAMPLES_PER_DECADE} samples per decade) +- **Overlapping metrics** use all data but windows overlap ${OVERLAP_PERCENT}%, inflating apparent performance +- Focus on NON-OVERLAPPING metrics for realistic assessment + +## Backtest Metrics (informational only) +- Sharpe ratio: If negative, features may be problematic (sanity check) +- Max drawdown: Not an optimization target + +## Feature Selection Guidelines +For ${PREDICTION_HORIZON_MONTHS}-month predictions: +- **Momentum features** (mom_*) capture recent trends but may mean-revert over ${PREDICTION_HORIZON_MONTHS} months +- **Trend features** (px_sma*, sma50_sma200) show long-term direction +- **Risk features** (vol_*, dd_*, mdd_*) capture volatility regimes + +Be skeptical of high R² values - with overlapping windows, apparent fit is inflated. +Focus on features with economic intuition for ${PREDICTION_HORIZON_MONTHS}-month horizons. + +## Tool Usage +IMPORTANT: Run exactly ONE experiment per turn. Do not run multiple experiments. + +Call runPython with: +- scriptName: "run_experiment.py" +- input: { "featureIds": [...], "seed": , "dataPath": "" } + +After you receive results, respond with your analysis. Do not call runPython again in the same turn. 
+ +## Response Format +After each experiment, respond with JSON (do not call any more tools): +{ + "status": "continue" | "final", + "selectedFeatures": ["feature1", "feature2", ...], + "reasoning": "Explain your analysis focusing on prediction accuracy", + "stopReason": "Explain why stopping if final, otherwise null" +} +`; + +// --- Run iterative optimization --- +const runAgentOptimization = async ( + dataPath: string, + initialLearnings: Learnings, + learningsManager: LearningsManager +) => { + const runPythonTool = createRunPythonTool({ + scriptsDir: SCRIPTS_DIR, + logger, + }); + + const agentRunner = new AgentRunner({ + name: "EtfFeatureOptimizer", + model: "gpt-5-mini", + tools: [runPythonTool], + outputType: AgentOutputSchema, + instructions: buildInstructions(), + logger, + logToolResults: verbose, + stateless: true, // Required for reasoning models to avoid "reasoning item without following item" errors + }); + + // Track state + let bestResult: ExperimentResult | null = null; + let bestIteration = ZERO; + let bestScore = Number.NEGATIVE_INFINITY; + let noImprovementCount = ZERO; + let iteration = ZERO; + let stopReason = "Max iterations reached"; + let learnings = initialLearnings; + + // Initial prompt with learnings context + const learningsSummary = formatLearningsForPrompt(learnings); + const runPythonUsage = buildRunPythonUsage({ seed, dataPath }); + let currentPrompt = ` +Start feature selection optimization for ISIN ${isin}. +${learningsSummary} +Begin by selecting ${MIN_FEATURES}-${MAX_FEATURES} features that you think will best predict ${PREDICTION_HORIZON_MONTHS}-month returns. +Consider using a mix from each category (momentum, trend, risk). + +${runPythonUsage} + +After running the experiment, analyze the results and decide whether to continue or stop. 
+`; + + while (iteration < maxIterations) { + iteration++; + logger.info("--- Iteration ---", { iteration, maxIterations }); + + let runResult; + try { + runResult = await agentRunner.run({ + prompt: currentPrompt, + maxTurns: MAX_TURNS_PER_ITERATION, // Limit turns per iteration: 1 tool call + 1 result + 1 output + }); + } catch (err) { + // Handle MaxTurnsExceededError - try to extract result from partial state + if ( + err && + typeof err === "object" && + "state" in err && + err.state && + typeof err.state === "object" && + "_newItems" in err.state + ) { + logger.warn("Agent exceeded turn limit, extracting partial results..."); + const state = err.state as { + _newItems?: { type: string; output?: unknown }[]; + }; + const partialResult = extractLastExperimentResult({ + newItems: state._newItems, + }); + if (partialResult) { + const score = computeScore(partialResult.metrics); + if (score > bestScore) { + bestScore = score; + bestResult = partialResult; + bestIteration = iteration; + } + } + currentPrompt = buildRecoveryPrompt( + "You ran too many experiments in one turn. Please run exactly ONE experiment, then respond with your JSON analysis.", + { seed, dataPath } + ); + continue; + } + throw err; + } + const parseResult = AgentOutputSchema.safeParse(runResult.finalOutput); + + if (!parseResult.success) { + logger.warn("Invalid agent response format, continuing..."); + if (verbose) { + logger.debug("Parse error", { error: parseResult.error }); + } + currentPrompt = buildRecoveryPrompt( + "Your response was not valid JSON. 
Please respond with the correct format.", + { seed, dataPath } + ); + continue; + } + + const output = parseResult.data; + logger.info("Features selected", { features: output.selectedFeatures }); + logger.info("Reasoning preview", { + preview: output.reasoning, + }); + + // Try to extract experiment result from the tool call outputs + const lastToolResult = extractLastExperimentResult(runResult); + + if (lastToolResult) { + const score = computeScore(lastToolResult.metrics); + logger.info("Prediction metrics", { + r2NonOverlapping: formatFixed( + lastToolResult.metrics.r2NonOverlapping, + DECIMAL_PLACES.r2 + ), + directionAccuracyNonOverlapping: formatPercent( + lastToolResult.metrics.directionAccuracyNonOverlapping + ), + mae: formatPercent(lastToolResult.metrics.mae), + score: formatFixed(score, DECIMAL_PLACES.score), + }); + if (verbose) { + logger.debug("Backtest metrics", { + sharpe: formatFixed( + lastToolResult.metrics.sharpe, + DECIMAL_PLACES.sharpe + ), + maxDrawdown: formatPercent(lastToolResult.metrics.maxDrawdown), + }); + } + + if (score > bestScore) { + bestScore = score; + bestResult = lastToolResult; + bestIteration = iteration; + noImprovementCount = ZERO; + logger.info("New best result!"); + } else { + noImprovementCount++; + logger.info("No improvement", { + noImprovementCount, + maxNoImprovement: MAX_NO_IMPROVEMENT, + }); + } + + // Record iteration in learnings and save progress + learnings = learningsManager.addIteration( + learnings, + iteration, + lastToolResult + ); + await learningsManager.save(isin, learnings); + } + + // Check stop conditions + if (output.status === "final") { + stopReason = output.stopReason ?? 
"Agent decided to stop"; + logger.info("Agent stopped", { stopReason }); + break; + } + + if (noImprovementCount >= MAX_NO_IMPROVEMENT) { + stopReason = NO_IMPROVEMENT_REASON; + logger.info(stopReason); + break; + } + + // Build next prompt with learnings context (required since stateless mode loses context) + const updatedLearningsSummary = formatLearningsForPrompt(learnings); + currentPrompt = ` +Continue feature selection optimization for ISIN ${isin}. +You have ${maxIterations - iteration} iterations remaining. + ${updatedLearningsSummary} + Based on your previous experiment, decide: + - If you want to try different features, select them and run another experiment + - If you think you've found a good set, respond with status "final" + +${runPythonUsage} + +Focus on: Higher r2NonOverlapping, higher directionAccuracyNonOverlapping, lower MAE. +Backtest metrics (Sharpe, drawdown) are informational only. +`; + } + + // Finalize learnings and save + learnings = learningsManager.finishRun(learnings, iteration); + await learningsManager.save(isin, learnings); + logger.info("Learnings saved", { + totalIterations: learnings.totalIterations, + historyCount: learnings.history.length, + }); + + // Output final results + if (bestResult) { + printFinalResults(logger, bestResult, bestIteration, iteration, stopReason); + } else { + logger.warn("No successful experiments completed."); + } +}; + +// --- Main --- +logger.info("ETF Backtest Feature Optimization starting...", { isin }); +if (verbose) { + logger.debug("Verbose mode enabled"); +} + +const fetcher = new EtfDataFetcher({ logger }); +const learningsManager = new LearningsManager({ logger }); + +try { + const { dataPath } = await fetcher.fetch(isin, refresh); + + // Load or create learnings + let learnings = await learningsManager.load(isin); + learnings ??= learningsManager.createInitial(isin); + + await runAgentOptimization(dataPath, learnings, learningsManager); +} finally { + await fetcher.close(); +} + 
+logger.info("\nETF Backtest completed."); diff --git a/src/cli/etf-backtest/schemas.ts b/src/cli/etf-backtest/schemas.ts new file mode 100644 index 0000000..7859d09 --- /dev/null +++ b/src/cli/etf-backtest/schemas.ts @@ -0,0 +1,139 @@ +import { z } from "zod"; + +import { + DEFAULT_ISIN, + DEFAULT_MAX_ITERATIONS, + DEFAULT_REFRESH, + DEFAULT_SEED, + DEFAULT_VERBOSE, +} from "./constants"; + +// ISIN validation: 2 letter country code + 10 alphanumeric characters +const IsinSchema = z + .string() + .regex(/^[A-Z]{2}[A-Z0-9]{10}$/, "Invalid ISIN format"); + +export const CliArgsSchema = z.object({ + verbose: z.coerce.boolean().default(DEFAULT_VERBOSE), + isin: IsinSchema.default(DEFAULT_ISIN), + refresh: z.coerce.boolean().default(DEFAULT_REFRESH), + maxIterations: z.coerce.number().default(DEFAULT_MAX_ITERATIONS), + seed: z.coerce.number().default(DEFAULT_SEED), +}); + +export type CliArgs = z.infer; + +export const AgentOutputSchema = z.object({ + status: z.enum(["continue", "final"]), + selectedFeatures: z.array(z.string()), + reasoning: z.string(), + stopReason: z.string().nullable(), +}); + +export type AgentOutput = z.infer; + +export const ExperimentResultSchema = z.object({ + featureIds: z.array(z.string()), + metrics: z.object({ + // Backtest metrics (informational) + sharpe: z.number(), + maxDrawdown: z.number(), + cagr: z.number(), + // Prediction metrics (overlapping) + r2: z.number(), + mse: z.number(), + directionAccuracy: z.number(), + mae: z.number(), + calibrationRatio: z.number(), + // Non-overlapping metrics (honest assessment) + r2NonOverlapping: z.number(), + directionAccuracyNonOverlapping: z.number(), + }), + prediction: z.object({ + pred12mReturn: z.number(), + ci95Low: z.number(), + ci95High: z.number(), + uncertainty: z.object({ + baseStd: z.number(), + adjustedStd: z.number(), + extrapolationMultiplier: z.number(), + isExtrapolating: z.boolean(), + }), + }), + modelInfo: z.object({ + trainSamples: z.number(), + valSamples: z.number(), + 
testSamples: z.number(), + }), + dataInfo: z.object({ + totalSamples: z.number(), + nonOverlappingSamples: z.number(), + effectiveIndependentPeriods: z.number(), + }), +}); + +export type ExperimentResult = z.infer; + +// Single iteration record - captures what was tried and what happened +export const IterationRecordSchema = z.object({ + iteration: z.number(), + timestamp: z.string(), + featureIds: z.array(z.string()), + score: z.number(), + metrics: z.object({ + r2NonOverlapping: z.number(), + directionAccuracyNonOverlapping: z.number(), + mae: z.number(), + sharpe: z.number(), + }), + wasBest: z.boolean(), +}); + +export type IterationRecord = z.infer; + +const BestResultSchema = IterationRecordSchema.omit({ + timestamp: true, + wasBest: true, +}); + +// Complete learnings file structure +export const LearningsSchema = z.object({ + isin: z.string(), + createdAt: z.string(), + updatedAt: z.string(), + totalIterations: z.number(), + bestResult: BestResultSchema.nullable(), + history: z.array(IterationRecordSchema), +}); + +export type Learnings = z.infer; + +// Value with raw number and localized string representation +export const LocalizedValueSchema = z.object({ + raw: z.number(), + localized: z.string(), +}); + +// Single data point in the time series +export const SeriesPointSchema = z.object({ + date: z.string(), // ISO format: "YYYY-MM-DD" + value: LocalizedValueSchema, +}); + +// Full API response from justetf.com +export const EtfDataResponseSchema = z.object({ + latestQuote: LocalizedValueSchema, + latestQuoteDate: z.string(), + price: LocalizedValueSchema, + performance: LocalizedValueSchema, + prevDaySeries: z.array(SeriesPointSchema), + series: z.array(SeriesPointSchema), +}); + +export type LocalizedValue = z.infer; +export type SeriesPoint = z.infer; +export type EtfDataResponse = z.infer; + +export const isEtfDataResponse = (data: unknown): data is EtfDataResponse => { + return EtfDataResponseSchema.safeParse(data).success; +}; diff --git 
a/src/cli/etf-backtest/scripts/run_experiment.py b/src/cli/etf-backtest/scripts/run_experiment.py new file mode 100644 index 0000000..5ae53c3 --- /dev/null +++ b/src/cli/etf-backtest/scripts/run_experiment.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 +""" +Single Experiment Runner for ETF Feature Selection + +Combines backtest and prediction into one script. +Accepts input via stdin JSON, outputs results as JSON to stdout. + +Input format: +{ + "featureIds": ["mom_1m", "mom_3m", "vol_1m", "px_sma50"], + "seed": 42 +} + +Output format: +{ + "featureIds": [...], + "metrics": { "sharpe": ..., "maxDrawdown": ..., "r2": ..., "mse": ..., "cagr": ... }, + "prediction": { "pred12mReturn": ..., "ci95Low": ..., "ci95High": ... } +} +""" + +import json +import sys +import numpy as np +import pandas as pd +import torch +from pathlib import Path + +from shared import ( + load_data, + build_selected_features, + add_forward_target, + split_data, + train_model, + FORWARD_DAYS, + ALL_FEATURE_IDS, +) + +# === CONFIG === +DEFAULT_DATA_PATH = Path(__file__).parent.parent.parent.parent.parent / "tmp" / "etf-backtest" / "data.json" +COST_BPS = 5 # transaction cost in basis points + + +def set_seed(seed: int): + """Set random seeds for reproducibility.""" + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def standardize( + train: pd.DataFrame, + val: pd.DataFrame, + test: pd.DataFrame, + feature_cols: list[str], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Standardize features using train mean/std only.""" + X_train = train[feature_cols].values + X_val = val[feature_cols].values + X_test = test[feature_cols].values + + y_train = train["target"].values + y_val = val["target"].values + y_test = test["target"].values + + mean = X_train.mean(axis=0) + std = X_train.std(axis=0) + std[std == 0] = 1 + + X_train = (X_train - mean) / std + X_val = (X_val - 
mean) / std + X_test = (X_test - mean) / std + + return X_train, X_val, X_test, y_train, y_val, y_test, mean, std + + +def run_backtest(test_df: pd.DataFrame, predictions: np.ndarray) -> dict: + """Run backtest and compute metrics.""" + df = test_df.copy() + df["pred"] = predictions + + # Signal: pred > 0 -> long + df["signal"] = (df["pred"] > 0).astype(int) + + # Position with 1-day lag (avoids lookahead) + df["position"] = df["signal"].shift(1).fillna(0) + + # Need daily returns for backtest + df["daily_ret"] = df["price"].pct_change() + + # Strategy returns + df["strat_ret"] = df["position"] * df["daily_ret"] + + # Transaction costs + df["trade"] = df["position"].diff().abs().fillna(0) + df["cost"] = df["trade"] * (COST_BPS / 10000) + df["strat_ret_net"] = df["strat_ret"] - df["cost"] + + # Equity curve + df["equity"] = (1 + df["strat_ret_net"]).cumprod() + + # Metrics + returns = df["strat_ret_net"].dropna().values + equity = df["equity"].dropna().values + + if len(equity) < 2: + return {"sharpe": 0, "maxDrawdown": 0, "cagr": 0, "totalReturn": 0} + + total_return = equity[-1] / equity[0] - 1 + n_days = len(returns) + years = n_days / 252 + cagr = (equity[-1] ** (1 / years)) - 1 if years > 0 else 0 + + ann_vol = returns.std() * np.sqrt(252) + sharpe = (returns.mean() * 252) / ann_vol if ann_vol > 0 else 0 + + peak = np.maximum.accumulate(equity) + drawdown = (equity - peak) / peak + max_dd = drawdown.min() + + return { + "sharpe": float(sharpe), + "maxDrawdown": float(max_dd), + "cagr": float(cagr), + "totalReturn": float(total_return), + } + + +def compute_prediction( + model, + df_raw: pd.DataFrame, + feature_ids: list[str], + mean: np.ndarray, + std: np.ndarray, + uncertainty: dict, +) -> dict: + """Generate 12-month forward prediction with adjusted confidence interval.""" + device = next(model.parameters()).device + + # Build features for latest data point + df = build_selected_features(df_raw, feature_ids) + df = df.dropna(subset=feature_ids) + + if len(df) 
== 0: + raise ValueError("No valid feature rows for prediction") + + latest = df.iloc[-1] + features = latest[feature_ids].values.astype(np.float64) + + # Standardize using training statistics + features_std = (features - mean) / std + features_t = torch.tensor(features_std, dtype=torch.float32, device=device) + + model.eval() + with torch.no_grad(): + prediction = model(features_t).item() + + # Use adjusted std for more realistic confidence intervals + adjusted_std = uncertainty["adjustedStd"] + + return { + "pred12mReturn": float(prediction), + "ci95Low": float(prediction - 1.96 * adjusted_std), + "ci95High": float(prediction + 1.96 * adjusted_std), + "uncertainty": uncertainty, + } + + +def compute_test_metrics(y_test: np.ndarray, predictions: np.ndarray) -> dict: + """Compute R² and MSE on test set.""" + mse = float(np.mean((y_test - predictions) ** 2)) + ss_res = np.sum((y_test - predictions) ** 2) + ss_tot = np.sum((y_test - y_test.mean()) ** 2) + r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else 0.0 + return {"r2": r2, "mse": mse} + + +def compute_prediction_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict: + """Metrics focused on 12-month prediction quality.""" + direction_accuracy = float(np.mean((y_true > 0) == (y_pred > 0))) + mae = float(np.mean(np.abs(y_true - y_pred))) + pred_std = y_pred.std() + true_std = y_true.std() + calibration_ratio = float(pred_std / true_std) if true_std > 0 else 0.0 + return { + "directionAccuracy": direction_accuracy, + "mae": mae, + "calibrationRatio": calibration_ratio, + } + + +def compute_non_overlapping_metrics( + y_true: np.ndarray, y_pred: np.ndarray, forward_days: int = FORWARD_DAYS +) -> dict: + """Evaluate on non-overlapping windows for honest assessment.""" + n = len(y_true) + indices = list(range(0, n, forward_days)) + if len(indices) < 2: + return { + "r2NonOverlapping": 0.0, + "directionAccuracyNonOverlapping": 0.0, + "nonOverlappingSamples": len(indices), + } + + y_true_no = y_true[indices] + 
y_pred_no = y_pred[indices] + + # R² on non-overlapping samples + ss_res = np.sum((y_true_no - y_pred_no) ** 2) + ss_tot = np.sum((y_true_no - y_true_no.mean()) ** 2) + r2_no = float(1 - ss_res / ss_tot) if ss_tot > 0 else 0.0 + + # Direction accuracy on non-overlapping samples + dir_acc_no = float(np.mean((y_true_no > 0) == (y_pred_no > 0))) + + return { + "r2NonOverlapping": r2_no, + "directionAccuracyNonOverlapping": dir_acc_no, + "nonOverlappingSamples": len(indices), + } + + +def compute_uncertainty_adjusted( + test_preds: np.ndarray, + y_test: np.ndarray, + latest_features: np.ndarray, + train_mean: np.ndarray, + train_std: np.ndarray, +) -> dict: + """Uncertainty with extrapolation penalty and market floor.""" + residuals = y_test - test_preds + base_std = float(residuals.std()) + + # Extrapolation penalty if features are outside training distribution + z_scores = np.abs((latest_features - train_mean) / train_std) + max_z = float(z_scores.max()) + extrapolation_mult = 1.0 + 0.1 * max(0, max_z - 2) + + # Market floor: 12-month returns are inherently uncertain (~10% minimum) + MARKET_FLOOR = 0.10 + adjusted_std = max(base_std * extrapolation_mult, MARKET_FLOOR) + + return { + "baseStd": base_std, + "adjustedStd": float(adjusted_std), + "extrapolationMultiplier": float(extrapolation_mult), + "isExtrapolating": bool(max_z > 2), + } + + +def run_experiment(feature_ids: list[str], seed: int, data_path: Path) -> dict: + """Run a single experiment with given features.""" + set_seed(seed) + + # Load data + if not data_path.exists(): + raise FileNotFoundError(f"Data file not found: {data_path}") + + df_raw = load_data(data_path) + + # Build features and add forward target + df = build_selected_features(df_raw, feature_ids) + df = add_forward_target(df, FORWARD_DAYS) + + # Drop NaN rows + df = df.dropna(subset=feature_ids + ["target"]).reset_index(drop=True) + + if len(df) < 100: + raise ValueError(f"Insufficient data: only {len(df)} valid rows") + + # Split data + 
train, val, test = split_data(df) + + # Standardize + X_train, X_val, X_test, y_train, y_val, y_test, mean, std = standardize( + train, val, test, feature_ids + ) + + # Train model + model = train_model(X_train, y_train, X_val, y_val) + + # Get test predictions + device = next(model.parameters()).device + X_test_t = torch.tensor(X_test, dtype=torch.float32, device=device) + model.eval() + with torch.no_grad(): + test_preds = model(X_test_t).cpu().numpy() + + # Compute model metrics (overlapping) + model_metrics = compute_test_metrics(y_test, test_preds) + prediction_metrics = compute_prediction_metrics(y_test, test_preds) + + # Compute non-overlapping metrics for honest assessment + non_overlap_metrics = compute_non_overlapping_metrics(y_test, test_preds, FORWARD_DAYS) + + # Build features for latest data point to compute uncertainty + df_latest = build_selected_features(df_raw, feature_ids) + df_latest = df_latest.dropna(subset=feature_ids) + latest_features = df_latest.iloc[-1][feature_ids].values.astype(np.float64) + + # Compute adjusted uncertainty + uncertainty = compute_uncertainty_adjusted(test_preds, y_test, latest_features, mean, std) + + # Run backtest on test set (informational only) + backtest_metrics = run_backtest(test, test_preds) + + # Generate forward prediction with adjusted uncertainty + prediction = compute_prediction(model, df_raw, feature_ids, mean, std, uncertainty) + + return { + "featureIds": feature_ids, + "metrics": { + # Backtest metrics (informational, not optimization target) + "sharpe": backtest_metrics["sharpe"], + "maxDrawdown": backtest_metrics["maxDrawdown"], + "cagr": backtest_metrics["cagr"], + # Prediction metrics (overlapping) + "r2": model_metrics["r2"], + "mse": model_metrics["mse"], + "directionAccuracy": prediction_metrics["directionAccuracy"], + "mae": prediction_metrics["mae"], + "calibrationRatio": prediction_metrics["calibrationRatio"], + # Non-overlapping metrics (honest assessment) + "r2NonOverlapping": 
non_overlap_metrics["r2NonOverlapping"], + "directionAccuracyNonOverlapping": non_overlap_metrics["directionAccuracyNonOverlapping"], + }, + "prediction": prediction, + "modelInfo": { + "trainSamples": len(train), + "valSamples": len(val), + "testSamples": len(test), + }, + "dataInfo": { + "totalSamples": len(df), + "nonOverlappingSamples": non_overlap_metrics["nonOverlappingSamples"], + "effectiveIndependentPeriods": non_overlap_metrics["nonOverlappingSamples"], + }, + } + + +def main(): + # Read input from stdin + try: + input_data = json.load(sys.stdin) + except json.JSONDecodeError as e: + print(json.dumps({"error": f"Invalid JSON input: {e}"}), file=sys.stdout) + sys.exit(1) + + feature_ids = input_data.get("featureIds") + if feature_ids is None: + feature_ids = input_data.get("feature_ids", []) + seed = input_data.get("seed", 42) + data_path_str = input_data.get("dataPath") + data_path = Path(data_path_str) if data_path_str else DEFAULT_DATA_PATH + + # Validate featureIds + if not feature_ids: + print(json.dumps({"error": "featureIds is required and must not be empty"}), file=sys.stdout) + sys.exit(1) + + invalid = [f for f in feature_ids if f not in ALL_FEATURE_IDS] + if invalid: + print(json.dumps({ + "error": f"Unknown featureIds: {invalid}", + "validFeatures": ALL_FEATURE_IDS, + }), file=sys.stdout) + sys.exit(1) + + try: + result = run_experiment(feature_ids, seed, data_path) + print(json.dumps(result, indent=2), file=sys.stdout) + except Exception as e: + print(json.dumps({"error": str(e)}), file=sys.stdout) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/cli/etf-backtest/scripts/shared.py b/src/cli/etf-backtest/scripts/shared.py new file mode 100644 index 0000000..203e10b --- /dev/null +++ b/src/cli/etf-backtest/scripts/shared.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +""" +Shared utilities for ETF backtest and prediction scripts. + +Contains common code for data loading, feature engineering, and model training. 
+""" + +import json +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from pathlib import Path +from typing import Callable + +# === CONFIG === +TRAIN_RATIO = 0.70 +VAL_RATIO = 0.15 +# TEST_RATIO = 0.15 (implicit) + +LAGS = 20 # number of lagged returns +MA_SHORT = 10 # short moving average window +MA_LONG = 50 # long moving average window +VOL_WINDOW = 20 # rolling volatility window + +HIDDEN1 = 64 +HIDDEN2 = 32 +DROPOUT = 0.2 +LR = 0.001 +EPOCHS = 100 +PATIENCE = 10 # early stopping patience +BATCH_SIZE = 32 + +FORWARD_DAYS = 252 # ~12 months for prediction target + + +# === DATA LOADING === +def load_data(path: Path) -> pd.DataFrame: + """Load JSON and convert to DataFrame with date and price.""" + with open(path) as f: + data = json.load(f) + + series = data["series"] + df = pd.DataFrame([ + {"date": item["date"], "cumret": item["value"]["raw"]} + for item in series + ]) + df["date"] = pd.to_datetime(df["date"]) + df = df.sort_values("date").reset_index(drop=True) + + # Convert cumulative % return to price (base=100) + df["price"] = 100 * (1 + df["cumret"] / 100) + return df + + +# === FEATURE ENGINEERING === +def build_base_features(df: pd.DataFrame) -> pd.DataFrame: + """ + Build base features from price series. + Does NOT set target - caller must add their own target column. 
+ """ + df = df.copy() + + # Daily returns + df["ret"] = df["price"].pct_change() + + # Lagged returns: r(t-1), r(t-2), ..., r(t-LAGS) + for i in range(1, LAGS + 1): + df[f"ret_lag{i}"] = df["ret"].shift(i) + + # Moving average ratios + df["ma_short"] = df["price"].rolling(MA_SHORT).mean() + df["ma_long"] = df["price"].rolling(MA_LONG).mean() + df["ma_ratio_short"] = df["price"] / df["ma_short"] - 1 + df["ma_ratio_long"] = df["price"] / df["ma_long"] - 1 + + # Rolling volatility + df["volatility"] = df["ret"].rolling(VOL_WINDOW).std() + + return df + + +def get_feature_cols() -> list[str]: + """Return list of feature column names.""" + cols = [f"ret_lag{i}" for i in range(1, LAGS + 1)] + cols += ["ma_ratio_short", "ma_ratio_long", "volatility"] + return cols + + +# === TRAIN/VAL/TEST SPLIT === +def split_data(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Chronological split into train/val/test.""" + n = len(df) + train_end = int(n * TRAIN_RATIO) + val_end = int(n * (TRAIN_RATIO + VAL_RATIO)) + + train = df.iloc[:train_end].copy() + val = df.iloc[train_end:val_end].copy() + test = df.iloc[val_end:].copy() + + return train, val, test + + +def standardize(train: pd.DataFrame, val: pd.DataFrame, test: pd.DataFrame, + feature_cols: list[str]) -> tuple[np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray]: + """ + Standardize features using train mean/std only. + Returns X_train, X_val, X_test, y_train, y_val, y_test, mean, std. 
+ """ + X_train = train[feature_cols].values + X_val = val[feature_cols].values + X_test = test[feature_cols].values + + y_train = train["target"].values + y_val = val["target"].values + y_test = test["target"].values + + # Compute mean/std from train only + mean = X_train.mean(axis=0) + std = X_train.std(axis=0) + std[std == 0] = 1 # avoid division by zero + + X_train = (X_train - mean) / std + X_val = (X_val - mean) / std + X_test = (X_test - mean) / std + + return X_train, X_val, X_test, y_train, y_val, y_test, mean, std + + +# === MODEL === +class MLP(nn.Module): + """Simple MLP: Input -> 64 -> 32 -> 1""" + def __init__(self, input_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, HIDDEN1), + nn.ReLU(), + nn.Dropout(DROPOUT), + nn.Linear(HIDDEN1, HIDDEN2), + nn.ReLU(), + nn.Dropout(DROPOUT), + nn.Linear(HIDDEN2, 1) + ) + + def forward(self, x): + return self.net(x).squeeze(-1) + + +def train_model(X_train: np.ndarray, y_train: np.ndarray, + X_val: np.ndarray, y_val: np.ndarray) -> MLP: + """Train MLP with early stopping.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + X_train_t = torch.tensor(X_train, dtype=torch.float32, device=device) + y_train_t = torch.tensor(y_train, dtype=torch.float32, device=device) + X_val_t = torch.tensor(X_val, dtype=torch.float32, device=device) + y_val_t = torch.tensor(y_val, dtype=torch.float32, device=device) + + model = MLP(X_train.shape[1]).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=LR) + criterion = nn.MSELoss() + + best_val_loss = float("inf") + patience_counter = 0 + best_state = None + + n_batches = (len(X_train_t) + BATCH_SIZE - 1) // BATCH_SIZE + + for epoch in range(EPOCHS): + model.train() + indices = torch.randperm(len(X_train_t)) + + for i in range(n_batches): + batch_idx = indices[i*BATCH_SIZE : (i+1)*BATCH_SIZE] + X_batch = X_train_t[batch_idx] + y_batch = y_train_t[batch_idx] + + optimizer.zero_grad() + pred = model(X_batch) + 
loss = criterion(pred, y_batch) + loss.backward() + optimizer.step() + + # Validation + model.eval() + with torch.no_grad(): + val_pred = model(X_val_t) + val_loss = criterion(val_pred, y_val_t).item() + + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + best_state = model.state_dict() + else: + patience_counter += 1 + if patience_counter >= PATIENCE: + print(f"Early stopping at epoch {epoch+1}") + break + + if best_state: + model.load_state_dict(best_state) + + return model + + +# === FEATURE REGISTRY === +def compute_rsi(prices: pd.Series, period: int = 14) -> pd.Series: + """Compute Relative Strength Index.""" + delta = prices.diff() + gain = delta.where(delta > 0, 0.0).rolling(period).mean() + loss = (-delta.where(delta < 0, 0.0)).rolling(period).mean() + rs = gain / loss.replace(0, np.nan) + return 100 - (100 / (1 + rs)) + + +def compute_bb_width(prices: pd.Series, period: int = 20, num_std: float = 2) -> pd.Series: + """Compute Bollinger Band width (normalized by middle band).""" + sma = prices.rolling(period).mean() + std = prices.rolling(period).std() + upper = sma + num_std * std + lower = sma - num_std * std + return (upper - lower) / sma + + +def compute_rolling_mdd(df: pd.DataFrame, window: int) -> pd.Series: + """Compute rolling maximum drawdown over a window.""" + def mdd_func(x): + peak = np.maximum.accumulate(x) + dd = (x - peak) / peak + return dd.min() + return df["price"].rolling(window).apply(mdd_func, raw=True) + + +# Feature registry: maps feature_id to a function that computes the feature +FEATURE_REGISTRY: dict[str, Callable[[pd.DataFrame], pd.Series]] = { + # Momentum (returns over periods) + "mom_1m": lambda df: df["price"].pct_change(21), + "mom_3m": lambda df: df["price"].pct_change(63), + "mom_6m": lambda df: df["price"].pct_change(126), + "mom_12m": lambda df: df["price"].pct_change(252), + + # Trend (price vs moving averages) + "px_sma50": lambda df: df["price"] / df["price"].rolling(50).mean() - 1, + 
"px_sma200": lambda df: df["price"] / df["price"].rolling(200).mean() - 1, + "sma50_sma200": lambda df: df["price"].rolling(50).mean() / df["price"].rolling(200).mean() - 1, + "dist_52w_high": lambda df: df["price"] / df["price"].rolling(252).max() - 1, + + # Risk (volatility and drawdown) + "vol_1m": lambda df: df["price"].pct_change().rolling(21).std(), + "vol_3m": lambda df: df["price"].pct_change().rolling(63).std(), + "vol_6m": lambda df: df["price"].pct_change().rolling(126).std(), + "dd_current": lambda df: df["price"] / df["price"].cummax() - 1, + "mdd_12m": lambda df: compute_rolling_mdd(df, 252), + + # Oscillators + "rsi_14": lambda df: compute_rsi(df["price"], 14), + "bb_width": lambda df: compute_bb_width(df["price"], 20, 2), +} + +ALL_FEATURE_IDS = list(FEATURE_REGISTRY.keys()) + + +def build_selected_features(df: pd.DataFrame, feature_ids: list[str]) -> pd.DataFrame: + """ + Build only the selected features from the registry. + Returns DataFrame with price, date, and selected feature columns. 
+ """ + df = df.copy() + + # Validate feature_ids + invalid = [f for f in feature_ids if f not in FEATURE_REGISTRY] + if invalid: + raise ValueError(f"Unknown feature_ids: {invalid}") + + # Compute each selected feature + for feature_id in feature_ids: + df[feature_id] = FEATURE_REGISTRY[feature_id](df) + + return df + + +def add_forward_target(df: pd.DataFrame, forward_days: int = FORWARD_DAYS) -> pd.DataFrame: + """Add forward return target for prediction.""" + df = df.copy() + df["target"] = df["price"].shift(-forward_days) / df["price"] - 1 + return df diff --git a/src/cli/etf-backtest/utils/experiment-extract.ts b/src/cli/etf-backtest/utils/experiment-extract.ts new file mode 100644 index 0000000..903abf2 --- /dev/null +++ b/src/cli/etf-backtest/utils/experiment-extract.ts @@ -0,0 +1,65 @@ +import { INDEX_NOT_FOUND, JSON_SLICE_END_OFFSET, ZERO } from "../constants"; +import { ExperimentResultSchema } from "../schemas"; +import type { ExperimentResult } from "../schemas"; + +const extractJsonFromStdout = (stdout: string): unknown => { + const startIdx = stdout.indexOf("{"); + if (startIdx === INDEX_NOT_FOUND) { + return null; + } + + let braceCount = ZERO; + let endIdx = INDEX_NOT_FOUND; + for (let i = startIdx; i < stdout.length; i++) { + if (stdout[i] === "{") { + braceCount++; + } + if (stdout[i] === "}") { + braceCount--; + } + if (braceCount === ZERO) { + endIdx = i; + break; + } + } + + if (endIdx === INDEX_NOT_FOUND) { + return null; + } + + const jsonStr = stdout.slice(startIdx, endIdx + JSON_SLICE_END_OFFSET); + return JSON.parse(jsonStr); +}; + +export const extractLastExperimentResult = (runResult: { + newItems?: { type: string; output?: unknown }[]; +}): ExperimentResult | null => { + try { + const items = runResult.newItems ?? 
[];
    for (const item of items) {
      if (item.type === "tool_call_output_item" && item.output) {
        const output = item.output;
        let parsed: unknown;
        if (typeof output === "string") {
          parsed = JSON.parse(output);
        } else {
          parsed = output;
        }

        const toolResult = parsed as { stdout?: string };
        if (toolResult.stdout) {
          const result = extractJsonFromStdout(toolResult.stdout);
          if (result) {
            const validated = ExperimentResultSchema.safeParse(result);
            if (validated.success) {
              return validated.data;
            }
          }
        }
      }
    }
  } catch {
    // Parsing failed, return null
  }
  return null;
};
diff --git a/src/cli/etf-backtest/utils/final-report.ts b/src/cli/etf-backtest/utils/final-report.ts
new file mode 100644
index 0000000..9cc76f5
--- /dev/null
+++ b/src/cli/etf-backtest/utils/final-report.ts
@@ -0,0 +1,115 @@
import type { Logger } from "~clients/logger";

import {
  CI_LEVEL_PERCENT,
  CONFIDENCE_THRESHOLDS,
  DECIMAL_PLACES,
  LINE_SEPARATOR,
  PREDICTION_HORIZON_MONTHS,
} from "../constants";
import type { ExperimentResult } from "../schemas";
import { formatFixed, formatPercent } from "./formatters";

export const printFinalResults = (
  logger: Logger,
  bestResult: ExperimentResult,
  bestIteration: number,
  totalIterations: number,
  stopReason: string
) => {
  const ciWidth =
    bestResult.prediction.ci95High - bestResult.prediction.ci95Low;
  let confidence = "LOW";
  if (
    bestResult.metrics.r2NonOverlapping >
      CONFIDENCE_THRESHOLDS.moderate.r2NonOverlapping &&
    bestResult.metrics.directionAccuracyNonOverlapping >
      CONFIDENCE_THRESHOLDS.moderate.directionAccuracyNonOverlapping &&
    ciWidth < CONFIDENCE_THRESHOLDS.moderate.maxCiWidth
  ) {
    confidence = "MODERATE";
  }
  if (
    bestResult.metrics.r2NonOverlapping >
      CONFIDENCE_THRESHOLDS.reasonable.r2NonOverlapping &&
    bestResult.metrics.directionAccuracyNonOverlapping >
      CONFIDENCE_THRESHOLDS.reasonable.directionAccuracyNonOverlapping &&
    ciWidth < CONFIDENCE_THRESHOLDS.reasonable.maxCiWidth
  ) {
    confidence = "REASONABLE";
  }

  const lines = [
    "",
    LINE_SEPARATOR,
    "OPTIMIZATION COMPLETE",
    LINE_SEPARATOR,
    `Iterations: ${totalIterations}`,
    `Best iteration: ${bestIteration}`,
    `Stop reason: ${stopReason}`,
    "",
    "Best Feature Set:",
    ...bestResult.featureIds.map((feature) => ` - ${feature}`),
    "",
    "Prediction Accuracy (Non-Overlapping - Honest Assessment):",
    ` R²: ${formatFixed(
      bestResult.metrics.r2NonOverlapping,
      DECIMAL_PLACES.r2
    )}`,
    ` Direction Accuracy: ${formatPercent(
      bestResult.metrics.directionAccuracyNonOverlapping
    )}`,
    ` Independent Samples: ${bestResult.dataInfo.nonOverlappingSamples}`,
    "",
    "Prediction Accuracy (Overlapping - Inflated):",
    ` R²: ${formatFixed(
      bestResult.metrics.r2,
      DECIMAL_PLACES.r2
    )}`,
    ` Direction Accuracy: ${formatPercent(
      bestResult.metrics.directionAccuracy
    )}`,
    ` MAE: ${formatPercent(bestResult.metrics.mae)}`,
    ` Calibration: ${formatFixed(
      bestResult.metrics.calibrationRatio,
      DECIMAL_PLACES.calibration
    )}`,
    "",
    "Backtest Metrics (Informational):",
    ` Sharpe Ratio: ${formatFixed(
      bestResult.metrics.sharpe,
      DECIMAL_PLACES.sharpe
    )}`,
    ` Max Drawdown: ${formatPercent(bestResult.metrics.maxDrawdown)}`,
    ` CAGR: ${formatPercent(
      bestResult.metrics.cagr,
      DECIMAL_PLACES.cagr
    )}`,
    "",
    `${PREDICTION_HORIZON_MONTHS}-Month Prediction:`,
    ` Expected Return: ${formatPercent(bestResult.prediction.pred12mReturn)}`,
    ` ${CI_LEVEL_PERCENT}% CI: [${formatPercent(
      bestResult.prediction.ci95Low
    )}, ${formatPercent(bestResult.prediction.ci95High)}]`,
    "",
    "Uncertainty Details:",
    ` Base Std: ${formatPercent(
      bestResult.prediction.uncertainty.baseStd
    )}`,
    ` Adjusted Std: ${formatPercent(
      bestResult.prediction.uncertainty.adjustedStd
    )}`,
    ` Extrapolation: ${
      bestResult.prediction.uncertainty.isExtrapolating
        ? "Yes (features outside training range)"
        : "No"
    }`,
    "",
    `Confidence: ${confidence}`,
    `Note: Non-overlapping metrics use only ${bestResult.dataInfo.nonOverlappingSamples} independent periods.`,
    "Past performance does not guarantee future results.",
    LINE_SEPARATOR,
  ];

  logger.answer(lines.join("\n"));
};
diff --git a/src/cli/etf-backtest/utils/formatters.ts b/src/cli/etf-backtest/utils/formatters.ts
new file mode 100644
index 0000000..55b4725
--- /dev/null
+++ b/src/cli/etf-backtest/utils/formatters.ts
@@ -0,0 +1,9 @@
import { DECIMAL_PLACES, PERCENT_MULTIPLIER } from "../constants";

export const formatPercent = (
  value: number,
  decimals = DECIMAL_PLACES.percent
): string => `${(value * PERCENT_MULTIPLIER).toFixed(decimals)}%`;

export const formatFixed = (value: number, decimals: number): string =>
  value.toFixed(decimals);
diff --git a/src/cli/etf-backtest/utils/learnings-formatter.ts b/src/cli/etf-backtest/utils/learnings-formatter.ts
new file mode 100644
index 0000000..a800fd1
--- /dev/null
+++ b/src/cli/etf-backtest/utils/learnings-formatter.ts
@@ -0,0 +1,115 @@
import { DECIMAL_PLACES, LEARNINGS_SUMMARY_TOP_N } from "../constants";
import type { Learnings } from "../schemas";
import { formatFixed, formatPercent } from "./formatters";

const FEATURE_PREVIEW_COUNT = 4;
const TOP_HALF_DIVISOR = 2;
const TOP_FEATURES_COUNT = 5;
const FEATURES_TO_AVOID_COUNT = 3;

/**
 * Format learnings into a prompt-friendly text summary.
 * Returns empty string if no useful learnings exist.
+ */ +export const formatLearningsForPrompt = ( + learnings: Learnings | null +): string => { + if (!learnings || learnings.history.length === 0) { + return ""; + } + + const lines: string[] = [ + "", + "## Previous Learnings", + `Total iterations run: ${learnings.totalIterations}`, + ]; + + // Best result summary + if (learnings.bestResult) { + lines.push(""); + lines.push("**Best result so far:**"); + lines.push(`- Features: ${learnings.bestResult.featureIds.join(", ")}`); + lines.push( + `- Score: ${formatFixed(learnings.bestResult.score, DECIMAL_PLACES.score)}` + ); + lines.push( + `- R² (non-overlapping): ${formatFixed(learnings.bestResult.metrics.r2NonOverlapping, DECIMAL_PLACES.r2)}` + ); + lines.push( + `- Direction accuracy: ${formatPercent(learnings.bestResult.metrics.directionAccuracyNonOverlapping)}` + ); + lines.push(`- MAE: ${formatPercent(learnings.bestResult.metrics.mae)}`); + } + + // Top N best attempts (sorted by score) + const sortedHistory = [...learnings.history] + .sort((a, b) => b.score - a.score) + .slice(0, LEARNINGS_SUMMARY_TOP_N); + + if (sortedHistory.length > 1) { + lines.push(""); + lines.push(`**Top ${sortedHistory.length} attempts:**`); + for (const record of sortedHistory) { + const featurePreview = record.featureIds.slice(0, FEATURE_PREVIEW_COUNT); + const suffix = + record.featureIds.length > FEATURE_PREVIEW_COUNT ? "..." : ""; + lines.push( + `- [Score ${formatFixed(record.score, DECIMAL_PLACES.score)}] ` + + `Features: ${featurePreview.join(", ")}${suffix}` + ); + } + } + + // Feature frequency analysis (which features appear in best results?) + const featureFrequency = new Map(); + const topHalf = sortedHistory.slice( + 0, + Math.ceil(sortedHistory.length / TOP_HALF_DIVISOR) + ); + for (const record of topHalf) { + for (const feature of record.featureIds) { + featureFrequency.set(feature, (featureFrequency.get(feature) ?? 
0) + 1); + } + } + + const frequentFeatures = [...featureFrequency.entries()] + .sort((a, b) => b[1] - a[1]) + .slice(0, TOP_FEATURES_COUNT) + .map(([feature]) => feature); + + if (frequentFeatures.length > 0) { + lines.push(""); + lines.push( + `**Features common in top results:** ${frequentFeatures.join(", ")}` + ); + } + + // Identify features that consistently appear in poor results + const bottomHalf = sortedHistory.slice( + Math.ceil(sortedHistory.length / TOP_HALF_DIVISOR) + ); + const poorFeatures = new Map(); + for (const record of bottomHalf) { + for (const feature of record.featureIds) { + poorFeatures.set(feature, (poorFeatures.get(feature) ?? 0) + 1); + } + } + + // Features in bottom half but not in top half + const toAvoid = [...poorFeatures.entries()] + .filter(([feature]) => !featureFrequency.has(feature)) + .slice(0, FEATURES_TO_AVOID_COUNT) + .map(([feature]) => feature); + + if (toAvoid.length > 0) { + lines.push(`**Features to reconsider:** ${toAvoid.join(", ")}`); + } + + lines.push(""); + lines.push( + "Use these learnings to guide your feature selection. Try to beat the best score." 
+ ); + lines.push(""); + + return lines.join("\n"); +}; diff --git a/src/cli/etf-backtest/utils/prompt-builders.test.ts b/src/cli/etf-backtest/utils/prompt-builders.test.ts new file mode 100644 index 0000000..c9eb975 --- /dev/null +++ b/src/cli/etf-backtest/utils/prompt-builders.test.ts @@ -0,0 +1,31 @@ +import { describe, expect, it } from "vitest"; + +import { buildRecoveryPrompt, buildRunPythonUsage } from "./prompt-builders"; + +describe("buildRunPythonUsage", () => { + it("includes seed and dataPath in the tool input", () => { + const result = buildRunPythonUsage({ + seed: 7, + dataPath: "tmp/etf-backtest/data.json", + }); + + expect(result).toContain('"seed": 7'); + expect(result).toContain('"dataPath": "tmp/etf-backtest/data.json"'); + expect(result).toContain('scriptName: "run_experiment.py"'); + }); +}); + +describe("buildRecoveryPrompt", () => { + it("appends runPython usage after the message", () => { + const message = "Recovery message."; + const result = buildRecoveryPrompt(message, { + seed: 1, + dataPath: "data.json", + }); + + expect(result.startsWith(message)).toBe(true); + expect(result).toContain("Use runPython with:"); + expect(result).toContain('"seed": 1'); + expect(result).toContain('"dataPath": "data.json"'); + }); +}); diff --git a/src/cli/etf-backtest/utils/prompt-builders.ts b/src/cli/etf-backtest/utils/prompt-builders.ts new file mode 100644 index 0000000..6822bdb --- /dev/null +++ b/src/cli/etf-backtest/utils/prompt-builders.ts @@ -0,0 +1,19 @@ +type RunPythonUsageOptions = { + seed: number; + dataPath: string; +}; + +export const buildRunPythonUsage = ({ + seed, + dataPath, +}: RunPythonUsageOptions): string => + [ + "Use runPython with:", + '- scriptName: "run_experiment.py"', + `- input: { "featureIds": [...your features...], "seed": ${seed}, "dataPath": "${dataPath}" }`, + ].join("\n"); + +export const buildRecoveryPrompt = ( + message: string, + options: RunPythonUsageOptions +): string => [message, "", 
buildRunPythonUsage(options)].join("\n"); diff --git a/src/cli/etf-backtest/utils/scoring.ts b/src/cli/etf-backtest/utils/scoring.ts new file mode 100644 index 0000000..4dbd25c --- /dev/null +++ b/src/cli/etf-backtest/utils/scoring.ts @@ -0,0 +1,21 @@ +import { + NEGATIVE_SHARPE_PENALTY, + NEGATIVE_SHARPE_THRESHOLD, + SCORE_WEIGHTS, + ZERO, +} from "../constants"; +import type { ExperimentResult } from "../schemas"; + +export const computeScore = (metrics: ExperimentResult["metrics"]): number => { + // Primary: prediction accuracy on non-overlapping samples (honest assessment) + // Secondary: Sharpe < 0 is a red flag (sanity check only) + return ( + metrics.r2NonOverlapping * SCORE_WEIGHTS.r2NonOverlapping + + metrics.directionAccuracyNonOverlapping * + SCORE_WEIGHTS.directionAccuracyNonOverlapping + + metrics.mae * SCORE_WEIGHTS.mae + + (metrics.sharpe < NEGATIVE_SHARPE_THRESHOLD + ? NEGATIVE_SHARPE_PENALTY + : ZERO) + ); +}; diff --git a/src/cli/guestbook/main.ts b/src/cli/guestbook/main.ts index f17dda2..a4013e6 100644 --- a/src/cli/guestbook/main.ts +++ b/src/cli/guestbook/main.ts @@ -1,19 +1,28 @@ // pnpm run:guestbook -import { Agent, run } from "@openai/agents"; - import "dotenv/config"; -import { readFileTool } from "~tools/read-file/read-file-tool"; -import { writeFileTool } from "~tools/write-file/write-file-tool"; +import { AgentRunner } from "~clients/agent-runner"; +import { Logger } from "~clients/logger"; +import { createReadFileTool } from "~tools/read-file/read-file-tool"; +import { createWriteFileTool } from "~tools/write-file/write-file-tool"; +import { z } from "zod"; import { question } from "zx"; -console.log("Guestbook running..."); +const logger = new Logger(); + +logger.info("Guestbook running..."); + +const OutputSchema = z.object({ + success: z.boolean(), + message: z.string(), +}); -const agent = new Agent({ +const agentRunner = new AgentRunner({ name: "GuestbookAgent", model: "gpt-5-mini", - tools: [writeFileTool, readFileTool], + 
tools: [createWriteFileTool({ logger }), createReadFileTool({ logger })], + outputType: OutputSchema, instructions: ` You maintain a shared "greeting guestbook" at guestbook.md. Rules: @@ -23,7 +32,12 @@ Rules: - If it doesn't exist, create it with a header and an Entries section. - Each entry must include the user's name. - Keep it upbeat and a little nerdy, but not cringe. + +IMPORTANT: Always respond with a JSON object in this format: +{"success": true/false, "message": "description of what was done"} `, + logger, + stateless: true, // Each run is independent }); const userName = await question("Enter user name: "); @@ -53,11 +67,23 @@ Steps: 4) Write the final Markdown back to guestbook.md. `; -const result = await run(agent, prompt); +const result = await agentRunner.run({ prompt }); +const parseResult = OutputSchema.safeParse(result.finalOutput); -console.log("Agent result:", result.finalOutput); +if (parseResult.success) { + logger.info("Result", { message: parseResult.data.message }); +} else { + logger.warn("Unexpected response format"); + logger.info(String(result.finalOutput)); +} -// Optional: show the file contents after write -const preview = await run(agent, `Read and print the contents of guestbook.md`); -console.log("\n--- Preview ---\n"); -console.log(preview.finalOutput); +// Show the file contents after write +const preview = await agentRunner.run({ + prompt: `Read guestbook.md and include its full contents in your response message.`, +}); +const previewResult = OutputSchema.safeParse(preview.finalOutput); +if (previewResult.success) { + logger.answer(previewResult.data.message); +} else { + logger.answer(JSON.stringify(preview.finalOutput, null, 2)); +} diff --git a/src/cli/name-explorer/clients/database.ts b/src/cli/name-explorer/clients/database.ts index 84c5a47..aaf6d71 100644 --- a/src/cli/name-explorer/clients/database.ts +++ b/src/cli/name-explorer/clients/database.ts @@ -92,9 +92,11 @@ export class NameDatabase { 
this.db.exec("ROLLBACK"); throw error; } - this.logger.debug( - `Inserted ${entries.length} ${gender} names for decade ${decade}` - ); + this.logger.debug("Inserted names for decade", { + count: entries.length, + gender, + decade, + }); } /** @@ -152,7 +154,9 @@ export class NameDatabase { this.insertNames(decadeData.decade, "boy", decadeData.boys); this.insertNames(decadeData.decade, "girl", decadeData.girls); } - this.logger.debug(`Loaded ${this.getTotalCount()} records from JSON`); + this.logger.debug("Loaded records from JSON", { + count: this.getTotalCount(), + }); } /** @@ -286,7 +290,7 @@ export class AggregatedNameDatabase { this.db.exec("ROLLBACK"); throw error; } - this.logger.debug(`Loaded ${gender} names from ${filePath}`); + this.logger.debug("Loaded names from file", { gender, filePath }); } /** diff --git a/src/cli/name-explorer/clients/pipeline.ts b/src/cli/name-explorer/clients/pipeline.ts index 5fc6c37..d299e68 100644 --- a/src/cli/name-explorer/clients/pipeline.ts +++ b/src/cli/name-explorer/clients/pipeline.ts @@ -108,12 +108,12 @@ export class NameSuggesterPipeline { let fromCache = false; if (htmlExists && mdExists && !this.refetch) { - this.logger.debug(`Cached: ${decade} page ${page}`); + this.logger.debug("Cache hit", { decade, page }); html = await fs.readFile(htmlFile, "utf-8"); markdown = await fs.readFile(mdFile, "utf-8"); fromCache = true; } else { - this.logger.info(`Fetching ${decade} page ${page}...`); + this.logger.info("Fetching decade page", { decade, page }); html = await this.fetchClient.fetchHtml(url); markdown = await this.fetchClient.fetchMarkdown(url); @@ -140,15 +140,17 @@ export class NameSuggesterPipeline { decades?: string[]; pages?: number[]; } = {}): Promise { - this.logger.info( - `Will process ${decades.length} decades × ${pages.length} pages = ${decades.length * pages.length} combinations` - ); + this.logger.info("Processing plan", { + decades: decades.length, + pages: pages.length, + combinations: decades.length * 
pages.length, + }); let cachedPages = 0; let fetchedPages = 0; for (const decade of decades) { - this.logger.info(`Processing decade ${decade}...`); + this.logger.info("Processing decade", { decade }); for (const page of pages) { const { parsedNames, fromCache } = await this.fetchDecadePage({ @@ -183,7 +185,7 @@ export class NameSuggesterPipeline { const consolidatedData: ConsolidatedData = this.db.getAll(); const outputPath = path.join(this.outputDir, filename); await fs.writeFile(outputPath, JSON.stringify(consolidatedData, null, 2)); - this.logger.info(`Saved consolidated data to ${outputPath}`); + this.logger.info("Saved consolidated data", { outputPath }); return outputPath; } @@ -200,9 +202,9 @@ export class NameSuggesterPipeline { const jsonContent = await fs.readFile(outputPath, "utf-8"); const data = JSON.parse(jsonContent) as ConsolidatedData; this.db.loadFromConsolidatedData(data); - this.logger.info( - `Loaded existing data from JSON (${this.db.getTotalCount()} records)` - ); + this.logger.info("Loaded existing data from JSON", { + count: this.db.getTotalCount(), + }); const aggregatedDb = await this.loadAggregatedCsvData(); @@ -219,13 +221,15 @@ export class NameSuggesterPipeline { const { totalPages, cachedPages, fetchedPages } = await this.processAllDecades(); - this.logger.info( - `Processing complete: ${fetchedPages} fetched, ${cachedPages} cached, ${totalPages} total` - ); + this.logger.info("Processing complete", { + fetchedPages, + cachedPages, + totalPages, + }); - this.logger.info( - `Database contains ${this.db.getTotalCount()} name records` - ); + this.logger.info("Database contains name records", { + count: this.db.getTotalCount(), + }); await this.saveConsolidatedData(); @@ -267,9 +271,9 @@ export class NameSuggesterPipeline { aggregatedDb.loadFromCsv(femaleCsvPath, "female"); } - this.logger.info( - `Loaded aggregated CSV data (${aggregatedDb.getTotalCount()} records)` - ); + this.logger.info("Loaded aggregated CSV data", { + count: 
aggregatedDb.getTotalCount(), + }); return aggregatedDb; } diff --git a/src/cli/name-explorer/main.ts b/src/cli/name-explorer/main.ts index 8d95b94..39f8247 100644 --- a/src/cli/name-explorer/main.ts +++ b/src/cli/name-explorer/main.ts @@ -4,7 +4,7 @@ import "dotenv/config"; import { writeFile } from "fs/promises"; -import { Agent, MemorySession, Runner } from "@openai/agents"; +import { AgentRunner } from "~clients/agent-runner"; import { Logger } from "~clients/logger"; import { parseArgs } from "~utils/parse-args"; import { QuestionHandler } from "~utils/question-handler"; @@ -55,7 +55,7 @@ const runStatsMode = async () => { const outputPath = "tmp/name-explorer/statistics.html"; await writeFile(outputPath, html, "utf-8"); - logger.info(`Statistics page written to ${outputPath}`); + logger.info("Statistics page written", { outputPath }); }; // --- AI Mode: Interactive Q&A with SQL agent --- @@ -73,7 +73,7 @@ const runAiMode = async () => { tools.push(createAggregatedSqlQueryTool(aggregatedDb)); } - const agent = new Agent({ + const agentRunner = new AgentRunner({ name: "NameExpertAgent", model: "gpt-5-mini", tools, @@ -95,33 +95,11 @@ IMPORTANT: Respond with ONLY a valid JSON object: - Use status "final" when you have the answer. Put the answer in "content". - Use status "needs_clarification" only if you cannot answer without more input. Put a single, concise question in "content". When answering, do not include any questions. 
Do not include markdown or extra keys.`, - }); - - const runner = new Runner(); - - const toolsInProgress = new Set(); - - runner.on("agent_tool_start", (_context, _agent, tool, details) => { - const toolCall = details.toolCall as Record; - const callId = toolCall.id as string; - if (toolsInProgress.has(callId)) { - return; - } - toolsInProgress.add(callId); - - const args = String(toolCall.arguments); - logger.tool(`Calling ${tool.name}: ${args || "no arguments"}`); - }); - - runner.on("agent_tool_end", (_context, _agent, tool, result) => { - logger.tool(`${tool.name} completed`); - const preview = - result.length > 200 ? result.substring(0, 200) + "..." : result; - logger.debug(`Result: ${preview}`); + logger, + logToolArgs: true, }); const questionHandler = new QuestionHandler({ logger }); - const session = new MemorySession(); const userQuestion = await questionHandler.askString({ prompt: "Ask about Finnish names: ", @@ -133,7 +111,7 @@ When answering, do not include any questions. Do not include markdown or extra k let currentQuestion = userQuestion; while (true) { - const result = await runner.run(agent, currentQuestion, { session }); + const result = await agentRunner.run({ prompt: currentQuestion }); const parseResult = NameSuggesterOutputSchema.safeParse(result.finalOutput); if (!parseResult.success) { diff --git a/src/cli/scrape-publications/clients/publication-pipeline.ts b/src/cli/scrape-publications/clients/publication-pipeline.ts index be8959f..58f5b78 100644 --- a/src/cli/scrape-publications/clients/publication-pipeline.ts +++ b/src/cli/scrape-publications/clients/publication-pipeline.ts @@ -197,7 +197,7 @@ export class PublicationPipeline { `content.md (${fromCache.markdown ? "cached" : sourceLabel})`, `content.html (${fromCache.html ? 
"cached" : sourceLabel})`, ]; - this.logger.info(`Content ready: ${contentStatus.join(", ")}`); + this.logger.info("Content ready", { contentStatus }); return { markdown, html, fromCache, source }; } @@ -226,7 +226,10 @@ export class PublicationPipeline { path.join(this.outputDir, "links.json"), JSON.stringify(allLinks, null, 2) ); - this.logger.info(`Saved ${allLinks.length} links to links.json`); + this.logger.info("Saved links", { + count: allLinks.length, + file: "links.json", + }); let filteredLinks = ( filterSubstring @@ -238,9 +241,10 @@ export class PublicationPipeline { path.join(this.outputDir, "filtered-links.json"), JSON.stringify(filteredLinks, null, 2) ); - this.logger.info( - `Saved ${filteredLinks.length} filtered links to filtered-links.json` - ); + this.logger.info("Saved filtered links", { + count: filteredLinks.length, + file: "filtered-links.json", + }); let filteredUrlSet = new Set(filteredLinks); let linkCandidates = this.scraper.extractLinkCandidates( @@ -296,18 +300,19 @@ export class PublicationPipeline { usedFallback = true; currentSource = "basic-fetch"; - this.logger.info( - `Fallback fetch found ${linkCandidates.length} link candidates` - ); + this.logger.info("Fallback fetch found link candidates", { + count: linkCandidates.length, + }); } await fs.writeFile( path.join(this.outputDir, "link-candidates.json"), JSON.stringify(linkCandidates, null, 2) ); - this.logger.info( - `Saved ${linkCandidates.length} link candidates to link-candidates.json` - ); + this.logger.info("Saved link candidates", { + count: linkCandidates.length, + file: "link-candidates.json", + }); return { allLinks, @@ -337,9 +342,11 @@ export class PublicationPipeline { JSON.stringify(selectors, null, 2) ); - this.logger.info(`Identified selectors:`); - this.logger.info(` Title: ${selectors.titleSelector}`); - this.logger.info(` Date: ${selectors.dateSelector ?? 
"(not found)"}`); + this.logger.info("Identified selectors"); + this.logger.info("Title selector", { selector: selectors.titleSelector }); + this.logger.info("Date selector", { + selector: selectors.dateSelector ?? "(not found)", + }); this.logger.info("Extracting publication data..."); @@ -353,9 +360,10 @@ export class PublicationPipeline { JSON.stringify(publications, null, 2) ); - this.logger.info( - `Saved ${publications.length} publications to publication-links.json` - ); + this.logger.info("Saved publication links", { + count: publications.length, + file: "publication-links.json", + }); return { selectors, publications }; } @@ -382,7 +390,9 @@ export class PublicationPipeline { titleSlugCounts.set(titleSlug, (titleSlugCounts.get(titleSlug) ?? 0) + 1); } - this.logger.info(`Found ${publications.length} publication links to fetch`); + this.logger.info("Found publication links to fetch", { + count: publications.length, + }); let fetchedCount = 0; let skippedCount = 0; @@ -393,7 +403,7 @@ export class PublicationPipeline { const titleSlug = titleSlugs[index]; if (!titleSlug) { - this.logger.warn(`Skipping publication with empty title slug: ${url}`); + this.logger.warn("Skipping publication with empty title slug", { url }); continue; } @@ -416,9 +426,11 @@ export class PublicationPipeline { if (!needsHtml && !needsMarkdown) { skippedCount++; - this.logger.info( - `[${skippedCount + fetchedCount}/${publications.length}] Cached: ${url}` - ); + this.logger.info("Publication cached", { + index: skippedCount + fetchedCount, + total: publications.length, + url, + }); continue; } @@ -442,20 +454,30 @@ export class PublicationPipeline { needsHtml ? "Fetched HTML" : "Cached HTML", needsMarkdown ? 
"Wrote Markdown" : "Cached Markdown", ]; - this.logger.info( - `[${skippedCount + fetchedCount}/${publications.length}] ${statusParts.join(", ")}: ${url}` - ); + this.logger.info("Publication processed", { + index: skippedCount + fetchedCount, + total: publications.length, + status: statusParts, + url, + }); } catch (error) { this.logger.error( - `[${skippedCount + fetchedCount}/${publications.length}] Failed: ${url}`, + "Publication fetch failed", + { + index: skippedCount + fetchedCount, + total: publications.length, + url, + }, error ); } } - this.logger.info( - `Fetch complete: ${fetchedCount} new HTML, ${markdownCount} markdown written, ${skippedCount} cached` - ); + this.logger.info("Fetch complete", { + fetchedCount, + markdownCount, + skippedCount, + }); return { fetchedCount, skippedCount, markdownCount }; } @@ -493,13 +515,13 @@ export class PublicationPipeline { const sampleHtmlPath = path.join(publicationsDir, firstHtmlFile); const sampleHtml = await fs.readFile(sampleHtmlPath, "utf-8"); - this.logger.info(`Analyzing sample HTML: ${firstHtmlFile}`); + this.logger.info("Analyzing sample HTML", { file: firstHtmlFile }); const contentSelectors = await this.scraper.identifyContentSelector(sampleHtml); - this.logger.info( - `Identified content selector: ${contentSelectors.contentSelector}` - ); + this.logger.info("Identified content selector", { + selector: contentSelectors.contentSelector, + }); await fs.writeFile( path.join(this.outputDir, "content-selectors.json"), @@ -542,7 +564,9 @@ export class PublicationPipeline { filename: firstPossibleFilename, error: "HTML file not found", }); - this.logger.warn(`HTML file not found for: ${publication.title}`); + this.logger.warn("HTML file not found for publication", { + title: publication.title, + }); continue; } @@ -557,7 +581,9 @@ export class PublicationPipeline { filename: usedFilename, error: "No content found with selector", }); - this.logger.warn(`No content found for: ${publication.title}`); + 
this.logger.warn("No content found for publication", { + title: publication.title, + }); continue; } @@ -574,9 +600,11 @@ export class PublicationPipeline { filename: usedFilename, }); - this.logger.info( - `[${extractionResults.length}/${publications.length}] Extracted: ${publication.title}` - ); + this.logger.info("Publication extracted", { + index: extractionResults.length, + total: publications.length, + title: publication.title, + }); } await fs.writeFile( @@ -596,9 +624,10 @@ export class PublicationPipeline { JSON.stringify(report, null, 2) ); - this.logger.info( - `Content extraction complete: ${report.successful}/${report.total} publications processed` - ); + this.logger.info("Content extraction complete", { + successful: report.successful, + total: report.total, + }); return { publications: publicationsWithContent, report }; } @@ -621,7 +650,7 @@ export class PublicationPipeline { const reviewPath = path.join(this.outputDir, "review.html"); await fs.writeFile(reviewPath, reviewHtml); - this.logger.info(`Review page saved to: ${reviewPath}`); + this.logger.info("Review page saved", { reviewPath }); return reviewPath; } diff --git a/src/cli/scrape-publications/clients/publication-scraper.ts b/src/cli/scrape-publications/clients/publication-scraper.ts index d2bf715..35c40b4 100644 --- a/src/cli/scrape-publications/clients/publication-scraper.ts +++ b/src/cli/scrape-publications/clients/publication-scraper.ts @@ -242,9 +242,12 @@ IMPORTANT: Respond with ONLY a valid JSON object: if (firstGroup) { const [topSignature, topGroup] = firstGroup; if (topGroup.length > 0) { - this.logger.debug( - `Selected structure group: ${topSignature} (${topGroup.length} candidates, score: ${this.scoreStructureSignature(topSignature)})` - ); + const score = this.scoreStructureSignature(topSignature); + this.logger.debug("Selected structure group", { + signature: topSignature, + candidates: topGroup.length, + score, + }); return topGroup.slice(0, maxSamples); } } @@ -590,7 +593,9 
@@ Respond with only a JSON object containing "titleSelector" and "dateSelector" (n const date = this.extractDate(candidate.html, selectors, candidate.url); if (!title) { - this.logger.warn(`Could not extract title for: ${candidate.url}`); + this.logger.warn("Could not extract title for candidate", { + url: candidate.url, + }); continue; } @@ -604,9 +609,10 @@ Respond with only a JSON object containing "titleSelector" and "dateSelector" (n if (result.success) { publications.push(result.data); } else { - this.logger.warn( - `Validation failed for ${candidate.url}: ${result.error.message}` - ); + this.logger.warn("Validation failed for candidate", { + url: candidate.url, + error: result.error.message, + }); } } diff --git a/src/cli/scrape-publications/main.ts b/src/cli/scrape-publications/main.ts index 17f07c4..a1630e0 100644 --- a/src/cli/scrape-publications/main.ts +++ b/src/cli/scrape-publications/main.ts @@ -43,7 +43,7 @@ const outputDir = path.join( urlSlug ); -logger.info(`Output directory: ${outputDir}`); +logger.info("Output directory", { outputDir }); // 3. 
Create pipeline const pipeline = new PublicationPipeline({ diff --git a/src/clients/agent-runner.test.ts b/src/clients/agent-runner.test.ts new file mode 100644 index 0000000..a0433bc --- /dev/null +++ b/src/clients/agent-runner.test.ts @@ -0,0 +1,396 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { z } from "zod"; + +// Import after mocking +import { AgentRunner } from "./agent-runner"; +import { Logger } from "./logger"; + +type EventHandler = (...args: unknown[]) => void; + +// Store instances for test access +let mockRunnerInstance: { + on: ReturnType; + run: ReturnType; +}; +let mockSessionInstance: object; +let eventHandlers: Map; + +vi.mock("@openai/agents", () => { + // Create fresh mocks that will be configured in beforeEach + return { + Agent: vi.fn(function MockAgent() { + return {}; + }), + Runner: vi.fn(function MockRunner() { + return mockRunnerInstance; + }), + MemorySession: vi.fn(function MockMemorySession() { + return mockSessionInstance; + }), + }; +}); + +const getHandler = ( + handlers: Map, + event: string +): EventHandler => { + const handler = handlers.get(event); + if (!handler) { + throw new Error(`Handler for event "${event}" not found`); + } + return handler; +}; + +describe("AgentRunner", () => { + let logger: Logger; + + const TestOutputSchema = z.object({ + message: z.string(), + }); + + beforeEach(() => { + logger = new Logger({ level: "error" }); + eventHandlers = new Map(); + + mockRunnerInstance = { + on: vi.fn((event: string, handler: EventHandler) => { + eventHandlers.set(event, handler); + }), + run: vi.fn().mockResolvedValue({ finalOutput: { message: "test" } }), + }; + + mockSessionInstance = {}; + }); + + afterEach(() => { + vi.clearAllMocks(); + eventHandlers.clear(); + }); + + describe("constructor", () => { + it("registers event handlers", () => { + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test 
instructions", + logger, + }); + + expect(mockRunnerInstance.on).toHaveBeenCalledWith( + "agent_tool_start", + expect.any(Function) + ); + expect(mockRunnerInstance.on).toHaveBeenCalledWith( + "agent_tool_end", + expect.any(Function) + ); + }); + }); + + describe("run", () => { + it("calls runner.run with agent and session", async () => { + const agentRunner = new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + await agentRunner.run({ prompt: "test prompt" }); + + expect(mockRunnerInstance.run).toHaveBeenCalledWith( + expect.anything(), // agent + "test prompt", + { session: mockSessionInstance } + ); + }); + + it("passes maxTurns option", async () => { + const agentRunner = new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + await agentRunner.run({ prompt: "test prompt", maxTurns: 3 }); + + expect(mockRunnerInstance.run).toHaveBeenCalledWith( + expect.anything(), + "test prompt", + { session: mockSessionInstance, maxTurns: 3 } + ); + }); + + it("returns the run result", async () => { + const expectedResult = { finalOutput: { message: "success" } }; + mockRunnerInstance.run.mockResolvedValue(expectedResult); + + const agentRunner = new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + const result = await agentRunner.run({ prompt: "test prompt" }); + + expect(result).toBe(expectedResult); + }); + }); + + describe("event handlers", () => { + it("deduplicates tool_start events by call id", () => { + const toolLogSpy = vi.spyOn(logger, "tool"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + const handler = 
getHandler(eventHandlers, "agent_tool_start"); + const mockTool = { name: "testTool" }; + const mockDetails = { toolCall: { id: "call-123", arguments: "{}" } }; + + // First call should log + handler(null, null, mockTool, mockDetails); + expect(toolLogSpy).toHaveBeenCalledTimes(1); + + // Second call with same id should not log + handler(null, null, mockTool, mockDetails); + expect(toolLogSpy).toHaveBeenCalledTimes(1); + + // Different id should log + const differentDetails = { + toolCall: { id: "call-456", arguments: "{}" }, + }; + handler(null, null, mockTool, differentDetails); + expect(toolLogSpy).toHaveBeenCalledTimes(2); + }); + + it("logs tool arguments when logToolArgs is true", () => { + const toolLogSpy = vi.spyOn(logger, "tool"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + logToolArgs: true, + }); + + const handler = getHandler(eventHandlers, "agent_tool_start"); + const mockTool = { name: "testTool" }; + const mockDetails = { + toolCall: { id: "call-123", arguments: '{"key":"value"}' }, + }; + + handler(null, null, mockTool, mockDetails); + + expect(toolLogSpy).toHaveBeenCalledWith("Calling tool", { + name: "testTool", + args: '{"key":"value"}', + }); + }); + + it("does not log tool arguments when logToolArgs is false", () => { + const toolLogSpy = vi.spyOn(logger, "tool"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + logToolArgs: false, + }); + + const handler = getHandler(eventHandlers, "agent_tool_start"); + const mockTool = { name: "testTool" }; + const mockDetails = { + toolCall: { id: "call-123", arguments: '{"key":"value"}' }, + }; + + handler(null, null, mockTool, mockDetails); + + expect(toolLogSpy).toHaveBeenCalledWith("Calling tool", { + name: "testTool", + }); + }); + + it("logs result preview when logToolResults 
is true", () => { + const testLogger = new Logger({ level: "debug" }); + const debugLogSpy = vi.spyOn(testLogger, "debug"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger: testLogger, + logToolResults: true, + }); + + const handler = getHandler(eventHandlers, "agent_tool_end"); + const mockTool = { name: "testTool" }; + + handler(null, null, mockTool, "short result"); + + expect(debugLogSpy).toHaveBeenCalledWith("Tool result preview", { + preview: "short result", + }); + }); + + it("truncates long results based on resultPreviewLimit", () => { + const testLogger = new Logger({ level: "debug" }); + const debugLogSpy = vi.spyOn(testLogger, "debug"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger: testLogger, + logToolResults: true, + resultPreviewLimit: 10, + }); + + const handler = getHandler(eventHandlers, "agent_tool_end"); + const mockTool = { name: "testTool" }; + + handler(null, null, mockTool, "this is a very long result string"); + + expect(debugLogSpy).toHaveBeenCalledWith("Tool result preview", { + preview: "this is a ...", + }); + }); + + it("does not log result when logToolResults is false", () => { + const testLogger = new Logger({ level: "debug" }); + const debugLogSpy = vi.spyOn(testLogger, "debug"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger: testLogger, + logToolResults: false, + }); + + const handler = getHandler(eventHandlers, "agent_tool_end"); + const mockTool = { name: "testTool" }; + + handler(null, null, mockTool, "some result"); + + expect(debugLogSpy).not.toHaveBeenCalled(); + }); + }); + + describe("memorySession", () => { + it("returns the session instance", () => { + const agentRunner = new AgentRunner({ + name: 
"TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + expect(agentRunner.memorySession).toBe(mockSessionInstance); + }); + }); + + describe("default config values", () => { + it("defaults logToolArgs to false", () => { + const toolLogSpy = vi.spyOn(logger, "tool"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger, + }); + + const handler = getHandler(eventHandlers, "agent_tool_start"); + const mockTool = { name: "testTool" }; + const mockDetails = { + toolCall: { id: "call-123", arguments: '{"key":"value"}' }, + }; + + handler(null, null, mockTool, mockDetails); + + // Should not include arguments + expect(toolLogSpy).toHaveBeenCalledWith("Calling tool", { + name: "testTool", + }); + }); + + it("defaults logToolResults to true", () => { + const testLogger = new Logger({ level: "debug" }); + const debugLogSpy = vi.spyOn(testLogger, "debug"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger: testLogger, + }); + + const handler = getHandler(eventHandlers, "agent_tool_end"); + const mockTool = { name: "testTool" }; + + handler(null, null, mockTool, "result"); + + expect(debugLogSpy).toHaveBeenCalled(); + }); + + it("defaults resultPreviewLimit to 200", () => { + const testLogger = new Logger({ level: "debug" }); + const debugLogSpy = vi.spyOn(testLogger, "debug"); + + new AgentRunner({ + name: "TestAgent", + model: "gpt-5-mini", + tools: [], + outputType: TestOutputSchema, + instructions: "Test instructions", + logger: testLogger, + }); + + const handler = getHandler(eventHandlers, "agent_tool_end"); + const mockTool = { name: "testTool" }; + const longResult = "x".repeat(250); + + handler(null, null, mockTool, longResult); + + // Should truncate at 200 chars + 
expect(debugLogSpy).toHaveBeenCalledWith("Tool result preview", { + preview: "x".repeat(200) + "...", + }); + }); + }); +}); diff --git a/src/clients/agent-runner.ts b/src/clients/agent-runner.ts new file mode 100644 index 0000000..ff8350c --- /dev/null +++ b/src/clients/agent-runner.ts @@ -0,0 +1,129 @@ +import { Agent, MemorySession, Runner } from "@openai/agents"; +import type { RunResult, Tool } from "@openai/agents"; +import type { ZodType } from "zod"; + +import type { Logger } from "./logger"; + +const DEFAULT_RESULT_PREVIEW_LIMIT = 200; + +export type AgentRunnerConfig = { + // Agent config + name: string; + model: "gpt-5-mini"; + tools: Tool[]; + outputType: ZodType; + instructions: string; + + // Logging config + logger: Logger; + logToolArgs?: boolean; + logToolResults?: boolean; + resultPreviewLimit?: number; + + /** + * If true, each run() call uses a fresh context (no session history). + * Required for reasoning models (gpt-5-mini) when making multiple independent runs. + */ + stateless?: boolean; +}; + +export type RunProps = { + prompt: string; + maxTurns?: number; + /** If true, run without session history (fresh context). Useful for independent follow-up queries. */ + stateless?: boolean; +}; + +type AgentType = Agent>; + +/** + * Wrapper around OpenAI Agent + Runner + MemorySession with built-in + * event logging for tool calls. Provides a consistent interface for + * running agents across different CLIs. + */ +export class AgentRunner { + private agent: AgentType; + private runner: Runner; + private session: MemorySession; + private logger: Logger; + private toolsInProgress: Set; + private logToolArgs: boolean; + private logToolResults: boolean; + private resultPreviewLimit: number; + private stateless: boolean; + + constructor(config: AgentRunnerConfig) { + this.logger = config.logger; + this.logToolArgs = config.logToolArgs ?? false; + this.logToolResults = config.logToolResults ?? 
true; + this.resultPreviewLimit = + config.resultPreviewLimit ?? DEFAULT_RESULT_PREVIEW_LIMIT; + this.toolsInProgress = new Set(); + this.stateless = config.stateless ?? false; + + this.agent = new Agent({ + name: config.name, + model: config.model, + tools: config.tools, + outputType: config.outputType, + instructions: config.instructions, + }); + + this.runner = new Runner(); + this.session = new MemorySession(); + + this.setupEventHandlers(); + } + + private setupEventHandlers(): void { + this.runner.on("agent_tool_start", (_context, _agent, tool, details) => { + const toolCall = details.toolCall as Record; + const callId = toolCall.id as string; + + // Deduplicate tool calls (events may fire multiple times) + if (this.toolsInProgress.has(callId)) { + return; + } + this.toolsInProgress.add(callId); + + if (this.logToolArgs) { + const args = String(toolCall.arguments); + this.logger.tool("Calling tool", { + name: tool.name, + args: args || "no arguments", + }); + } else { + this.logger.tool("Calling tool", { name: tool.name }); + } + }); + + this.runner.on("agent_tool_end", (_context, _agent, tool, result) => { + this.logger.tool("Tool completed", { name: tool.name }); + + if (this.logToolResults) { + const preview = + result.length > this.resultPreviewLimit + ? result.substring(0, this.resultPreviewLimit) + "..." + : result; + this.logger.debug("Tool result preview", { preview }); + } + }); + } + + async run({ + prompt, + ...rest + }: RunProps): Promise>> { + // When stateless=true, omit session to avoid reasoning item sequence errors + // that occur when reusing MemorySession with reasoning models + const sessionOption = this.stateless ? 
{} : { session: this.session }; + return this.runner.run(this.agent, prompt, { + ...sessionOption, + ...rest, + }); + } + + get memorySession(): MemorySession { + return this.session; + } +} diff --git a/src/clients/playwright-scraper.ts b/src/clients/playwright-scraper.ts index b158951..b2e4b0d 100644 --- a/src/clients/playwright-scraper.ts +++ b/src/clients/playwright-scraper.ts @@ -25,6 +25,23 @@ export type ScrapeRequest = { targetUrl: string; } & ScrapeOptions; +// Options for network capture during scraping +export type NetworkCaptureOptions = ScrapeOptions & { + captureUrlPattern: RegExp; // Pattern to match API requests to capture + captureTimeoutMs?: number; // Timeout waiting for API response (default: 15000) + validateResponse?: (data: unknown) => data is T; // Optional validator to filter responses + localStorage?: Record; // Key-value pairs to set in localStorage before navigation +}; + +export type NetworkCaptureRequest = { + targetUrl: string; +} & NetworkCaptureOptions; + +export type NetworkCaptureResult = { + data: T; + capturedUrl: string; +}; + /** * A web scraper client that uses Playwright to scrape webpages * requiring JavaScript rendering. Returns sanitized HTML or Markdown. @@ -80,8 +97,11 @@ export class PlaywrightScraper { const timeout = options.timeoutMs ?? this.defaultTimeoutMs; const waitStrategy = options.waitStrategy ?? 
this.defaultWaitStrategy; - this.logger.debug(`Navigating to: ${targetUrl}`); - this.logger.debug(`Wait strategy: ${waitStrategy}, timeout: ${timeout}ms`); + this.logger.debug("Navigating to URL", { targetUrl }); + this.logger.debug("Wait strategy", { + waitStrategy, + timeoutMs: timeout, + }); await page.goto(targetUrl, { timeout, @@ -89,7 +109,9 @@ export class PlaywrightScraper { }); if (options.waitForSelector) { - this.logger.debug(`Waiting for selector: ${options.waitForSelector}`); + this.logger.debug("Waiting for selector", { + selector: options.waitForSelector, + }); await page.waitForSelector(options.waitForSelector, { timeout }); } @@ -112,9 +134,9 @@ export class PlaywrightScraper { const html = await page.content(); const sanitized = sanitizeHtml(html); - this.logger.debug( - `Scraped and sanitized HTML (${sanitized.length} chars)` - ); + this.logger.debug("Scraped and sanitized HTML", { + length: sanitized.length, + }); return sanitized; } catch (error) { this.handleError({ targetUrl, error }); @@ -138,7 +160,7 @@ export class PlaywrightScraper { const html = await this.scrapeHtml({ targetUrl, ...options }); const markdown = convertToMarkdown(html); - this.logger.debug(`Converted to Markdown (${markdown.length} chars)`); + this.logger.debug("Converted to Markdown", { length: markdown.length }); return markdown; } @@ -157,29 +179,157 @@ export class PlaywrightScraper { }): void { if (error instanceof Error) { if (error.name === "TimeoutError" || error.message.includes("Timeout")) { - this.logger.error( - `Timeout while scraping ${targetUrl}: ${error.message}` - ); + this.logger.error("Timeout while scraping", { + targetUrl, + message: error.message, + }); return; } if (error.message.includes("net::ERR_")) { - this.logger.error( - `Network error scraping ${targetUrl}: ${error.message}` - ); + this.logger.error("Network error scraping", { + targetUrl, + message: error.message, + }); return; } if (error.message.includes("Navigation failed")) { - 
this.logger.error( - `Navigation failed for ${targetUrl}: ${error.message}` - ); + this.logger.error("Navigation failed", { + targetUrl, + message: error.message, + }); return; } - this.logger.error(`Error scraping ${targetUrl}: ${error.message}`); + this.logger.error("Error scraping", { + targetUrl, + message: error.message, + }); } else { - this.logger.error(`Unknown error scraping ${targetUrl}:`, error); + this.logger.error("Unknown error scraping", { targetUrl }, error); + } + } + + /** + * Scrape a URL while capturing a specific network response. + * Sets up route interception to capture JSON responses matching the URL pattern. + * If validateResponse is provided, only responses passing validation are captured. + */ + async scrapeWithNetworkCapture({ + targetUrl, + captureUrlPattern, + captureTimeoutMs = 15000, + validateResponse, + localStorage, + ...options + }: NetworkCaptureRequest): Promise> { + const browser = await this.getBrowser(); + const page = await browser.newPage(); + + // Set localStorage before any navigation if provided + if (localStorage && Object.keys(localStorage).length > 0) { + const entries = Object.entries(localStorage); + this.logger.debug("Setting localStorage entries", { + keys: entries.map(([k]) => k), + }); + + // Add init script that runs before page load to set localStorage + await page.addInitScript((items: [string, string][]) => { + for (const [key, value] of items) { + window.localStorage.setItem(key, value); + } + }, entries); + } + + let resolveCapture: (result: NetworkCaptureResult) => void; + let rejectCapture: (error: Error) => void; + let captured = false; + + const capturePromise = new Promise>( + (resolve, reject) => { + resolveCapture = resolve; + rejectCapture = reject; + } + ); + + const captureTimeout = setTimeout(() => { + rejectCapture( + new Error( + `Network capture timeout: No response matching ${captureUrlPattern.source} within ${captureTimeoutMs}ms` + ) + ); + }, captureTimeoutMs); + + try { + await 
page.route("**/*", async (route) => { + // Skip route handling if page is closing or already captured + if (page.isClosed()) { + return; + } + + const request = route.request(); + const url = request.url(); + + if (captureUrlPattern.test(url) && !captured) { + this.logger.debug("Intercepted matching request", { url }); + + try { + const response = await route.fetch(); + const body = await response.text(); + const data = JSON.parse(body) as unknown; + + // If validator provided, check if response matches expected shape + if (validateResponse && !validateResponse(data)) { + this.logger.debug("Response did not pass validation, skipping", { + url, + }); + await route.fulfill({ response }); + return; + } + + this.logger.debug("Captured network response", { + url, + bodyLength: body.length, + }); + + captured = true; + clearTimeout(captureTimeout); + + // Fulfill the route before resolving to avoid race condition + await route.fulfill({ response }); + resolveCapture({ data: data as T, capturedUrl: url }); + } catch (err) { + // Only continue if not already handled and page is still open + if (!page.isClosed()) { + this.logger.warn("Failed to capture response", { + url, + error: err, + }); + try { + await route.continue(); + } catch { + // Route may already be handled, ignore + } + } + } + } else { + try { + await route.continue(); + } catch { + // Route may already be handled or page closed, ignore + } + } + }); + + await this.navigateAndWait({ page, targetUrl, options }); + return await capturePromise; + } catch (error) { + clearTimeout(captureTimeout); + this.handleError({ targetUrl, error }); + throw error; + } finally { + await page.close(); } } diff --git a/src/tools/fetch-url/fetch-url-tool.test.ts b/src/tools/fetch-url/fetch-url-tool.test.ts index c8ee32e..7728793 100644 --- a/src/tools/fetch-url/fetch-url-tool.test.ts +++ b/src/tools/fetch-url/fetch-url-tool.test.ts @@ -3,7 +3,10 @@ import * as urlSafety from "~tools/utils/url-safety"; import { afterEach, 
beforeEach, describe, expect, it, vi } from "vitest"; import type { FetchResult } from "./fetch-url-tool"; -import { fetchUrlTool } from "./fetch-url-tool"; +import { createFetchUrlTool } from "./fetch-url-tool"; + +// eslint-disable-next-line @typescript-eslint/no-empty-function +const mockLogger = { tool: () => {} } as never; // Mock the url-safety module vi.mock("~tools/utils/url-safety", async (importOriginal) => { @@ -59,7 +62,7 @@ const createMockResponse = (options: { const parseResult = (result: string): FetchResult => JSON.parse(result) as FetchResult; -describe("fetchUrlTool", () => { +describe("createFetchUrlTool", () => { beforeEach(() => { vi.resetAllMocks(); vi.stubGlobal("fetch", vi.fn()); @@ -83,7 +86,7 @@ describe("fetchUrlTool", () => { }); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "http://localhost/secret", }) ); @@ -100,7 +103,7 @@ describe("fetchUrlTool", () => { }); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "http://192.168.1.1/admin", }) ); @@ -116,7 +119,7 @@ describe("fetchUrlTool", () => { }); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "http://169.254.169.254/latest/meta-data/", }) ); @@ -143,7 +146,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/redirect", }) ); @@ -178,7 +181,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page1", }) ); @@ -203,7 +206,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await 
invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/start", maxRedirects: 2, }) @@ -227,7 +230,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/bad-redirect", }) ); @@ -257,7 +260,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/start", }) ); @@ -276,7 +279,7 @@ describe("fetchUrlTool", () => { }) ); - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", etag: '"abc123"', }); @@ -297,7 +300,7 @@ describe("fetchUrlTool", () => { }) ); - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", lastModified: "Wed, 21 Oct 2024 07:28:00 GMT", }); @@ -320,7 +323,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", etag: '"abc123"', }) @@ -342,7 +345,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -361,7 +364,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -381,7 +384,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/large", maxBytes: 1024, }) @@ -403,7 +406,7 @@ describe("fetchUrlTool", () => { ); 
const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/large", maxChars: 1000, }) @@ -423,7 +426,7 @@ describe("fetchUrlTool", () => { }); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/slow", timeoutMs: 1000, }) @@ -446,7 +449,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -466,7 +469,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -485,7 +488,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -502,7 +505,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -520,7 +523,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -539,7 +542,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); @@ -560,7 +563,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/file.txt", }) ); @@ -579,7 
+582,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/api/data", }) ); @@ -599,7 +602,7 @@ describe("fetchUrlTool", () => { ); const result = parseResult( - await invokeTool(fetchUrlTool, { + await invokeTool(createFetchUrlTool({ logger: mockLogger }), { url: "https://example.com/page", }) ); diff --git a/src/tools/fetch-url/fetch-url-tool.ts b/src/tools/fetch-url/fetch-url-tool.ts index 0edbf9b..a12c023 100644 --- a/src/tools/fetch-url/fetch-url-tool.ts +++ b/src/tools/fetch-url/fetch-url-tool.ts @@ -1,5 +1,6 @@ import crypto from "node:crypto"; import { tool } from "@openai/agents"; +import type { Logger } from "~clients/logger"; import { processHtmlContent } from "~tools/utils/html-processing"; import { resolveAndValidateUrl } from "~tools/utils/url-safety"; @@ -358,75 +359,78 @@ const executeFetch = async (params: { } }; +export type FetchUrlToolOptions = { + logger: Logger; +}; + /** * Safe HTTP GET fetch tool for agent runtime. * Fetches web pages with SSRF protection, HTML sanitization, and Markdown conversion. */ -export const fetchUrlTool = tool({ - name: "fetchUrl", - description: - "Fetches a web page via HTTP GET and returns clean, sanitized Markdown content. " + - "Includes SSRF protection (blocks localhost, private IPs, cloud metadata endpoints). 
" + - "HTML content is sanitized to remove scripts, iframes, and event handlers before conversion.", - parameters: { - type: "object", - properties: { - url: { - type: "string", - description: "The URL to fetch (must be http or https)", - }, - timeoutMs: { - type: "number", - description: - "Request timeout in milliseconds (default: 15000, max: 30000)", - }, - maxBytes: { - type: "number", - description: - "Maximum response size in bytes (default: 2097152 / 2MB, max: 5242880 / 5MB)", - }, - maxRedirects: { - type: "number", - description: - "Maximum number of redirects to follow (default: 5, max: 10)", - }, - maxChars: { - type: "number", - description: - "Maximum characters in output markdown/text (default: 50000)", - }, - etag: { - type: "string", - description: "ETag from previous request for conditional fetch", - }, - lastModified: { - type: "string", - description: - "Last-Modified value from previous request for conditional fetch", +export const createFetchUrlTool = ({ logger }: FetchUrlToolOptions) => + tool({ + name: "fetchUrl", + description: + "Fetches a web page via HTTP GET and returns clean, sanitized Markdown content. " + + "Includes SSRF protection (blocks localhost, private IPs, cloud metadata endpoints). 
" + + "HTML content is sanitized to remove scripts, iframes, and event handlers before conversion.", + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: "The URL to fetch (must be http or https)", + }, + timeoutMs: { + type: "number", + description: + "Request timeout in milliseconds (default: 15000, max: 30000)", + }, + maxBytes: { + type: "number", + description: + "Maximum response size in bytes (default: 2097152 / 2MB, max: 5242880 / 5MB)", + }, + maxRedirects: { + type: "number", + description: + "Maximum number of redirects to follow (default: 5, max: 10)", + }, + maxChars: { + type: "number", + description: + "Maximum characters in output markdown/text (default: 50000)", + }, + etag: { + type: "string", + description: "ETag from previous request for conditional fetch", + }, + lastModified: { + type: "string", + description: + "Last-Modified value from previous request for conditional fetch", + }, }, + required: ["url"], + additionalProperties: false, + }, + execute: async (params: { + url: string; + timeoutMs?: number; + maxBytes?: number; + maxRedirects?: number; + maxChars?: number; + etag?: string; + lastModified?: string; + }) => { + logger.tool("Fetching URL", { url: params.url }); + const result = await executeFetch(params); + logger.tool("Fetch result", { + ok: result.ok, + status: result.status, + finalUrl: result.finalUrl, + error: result.error, + }); + return JSON.stringify(result, null, 2); }, - required: ["url"], - additionalProperties: false, - }, - execute: async (params: { - url: string; - timeoutMs?: number; - maxBytes?: number; - maxRedirects?: number; - maxChars?: number; - etag?: string; - lastModified?: string; - }) => { - console.log("Fetching URL:", params.url); - const result = await executeFetch(params); - console.log("Fetch result:", { - ok: result.ok, - status: result.status, - finalUrl: result.finalUrl, - hasMarkdown: !!result.markdown, - hasText: !!result.text, - error: result.error, - }); - 
return JSON.stringify(result, null, 2); - }, -}); + }); diff --git a/src/tools/list-files/list-files-tool.test.ts b/src/tools/list-files/list-files-tool.test.ts index 660d3f7..106a32d 100644 --- a/src/tools/list-files/list-files-tool.test.ts +++ b/src/tools/list-files/list-files-tool.test.ts @@ -4,11 +4,13 @@ import { TMP_ROOT } from "~tools/utils/fs"; import { invokeTool, tryCreateSymlink } from "~tools/utils/test-utils"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import { listFilesTool } from "./list-files-tool"; +import { createListFilesTool } from "./list-files-tool"; -describe("listFilesTool tmp path safety", () => { +describe("createListFilesTool tmp path safety", () => { let testDir = ""; let relativeDir = ""; + // eslint-disable-next-line @typescript-eslint/no-empty-function + const mockLogger = { tool: () => {} } as never; beforeEach(async () => { await fs.mkdir(TMP_ROOT, { recursive: true }); @@ -29,6 +31,7 @@ describe("listFilesTool tmp path safety", () => { await fs.writeFile(path.join(testDir, "file2.txt"), "content2", "utf8"); await fs.mkdir(path.join(testDir, "subdir"), { recursive: true }); + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, { path: relativeDir, }); @@ -41,6 +44,7 @@ describe("listFilesTool tmp path safety", () => { it("lists files with absolute paths under tmp", async () => { await fs.writeFile(path.join(testDir, "absolute.txt"), "content", "utf8"); + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, { path: testDir, }); @@ -49,6 +53,7 @@ describe("listFilesTool tmp path safety", () => { }); it("lists root of tmp when no path provided", async () => { + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, {}); expect(result).toContain("Contents of tmp:"); @@ -56,6 +61,7 @@ describe("listFilesTool tmp path safety", () => 
{ }); it("rejects path traversal attempts", async () => { + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, { path: "../", }); @@ -74,6 +80,7 @@ describe("listFilesTool tmp path safety", () => { const symlinkPath = path.join(relativeDir, "link"); + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, { path: symlinkPath, }); @@ -84,6 +91,7 @@ describe("listFilesTool tmp path safety", () => { const emptyDir = path.join(testDir, "empty"); await fs.mkdir(emptyDir, { recursive: true }); + const listFilesTool = createListFilesTool({ logger: mockLogger }); const result = await invokeTool(listFilesTool, { path: path.join(relativeDir, "empty"), }); diff --git a/src/tools/list-files/list-files-tool.ts b/src/tools/list-files/list-files-tool.ts index 466743d..e11bddf 100644 --- a/src/tools/list-files/list-files-tool.ts +++ b/src/tools/list-files/list-files-tool.ts @@ -1,39 +1,48 @@ import fs from "node:fs/promises"; import path from "node:path"; import { tool } from "@openai/agents"; +import type { Logger } from "~clients/logger"; import { resolveTmpPathForList, TMP_ROOT } from "~tools/utils/fs"; -export const listFilesTool = tool({ - name: "listFiles", - description: - "Lists files and directories under the repo tmp directory (path is relative to tmp). If no path provided, lists root of tmp.", - parameters: { - type: "object", - properties: { - path: { - type: "string", - description: - "Relative path within the repo tmp directory (optional, defaults to tmp root)", +export type ListFilesToolOptions = { + logger: Logger; +}; + +export const createListFilesTool = ({ logger }: ListFilesToolOptions) => + tool({ + name: "listFiles", + description: + "Lists files and directories under the repo tmp directory (path is relative to tmp). 
If no path provided, lists root of tmp.", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: + "Relative path within the repo tmp directory (optional, defaults to tmp root)", + }, }, + required: [], + additionalProperties: false, }, - required: [], - additionalProperties: false, - }, - execute: async ({ path: dirPath }: { path?: string }) => { - console.log("Listing files at path:", dirPath ?? "(tmp root)"); - const targetPath = await resolveTmpPathForList(dirPath); - console.log("Resolved target path:", targetPath); + execute: async ({ path: dirPath }: { path?: string }) => { + logger.tool("Listing files", { path: dirPath ?? "tmp root" }); + const targetPath = await resolveTmpPathForList(dirPath); - const entries = await fs.readdir(targetPath, { withFileTypes: true }); - const lines = entries.map((entry) => { - const type = entry.isDirectory() ? "[dir] " : "[file]"; - return `${type} ${entry.name}`; - }); + const entries = await fs.readdir(targetPath, { withFileTypes: true }); + const lines = entries.map((entry) => { + const type = entry.isDirectory() ? "[dir] " : "[file]"; + return `${type} ${entry.name}`; + }); - const relativePath = path.relative(TMP_ROOT, targetPath); - const displayPath = relativePath || "tmp"; - return lines.length > 0 - ? `Contents of ${displayPath}:\n${lines.join("\n")}` - : `${displayPath} is empty`; - }, -}); + const relativePath = path.relative(TMP_ROOT, targetPath); + const displayPath = relativePath || "tmp"; + logger.tool("Listed entries", { + count: entries.length, + displayPath, + }); + return lines.length > 0 + ? 
`Contents of ${displayPath}:\n${lines.join("\n")}` + : `${displayPath} is empty`; + }, + }); diff --git a/src/tools/read-file/read-file-tool.test.ts b/src/tools/read-file/read-file-tool.test.ts index e23b2af..eba6541 100644 --- a/src/tools/read-file/read-file-tool.test.ts +++ b/src/tools/read-file/read-file-tool.test.ts @@ -4,11 +4,13 @@ import { TMP_ROOT } from "~tools/utils/fs"; import { invokeTool, tryCreateSymlink } from "~tools/utils/test-utils"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import { readFileTool } from "./read-file-tool"; +import { createReadFileTool } from "./read-file-tool"; -describe("readFileTool tmp path safety", () => { +describe("createReadFileTool tmp path safety", () => { let testDir = ""; let relativeDir = ""; + // eslint-disable-next-line @typescript-eslint/no-empty-function + const mockLogger = { tool: () => {} } as never; beforeEach(async () => { await fs.mkdir(TMP_ROOT, { recursive: true }); @@ -29,6 +31,7 @@ describe("readFileTool tmp path safety", () => { const content = "hello"; await fs.writeFile(path.join(TMP_ROOT, relativePath), content, "utf8"); + const readFileTool = createReadFileTool({ logger: mockLogger }); const readResult = await invokeTool(readFileTool, { path: relativePath, }); @@ -40,6 +43,7 @@ describe("readFileTool tmp path safety", () => { const content = "absolute"; await fs.writeFile(absolutePath, content, "utf8"); + const readFileTool = createReadFileTool({ logger: mockLogger }); const readResult = await invokeTool(readFileTool, { path: absolutePath, }); @@ -47,6 +51,7 @@ describe("readFileTool tmp path safety", () => { }); it("rejects path traversal attempts", async () => { + const readFileTool = createReadFileTool({ logger: mockLogger }); const readResult = await invokeTool(readFileTool, { path: "../outside.txt", }); @@ -65,6 +70,7 @@ describe("readFileTool tmp path safety", () => { const symlinkPath = path.join(relativeDir, "link", "file.txt"); + const readFileTool = 
createReadFileTool({ logger: mockLogger }); const readResult = await invokeTool(readFileTool, { path: symlinkPath, }); diff --git a/src/tools/read-file/read-file-tool.ts b/src/tools/read-file/read-file-tool.ts index fa32179..f122d48 100644 --- a/src/tools/read-file/read-file-tool.ts +++ b/src/tools/read-file/read-file-tool.ts @@ -1,26 +1,32 @@ import fs from "node:fs/promises"; import { tool } from "@openai/agents"; +import type { Logger } from "~clients/logger"; import { resolveTmpPathForRead } from "~tools/utils/fs"; -export const readFileTool = tool({ - name: "readFile", - description: - "Reads content from a file under the repo tmp directory (path is relative to tmp).", - parameters: { - type: "object", - properties: { - path: { - type: "string", - description: "Relative path within the repo tmp directory", +export type ReadFileToolOptions = { + logger: Logger; +}; + +export const createReadFileTool = ({ logger }: ReadFileToolOptions) => + tool({ + name: "readFile", + description: + "Reads content from a file under the repo tmp directory (path is relative to tmp).", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "Relative path within the repo tmp directory", + }, }, + required: ["path"], + additionalProperties: false, + }, + execute: async ({ path: filePath }: { path: string }) => { + logger.tool("Reading file", { path: filePath }); + const targetPath = await resolveTmpPathForRead(filePath); + logger.tool("Read file result", { targetPath }); + return fs.readFile(targetPath, "utf8"); }, - required: ["path"], - additionalProperties: false, - }, - execute: async ({ path: filePath }: { path: string }) => { - console.log("Reading file at path:", filePath); - const targetPath = await resolveTmpPathForRead(filePath); - console.log("Resolved target path:", targetPath); - return fs.readFile(targetPath, "utf8"); - }, -}); + }); diff --git a/src/tools/run-python/run-python-tool.test.ts 
b/src/tools/run-python/run-python-tool.test.ts new file mode 100644 index 0000000..357a4b8 --- /dev/null +++ b/src/tools/run-python/run-python-tool.test.ts @@ -0,0 +1,120 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import { TMP_ROOT } from "~tools/utils/fs"; +import { invokeTool } from "~tools/utils/test-utils"; +import { afterEach, beforeEach, describe, expect, it } from "vitest"; + +import type { PythonResult } from "./run-python-tool"; +import { createRunPythonTool, isValidScriptName } from "./run-python-tool"; + +describe("isValidScriptName", () => { + it("accepts valid script names", () => { + expect(isValidScriptName("hello.py")).toBe(true); + expect(isValidScriptName("my_script.py")).toBe(true); + expect(isValidScriptName("test-script.py")).toBe(true); + expect(isValidScriptName("Script123.py")).toBe(true); + }); + + it("rejects non-.py extensions", () => { + expect(isValidScriptName("hello.js")).toBe(false); + expect(isValidScriptName("hello.txt")).toBe(false); + expect(isValidScriptName("hello")).toBe(false); + expect(isValidScriptName("hello.py.txt")).toBe(false); + }); + + it("rejects path separators", () => { + expect(isValidScriptName("subdir/hello.py")).toBe(false); + expect(isValidScriptName("../hello.py")).toBe(false); + expect(isValidScriptName("subdir\\hello.py")).toBe(false); + }); + + it("rejects path traversal", () => { + expect(isValidScriptName("..hello.py")).toBe(false); + expect(isValidScriptName("hello..py")).toBe(false); + }); + + it("rejects special characters", () => { + expect(isValidScriptName("hello world.py")).toBe(false); + expect(isValidScriptName("hello@script.py")).toBe(false); + expect(isValidScriptName("hello$script.py")).toBe(false); + }); +}); + +describe("createRunPythonTool", () => { + let testDir = ""; + let scriptsDir = ""; + // eslint-disable-next-line @typescript-eslint/no-empty-function + const mockLogger = { tool: () => {} } as never; + + beforeEach(async () => { + await fs.mkdir(TMP_ROOT, 
{ recursive: true }); + testDir = await fs.mkdtemp(path.join(TMP_ROOT, "vitest-python-")); + scriptsDir = path.join(testDir, "scripts"); + await fs.mkdir(scriptsDir, { recursive: true }); + }); + + afterEach(async () => { + if (testDir) { + await fs.rm(testDir, { recursive: true, force: true }); + } + testDir = ""; + scriptsDir = ""; + }); + + it("rejects invalid script names", async () => { + const tool = createRunPythonTool({ scriptsDir, logger: mockLogger }); + const resultJson = await invokeTool(tool, { + scriptName: "../etc/passwd", + input: "", + }); + const result = JSON.parse(resultJson) as PythonResult; + + expect(result.success).toBe(false); + expect(result.error).toContain("Invalid script name"); + }); + + it("handles non-existent scripts", async () => { + const tool = createRunPythonTool({ scriptsDir, logger: mockLogger }); + const resultJson = await invokeTool(tool, { + scriptName: "nonexistent.py", + input: "", + }); + const result = JSON.parse(resultJson) as PythonResult; + + expect(result.success).toBe(false); + // Python exits with non-zero code when script doesn't exist + expect(result.exitCode).not.toBe(0); + }); + + it("calls logger when provided", async () => { + const scriptContent = 'print("test")'; + await fs.writeFile(path.join(scriptsDir, "test.py"), scriptContent, "utf8"); + + const loggedMessages: string[] = []; + const mockLogger = { + tool: (msg: string) => loggedMessages.push(msg), + }; + + const tool = createRunPythonTool({ + scriptsDir, + logger: mockLogger as never, + }); + await invokeTool(tool, { scriptName: "test.py", input: "" }); + + expect(loggedMessages.length).toBe(2); + expect(loggedMessages[0]).toContain("Running Python script"); + expect(loggedMessages[1]).toContain("Python result"); + }); + + it("handles invalid JSON input", async () => { + const tool = createRunPythonTool({ scriptsDir, logger: mockLogger }); + const resultJson = await invokeTool(tool, { + scriptName: "any.py", + input: "not valid json", + }); + const 
result = JSON.parse(resultJson) as PythonResult; + + expect(result.success).toBe(false); + expect(result.error).toBe("Invalid JSON in input parameter"); + }); +}); diff --git a/src/tools/run-python/run-python-tool.ts b/src/tools/run-python/run-python-tool.ts new file mode 100644 index 0000000..96d0be8 --- /dev/null +++ b/src/tools/run-python/run-python-tool.ts @@ -0,0 +1,278 @@ +import { spawn } from "node:child_process"; +import path from "node:path"; +import { tool } from "@openai/agents"; +import type { Logger } from "~clients/logger"; + +/** + * Result of a Python script execution + */ +export type PythonResult = { + success: boolean; + exitCode: number | null; + stdout: string; + stderr: string; + durationMs: number; + error?: string; +}; + +/** + * Repo venv Python path for CLIs that use the project .venv. + */ +export const PYTHON_BINARY = path.join( + process.cwd(), + ".venv", + "bin", + "python3" +); + +/** + * Default configuration values + */ +const DEFAULTS = { + timeoutMs: 30000, + maxOutputBytes: 50 * 1024, // 50KB + pythonBinary: PYTHON_BINARY, +} as const; + +/** + * Maximum allowed values + */ +const MAX_VALUES = { + timeoutMs: 120000, +} as const; + +/** + * Clamp a value between bounds + */ +const clamp = (value: number, min: number, max: number): number => + Math.max(min, Math.min(max, value)); + +/** + * Validate that script name is safe (no path traversal) + */ +export const isValidScriptName = (scriptName: string): boolean => { + // Must end with .py + if (!scriptName.endsWith(".py")) { + return false; + } + + // No path separators allowed (no subdirectories) + if (scriptName.includes("/") || scriptName.includes("\\")) { + return false; + } + + // No path traversal + if (scriptName.includes("..")) { + return false; + } + + // Only allow alphanumeric, underscores, hyphens, and .py extension + const validPattern = /^[a-zA-Z0-9_-]+\.py$/; + return validPattern.test(scriptName); +}; + +/** + * Execute a Python script from the specified scripts 
directory + */ +const executePython = async (params: { + scriptsDir: string; + scriptName: string; + args?: string[]; + input?: Record; + timeoutMs?: number; + pythonBinary?: string; +}): Promise => { + const { + scriptsDir, + scriptName, + args = [], + input, + timeoutMs = DEFAULTS.timeoutMs, + pythonBinary = DEFAULTS.pythonBinary, + } = params; + + const startTime = Date.now(); + + // Validate script name + if (!isValidScriptName(scriptName)) { + return { + success: false, + exitCode: null, + stdout: "", + stderr: "", + durationMs: Date.now() - startTime, + error: `Invalid script name: "${scriptName}". Must be a .py file with no path separators.`, + }; + } + + const scriptPath = path.join(scriptsDir, scriptName); + const effectiveTimeout = clamp(timeoutMs, 1000, MAX_VALUES.timeoutMs); + + return new Promise((resolve) => { + const controller = new AbortController(); + const timeoutId = setTimeout(() => { + controller.abort(); + }, effectiveTimeout); + + let stdout = ""; + let stderr = ""; + let stdoutTruncated = false; + let stderrTruncated = false; + + const proc = spawn(pythonBinary, [scriptPath, ...args], { + signal: controller.signal, + cwd: scriptsDir, + }); + + // Write JSON input to stdin if provided + if (input !== undefined) { + proc.stdin.write(JSON.stringify(input)); + proc.stdin.end(); + } else { + proc.stdin.end(); + } + + proc.stdout.on("data", (data: Buffer) => { + if (stdout.length < DEFAULTS.maxOutputBytes) { + stdout += data.toString(); + if (stdout.length > DEFAULTS.maxOutputBytes) { + stdout = stdout.slice(0, DEFAULTS.maxOutputBytes); + stdoutTruncated = true; + } + } + }); + + proc.stderr.on("data", (data: Buffer) => { + if (stderr.length < DEFAULTS.maxOutputBytes) { + stderr += data.toString(); + if (stderr.length > DEFAULTS.maxOutputBytes) { + stderr = stderr.slice(0, DEFAULTS.maxOutputBytes); + stderrTruncated = true; + } + } + }); + + proc.on("close", (code) => { + clearTimeout(timeoutId); + const durationMs = Date.now() - startTime; + + 
if (stdoutTruncated) { + stdout += "\n[OUTPUT TRUNCATED]"; + } + if (stderrTruncated) { + stderr += "\n[OUTPUT TRUNCATED]"; + } + + resolve({ + success: code === 0, + exitCode: code, + stdout, + stderr, + durationMs, + }); + }); + + proc.on("error", (err) => { + clearTimeout(timeoutId); + const durationMs = Date.now() - startTime; + + if (err.name === "AbortError") { + resolve({ + success: false, + exitCode: null, + stdout, + stderr, + durationMs, + error: `Script execution timed out after ${effectiveTimeout}ms`, + }); + } else { + resolve({ + success: false, + exitCode: null, + stdout, + stderr, + durationMs, + error: err.message, + }); + } + }); + }); +}; + +export type RunPythonToolOptions = { + /** Absolute path to the directory containing Python scripts */ + scriptsDir: string; + /** Logger for tool execution logging */ + logger: Logger; + /** Python binary to use (defaults to "python3") */ + pythonBinary?: string; +}; + +/** + * Creates a tool to execute Python scripts from a specified directory. + * Scripts must be pre-defined .py files in the configured scriptsDir. + */ +export const createRunPythonTool = ({ + scriptsDir, + logger, + pythonBinary, +}: RunPythonToolOptions) => + tool({ + name: "runPython", + description: + "Executes a Python script from the configured scripts directory. " + + "Only .py files in the scripts directory can be executed. " + + "Optionally accepts JSON input to pass via stdin. " + + "Returns stdout, stderr, exit code, and execution time.", + parameters: { + type: "object", + properties: { + scriptName: { + type: "string", + description: + 'Name of the Python script to run (e.g., "hello.py"). Must be a .py file in the scripts directory.', + }, + input: { + type: "string", + description: + 'JSON string to pass to the script via stdin. Pass empty string "" if no input needed. 
The script should read from stdin using json.load(sys.stdin).', + }, + }, + required: ["scriptName", "input"], + additionalProperties: false, + }, + execute: async (params: { scriptName: string; input: string }) => { + logger.tool("Running Python script", { scriptName: params.scriptName }); + + // Parse the input string to object if provided (empty string means no input) + let parsedInput: Record | undefined; + if (params.input && params.input.trim() !== "") { + try { + parsedInput = JSON.parse(params.input) as Record; + } catch { + return JSON.stringify({ + success: false, + exitCode: null, + stdout: "", + stderr: "", + durationMs: 0, + error: "Invalid JSON in input parameter", + } satisfies PythonResult); + } + } + + const result = await executePython({ + scriptsDir, + scriptName: params.scriptName, + input: parsedInput, + pythonBinary, + }); + logger.tool("Python result", { + success: result.success, + exitCode: result.exitCode, + durationMs: result.durationMs, + error: result.error, + }); + return JSON.stringify(result, null, 2); + }, + }); diff --git a/src/tools/write-file/write-file-tool.test.ts b/src/tools/write-file/write-file-tool.test.ts index 18fee0a..563624f 100644 --- a/src/tools/write-file/write-file-tool.test.ts +++ b/src/tools/write-file/write-file-tool.test.ts @@ -4,11 +4,13 @@ import { TMP_ROOT } from "~tools/utils/fs"; import { invokeTool, tryCreateSymlink } from "~tools/utils/test-utils"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; -import { writeFileTool } from "./write-file-tool"; +import { createWriteFileTool } from "./write-file-tool"; -describe("writeFileTool tmp path safety", () => { +describe("createWriteFileTool tmp path safety", () => { let testDir = ""; let relativeDir = ""; + // eslint-disable-next-line @typescript-eslint/no-empty-function + const mockLogger = { tool: () => {} } as never; beforeEach(async () => { await fs.mkdir(TMP_ROOT, { recursive: true }); @@ -28,6 +30,7 @@ describe("writeFileTool tmp path 
safety", () => { const relativePath = path.join(relativeDir, "relative.txt"); const content = "hello"; + const writeFileTool = createWriteFileTool({ logger: mockLogger }); const writeResult = await invokeTool(writeFileTool, { path: relativePath, content, @@ -44,6 +47,7 @@ describe("writeFileTool tmp path safety", () => { const absolutePath = path.join(testDir, "absolute.txt"); const content = "absolute"; + const writeFileTool = createWriteFileTool({ logger: mockLogger }); const writeResult = await invokeTool(writeFileTool, { path: absolutePath, content, @@ -54,6 +58,7 @@ describe("writeFileTool tmp path safety", () => { }); it("rejects path traversal attempts", async () => { + const writeFileTool = createWriteFileTool({ logger: mockLogger }); const writeResult = await invokeTool(writeFileTool, { path: "../outside.txt", content: "nope", @@ -73,6 +78,7 @@ describe("writeFileTool tmp path safety", () => { const symlinkPath = path.join(relativeDir, "link", "file.txt"); + const writeFileTool = createWriteFileTool({ logger: mockLogger }); const writeResult = await invokeTool(writeFileTool, { path: symlinkPath, content: "nope", diff --git a/src/tools/write-file/write-file-tool.ts b/src/tools/write-file/write-file-tool.ts index 2d94482..ade755b 100644 --- a/src/tools/write-file/write-file-tool.ts +++ b/src/tools/write-file/write-file-tool.ts @@ -1,35 +1,44 @@ import fs from "node:fs/promises"; import path from "node:path"; import { tool } from "@openai/agents"; +import type { Logger } from "~clients/logger"; import { resolveTmpPathForWrite, TMP_ROOT } from "~tools/utils/fs"; -export const writeFileTool = tool({ - name: "writeFile", - description: - "Writes content to a file under the repo tmp directory (path is relative to tmp).", - parameters: { - type: "object", - properties: { - path: { - type: "string", - description: "Relative path within the repo tmp directory", +export type WriteFileToolOptions = { + logger: Logger; +}; + +export const createWriteFileTool = ({ 
logger }: WriteFileToolOptions) => + tool({ + name: "writeFile", + description: + "Writes content to a file under the repo tmp directory (path is relative to tmp).", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "Relative path within the repo tmp directory", + }, + content: { type: "string", description: "The content to write" }, }, - content: { type: "string", description: "The content to write" }, + required: ["path", "content"], + additionalProperties: false, + }, + execute: async ({ + path: filePath, + content, + }: { + path: string; + content: string; + }) => { + logger.tool("Writing file", { path: filePath }); + const targetPath = await resolveTmpPathForWrite(filePath); + await fs.writeFile(targetPath, content, "utf8"); + const relativePath = path.relative(TMP_ROOT, targetPath); + const displayPath = path.join("tmp", relativePath); + const bytes = Buffer.byteLength(content, "utf8"); + logger.tool("Wrote file", { bytes, path: displayPath }); + return `Wrote ${bytes} bytes to ${displayPath}`; }, - required: ["path", "content"], - additionalProperties: false, - }, - execute: async ({ - path: filePath, - content, - }: { - path: string; - content: string; - }) => { - console.log("Writing file at path:", filePath); - const targetPath = await resolveTmpPathForWrite(filePath); - await fs.writeFile(targetPath, content, "utf8"); - const relativePath = path.relative(TMP_ROOT, targetPath); - return `Wrote ${Buffer.byteLength(content, "utf8")} bytes to tmp/${relativePath}`; - }, -}); + }); diff --git a/src/utils/parse-args.ts b/src/utils/parse-args.ts index d26f588..f97c2c6 100644 --- a/src/utils/parse-args.ts +++ b/src/utils/parse-args.ts @@ -19,6 +19,6 @@ export const parseArgs = ({ }: ParseArgsOptions): z.infer => { logger.debug("Parsing CLI arguments..."); const args = schema.parse(argv); - logger.debug(`Parsed args: ${JSON.stringify(args)}`); + logger.debug("Parsed args", { args }); return args; }; diff --git 
a/src/utils/question-handler.ts b/src/utils/question-handler.ts index 745f785..973e51c 100644 --- a/src/utils/question-handler.ts +++ b/src/utils/question-handler.ts @@ -100,12 +100,12 @@ export class QuestionHandler { const validationMessage = errorMessage ?? result.error.issues[0]?.message ?? "Invalid input"; - this.logger.question(`Validation failed: ${validationMessage}`); + this.logger.question("Validation failed", { message: validationMessage }); if (attempts < maxRetries) { - this.logger.question( - `Please try again (${maxRetries - attempts} attempts remaining)` - ); + this.logger.question("Please try again", { + remainingAttempts: maxRetries - attempts, + }); } }