From 18bf25aadffc6640df220b150d18a77d8e3f110a Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 07:10:50 -0800 Subject: [PATCH 01/36] Add superfacility staging fallbacks --- demo/superfacility/README.md | 7 ++ .../config_perlmutter_remote.yaml | 9 ++- .../config_perlmutter_remote_exact.yaml | 28 ++++++++ src/services/run_superfacility.py | 34 +++++++-- src/services/run_superfacility_tools.py | 70 +++++++++++++++++++ 5 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 demo/superfacility/config_perlmutter_remote_exact.yaml diff --git a/demo/superfacility/README.md b/demo/superfacility/README.md index 5bdca86..47d58fd 100644 --- a/demo/superfacility/README.md +++ b/demo/superfacility/README.md @@ -74,14 +74,21 @@ Start with one of the configs in this folder and update the paths: - `config_perlmutter_remote.yaml`: stage from another host to Perlmutter via SFAPI Synapse-style shared layout (SFAPI on NERSC systems, example): +`/global/cfs/cdirs/$SBATCH_ACCOUNT/superfacility` + +Fallback if the shared layout is unavailable: `/global/cfs/cdirs/$SBATCH_ACCOUNT/$USER/superfacility` +The runner now attempts the shared layout first and falls back to the +user-specific path when it cannot verify or write to the shared location. + ## Example command ```bash # export SBATCH_ACCOUNT=amsc014 python amrex_agent.py \ --prompt "Run AMReX Advection_AmrCore using the default inputs file without changes." \ + --baseline-override AMReX/Tests/Amr/Advection_AmrCore/Exec \ --config demo/superfacility/config_perlmutter_remote.yaml \ --environment perlmutter ``` diff --git a/demo/superfacility/config_perlmutter_remote.yaml b/demo/superfacility/config_perlmutter_remote.yaml index 1990355..2c2ef78 100644 --- a/demo/superfacility/config_perlmutter_remote.yaml +++ b/demo/superfacility/config_perlmutter_remote.yaml @@ -7,17 +7,16 @@ environment: perlmutter output_dir: /tmp/amrex_agent_runs # Demo overrides (baseline + inputs) -baseline_override: AMReX/Tests/Amr/Advection_AmrCore/Exec inputs_file_override: inputs # Remote destination for staged runs -remote_output_dir: /global/cfs/cdirs/$SBATCH_ACCOUNT/$USER/superfacility +remote_output_dir: /global/cfs/cdirs/amsc014/superfacility/output remote_staging: true remote_staging_method: sfapi_client -remote_run_dir: /global/cfs/cdirs/$SBATCH_ACCOUNT/$USER/superfacility/run_test_1 -superfacility_account: $SBATCH_ACCOUNT +remote_run_dir: /global/cfs/cdirs/amsc014/superfacility/output/run_test_1 +superfacility_account: amsc014 # Remote executable discovery (finds *.ex, prefers CUDA) -remote_executable_template: /global/cfs/cdirs/$SBATCH_ACCOUNT/$USER/superfacility/{repo_name}/{case_dir} +remote_executable_template: /global/cfs/cdirs/amsc014/superfacility/{repo_name}/{case_dir} remote_executable_find: true # Repo path (optional; follows normal priority if unset) diff --git a/demo/superfacility/config_perlmutter_remote_exact.yaml b/demo/superfacility/config_perlmutter_remote_exact.yaml new file mode 100644 index 0000000..158a4b0 --- /dev/null +++ b/demo/superfacility/config_perlmutter_remote_exact.yaml @@ -0,0 +1,28 @@ +# Superfacility (SFAPI) config overrides for Perlmutter (remote staging) + +# Environment selection +environment: perlmutter + +# Local output directory (this host) +output_dir: /tmp/amrex_agent_runs + +# Demo overrides (baseline + inputs) +baseline_override: PeleLMeX/Exec/Production/JetInCrossflow +inputs_file_override: inputs + +# Remote destination for staged runs +remote_output_dir: /global/cfs/cdirs/amsc014/superfacility/output +remote_staging: true +remote_staging_method: sfapi_client +remote_run_dir: /global/cfs/cdirs/amsc014/superfacility/output/run_test_1 +superfacility_account: amsc014 + +# Remote executable (exact path) +remote_executable_path: /global/cfs/cdirs/amsc014/superfacility/PeleLMeX/Exec/Production/JetInCrossflow/PeleLMeX3d.gnu.MPI.CUDA.ex +remote_executable_find: false + +# Repo path (optional; follows normal priority if unset) +# amrex_repo_path: /path/to/amrex + +# Execution +allow_local_run: false diff --git a/src/services/run_superfacility.py b/src/services/run_superfacility.py index bcc8ade..ad7d79f 100644 --- a/src/services/run_superfacility.py +++ b/src/services/run_superfacility.py @@ -14,10 +14,13 @@ from src.services.build_tools import compile_amrex from src.services.run_superfacility_tools import ( + ensure_remote_directory_rest, generate_slurm_script, _load_sfapi_key_file, find_remote_executable, + list_remote_entries, monitor_job, + resolve_remote_output_dir, stage_out_outputs, stage_run_directory, submit_job, @@ -179,6 +182,12 @@ def _resolve_remote_executable( if rendered_path.suffix == ".ex": return rendered_path if getattr(self.config, "remote_executable_find", False): + try: + list_remote_entries(str(rendered_path), system=system) + except Exception as exc: + raise RuntimeError( + f"Remote case directory not found for executable search: {rendered_path}" + ) from exc found = find_remote_executable( remote_case_dir=str(rendered_path), system=system, @@ -195,6 +204,12 @@ def _resolve_remote_executable( return None remote_case_dir = Path("/global/cfs/cdirs") / account / user / repo_name / relative_case_dir + try: + list_remote_entries(str(remote_case_dir), system=system) + except Exception as exc: + raise RuntimeError( + f"Remote case directory not found for executable search: {remote_case_dir}" + ) from exc found = find_remote_executable( remote_case_dir=str(remote_case_dir), system=system, @@ -410,17 +425,26 @@ def submit(self, remote_run_dir = None exclude_names = None if remote_staging: - remote_output_dir = getattr(self.config, "remote_output_dir", None) - if remote_output_dir is None: - remote_output_dir = self.config.output_dir + preferred_output_dir = getattr(self.config, "remote_output_dir", None) + if preferred_output_dir is None: + preferred_output_dir = self.config.output_dir logger.warning( "[Config] remote_output_dir not set; defaulting staging target to %s", - remote_output_dir, + preferred_output_dir, ) - remote_output_dir = Path(os.path.expandvars(str(remote_output_dir))) + remote_output_dir = resolve_remote_output_dir( + preferred_output_dir=preferred_output_dir, + account=account, + user=os.getenv("USER"), + system=system, + ) fixed_remote_run_dir = getattr(self.config, "remote_run_dir", None) if fixed_remote_run_dir: remote_run_dir = Path(os.path.expandvars(str(fixed_remote_run_dir))) + ensure_remote_directory_rest( + remote_run_dir=str(remote_run_dir), + upload_host=system, + ) else: remote_run_dir = Path(remote_output_dir) / run_dir.name if remote_executable_path and executable: diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index f0d648e..44f0f9e 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -917,6 +917,76 @@ def _resolve_sfapi_credentials() -> tuple[str | None, str | None]: return None, None +def _superfacility_suffix(local_path: Path) -> str | None: + parts = local_path.parts + for idx, part in enumerate(parts): + if part == "superfacility": + suffix_parts = parts[idx + 1 :] + return str(Path(*suffix_parts)) if suffix_parts else "" + return None + + +def resolve_remote_output_dir( + preferred_output_dir: str | None, + account: str, + user: str | None = None, + system: str = "perlmutter", + nersc_session: dict | None = None, +) -> Path: + """ + Resolve a writable remote output directory with shared->user fallback. + """ + import logging + import os + + logger = logging.getLogger(__name__) + candidates: list[Path] = [] + suffix: str | None = None + if preferred_output_dir: + preferred_path = Path(os.path.expandvars(str(preferred_output_dir))) + if not str(preferred_path).startswith("/global/cfs/cdirs/"): + logger.info("Using local output dir for staging: %s", preferred_path) + return preferred_path + candidates.append(preferred_path) + suffix = _superfacility_suffix(preferred_path) + + shared_root = Path(f"/global/cfs/cdirs/{account}/superfacility") + user_root = ( + Path(f"/global/cfs/cdirs/{account}/{user}/superfacility") if user else None + ) + for root in [shared_root, user_root]: + if not root: + continue + candidate = root + if suffix is not None: + candidate = root / suffix if suffix else root + if candidate not in candidates: + candidates.append(candidate) + + for candidate in candidates: + try: + list_remote_entries(str(candidate), nersc_session=nersc_session, system=system) + except Exception as exc: + logger.debug("Remote output dir check failed for %s: %s", candidate, exc) + continue + try: + ensure_remote_directory_rest( + remote_run_dir=str(candidate), + nersc_session=nersc_session, + upload_host=system, + ) + except Exception as exc: + logger.debug("Remote output dir not writable for %s: %s", candidate, exc) + continue + logger.info("Using remote output dir: %s", candidate) + return candidate + + raise RuntimeError( + "No writable remote output dir found. Checked: " + + ", ".join(str(candidate) for candidate in candidates) + ) + + def _download_remote_file_sfapi( remote_path: str, local_path: str | Path, From e45bf8d137d0575ad9fc4a8e7fab5ba60ef5d712 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 07:30:21 -0800 Subject: [PATCH 02/36] Add superfacility staging tests --- tests/e2e/readme_command_registry.json | 254 +++++++++++++++++++----- tests/e2e/test_readme_command_runner.py | 60 ++++++ tests/unit/test_superfacility_tools.py | 54 +++++ 3 files changed, 322 insertions(+), 46 deletions(-) diff --git a/tests/e2e/readme_command_registry.json b/tests/e2e/readme_command_registry.json index a468253..232e557 100644 --- a/tests/e2e/readme_command_registry.json +++ b/tests/e2e/readme_command_registry.json @@ -36,8 +36,8 @@ "start_line": 35 }, { - "first_line": "python amrex_agent.py --prompt \"Run AMReX Advection_AmrCore with a 64x64 grid and 2 AMR levels\"", - "hash": "sha256:36c7bdbbf7b8d90b4a9dbca773b91fbe5e68ede6593f1c7017e4ea036e689e71", + "first_line": "python amrex_agent.py \\", + "hash": "sha256:4e7d7d2c243598eb295d88a4a37e850a86e329d33f5f4e9bb389397f1a2194c9", "id": "README.md::bash_block::5", "index": 5, "kind": "bash_block", @@ -51,7 +51,7 @@ "index": 6, "kind": "bash_block", "path": "README.md", - "start_line": 55 + "start_line": 58 }, { "first_line": "python amrex_agent.py --prompt \"Run AMReX Advection_AmrCore with a 64x64 grid\" --config demo/amrex/config.yaml", @@ -60,7 +60,7 @@ "index": 7, "kind": "bash_block", "path": "README.md", - "start_line": 70 + "start_line": 73 }, { "first_line": "python amrex_agent.py \\", @@ -69,7 +69,7 @@ "index": 8, "kind": "bash_block", "path": "README.md", - "start_line": 93 + "start_line": 104 }, { "first_line": "python amrex_agent.py \\", @@ -78,7 +78,7 @@ "index": 9, "kind": "bash_block", "path": "README.md", - "start_line": 104 + "start_line": 115 }, { "first_line": "pytest tests/unit", @@ -87,7 +87,7 @@ "index": 10, "kind": "bash_block", "path": "README.md", - "start_line": 132 + "start_line": 143 }, { "first_line": "# 1. Clone the target AMReX code (if not already done)", @@ -359,15 +359,6 @@ "path": "database/indexing/README.md", "start_line": 651 }, - { - "first_line": "# Generate schema for PeleC", - "hash": "sha256:c76a68c7ba96f7f2f7eceaa5f770942d51d8822edccd26601478cfbd071ec40e", - "id": "database/schemas/README.md::bash_block::1", - "index": 1, - "kind": "bash_block", - "path": "database/schemas/README.md", - "start_line": 37 - }, { "first_line": "# From project root", "hash": "sha256:526d080f54d47035e7b7f063dde70631a0520fe009713c50f68ee797246a1420", @@ -431,15 +422,6 @@ "path": "database/scripts/README.md", "start_line": 199 }, - { - "first_line": "**Note:** This script is under development. For reliable testing, use Option 1 with `amrex_agent.py` and basic user requirement files.", - "hash": "sha256:4f81e95b8e596bae8d6fa03c5839abb949b8b3af46c60e6651cf73044c8cf5f5", - "id": "demo/README.md::amrex_agent_inline::1", - "index": 1, - "kind": "amrex_agent_inline", - "path": "demo/README.md", - "start_line": 197 - }, { "first_line": "cd /path/to/amrex-agent", "hash": "sha256:b6153898e80245459e5a15e933bd83427bd8bdb4724e82e5df0134e5e2894b1a", @@ -620,6 +602,15 @@ "path": "demo/README.md", "start_line": 231 }, + { + "first_line": "**Note:** This script is under development. For reliable testing, use Option 1 with `amrex_agent.py` and basic user requirement files.", + "hash": "sha256:4f81e95b8e596bae8d6fa03c5839abb949b8b3af46c60e6651cf73044c8cf5f5", + "id": "demo/README.md::amrex_agent_inline::1", + "index": 1, + "kind": "amrex_agent_inline", + "path": "demo/README.md", + "start_line": 197 + }, { "first_line": "export ALCF_API_KEY=your_access_token", "hash": "sha256:cc75cd9bfacd0bef4e24290cb91411a8a36fc56ffc53943bdf00f3440e0ed98d", @@ -773,6 +764,96 @@ "path": "demo/erf/README.md", "start_line": 34 }, + { + "first_line": "cd /global/cfs/cdirs/amsc014/superfacility/amrex-agent", + "hash": "sha256:1d59725073fa20f2106ed37942d5953c33ba4b9496e16e5c6e4988947e43fcfb", + "id": "demo/mcp/README.md::bash_block::1", + "index": 1, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 10 + }, + { + "first_line": "conda env create -f environment.yaml", + "hash": "sha256:f81bd5f2144ae1b693dfb4442dda601d3a7d8207d67b64cec8594b4bfd220ce7", + "id": "demo/mcp/README.md::bash_block::2", + "index": 2, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 17 + }, + { + "first_line": "module use /soft/modulefiles", + "hash": "sha256:346c06319cb04c2ac5d538d724f28dd972ed3fd648b69c008d85525cc5e4211b", + "id": "demo/mcp/README.md::bash_block::3", + "index": 3, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 24 + }, + { + "first_line": "cd /global/cfs/cdirs/amsc014/superfacility/amrex-agent", + "hash": "sha256:ca8203801498ee34f5332c59704b81cf4866f32be2bf30caac8a21a44eed05f8", + "id": "demo/mcp/README.md::bash_block::4", + "index": 4, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 42 + }, + { + "first_line": "cd /lus/eagle/projects/COMB-FLOW-UNI/mcp/amrex-agent", + "hash": "sha256:86321cc89cddab6084a631582d8d1a362f3204c515c419e8bb627247911543cf", + "id": "demo/mcp/README.md::bash_block::5", + "index": 5, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 50 + }, + { + "first_line": "python demo/mcp/mcp_stdio_client_smoke.py", + "hash": "sha256:b8d309c23897620c9842486dc7968711feb70b19d08b4370a05342f30eea16a3", + "id": "demo/mcp/README.md::bash_block::6", + "index": 6, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 57 + }, + { + "first_line": "python demo/mcp/mcp_execute_workflow_examples.py", + "hash": "sha256:11a728be0dac834506b22d669e2914105ee7c284e430a4eb445b77e0df3edda5", + "id": "demo/mcp/README.md::bash_block::7", + "index": 7, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 100 + }, + { + "first_line": "python demo/mcp/mcp_inprocess_examples.py", + "hash": "sha256:624872aa809c6348936d478ca8d33097aa067c6f0ca3066de418d3851a4ada0a", + "id": "demo/mcp/README.md::bash_block::8", + "index": 8, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 108 + }, + { + "first_line": "./demo/mcp/run_dns_cli_examples.sh", + "hash": "sha256:5d298aa7e620abb1c8432da7b173407e92933816e1b9aacd5508ed186c94291d", + "id": "demo/mcp/README.md::bash_block::9", + "index": 9, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 114 + }, + { + "first_line": "./demo/mcp/run_dns_cli_examples_dry_verbose.sh", + "hash": "sha256:fc29455ac36ccafda8e0188a7186c27b78b1e1e603bb3a41db1aec0b5664b583", + "id": "demo/mcp/README.md::bash_block::10", + "index": 10, + "kind": "bash_block", + "path": "demo/mcp/README.md", + "start_line": 120 + }, { "first_line": "bash demo/setup_demo_database.sh --code pelec", "hash": "sha256:4973f05325f853e874e58e3830b91f38544eba7caada0695645a8b4e5c95a334", @@ -874,13 +955,22 @@ }, { "first_line": "python amrex_agent.py \\", - "hash": "sha256:2573fa46bbc8e087921e56c320bbd4379b697086dc7a8d59a453bb172d234b39", + "hash": "sha256:edfef53a53e263a9e0914642d29f9a87791377a55a63a77187e872ea9d78e590", "id": "demo/pelelmex/README.md::bash_block::8", "index": 8, "kind": "bash_block", "path": "demo/pelelmex/README.md", "start_line": 77 }, + { + "first_line": "python amrex_agent.py \\", + "hash": "sha256:e3be2cd0df49f39c1c8f370a1d752b4a995b5bb5813b1b0f885d3d5b1570e05d", + "id": "demo/remora/README.md::bash_block::1", + "index": 1, + "kind": "bash_block", + "path": "demo/remora/README.md", + "start_line": 7 + }, { "first_line": "export SFAPI_KEY_PATH=/path/to/priv_key.pem", "hash": "sha256:44141da296bb9d0107394cc56cd355d556e83d83d4d37b707e89294a20f52782", @@ -888,7 +978,7 @@ "index": 1, "kind": "bash_block", "path": "demo/superfacility/README.md", - "start_line": 18 + "start_line": 19 }, { "first_line": "cp ~/.superfacility/priv_key.pem ~/.superfacility/priv_key.pem.no_client_id", @@ -897,7 +987,7 @@ "index": 2, "kind": "bash_block", "path": "demo/superfacility/README.md", - "start_line": 51 + "start_line": 52 }, { "first_line": "export NERSC_API_TOKEN=your_token", @@ -906,34 +996,25 @@ "index": 3, "kind": "bash_block", "path": "demo/superfacility/README.md", - "start_line": 63 + "start_line": 64 }, { "first_line": "# export SBATCH_ACCOUNT=amsc014", - "hash": "sha256:2af258b85fb1a349bfd4708659189f009890744c503bb613b2f77f6e73a7798c", + "hash": "sha256:2551fb16b78ce6d8e946eec0311bd79c6aaf0b6968bd81a4df2b484d58b74f36", "id": "demo/superfacility/README.md::bash_block::4", "index": 4, "kind": "bash_block", "path": "demo/superfacility/README.md", - "start_line": 80 + "start_line": 87 }, { "first_line": "# export SBATCH_ACCOUNT=amsc014", - "hash": "sha256:98afdccff070e7c83638ecff4e27b50bc17a3462c0053374e6121b42f4a9523f", + "hash": "sha256:f3d445a657b47410e0205e8915ac35eae2910f558aefb1c4277a4df37b8342c8", "id": "demo/superfacility/README.md::bash_block::5", "index": 5, "kind": "bash_block", "path": "demo/superfacility/README.md", - "start_line": 90 - }, - { - "first_line": "Now run any demo or `amrex_agent.py` as usual. Retrieval will use the hosted", - "hash": "sha256:01ddafaface9f6a9bcb96fc7c2d0d876efa95f5ed12445585758227f154a0ffc", - "id": "demo/vector_store/README.md::amrex_agent_inline::1", - "index": 1, - "kind": "amrex_agent_inline", - "path": "demo/vector_store/README.md", - "start_line": 50 + "start_line": 98 }, { "first_line": "export OPENAI_API_KEY=your_key_here", @@ -989,6 +1070,51 @@ "path": "demo/vector_store/README.md", "start_line": 38 }, + { + "first_line": "Now run any demo or `amrex_agent.py` as usual. Retrieval will use the hosted", + "hash": "sha256:01ddafaface9f6a9bcb96fc7c2d0d876efa95f5ed12445585758227f154a0ffc", + "id": "demo/vector_store/README.md::amrex_agent_inline::1", + "index": 1, + "kind": "amrex_agent_inline", + "path": "demo/vector_store/README.md", + "start_line": 50 + }, + { + "first_line": "cd run_20260202_091048", + "hash": "sha256:b43482a22e6a170aa27a95abd3254b5830b2bfd2f2801b15abb902633439b4d6", + "id": "output/run_20260202_091048/README.md::bash_block::1", + "index": 1, + "kind": "bash_block", + "path": "output/run_20260202_091048/README.md", + "start_line": 18 + }, + { + "first_line": "tail -f run.out", + "hash": "sha256:62052093f40b6d7db0c8b99e86d8bf2e26ed5fdd15760285c8f6e9f37031dfe5", + "id": "output/run_20260202_091048/README.md::bash_block::2", + "index": 2, + "kind": "bash_block", + "path": "output/run_20260202_091048/README.md", + "start_line": 24 + }, + { + "first_line": "cd run_20260202_091204", + "hash": "sha256:cc5a923e195382c15e269b3e0aa11d9ed2aeb1f8132eba9d92505938075bc1d7", + "id": "output/run_20260202_091204/README.md::bash_block::1", + "index": 1, + "kind": "bash_block", + "path": "output/run_20260202_091204/README.md", + "start_line": 18 + }, + { + "first_line": "tail -f run.out", + "hash": "sha256:62052093f40b6d7db0c8b99e86d8bf2e26ed5fdd15760285c8f6e9f37031dfe5", + "id": "output/run_20260202_091204/README.md::bash_block::2", + "index": 2, + "kind": "bash_block", + "path": "output/run_20260202_091204/README.md", + "start_line": 24 + }, { "first_line": "pytest --unit", "hash": "sha256:add09341b6512043e4c1c9a669bf63bce56b49fd471e3b20ec544ae0069caaa1", @@ -1008,13 +1134,49 @@ "start_line": 39 }, { - "first_line": "pytest --unit --cov --cov-report=term-missing", - "hash": "sha256:8f8c3aabf88eeb9a64d11e7c20d28ee9b84915a32f20179bd9bb22e925bc24d5", + "first_line": "python -m tests.e2e.readme_command_runner --dry-run", + "hash": "sha256:427a274eaf1c02c394dc063e0a70b4a9710d079917ffa3b98b73cc89e2823d0a", "id": "tests/README.md::bash_block::3", "index": 3, "kind": "bash_block", "path": "tests/README.md", - "start_line": 49 + "start_line": 54 + }, + { + "first_line": "pytest --unit --cov --cov-report=term-missing", + "hash": "sha256:8f8c3aabf88eeb9a64d11e7c20d28ee9b84915a32f20179bd9bb22e925bc24d5", + "id": "tests/README.md::bash_block::4", + "index": 4, + "kind": "bash_block", + "path": "tests/README.md", + "start_line": 77 + }, + { + "first_line": "We track README bash blocks and `amrex_agent.py` calls in", + "hash": "sha256:d2ae73f87341e4ccd8c587b34d54f04d028ab7e3dcb3cef12c5712ecc0f7ab97", + "id": "tests/README.md::amrex_agent_inline::1", + "index": 1, + "kind": "amrex_agent_inline", + "path": "tests/README.md", + "start_line": 48 + }, + { + "first_line": "Execution is limited to commands that include `amrex_agent.py`, with a 30s", + "hash": "sha256:8424e96a60c6926b49b995ac35128eb5eb0c0073e9990872558e08f8f2d5bc5a", + "id": "tests/README.md::amrex_agent_inline::2", + "index": 2, + "kind": "amrex_agent_inline", + "path": "tests/README.md", + "start_line": 66 + }, + { + "first_line": "uses all `amrex_agent.py` commands in those files. It also skips if no LLM API", + "hash": "sha256:7663dac4223621dd624da8374d287b57eb0a94e0af5cce08a6b647e98401b62e", + "id": "tests/README.md::amrex_agent_inline::3", + "index": 3, + "kind": "amrex_agent_inline", + "path": "tests/README.md", + "start_line": 68 }, { "first_line": "# Run only demo e2e smoke tests", @@ -1052,4 +1214,4 @@ "path": "tests/integration/README.md", "start_line": 155 } -] +] \ No newline at end of file diff --git a/tests/e2e/test_readme_command_runner.py b/tests/e2e/test_readme_command_runner.py index bae8c0d..a363b1a 100644 --- a/tests/e2e/test_readme_command_runner.py +++ b/tests/e2e/test_readme_command_runner.py @@ -60,6 +60,47 @@ def test_readme_command_runner_execute_amrex_agent_only() -> None: pytest.fail("\n".join(lines)) +@pytest.mark.e2e +@pytest.mark.use_real_services +@pytest.mark.requires_repos +@pytest.mark.requires_schema +def test_readme_command_runner_execute_superfacility_sfapi() -> None: + if not _assets_available(): + pytest.skip("Required repos/schemas/indices not available for README execution.") + if not _llm_available(): + pytest.skip("LLM API key not available for README execution.") + if not _sfapi_available(): + pytest.skip("SFAPI credentials not available for README execution.") + + repo_root = __import__("pathlib").Path(__file__).resolve().parents[2] + file_filter = ["demo/superfacility/README.md"] + results = run_commands_by_file( + repo_root, + file_filter=file_filter, + dry_run=False, + timeout_seconds=120, + stop_on_failure=True, + entry_filter=_is_executable_readme_command, + command_transform=_force_stage_run, + ) + + failures = [ + (path, entry) + for path, entries in results.items() + for entry in entries + if entry["status"] in {"failed", "timeout"} + ] + if failures: + lines = ["Superfacility README command execution failures:"] + for path, entry in failures[:20]: + lines.append( + f" - {path} :: {entry['id']} ({entry['status']}, rc={entry.get('returncode')})" + ) + if len(failures) > 20: + lines.append(f" - ... and {len(failures) - 20} more") + pytest.fail("\n".join(lines)) + + @pytest.mark.e2e @pytest.mark.skip(reason="Manual-only: run full README commands when needed.") def test_readme_command_runner_execute_full_file_manual() -> None: @@ -126,6 +167,25 @@ def _force_dry_run(command: str) -> str: return f"{command} --run-mode dry" +def _force_stage_run(command: str) -> str: + if "amrex_agent.py" not in command: + return command + if "--run-mode" in command or "--dry-run" in command: + return command + return f"{command} --run-mode stage" + + def _is_executable_readme_command(entry: dict) -> bool: text = entry["text"] return "amrex_agent.py" in text + + +def _sfapi_available() -> bool: + from src.services.run_superfacility_tools import _resolve_sfapi_credentials + + client_id, secret = _resolve_sfapi_credentials() + if client_id and secret: + return True + if os.getenv("NERSC_API_TOKEN") or os.getenv("SFAPI_TOKEN"): + return True + return False diff --git a/tests/unit/test_superfacility_tools.py b/tests/unit/test_superfacility_tools.py index 88673fe..496207f 100644 --- a/tests/unit/test_superfacility_tools.py +++ b/tests/unit/test_superfacility_tools.py @@ -6,6 +6,7 @@ create_nersc_session, generate_slurm_script, monitor_job, + resolve_remote_output_dir, submit_job, submit_via_sfapi, ) @@ -97,3 +98,56 @@ def fake_run(cmd, capture_output=True, text=True): state = monitor_job(job_id="123", method="sbatch", poll_interval=0, max_polls=2) assert state == "COMPLETED" + + +def test_resolve_remote_output_dir_prefers_shared_then_fallback(monkeypatch): + checks = [] + + def fake_list_remote_entries(remote_dir, **_kwargs): + checks.append(remote_dir) + if remote_dir.endswith("/acct/superfacility/output"): + raise RuntimeError("shared missing") + return [{"name": "ok"}] + + def fake_ensure_remote_directory_rest(remote_run_dir, **_kwargs): + checks.append(f"ensure:{remote_run_dir}") + + monkeypatch.setattr( + "src.services.run_superfacility_tools.list_remote_entries", + fake_list_remote_entries, + ) + monkeypatch.setattr( + "src.services.run_superfacility_tools.ensure_remote_directory_rest", + fake_ensure_remote_directory_rest, + ) + + result = resolve_remote_output_dir( + preferred_output_dir="/global/cfs/cdirs/acct/superfacility/output", + account="acct", + user="jdoe", + ) + + assert result == Path("/global/cfs/cdirs/acct/jdoe/superfacility/output") + assert f"ensure:{result}" in checks + + +def test_resolve_remote_output_dir_returns_local_path(monkeypatch): + def fail_call(*_args, **_kwargs): + raise AssertionError("unexpected remote check") + + monkeypatch.setattr( + "src.services.run_superfacility_tools.list_remote_entries", + fail_call, + ) + monkeypatch.setattr( + "src.services.run_superfacility_tools.ensure_remote_directory_rest", + fail_call, + ) + + result = resolve_remote_output_dir( + preferred_output_dir="/tmp/amrex_agent_runs", + account="acct", + user="jdoe", + ) + + assert result == Path("/tmp/amrex_agent_runs") From 1e861569579ee6d9411f90d05e37222debeda2f8 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 07:44:49 -0800 Subject: [PATCH 03/36] Clarify SFAPI auth errors for remote executable lookup --- src/services/run_superfacility.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/services/run_superfacility.py b/src/services/run_superfacility.py index ad7d79f..26dc8db 100644 --- a/src/services/run_superfacility.py +++ b/src/services/run_superfacility.py @@ -185,8 +185,15 @@ def _resolve_remote_executable( try: list_remote_entries(str(rendered_path), system=system) except Exception as exc: + hint = "" + if "No NERSC session" in str(exc): + hint = ( + " (SFAPI auth missing; set SFAPI_KEY_PATH or NERSC_API_TOKEN " + "to enable remote directory listing)" + ) raise RuntimeError( - f"Remote case directory not found for executable search: {rendered_path}" + "Remote case directory not available for executable search: " + f"{rendered_path}{hint}" ) from exc found = find_remote_executable( remote_case_dir=str(rendered_path), @@ -207,8 +214,15 @@ def _resolve_remote_executable( try: list_remote_entries(str(remote_case_dir), system=system) except Exception as exc: + hint = "" + if "No NERSC session" in str(exc): + hint = ( + " (SFAPI auth missing; set SFAPI_KEY_PATH or NERSC_API_TOKEN " + "to enable remote directory listing)" + ) raise RuntimeError( - f"Remote case directory not found for executable search: {remote_case_dir}" + "Remote case directory not available for executable search: " + f"{remote_case_dir}{hint}" ) from exc found = find_remote_executable( remote_case_dir=str(remote_case_dir), From c8d1008bca046efab29b59d7cdc213d1c3bc12c8 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 08:06:25 -0800 Subject: [PATCH 04/36] Use sfapi_client key paths for auth --- demo/superfacility/README.md | 3 + src/services/run_superfacility_tools.py | 98 +++++++++++++++++++++---- 2 files changed, 85 insertions(+), 16 deletions(-) diff --git a/demo/superfacility/README.md b/demo/superfacility/README.md index 47d58fd..0a8660c 100644 --- a/demo/superfacility/README.md +++ b/demo/superfacility/README.md @@ -21,6 +21,9 @@ export SFAPI_KEY_PATH=/path/to/priv_key.pem # or SUPERFACILITY_KEY_PATH / NERSC_SFAPI_KEY_PATH ``` +The key file must be read/write only by you (`chmod 600`), or `sfapi_client` +will refuse to authenticate. + To obtain a client key, follow the NERSC SFAPI client instructions: ``` https://docs.nersc.gov/services/sfapi/authentication/#client diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index 44f0f9e..9aee5e6 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -164,8 +164,8 @@ def _parse_sfapi_pem(key_str: str) -> tuple[str, str] | None: return first_line, secret -def _load_sfapi_key_file() -> tuple[str, str] | None: - """Load Superfacility API client ID + key from a PEM file, if present.""" +def _resolve_sfapi_key_path() -> Path | None: + """Return the first available SFAPI PEM key path, if any.""" env_paths = [ os.getenv("SFAPI_KEY_PATH"), os.getenv("SUPERFACILITY_KEY_PATH"), @@ -175,23 +175,36 @@ def _load_sfapi_key_file() -> tuple[str, str] | None: superfacility_dir = Path.home() / ".superfacility" if superfacility_dir.exists(): - search_paths.extend(sorted(superfacility_dir.glob("*.pem"))) search_paths.append(superfacility_dir / "key.pem") search_paths.append(superfacility_dir / "priv_key.pem") + search_paths.extend(sorted(superfacility_dir.glob("*.pem"))) search_paths.extend([ Path.cwd() / "priv_key.pem", Path.home() / "sfapi" / "priv_key.pem", ]) + seen = set() for path in search_paths: + if path in seen: + continue + seen.add(path) if path.exists(): - try: - parsed = _parse_sfapi_pem(path.read_text()) - if parsed: - return parsed - except Exception: - continue + return path + return None + + +def _load_sfapi_key_file() -> tuple[str, str] | None: + """Load Superfacility API client ID + key from a PEM file, if present.""" + key_path = _resolve_sfapi_key_path() + if not key_path: + return None + try: + parsed = _parse_sfapi_pem(key_path.read_text()) + if parsed: + return parsed + except Exception: + return None return None @@ -259,6 +272,7 @@ def submit_via_sfapi_client( system: str = "perlmutter", client_id: str | None = None, secret: str | None = None, + key_path: str | Path | None = None, is_path: bool = False, config: dict | None = None, ) -> dict[str, Any]: @@ -275,6 +289,8 @@ def submit_via_sfapi_client( SFAPI OAuth client ID. secret : str or None, optional SFAPI private key (PEM). + key_path : str or Path or None, optional + SFAPI key file path (first line client_id, remainder PEM). config : dict or None, optional Optional configuration overrides. @@ -291,7 +307,7 @@ def submit_via_sfapi_client( except Exception: return {"error": "sfapi_client not available"} - if not client_id or not secret: + if not key_path and (not client_id or not secret): return {"error": "Missing SFAPI client credentials"} machine = Machine.perlmutter @@ -300,7 +316,11 @@ def submit_via_sfapi_client( try: logger = logging.getLogger(__name__) - with Client(client_id=client_id, secret=secret) as client: + if key_path: + client = Client(key=Path(key_path)) + else: + client = Client(client_id=client_id, secret=secret) + with client: perlmutter = client.compute(machine) if is_path: job = perlmutter.submit_job(script_path) @@ -475,19 +495,22 @@ def submit_job( logger = logging.getLogger(__name__) client_id = None secret = None + key_path = None if config is not None: client_id = config.get("superfacility_client_id") secret = config.get("superfacility_secret") + key_path = _resolve_sfapi_key_path() if not client_id or not secret: parsed = _load_sfapi_key_file() if parsed: client_id, secret = parsed - if client_id and secret: + if key_path or (client_id and secret): result = submit_via_sfapi_client( script_path=script_path, system=system, client_id=client_id, secret=secret, + key_path=key_path, is_path=is_path, config=config, ) @@ -524,6 +547,7 @@ def stage_run_directory_sfapi_client( remote_run_dir: str, client_id: str | None = None, secret: str | None = None, + key_path: str | Path | None = None, exclude_names: list[str] | None = None, ) -> None: """ @@ -546,7 +570,7 @@ def stage_run_directory_sfapi_client( except Exception as exc: raise RuntimeError("sfapi_client not available for staging") from exc - if not client_id or not secret: + if not key_path and (not client_id or not secret): raise RuntimeError("Missing SFAPI client credentials for staging") local_run_dir = Path(local_run_dir) @@ -554,7 +578,11 @@ def stage_run_directory_sfapi_client( if not local_run_dir.exists(): raise FileNotFoundError(f"Local run directory not found: {local_run_dir}") - with Client(client_id=client_id, secret=secret) as client: + if key_path: + client = Client(key=Path(key_path)) + else: + client = Client(client_id=client_id, secret=secret) + with client: perlmutter = client.compute(Machine.perlmutter) target_dir = None @@ -690,6 +718,7 @@ def stage_run_directory( remote_run_dir: str, client_id: str | None = None, secret: str | None = None, + key_path: str | Path | None = None, method: str = "auto", nersc_session: dict | None = None, exclude_names: list[str] | None = None, @@ -702,12 +731,16 @@ def stage_run_directory( - sfapi_client: require sfapi_client credentials - rest_upload: use REST upload endpoint (token/OAuth) """ + if key_path is None: + key_path = _resolve_sfapi_key_path() + if method == "sfapi_client": stage_run_directory_sfapi_client( local_run_dir=local_run_dir, remote_run_dir=remote_run_dir, client_id=client_id, secret=secret, + key_path=key_path, exclude_names=exclude_names, ) return @@ -736,6 +769,7 @@ def stage_run_directory( remote_run_dir=remote_run_dir, client_id=client_id, secret=secret, + key_path=key_path, exclude_names=exclude_names, ) return @@ -764,8 +798,9 @@ def list_remote_files( import requests logger = logging.getLogger(__name__) + key_path = _resolve_sfapi_key_path() client_id, secret = _resolve_sfapi_credentials() - if client_id and secret: + if key_path or (client_id and secret): try: from sfapi_client import Client from sfapi_client.compute import Machine @@ -775,7 +810,11 @@ def list_remote_files( else: if system == "perlmutter": try: - with Client(client_id=client_id, secret=secret) as client: + if key_path: + client = Client(key=Path(key_path)) + else: + client = Client(client_id=client_id, secret=secret) + with client: perlmutter = client.compute(Machine.perlmutter) entries = perlmutter.ls(remote_dir, directory=True) names: list[str] = [] @@ -849,6 +888,33 @@ def list_remote_entries( """ import requests + key_path = _resolve_sfapi_key_path() + client_id, secret = _resolve_sfapi_credentials() + if key_path or (client_id and secret): + try: + from sfapi_client import Client + from sfapi_client.compute import Machine + except Exception: + client_id = None + secret = None + else: + if system == "perlmutter": + try: + if key_path: + client = Client(key=Path(key_path)) + else: + client = Client(client_id=client_id, secret=secret) + with client: + perlmutter = client.compute(Machine.perlmutter) + entries = perlmutter.ls(remote_dir, directory=True) + return [ + {"name": getattr(entry, "name", None) or getattr(entry, "path", "")} + for entry in entries + if entry + ] + except Exception: + pass + if nersc_session is None: clients = find_nersc_clients() if clients: From 581b27278574d04c863507314f3e43d3e4319f61 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 08:09:10 -0800 Subject: [PATCH 05/36] Fix sfapi_client file listing for executables --- src/services/run_superfacility_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index 9aee5e6..b0e7d32 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -816,7 +816,7 @@ def list_remote_files( client = Client(client_id=client_id, secret=secret) with client: perlmutter = client.compute(Machine.perlmutter) - entries = perlmutter.ls(remote_dir, directory=True) + entries = perlmutter.ls(remote_dir, directory=False) names: list[str] = [] for entry in entries: name = getattr(entry, "name", None) or getattr(entry, "path", None) From 5f33c68d889b227d95033d2431d907d5aa93af10 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 08:13:01 -0800 Subject: [PATCH 06/36] Prefer remote_output_dir on Perlmutter --- src/services/run_superfacility.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/services/run_superfacility.py b/src/services/run_superfacility.py index 26dc8db..6cf7e2a 100644 --- a/src/services/run_superfacility.py +++ b/src/services/run_superfacility.py @@ -271,9 +271,12 @@ def setup_job(self, logger.debug("\n=== Setting Up Job ===\n") - # Use config output_dir if not specified + # Use config output_dir if not specified (prefer remote_output_dir on Perlmutter) if output_dir is None: - output_dir = self.config.output_dir + if getattr(self.config, "environment", None) == "perlmutter": + output_dir = getattr(self.config, "remote_output_dir", None) or self.config.output_dir + else: + output_dir = self.config.output_dir output_dir_path = Path(output_dir) if output_dir else None if output_dir_path and output_dir_path.exists() and (output_dir_path / "inputs").exists(): From 91e64908522867e7014ec3bde200dd2cfcde278b Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Fri, 13 Feb 2026 08:17:28 -0800 Subject: [PATCH 07/36] Drop local output_dir from remote config --- demo/superfacility/config_perlmutter_remote.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/superfacility/config_perlmutter_remote.yaml b/demo/superfacility/config_perlmutter_remote.yaml index 2c2ef78..dba1389 100644 --- a/demo/superfacility/config_perlmutter_remote.yaml +++ b/demo/superfacility/config_perlmutter_remote.yaml @@ -4,7 +4,7 @@ environment: perlmutter # Local output directory (this host) -output_dir: /tmp/amrex_agent_runs +# output_dir: /tmp/amrex_agent_runs # Demo overrides (baseline + inputs) inputs_file_override: inputs From 88239bd0b727eec6dd91c3c4a220ef3dc9043daa Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 05:44:55 -0800 Subject: [PATCH 08/36] Add benchmark case grid and runner scaffold --- benchmark/cases/erf.yaml | 118 +++++++++++++ benchmark/cases/pelec.yaml | 118 +++++++++++++ benchmark/cases/pelelmex.yaml | 118 +++++++++++++ benchmark/cases/remora.yaml | 118 +++++++++++++ benchmark/specs/case_schema.yaml | 129 ++++++++++++++ scripts/run_benchmark.py | 291 +++++++++++++++++++++++++++++++ 6 files changed, 892 insertions(+) create mode 100644 benchmark/cases/erf.yaml create mode 100644 benchmark/cases/pelec.yaml create mode 100644 benchmark/cases/pelelmex.yaml create mode 100644 benchmark/cases/remora.yaml create mode 100644 benchmark/specs/case_schema.yaml create mode 100644 scripts/run_benchmark.py diff --git a/benchmark/cases/erf.yaml b/benchmark/cases/erf.yaml new file mode 100644 index 0000000..6f35ed5 --- /dev/null +++ b/benchmark/cases/erf.yaml @@ -0,0 +1,118 @@ +schema_version: "1.0" +suite_id: "erf_scaling_v1" +solver: "ERF" +cases: + - id: "erf_bubble_strong" + solver: "ERF" + case_name: "Bubble" + case_dir: "Exec/RegTests/Bubble" + inputs: "inputs" + description: "Thermal bubble rise for atmospheric dynamics scaling." + dimension: 3 + physics: [buoyancy, stratification, atmosphere] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "128x128x128" + amr_levels: 0 + - label: "M" + grid: "256x256x256" + amr_levels: 0 + - label: "L" + grid: "384x384x384" + amr_levels: 0 + tags: [regtest, atmosphere] + status: "ready" + + - id: "erf_density_current_strong" + solver: "ERF" + case_name: "DensityCurrent" + case_dir: "Exec/RegTests/DensityCurrent" + inputs: "inputs" + description: "Density current propagation for shear-driven flow scaling." + dimension: 3 + physics: [density_current, stratification, atmosphere] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "128x256x128" + amr_levels: 0 + - label: "M" + grid: "256x384x256" + amr_levels: 0 + - label: "L" + grid: "384x512x384" + amr_levels: 0 + tags: [regtest, atmosphere] + status: "ready" + + - id: "erf_abl_neutral_weak" + solver: "ERF" + case_name: "ABL-Neutral" + case_dir: "Exec/ABL" + inputs: "inputs" + description: "Neutral atmospheric boundary layer for weak scaling." + dimension: 3 + physics: [abl, turbulence, atmosphere] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x64" + amr_levels: 0 + - label: "M" + grid: "256x256x128" + amr_levels: 0 + - label: "L" + grid: "512x512x256" + amr_levels: 0 + tags: [abl, turbulence] + status: "ready" + + - id: "erf_abl_stable_weak" + solver: "ERF" + case_name: "ABL-Stable" + case_dir: "Exec/ABL" + inputs: "inputs_stable" + description: "Stable boundary layer for night-time stratification scaling." + dimension: 3 + physics: [abl, stable, atmosphere] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x64" + amr_levels: 0 + - label: "M" + grid: "256x256x128" + amr_levels: 0 + - label: "L" + grid: "512x512x256" + amr_levels: 0 + tags: [abl, stable] + status: "draft" + + - id: "erf_abl_convective_weak" + solver: "ERF" + case_name: "ABL-Convective" + case_dir: "Exec/ABL" + inputs: "inputs_convective" + description: "Convective boundary layer with surface heating for weak scaling." + dimension: 3 + physics: [abl, convection, atmosphere] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x64" + amr_levels: 0 + - label: "M" + grid: "256x256x128" + amr_levels: 0 + - label: "L" + grid: "512x512x256" + amr_levels: 0 + tags: [abl, convection] + status: "draft" diff --git a/benchmark/cases/pelec.yaml b/benchmark/cases/pelec.yaml new file mode 100644 index 0000000..ca61c44 --- /dev/null +++ b/benchmark/cases/pelec.yaml @@ -0,0 +1,118 @@ +schema_version: "1.0" +suite_id: "pelec_scaling_v1" +solver: "PeleC" +cases: + - id: "pelec_pmf_weak" + solver: "PeleC" + case_name: "PMF" + case_dir: "Exec/RegTests/PMF" + inputs: "pmf-lidryer-rk64.inp" + description: "Premixed methane flame; baseline reacting flow scaling case." + dimension: 3 + physics: [premixed, combustion, compressible] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x256" + amr_levels: 0 + - label: "M" + grid: "256x256x512" + amr_levels: 0 + - label: "L" + grid: "512x512x1024" + amr_levels: 0 + tags: [regtest, combustion, baseline] + status: "ready" + + - id: "pelec_sedov_strong" + solver: "PeleC" + case_name: "Sedov" + case_dir: "Exec/RegTests/Sedov" + inputs: "sedov-1.inp" + description: "Sedov blast wave for compressible hydro strong scaling." + dimension: 3 + physics: [blast, hydro, compressible] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x256" + amr_levels: 0 + - label: "M" + grid: "384x384x384" + amr_levels: 0 + - label: "L" + grid: "512x512x512" + amr_levels: 0 + tags: [regtest, hydro, shock] + status: "ready" + + - id: "pelec_tg_strong" + solver: "PeleC" + case_name: "TaylorGreen" + case_dir: "Exec/RegTests/TG" + inputs: "tg-1.inp" + description: "Taylor-Green vortex for turbulence scaling." + dimension: 3 + physics: [turbulence, vortex, compressible] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x256" + amr_levels: 0 + - label: "M" + grid: "384x384x384" + amr_levels: 0 + - label: "L" + grid: "512x512x512" + amr_levels: 0 + tags: [regtest, turbulence] + status: "ready" + + - id: "pelec_tgreact_weak" + solver: "PeleC" + case_name: "TGReact" + case_dir: "Exec/RegTests/TGReact" + inputs: "tgreact.inp" + description: "Reacting Taylor-Green vortex for weak scaling." + dimension: 3 + physics: [turbulence, combustion, compressible] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x128" + amr_levels: 0 + - label: "M" + grid: "256x256x256" + amr_levels: 0 + - label: "L" + grid: "512x512x512" + amr_levels: 0 + tags: [regtest, combustion, turbulence] + status: "ready" + + - id: "pelec_jetflame_weak" + solver: "PeleC" + case_name: "JetFlame" + case_dir: "Exec/Production/JetFlame" + inputs: "inputs" + description: "Turbulent jet flame production case for weak scaling." + dimension: 3 + physics: [jet, combustion, compressible] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x256" + amr_levels: 1 + - label: "M" + grid: "256x256x512" + amr_levels: 1 + - label: "L" + grid: "384x384x768" + amr_levels: 1 + tags: [production, combustion, jet] + status: "ready" diff --git a/benchmark/cases/pelelmex.yaml b/benchmark/cases/pelelmex.yaml new file mode 100644 index 0000000..d079624 --- /dev/null +++ b/benchmark/cases/pelelmex.yaml @@ -0,0 +1,118 @@ +schema_version: "1.0" +suite_id: "pelelmex_scaling_v1" +solver: "PeleLMeX" +cases: + - id: "pelelmex_flamesheet_strong" + solver: "PeleLMeX" + case_name: "FlameSheet" + case_dir: "Exec/RegTests/FlameSheet" + inputs: "inputs" + description: "Low-Mach diffusion flame sheet for strong scaling." + dimension: 3 + physics: [diffusion_flame, combustion, low_mach] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x128" + amr_levels: 0 + - label: "M" + grid: "384x384x192" + amr_levels: 0 + - label: "L" + grid: "512x512x256" + amr_levels: 0 + tags: [regtest, combustion, low_mach] + status: "ready" + + - id: "pelelmex_taylorgreen_strong" + solver: "PeleLMeX" + case_name: "TaylorGreen" + case_dir: "Exec/RegTests/TaylorGreen" + inputs: "inputs" + description: "Low-Mach Taylor-Green vortex for turbulence scaling." + dimension: 3 + physics: [turbulence, vortex, low_mach] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x256" + amr_levels: 0 + - label: "M" + grid: "384x384x384" + amr_levels: 0 + - label: "L" + grid: "512x512x512" + amr_levels: 0 + tags: [regtest, turbulence] + status: "ready" + + - id: "pelelmex_counterflow_weak" + solver: "PeleLMeX" + case_name: "CounterFlow" + case_dir: "Exec/Production/CounterFlow" + inputs: "inputs" + description: "Counterflow diffusion flame for weak scaling." + dimension: 3 + physics: [counterflow, combustion, low_mach] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x128x256" + amr_levels: 1 + - label: "M" + grid: "256x256x512" + amr_levels: 1 + - label: "L" + grid: "384x384x768" + amr_levels: 1 + tags: [production, combustion] + status: "ready" + + - id: "pelelmex_jet_in_crossflow_weak" + solver: "PeleLMeX" + case_name: "JetInCrossflow" + case_dir: "Exec/Production/JetInCrossflow" + inputs: "inputs" + description: "Jet in crossflow combustion for weak scaling." + dimension: 3 + physics: [jet, combustion, low_mach] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "128x256x128" + amr_levels: 1 + - label: "M" + grid: "256x384x256" + amr_levels: 1 + - label: "L" + grid: "384x512x384" + amr_levels: 1 + tags: [production, combustion, jet] + status: "ready" + + - id: "pelelmex_eb_c7_strong" + solver: "PeleLMeX" + case_name: "EB-C7" + case_dir: "Exec/RegTests/EB-C7" + inputs: "inputs" + description: "Embedded boundary low-Mach case for geometry-aware scaling." + dimension: 3 + physics: [embedded_boundary, combustion, low_mach] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "192x192x192" + amr_levels: 1 + - label: "M" + grid: "256x256x256" + amr_levels: 1 + - label: "L" + grid: "384x384x384" + amr_levels: 1 + tags: [regtest, eb, geometry] + status: "draft" diff --git a/benchmark/cases/remora.yaml b/benchmark/cases/remora.yaml new file mode 100644 index 0000000..a964284 --- /dev/null +++ b/benchmark/cases/remora.yaml @@ -0,0 +1,118 @@ +schema_version: "1.0" +suite_id: "remora_scaling_v1" +solver: "REMORA" +cases: + - id: "remora_seamount_strong" + solver: "REMORA" + case_name: "Seamount" + case_dir: "Exec/Seamount" + inputs: "inputs" + description: "Seamount circulation case for strong scaling." + dimension: 3 + physics: [ocean, topography, circulation] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x64" + amr_levels: 0 + - label: "M" + grid: "384x384x96" + amr_levels: 0 + - label: "L" + grid: "512x512x128" + amr_levels: 0 + tags: [ocean, topography] + status: "ready" + + - id: "remora_upwelling_strong" + solver: "REMORA" + case_name: "Upwelling" + case_dir: "Exec/Upwelling" + inputs: "inputs" + description: "Coastal upwelling case for strong scaling." + dimension: 3 + physics: [ocean, upwelling, coastal] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x256x64" + amr_levels: 0 + - label: "M" + grid: "384x384x96" + amr_levels: 0 + - label: "L" + grid: "512x512x128" + amr_levels: 0 + tags: [ocean, coastal] + status: "ready" + + - id: "remora_double_gyre_weak" + solver: "REMORA" + case_name: "DoubleGyre" + case_dir: "Exec/DoubleGyre" + inputs: "inputs" + description: "Wind-driven double gyre for weak scaling." + dimension: 3 + physics: [ocean, gyre, wind_forcing] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "256x256x64" + amr_levels: 0 + - label: "M" + grid: "384x384x96" + amr_levels: 0 + - label: "L" + grid: "512x512x128" + amr_levels: 0 + tags: [ocean, wind] + status: "ready" + + - id: "remora_channel_test_strong" + solver: "REMORA" + case_name: "Channel_Test" + case_dir: "Exec/Channel_Test" + inputs: "inputs" + description: "Idealized channel flow test for strong scaling." + dimension: 3 + physics: [ocean, channel, circulation] + scaling: + type: "strong" + sizes: + - label: "S" + grid: "256x128x64" + amr_levels: 0 + - label: "M" + grid: "384x192x96" + amr_levels: 0 + - label: "L" + grid: "512x256x128" + amr_levels: 0 + tags: [ocean, test] + status: "ready" + + - id: "remora_doubly_periodic_weak" + solver: "REMORA" + case_name: "DoublyPeriodic" + case_dir: "Exec/DoublyPeriodic" + inputs: "inputs" + description: "Doubly periodic ocean box for weak scaling." + dimension: 3 + physics: [ocean, periodic, idealized] + scaling: + type: "weak" + sizes: + - label: "S" + grid: "256x256x64" + amr_levels: 0 + - label: "M" + grid: "384x384x96" + amr_levels: 0 + - label: "L" + grid: "512x512x128" + amr_levels: 0 + tags: [ocean, idealized] + status: "ready" diff --git a/benchmark/specs/case_schema.yaml b/benchmark/specs/case_schema.yaml new file mode 100644 index 0000000..74b9fb5 --- /dev/null +++ b/benchmark/specs/case_schema.yaml @@ -0,0 +1,129 @@ +$schema: "https://json-schema.org/draft/2020-12/schema" +title: "Scaling Benchmark Case Suite" +type: object +additionalProperties: false +required: + - schema_version + - suite_id + - solver + - cases +properties: + schema_version: + type: string + suite_id: + type: string + solver: + type: string + cases: + type: array + minItems: 1 + items: + type: object + additionalProperties: false + required: + - id + - solver + - case_name + - case_dir + - inputs + - description + - dimension + - physics + - scaling + - tags + properties: + id: + type: string + minLength: 3 + solver: + type: string + enum: + - PeleC + - PeleLMeX + - ERF + - REMORA + case_name: + type: string + case_dir: + type: string + inputs: + type: string + description: + type: string + dimension: + type: integer + enum: [2, 3] + physics: + type: array + minItems: 1 + items: + type: string + scaling: + type: object + additionalProperties: false + required: + - type + - sizes + properties: + type: + type: string + enum: [strong, weak] + sizes: + type: array + minItems: 1 + items: + type: object + additionalProperties: false + required: + - label + - grid + - amr_levels + properties: + label: + type: string + grid: + type: string + amr_levels: + type: integer + minimum: 0 + notes: + type: string + resources: + type: object + additionalProperties: false + properties: + ranks: + type: integer + minimum: 1 + nodes: + type: integer + minimum: 1 + gpus: + type: integer + minimum: 0 + walltime: + type: string + tags: + type: array + minItems: 1 + items: + type: string + references: + type: array + items: + type: string + status: + type: string + enum: [draft, ready, deprecated] + run: + type: object + additionalProperties: false + properties: + command: + type: string + working_dir: + type: string + env: + type: object + additionalProperties: + type: string diff --git a/scripts/run_benchmark.py b/scripts/run_benchmark.py new file mode 100644 index 0000000..cc3584f --- /dev/null +++ b/scripts/run_benchmark.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +"""Benchmark runner scaffold for scaling case suites.""" + +from __future__ import annotations + +import argparse +import json +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import yaml +from jsonschema import Draft202012Validator + + +DEFAULT_SCHEMA = Path("benchmark/specs/case_schema.yaml") +DEFAULT_CASES_DIR = Path("benchmark/cases") +DEFAULT_RUNS_DIR = Path("benchmark/runs") + + +def _load_yaml(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + return yaml.safe_load(handle) or {} + + +def _load_schema(schema_path: Path) -> dict[str, Any]: + return _load_yaml(schema_path) + + +def _iter_case_files(cases_dir: Path) -> list[Path]: + return sorted(cases_dir.glob("*.yaml")) + + +def _validate_file(data: dict[str, Any], schema: dict[str, Any], path: Path) -> list[str]: + validator = Draft202012Validator(schema) + errors = [] + for error in sorted(validator.iter_errors(data), key=lambda err: list(err.path)): + location = "/".join(str(part) for part in error.path) or "" + errors.append(f"{path}: {location}: {error.message}") + solver = data.get("solver") + for case in data.get("cases", []): + case_solver = case.get("solver") + if solver and case_solver and solver != case_solver: + errors.append( + f"{path}: cases/{case.get('id', 'unknown')}: solver {case_solver} " + f"does not match suite solver {solver}" + ) + return errors + + +def _collect_cases(cases_dir: Path) -> list[dict[str, Any]]: + cases: list[dict[str, Any]] = [] + for path in _iter_case_files(cases_dir): + data = _load_yaml(path) + for case in data.get("cases", []): + case = dict(case) + case["_suite_id"] = data.get("suite_id") + case["_suite_solver"] = data.get("solver") + case["_source"] = str(path) + cases.append(case) + return cases + + +def _filter_cases( + cases: list[dict[str, Any]], solver: str | None, case_ids: list[str] +) -> list[dict[str, Any]]: + filtered = cases + if solver: + filtered = [case for case in filtered if case.get("solver") == solver] + if case_ids: + wanted = set(case_ids) + filtered = [case for case in filtered if case.get("id") in wanted] + return filtered + + +def _expand_case_runs(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: + runs: list[dict[str, Any]] = [] + for case in cases: + for size in case["scaling"]["sizes"]: + runs.append( + { + "case_id": case["id"], + "solver": case["solver"], + "case_name": case["case_name"], + "size_label": size["label"], + "grid": size["grid"], + "amr_levels": size["amr_levels"], + "source": case["_source"], + } + ) + return runs + + +def _check_unique_ids(cases: list[dict[str, Any]]) -> list[str]: + seen: dict[str, str] = {} + errors = [] + for case in cases: + case_id = case.get("id") + if not case_id: + continue + source = case.get("_source", "") + if case_id in seen: + errors.append(f"Duplicate case id {case_id} in {source} (also {seen[case_id]})") + else: + seen[case_id] = source + return errors + + +def _print_cases(cases: list[dict[str, Any]]) -> None: + for case in cases: + print(f"{case['id']}: {case['solver']} - {case['case_name']}") + + +def _print_plan(runs: list[dict[str, Any]]) -> None: + for run in runs: + print( + f"{run['case_id']} [{run['size_label']}] " + f"grid={run['grid']} amr={run['amr_levels']}" + ) + + +def _init_run_state(cases: list[dict[str, Any]]) -> dict[str, Any]: + created_at = datetime.now(timezone.utc).isoformat() + state = { + "schema_version": 1, + "run_id": f"run_{created_at.replace(':', '').replace('-', '')}", + "created_at": created_at, + "cases": {}, + } + for case in cases: + sizes = {} + for size in case["scaling"]["sizes"]: + sizes[size["label"]] = { + "grid": size["grid"], + "amr_levels": size["amr_levels"], + "status": "pending", + } + state["cases"][case["id"]] = { + "solver": case["solver"], + "case_name": case["case_name"], + "source": case["_source"], + "status": "pending", + "sizes": sizes, + } + return state + + +def _write_state(path: Path, state: dict[str, Any]) -> None: + path.write_text(json.dumps(state, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def _load_state(path: Path) -> dict[str, Any]: + return json.loads(path.read_text(encoding="utf-8")) + + +def _format_command(case: dict[str, Any], size: dict[str, Any]) -> str | None: + run_info = case.get("run") or {} + command = run_info.get("command") + if not command: + return None + return command.format( + solver=case.get("solver"), + case_dir=case.get("case_dir"), + inputs=case.get("inputs"), + grid=size.get("grid"), + label=size.get("label"), + ) + + +def _execute_case(case: dict[str, Any], size: dict[str, Any], dry_run: bool) -> tuple[str, str]: + command = _format_command(case, size) + if not command: + return "skipped", "no command defined" + if dry_run: + return "planned", command + import subprocess # Imported lazily to keep startup fast. + + working_dir = (case.get("run") or {}).get("working_dir") or case.get("case_dir") + result = subprocess.run(command, shell=True, cwd=working_dir) + if result.returncode == 0: + return "complete", command + return "failed", command + + +def _run_cases( + cases: list[dict[str, Any]], + run_dir: Path, + resume: bool, + dry_run: bool, +) -> int: + run_dir.mkdir(parents=True, exist_ok=True) + state_path = run_dir / "run_state.json" + + if resume: + if not state_path.exists(): + print(f"Missing run state at {state_path}", file=sys.stderr) + return 2 + state = _load_state(state_path) + else: + state = _init_run_state(cases) + _write_state(state_path, state) + + for case in cases: + case_state = state["cases"].get(case["id"]) + if not case_state: + continue + for size in case["scaling"]["sizes"]: + size_state = case_state["sizes"][size["label"]] + if size_state["status"] == "complete": + continue + status, detail = _execute_case(case, size, dry_run) + size_state["status"] = status + size_state["detail"] = detail + if status == "failed": + case_state["status"] = "failed" + _write_state(state_path, state) + print(f"Failed: {case['id']} [{size['label']}]", file=sys.stderr) + return 1 + if all(entry["status"] == "complete" for entry in case_state["sizes"].values()): + case_state["status"] = "complete" + elif all(entry["status"] == "skipped" for entry in case_state["sizes"].values()): + case_state["status"] = "skipped" + else: + case_state["status"] = "in_progress" + _write_state(state_path, state) + + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser(description="Scaling benchmark runner scaffold") + parser.add_argument("--cases-dir", type=Path, default=DEFAULT_CASES_DIR) + parser.add_argument("--schema", type=Path, default=DEFAULT_SCHEMA) + parser.add_argument("--solver", choices=["PeleC", "PeleLMeX", "ERF", "REMORA"]) + parser.add_argument("--case-id", action="append", default=[]) + + subparsers = parser.add_subparsers(dest="command", required=True) + + subparsers.add_parser("list", help="List benchmark cases") + subparsers.add_parser("validate", help="Validate benchmark case files") + subparsers.add_parser("plan", help="Print expanded run plan") + + run_parser = subparsers.add_parser("run", help="Execute benchmark cases (stub)") + run_parser.add_argument("--run-dir", type=Path, default=None) + run_parser.add_argument("--resume", action="store_true") + run_parser.add_argument("--execute", action="store_true") + + args = parser.parse_args() + + cases_dir = args.cases_dir + schema = _load_schema(args.schema) + + if args.command == "validate": + errors: list[str] = [] + for path in _iter_case_files(cases_dir): + data = _load_yaml(path) + errors.extend(_validate_file(data, schema, path)) + cases = _collect_cases(cases_dir) + errors.extend(_check_unique_ids(cases)) + if errors: + for error in errors: + print(error, file=sys.stderr) + return 1 + print("All benchmark case files are valid.") + return 0 + + cases = _collect_cases(cases_dir) + cases = _filter_cases(cases, args.solver, args.case_id) + + if args.command == "list": + _print_cases(cases) + return 0 + + if args.command == "plan": + runs = _expand_case_runs(cases) + _print_plan(runs) + return 0 + + if args.command == "run": + run_dir = args.run_dir + if run_dir is None: + run_dir = DEFAULT_RUNS_DIR / datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + dry_run = not args.execute + return _run_cases(cases, run_dir, args.resume, dry_run) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 8318355ec0313ec890c654dca45fd3364a885649 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 05:41:52 -0800 Subject: [PATCH 09/36] Add metrics collection and JSONL summaries --- src/config.py | 61 ++++++ src/main.py | 36 ++++ src/models/graph_state_canonical.py | 1 + src/nodes/analysis_node.py | 45 +++-- src/nodes/architect_node.py | 26 ++- src/nodes/input_writer_node.py | 26 ++- src/nodes/reviewer_node.py | 61 ++++-- src/services/inputs_file_selector.py | 51 +++++ src/services/knowledge.py | 54 ++++++ src/utils/metrics.py | 277 +++++++++++++++++++++++++++ tests/unit/test_metrics_collector.py | 29 +++ 11 files changed, 621 insertions(+), 46 deletions(-) create mode 100644 src/utils/metrics.py create mode 100644 tests/unit/test_metrics_collector.py diff --git a/src/config.py b/src/config.py index a9dd50f..1a4d9fa 100644 --- a/src/config.py +++ b/src/config.py @@ -648,10 +648,22 @@ class AMReXAgentConfig(BaseModel): default=Path("./output"), description="Base output directory for simulations" ) + metrics_output_dir: Optional[Path] = Field( + default=None, + description="Directory for metrics.jsonl when run_directory is unavailable (defaults to output_dir)." + ) + metrics_filename: str = Field( + default="metrics.jsonl", + description="Metrics JSONL filename for workflow summaries." + ) save_intermediate: bool = Field( default=True, description="Save intermediate results (plans, configs, etc.)" ) + metrics_enabled: bool = Field( + default=True, + description="Enable metrics collection and JSONL output" + ) # === Validator Configuration === disabled_validators: List[str] = Field( @@ -947,6 +959,7 @@ def get_llm_client(config: AMReXAgentConfig): def _wrap_llm_client_if_needed(client, config: AMReXAgentConfig): client = _wrap_llm_client_with_retry(client, config) + client = _wrap_llm_client_with_metrics(client, config) strategy = getattr(config, "llm_gate_strategy", "off") or "off" if strategy == "off": return client @@ -967,6 +980,14 @@ def _wrap_llm_client_with_retry(client, config: AMReXAgentConfig): return _LLMRetryClient(client, max_attempts) +def _wrap_llm_client_with_metrics(client, config: AMReXAgentConfig): + if getattr(config, "metrics_enabled", True) is False: + return client + if isinstance(client, _LLMMetricsClient): + return client + return _LLMMetricsClient(client, config) + + def wrap_llm_client(client, config: AMReXAgentConfig): """Public helper to apply LLM gating to an existing client instance.""" return _wrap_llm_client_if_needed(client, config) @@ -996,6 +1017,46 @@ def __getattr__(self, name: str): return getattr(self._client, name) +class _LLMMetricsClient: + def __init__(self, client, config: AMReXAgentConfig) -> None: + self._client = client + self._config = config + self.chat = _LLMMetricsChat(client.chat, config) + + def __getattr__(self, name: str): + return getattr(self._client, name) + + +class _LLMMetricsChat: + def __init__(self, chat_resource, config: AMReXAgentConfig) -> None: + self._chat = chat_resource + self._config = config + self.completions = _LLMMetricsCompletions(chat_resource.completions, config) + + def __getattr__(self, name: str): + return getattr(self._chat, name) + + +class _LLMMetricsCompletions: + def __init__(self, completions_resource, config: AMReXAgentConfig) -> None: + self._completions = completions_resource + self._config = config + + def create(self, *args: Any, **kwargs: Any) -> Any: + response = self._completions.create(*args, **kwargs) + try: + from src.utils.metrics import metrics_collector + + metrics_collector.record_llm_usage( + response, + model=kwargs.get("model"), + provider=getattr(self._config, "llm_provider", None), + ) + except Exception: + pass + return response + + class _LLMRetryChat: def __init__(self, chat_resource, max_attempts: int) -> None: self._chat = chat_resource diff --git a/src/main.py b/src/main.py index cf220b5..792b2d8 100644 --- a/src/main.py +++ b/src/main.py @@ -688,6 +688,42 @@ def main(args: list[str] | None = None) -> None: logger.debug("Starting AMReXAgent workflow...") result = run_agent(user_requirement, config) + # Save metrics JSONL (if enabled) + try: + from src.utils.metrics import metrics_collector + + if getattr(config, "metrics_enabled", True) and metrics_collector.events(): + summary = metrics_collector.build_workflow_summary() + summary.update({ + "job_status": result.get("job_status", "unknown"), + "iteration": result.get("iteration", 0), + "run_directory": result.get("run_directory"), + }) + metrics_collector.record_event( + "workflow_summary", + summary, + stage="workflow", + node="main", + iteration=result.get("iteration", 0), + ) + if 'run_directory' in result: + run_dir = Path(result['run_directory']) + metrics_path = run_dir / getattr(config, "metrics_filename", "metrics.jsonl") + else: + base_dir = ( + Path(parsed_args.output_dir) + if parsed_args.output_dir + else (config.metrics_output_dir or config.output_dir) + ) + base_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"metrics_{timestamp}.jsonl" + metrics_path = base_dir / filename + metrics_collector.write_jsonl(str(metrics_path)) + logger.info(f"Metrics saved to {metrics_path}") + except Exception as e: + logger.warning(f"Failed to save metrics JSONL: {e}") + # Save workflow_history if requested if parsed_args.save_workflow: try: diff --git a/src/models/graph_state_canonical.py b/src/models/graph_state_canonical.py index 64ebade..fd0040f 100644 --- a/src/models/graph_state_canonical.py +++ b/src/models/graph_state_canonical.py @@ -209,6 +209,7 @@ class GraphState(TypedDict, total=False): # ======================================== # METADATA # ======================================== + metrics: Optional[Dict[str, Any]] # Aggregated metrics (tokens, retrieval, validation) phase: str # Current workflow phase # Values: "planning", "execution", "analysis", "complete" diff --git a/src/nodes/analysis_node.py b/src/nodes/analysis_node.py index 6fa5246..c422d07 100644 --- a/src/nodes/analysis_node.py +++ b/src/nodes/analysis_node.py @@ -286,6 +286,8 @@ def analysis_node(state: GraphState) -> dict[str, Any]: and getattr(config, "retry_guidance_use_llm", False) ): try: + from src.utils.metrics import metrics_context + from database.configs.base_amrex_config import BaseAMReXConfig from src.config import get_llm_client @@ -311,16 +313,17 @@ class RetryGuidance(BaseModel): baseline_base_action: str = Field(description="keep or switch") rationale: str | None = None - base_client = unwrap_llm_client(llm_client) - instr_client = instructor.from_openai(base_client) - instr_client = wrap_llm_client(instr_client, config) - parsed = instr_client.chat.completions.create( - model=config.llm_model, - response_model=RetryGuidance, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_retries=2, - ) + with metrics_context("analysis", node="analysis", iteration=iteration): + base_client = unwrap_llm_client(llm_client) + instr_client = instructor.from_openai(base_client) + instr_client = wrap_llm_client(instr_client, config) + parsed = instr_client.chat.completions.create( + model=config.llm_model, + response_model=RetryGuidance, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_retries=2, + ) retry_guidance.update({ "inputs_base_action": parsed.inputs_base_action or "keep", "baseline_base_action": parsed.baseline_base_action or "keep", @@ -328,12 +331,13 @@ class RetryGuidance(BaseModel): "baseline_reason": parsed.rationale, }) except (ImportError, ModuleNotFoundError): - response = llm_client.chat.completions.create( - model=config.llm_model, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_tokens=200, - ) + with metrics_context("analysis", node="analysis", iteration=iteration): + response = llm_client.chat.completions.create( + model=config.llm_model, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_tokens=200, + ) content = response.choices[0].message.content.strip() import json parsed = json.loads(content) @@ -347,6 +351,13 @@ class RetryGuidance(BaseModel): except Exception as exc: logger.debug(f"Retry guidance LLM unavailable: {exc}") + try: + from src.utils.metrics import metrics_collector + + metrics_summary = metrics_collector.summarize_stage("analysis", iteration=iteration) + except Exception: + metrics_summary = {} + history_entry = { "node": "analysis", "timestamp": datetime.utcnow().isoformat() + "Z", @@ -365,6 +376,8 @@ class RetryGuidance(BaseModel): "retry_guidance": retry_guidance, } } + if metrics_summary: + history_entry["details"]["metrics"] = metrics_summary new_history = workflow_history + [history_entry] diff --git a/src/nodes/architect_node.py b/src/nodes/architect_node.py index 88706cc..d7d3ab0 100644 --- a/src/nodes/architect_node.py +++ b/src/nodes/architect_node.py @@ -300,16 +300,19 @@ def architect_node(state: GraphState) -> dict[str, Any]: logger.debug(f"Total exclusions: {len(excluded_cases)} cases, {len(excluded_inputs_files)} inputs files") try: + from src.utils.metrics import metrics_context + # Execute planning using strategy dispatcher # This calls create_plan_rag() or create_plan() based on config.indexing_strategy # and normalizes the output to canonical format - plan_result = service.execute_planning( - user_prompt=prompt, - prefer_quality="excellent", - excluded_cases=excluded_cases, - excluded_inputs_files=excluded_inputs_files, - parameter_resolution_feedback=parameter_resolution_feedback, # NEW: pass to service - ) + with metrics_context("architect", node="architect", iteration=new_iteration): + plan_result = service.execute_planning( + user_prompt=prompt, + prefer_quality="excellent", + excluded_cases=excluded_cases, + excluded_inputs_files=excluded_inputs_files, + parameter_resolution_feedback=parameter_resolution_feedback, # NEW: pass to service + ) logger.debug(f"[DATA TRANSFER] Called architect service with feedback={parameter_resolution_feedback is not None}") logger.info(f"Plan created: {plan_result.selected_case}") @@ -406,6 +409,13 @@ def architect_node(state: GraphState) -> dict[str, Any]: else: action = "plan_created" + try: + from src.utils.metrics import metrics_collector + + metrics_summary = metrics_collector.summarize_stage("architect", iteration=new_iteration) + except Exception: + metrics_summary = {} + # Create structured history entry (Fix #7 - canonical format with complete computation output) # CRITICAL: Store FULL computation output here, not snippets or counts history_entry = { @@ -433,6 +443,8 @@ def architect_node(state: GraphState) -> dict[str, Any]: ), } } + if metrics_summary: + history_entry["details"]["metrics"] = metrics_summary # Append to history (immutable - create new list) new_history = workflow_history + [history_entry] diff --git a/src/nodes/input_writer_node.py b/src/nodes/input_writer_node.py index 976ba58..159926a 100644 --- a/src/nodes/input_writer_node.py +++ b/src/nodes/input_writer_node.py @@ -213,15 +213,18 @@ def _should_prefer_strategy_on_retry(state: GraphState) -> bool: } try: + from src.utils.metrics import metrics_context + # 2. Call service with individual parameters (per contract) # apply_plan() signature: selected_case, modifications, baseline, reasoning, output_dir - result = service.apply_plan( - selected_case=selected_case, - modifications=modifications, - baseline=baseline, - reasoning=reasoning, - output_dir=str(run_dir) # Service physically creates this - ) + with metrics_context("input_writer", node="input_writer", iteration=state.get("iteration", 0)): + result = service.apply_plan( + selected_case=selected_case, + modifications=modifications, + baseline=baseline, + reasoning=reasoning, + output_dir=str(run_dir) # Service physically creates this + ) logger.info(f"Files written to: {result.get('run_dir', 'unknown')}") @@ -344,6 +347,13 @@ def _should_prefer_strategy_on_retry(state: GraphState) -> bool: # Create history entry (per contract line 62-89) # Store full computation output in details (canonical path) + try: + from src.utils.metrics import metrics_collector + + metrics_summary = metrics_collector.summarize_stage("input_writer", iteration=state.get("iteration", 0)) + except Exception: + metrics_summary = {} + history_entry = { "node": "input_writer", "timestamp": datetime.utcnow().isoformat() + "Z", @@ -368,6 +378,8 @@ def _should_prefer_strategy_on_retry(state: GraphState) -> bool: "inputs_candidates": result.get("inputs_candidates", []), } } + if metrics_summary: + history_entry["details"]["metrics"] = metrics_summary new_history = workflow_history + [history_entry] diff --git a/src/nodes/reviewer_node.py b/src/nodes/reviewer_node.py index 3e9bdab..c59383b 100644 --- a/src/nodes/reviewer_node.py +++ b/src/nodes/reviewer_node.py @@ -624,6 +624,8 @@ def _derive_retry_guidance(violations, rejected_inputs_file, rejected_baseline_c and getattr(config, "retry_guidance_use_llm", False) ): try: + from src.utils.metrics import metrics_context + from database.configs import get_config_for_path from src.config import get_llm_client @@ -650,16 +652,17 @@ class RetryGuidance(BaseModel): baseline_base_action: str = Field(description="keep or switch") rationale: str | None = None - base_client = unwrap_llm_client(llm_client) - instr_client = instructor.from_openai(base_client) - instr_client = wrap_llm_client(instr_client, config) - parsed = instr_client.chat.completions.create( - model=config.llm_model, - response_model=RetryGuidance, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_retries=2, - ) + with metrics_context("reviewer", node="reviewer", iteration=iteration): + base_client = unwrap_llm_client(llm_client) + instr_client = instructor.from_openai(base_client) + instr_client = wrap_llm_client(instr_client, config) + parsed = instr_client.chat.completions.create( + model=config.llm_model, + response_model=RetryGuidance, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_retries=2, + ) retry_guidance.update({ "inputs_base_action": parsed.inputs_base_action or "keep", "baseline_base_action": parsed.baseline_base_action or "keep", @@ -667,12 +670,13 @@ class RetryGuidance(BaseModel): "baseline_reason": parsed.rationale, }) except (ImportError, ModuleNotFoundError): - response = llm_client.chat.completions.create( - model=config.llm_model, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_tokens=200, - ) + with metrics_context("reviewer", node="reviewer", iteration=iteration): + response = llm_client.chat.completions.create( + model=config.llm_model, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_tokens=200, + ) content = response.choices[0].message.content.strip() import json parsed = json.loads(content) @@ -696,6 +700,29 @@ class RetryGuidance(BaseModel): if validation_result.available_schema_params: logger.debug(f"[REVIEWER NODE] Sample schema params: {validation_result.available_schema_params[:10]}") + try: + from src.utils.metrics import metrics_collector + + validation_metrics = { + "error_count": len(errors_current), + "warning_count": len(warnings), + "violation_count": len(validation_result.violations), + "unknown_param_count": unknown_param_count, + "persistent_unknown_count": persistent_unknown_count, + "schema_missing": has_schema_missing, + "solver_unknown": has_solver_unknown, + } + metrics_collector.record_event( + "validation_metrics", + validation_metrics, + stage="reviewer", + node="reviewer", + iteration=iteration, + ) + metrics_summary = metrics_collector.summarize_stage("reviewer", iteration=iteration) + except Exception: + metrics_summary = {} + history_entry = { "node": "reviewer", "timestamp": datetime.utcnow().isoformat() + "Z", @@ -723,6 +750,8 @@ class RetryGuidance(BaseModel): "plan_rejected_inputs_file": rejected_inputs, } } + if metrics_summary: + history_entry["details"]["metrics"] = metrics_summary new_history = workflow_history + [history_entry] diff --git a/src/services/inputs_file_selector.py b/src/services/inputs_file_selector.py index 8fb4da9..cd8a935 100644 --- a/src/services/inputs_file_selector.py +++ b/src/services/inputs_file_selector.py @@ -160,6 +160,8 @@ def select_best_inputs_file( Selected inputs file path, or None if not found. """ excluded_files = excluded_files or [] + original_strategy = strategy + fallback_reason = None # Find candidate files (use provided or discover) if available_files: @@ -185,16 +187,30 @@ def select_best_inputs_file( selected = cls._select_override(case_dir, config) if selected: logger.info(f"Selected inputs file: {selected.name} (strategy: override)") + _record_inputs_selection( + strategy="override", + original_strategy=original_strategy, + selected=selected, + candidates=candidates, + ) return selected # Fallback to smallest if override didn't resolve + fallback_reason = "override_unresolved" strategy = "smallest" if strategy == "llm_compare": selected = cls._select_with_llm(case_dir, candidates, config=config) if selected: logger.info(f"Selected inputs file: {selected.name} (strategy: llm_compare)") + _record_inputs_selection( + strategy="llm_compare", + original_strategy=original_strategy, + selected=selected, + candidates=candidates, + ) return selected # LLM unavailable or failed → fallback to smallest + fallback_reason = "llm_unavailable" strategy = "smallest" # Score by strategy @@ -209,6 +225,14 @@ def select_best_inputs_file( selected = scored[0][1] logger.info(f"Selected inputs file: {selected.name} (strategy: {strategy}, score: {scored[0][0]:.2f})") + _record_inputs_selection( + strategy=strategy, + original_strategy=original_strategy, + selected=selected, + candidates=candidates, + fallback_reason=fallback_reason, + score=scored[0][0], + ) return selected @@ -313,3 +337,30 @@ def _select_override(case_dir: Path, config: Any | None) -> Path | None: if default_path.exists(): return default_path return None + + +def _record_inputs_selection( + *, + strategy: str, + original_strategy: str, + selected: Path, + candidates: list[Path], + fallback_reason: str | None = None, + score: float | None = None, +) -> None: + try: + from src.utils.metrics import metrics_collector + + metrics_collector.record_event( + "retrieval_strategy", + { + "strategy": strategy, + "original_strategy": original_strategy, + "fallback_reason": fallback_reason, + "selected": selected.name, + "candidate_count": len(candidates), + "score": score, + }, + ) + except Exception: + return diff --git a/src/services/knowledge.py b/src/services/knowledge.py index 06c9427..47028d4 100644 --- a/src/services/knowledge.py +++ b/src/services/knowledge.py @@ -104,12 +104,22 @@ def query(self, question: str, context: dict | None = None) -> dict: # If FAISS has high confidence, return it directly if faiss_result and faiss_result.get('confidence', 0) > 0.8: logger.debug(f" Using FAISS result (confidence: {faiss_result['confidence']:.2f})") + _record_retrieval_metrics( + strategy="faiss", + confidence=faiss_result.get("confidence", 0.0), + source=faiss_result.get("source"), + ) return faiss_result tools, solver_config = self._get_tools_for_context(context) if not tools or not tools.get("ask"): solver_name = getattr(solver_config, "code_name", "unknown") if faiss_result: + _record_retrieval_metrics( + strategy="faiss", + confidence=faiss_result.get("confidence", 0.0), + source=faiss_result.get("source"), + ) return faiss_result return { "answer": f"Knowledge tools not available for solver {solver_name}", @@ -122,6 +132,11 @@ def query(self, question: str, context: dict | None = None) -> dict: if not self.knowledge_loaded: if faiss_result: # Return FAISS result even if low confidence + _record_retrieval_metrics( + strategy="faiss", + confidence=faiss_result.get("confidence", 0.0), + source=faiss_result.get("source"), + ) return faiss_result return { "answer": "Knowledge base not loaded", @@ -156,8 +171,19 @@ def query(self, question: str, context: dict | None = None) -> dict: # Combine FAISS and LLM results if both available if faiss_result: + _record_retrieval_metrics( + strategy="hybrid", + confidence=llm_result.get("confidence", 0.0), + source=llm_result.get("method"), + faiss_confidence=faiss_result.get("confidence", 0.0), + ) return self._combine_results(faiss_result, llm_result) + _record_retrieval_metrics( + strategy="llm", + confidence=llm_result.get("confidence", 0.0), + source=llm_result.get("method"), + ) return llm_result except Exception as e: @@ -166,6 +192,11 @@ def query(self, question: str, context: dict | None = None) -> dict: # Return FAISS result if available as fallback if faiss_result: logger.debug(" Using FAISS result as fallback after LLM failure") + _record_retrieval_metrics( + strategy="faiss_fallback", + confidence=faiss_result.get("confidence", 0.0), + source=faiss_result.get("source"), + ) return faiss_result return { @@ -545,6 +576,29 @@ def build_cborg_embeddings_rag(self, documents: list[str]) -> None: pass +def _record_retrieval_metrics( + *, + strategy: str, + confidence: float | None = None, + source: str | None = None, + faiss_confidence: float | None = None, +) -> None: + try: + from src.utils.metrics import metrics_collector + + metrics_collector.record_event( + "retrieval_strategy", + { + "strategy": strategy, + "confidence": confidence, + "faiss_confidence": faiss_confidence, + "source": source, + }, + ) + except Exception: + return + + # Test the service if __name__ == "__main__": import sys diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 0000000..ab82999 --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,277 @@ +"""Metrics collection utilities for AMReXAgent.""" + +from __future__ import annotations + +import json +import logging +from contextlib import contextmanager +from contextvars import ContextVar +from datetime import datetime +from typing import Any + +logger = logging.getLogger(__name__) + +_stage_var: ContextVar[str | None] = ContextVar("metrics_stage", default=None) +_node_var: ContextVar[str | None] = ContextVar("metrics_node", default=None) +_iteration_var: ContextVar[int | None] = ContextVar("metrics_iteration", default=None) +_extra_var: ContextVar[dict[str, Any] | None] = ContextVar("metrics_extra", default=None) + + +@contextmanager +def metrics_context( + stage: str, + *, + node: str | None = None, + iteration: int | None = None, + extra: dict[str, Any] | None = None, +): + """Set metrics context for downstream instrumentation.""" + tokens = [] + tokens.append(_stage_var.set(stage)) + if node is not None: + tokens.append(_node_var.set(node)) + if iteration is not None: + tokens.append(_iteration_var.set(iteration)) + if extra is not None: + tokens.append(_extra_var.set(extra)) + try: + yield + finally: + for token in reversed(tokens): + token.var.reset(token) + + +class MetricsCollector: + """Collects instrumentation events for aggregation and export.""" + + def __init__(self) -> None: + self._events: list[dict[str, Any]] = [] + + def record_event( + self, + event_type: str, + data: dict[str, Any], + *, + stage: str | None = None, + node: str | None = None, + iteration: int | None = None, + ) -> dict[str, Any]: + event = { + "timestamp": datetime.utcnow().isoformat() + "Z", + "type": event_type, + "stage": stage or _stage_var.get() or "unknown", + "node": node or _node_var.get() or stage or _stage_var.get() or "unknown", + "iteration": iteration if iteration is not None else _iteration_var.get(), + "data": data, + } + extra = _extra_var.get() + if extra: + event["context"] = dict(extra) + self._events.append(event) + return event + + def record_llm_usage( + self, + response: Any, + *, + model: str | None = None, + provider: str | None = None, + ) -> dict[str, Any] | None: + usage = _extract_usage(response) + if not usage: + return None + payload = { + "model": model or _extract_model(response), + "provider": provider, + **usage, + } + return self.record_event("llm_usage", payload) + + def summarize_stage(self, stage: str, iteration: int | None = None) -> dict[str, Any]: + events = [ + e + for e in self._events + if e.get("stage") == stage and (iteration is None or e.get("iteration") == iteration) + ] + if not events: + return {} + + summary: dict[str, Any] = {} + llm_summary = _aggregate_llm_usage(events) + if llm_summary: + summary["llm"] = llm_summary + retrieval_summary = _aggregate_retrieval(events) + if retrieval_summary: + summary["retrieval"] = retrieval_summary + validation_summary = _aggregate_validation(events) + if validation_summary: + summary["validation"] = validation_summary + return summary + + def build_workflow_summary(self, stages: list[str] | None = None) -> dict[str, Any]: + events = list(self._events) + if not events: + return {} + + if stages is None: + stages = sorted({e.get("stage") for e in events if e.get("stage")}) + + stage_summaries: dict[str, Any] = {} + tokens_by_stage: dict[str, dict[str, int]] = {} + for stage in stages: + summary = self.summarize_stage(stage) + if summary: + stage_summaries[stage] = summary + llm = summary.get("llm", {}) + if llm: + tokens_by_stage[stage] = { + "input": llm.get("prompt_tokens", 0), + "output": llm.get("completion_tokens", 0), + "total": llm.get("total_tokens", 0), + } + + llm_all = _aggregate_llm_usage(events) + tokens_total_input = llm_all.get("prompt_tokens", 0) + tokens_total_output = llm_all.get("completion_tokens", 0) + tokens_total = llm_all.get("total_tokens", 0) + + models, providers = _aggregate_models(events) + + summary = { + "tokens_total_input": tokens_total_input, + "tokens_total_output": tokens_total_output, + "tokens_total": tokens_total, + "tokens_by_stage": tokens_by_stage, + "models": models, + "providers": providers, + } + if stage_summaries: + summary["stages"] = stage_summaries + return summary + + def events(self) -> list[dict[str, Any]]: + return list(self._events) + + def write_jsonl(self, path: str) -> None: + if not self._events: + return + try: + with open(path, "w", encoding="utf-8") as handle: + for event in self._events: + handle.write(json.dumps(event, default=str)) + handle.write("\n") + except Exception as exc: + logger.warning("Failed to write metrics JSONL to %s: %s", path, exc) + + def reset(self) -> None: + self._events = [] + + +def _extract_usage(response: Any) -> dict[str, Any]: + usage = getattr(response, "usage", None) + if usage is None: + usage = getattr(getattr(response, "_raw_response", None), "usage", None) + if usage is None: + usage = getattr(getattr(response, "response", None), "usage", None) + if usage is None: + return {} + if isinstance(usage, dict): + prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens") + completion_tokens = usage.get("completion_tokens") or usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + else: + prompt_tokens = getattr(usage, "prompt_tokens", None) or getattr(usage, "input_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) or getattr(usage, "output_tokens", None) + total_tokens = getattr(usage, "total_tokens", None) + + if total_tokens is None and (prompt_tokens is not None or completion_tokens is not None): + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + if prompt_tokens is None and completion_tokens is None and total_tokens is None: + return {} + + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + +def _extract_model(response: Any) -> str | None: + model = getattr(response, "model", None) + if model: + return model + raw = getattr(response, "_raw_response", None) + return getattr(raw, "model", None) if raw else None + + +def _aggregate_llm_usage(events: list[dict[str, Any]]) -> dict[str, Any]: + llm_events = [e for e in events if e.get("type") == "llm_usage"] + if not llm_events: + return {} + summary = { + "total_calls": len(llm_events), + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "by_model": {}, + } + for event in llm_events: + data = event.get("data", {}) + model = data.get("model") or "unknown" + prompt = data.get("prompt_tokens") or 0 + completion = data.get("completion_tokens") or 0 + total = data.get("total_tokens") + if total is None: + total = prompt + completion + summary["prompt_tokens"] += prompt + summary["completion_tokens"] += completion + summary["total_tokens"] += total + per_model = summary["by_model"].setdefault( + model, + {"calls": 0, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + per_model["calls"] += 1 + per_model["prompt_tokens"] += prompt + per_model["completion_tokens"] += completion + per_model["total_tokens"] += total + return summary + + +def _aggregate_retrieval(events: list[dict[str, Any]]) -> dict[str, Any]: + retrieval_events = [e for e in events if e.get("type") == "retrieval_strategy"] + if not retrieval_events: + return {} + summary: dict[str, Any] = {"strategies": {}, "last": None} + for event in retrieval_events: + data = event.get("data", {}) + strategy = data.get("strategy") or "unknown" + summary["strategies"][strategy] = summary["strategies"].get(strategy, 0) + 1 + summary["last"] = data + return summary + + +def _aggregate_validation(events: list[dict[str, Any]]) -> dict[str, Any]: + validation_events = [e for e in events if e.get("type") == "validation_metrics"] + if not validation_events: + return {} + return validation_events[-1].get("data", {}) + + +def _aggregate_models(events: list[dict[str, Any]]) -> tuple[list[str], list[str]]: + models: list[str] = [] + providers: list[str] = [] + for event in events: + if event.get("type") != "llm_usage": + continue + data = event.get("data", {}) + model = data.get("model") + provider = data.get("provider") + if model and model not in models: + models.append(model) + if provider and provider not in providers: + providers.append(provider) + return models, providers + + +metrics_collector = MetricsCollector() diff --git a/tests/unit/test_metrics_collector.py b/tests/unit/test_metrics_collector.py new file mode 100644 index 0000000..f96d14f --- /dev/null +++ b/tests/unit/test_metrics_collector.py @@ -0,0 +1,29 @@ +from src.utils.metrics import metrics_collector, metrics_context + + +class _FakeUsage: + def __init__(self, prompt_tokens: int, completion_tokens: int) -> None: + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = prompt_tokens + completion_tokens + + +class _FakeResponse: + def __init__(self, model: str, usage: _FakeUsage) -> None: + self.model = model + self.usage = usage + + +def test_metrics_collector_aggregates_llm_tokens() -> None: + metrics_collector.reset() + response = _FakeResponse("test-model", _FakeUsage(12, 8)) + + with metrics_context("architect", node="architect", iteration=1): + metrics_collector.record_llm_usage(response, model="test-model", provider="test") + + summary = metrics_collector.summarize_stage("architect", iteration=1) + assert summary["llm"]["total_calls"] == 1 + assert summary["llm"]["total_tokens"] == 20 + assert summary["llm"]["prompt_tokens"] == 12 + assert summary["llm"]["completion_tokens"] == 8 + assert summary["llm"]["by_model"]["test-model"]["calls"] == 1 From 9aceb6e422319231136b39dd153c6e3c0d08af8e Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 05:40:41 -0800 Subject: [PATCH 10/36] Add benchmark runner outputs and compare_models test --- scripts/benchmark_example.json | 33 ++++++++ scripts/compare_models.py | 125 ++++++++++++++++++++++++++++++ tests/unit/test_compare_models.py | 80 +++++++++++++++++++ 3 files changed, 238 insertions(+) create mode 100644 scripts/benchmark_example.json create mode 100644 scripts/compare_models.py create mode 100644 tests/unit/test_compare_models.py diff --git a/scripts/benchmark_example.json b/scripts/benchmark_example.json new file mode 100644 index 0000000..dc0b0ed --- /dev/null +++ b/scripts/benchmark_example.json @@ -0,0 +1,33 @@ +{ + "prompts": [ + { + "id": "demo_advection", + "prompt": "Run AMReX Advection_AmrCore with a 128x128 grid and 3 AMR levels." + }, + { + "id": "demo_premixed_flame", + "prompt": "2D hydrogen premixed flame with 32x128 grid cells, 2 AMR levels, 200 timesteps." + } + ], + "models": [ + { + "id": "cborg_default", + "overrides": { + "llm_provider": "cborg", + "llm_model": "lbl/Llama-4-Scout-17B-16E-Instruct" + } + }, + { + "id": "pnnl_claude", + "overrides": { + "llm_provider": "pnnl", + "llm_model": "claude-haiku-4-5-20251001-v1-birthright" + } + } + ], + "run_args": { + "run_mode": "dry", + "save_workflow": true, + "save_log": true + } +} diff --git a/scripts/compare_models.py b/scripts/compare_models.py new file mode 100644 index 0000000..1ce79b0 --- /dev/null +++ b/scripts/compare_models.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +"""Aggregate benchmark metrics into comparison tables.""" + +from __future__ import annotations + +import argparse +import csv +import json +from pathlib import Path +from typing import Any + + +def _read_jsonl(path: Path) -> list[dict[str, Any]]: + records: list[dict[str, Any]] = [] + for line in path.read_text().splitlines(): + if not line.strip(): + continue + records.append(json.loads(line)) + return records + + +def _mean(values: list[float]) -> float | None: + if not values: + return None + return sum(values) / len(values) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate model comparison tables.") + parser.add_argument("--run-dir", help="Benchmark run directory (contains benchmark_runs.jsonl).") + parser.add_argument("--input", help="Path to benchmark_runs.jsonl.") + parser.add_argument( + "--output-dir", + help="Directory to write CSV outputs (default: run dir).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.run_dir and not args.input: + raise ValueError("Provide --run-dir or --input.") + + run_dir = Path(args.run_dir) if args.run_dir else None + input_path = Path(args.input) if args.input else (run_dir / "benchmark_runs.jsonl") + output_dir = Path(args.output_dir) if args.output_dir else (run_dir or input_path.parent) + output_dir.mkdir(parents=True, exist_ok=True) + + records = _read_jsonl(input_path) + if not records: + raise ValueError(f"No records found in {input_path}") + + by_model: dict[str, list[dict[str, Any]]] = {} + for record in records: + model_id = record.get("model_id", "unknown") + by_model.setdefault(model_id, []).append(record) + + by_model_rows = [] + for model_id, items in sorted(by_model.items()): + durations = [r["duration_seconds"] for r in items if isinstance(r.get("duration_seconds"), (int, float))] + perf_values = [] + for item in items: + perf = (item.get("analysis_performance") or {}).get("avg_cells_per_sec") + if isinstance(perf, (int, float)): + perf_values.append(perf) + total = len(items) + completed = sum(1 for r in items if r.get("job_status") == "completed") + failed = sum(1 for r in items if r.get("job_status") == "failed") + skipped = sum(1 for r in items if r.get("job_status") == "skipped") + analysis_success = sum(1 for r in items if r.get("analysis_status") == "success") + by_model_rows.append({ + "model_id": model_id, + "total_runs": total, + "completed_runs": completed, + "failed_runs": failed, + "skipped_runs": skipped, + "analysis_success_runs": analysis_success, + "success_rate": (completed / total) if total else 0.0, + "analysis_success_rate": (analysis_success / total) if total else 0.0, + "avg_duration_seconds": _mean(durations), + "avg_cells_per_sec": _mean(perf_values), + }) + + by_prompt_rows = [] + for record in records: + by_prompt_rows.append({ + "prompt_id": record.get("prompt_id"), + "model_id": record.get("model_id"), + "job_status": record.get("job_status"), + "analysis_status": record.get("analysis_status"), + "duration_seconds": record.get("duration_seconds"), + "avg_cells_per_sec": (record.get("analysis_performance") or {}).get("avg_cells_per_sec"), + "run_directory": record.get("run_directory"), + "prompt_excerpt": record.get("prompt_excerpt"), + }) + + summary_row = { + "total_models": len(by_model), + "total_runs": len(records), + "completed_runs": sum(1 for r in records if r.get("job_status") == "completed"), + "failed_runs": sum(1 for r in records if r.get("job_status") == "failed"), + "skipped_runs": sum(1 for r in records if r.get("job_status") == "skipped"), + } + + def write_csv(path: Path, rows: list[dict[str, Any]]) -> None: + if not rows: + return + with path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + write_csv(output_dir / "by_model.csv", by_model_rows) + write_csv(output_dir / "by_prompt.csv", by_prompt_rows) + write_csv(output_dir / "summary.csv", [summary_row]) + + print(json.dumps({ + "input": str(input_path), + "output_dir": str(output_dir), + "rows": len(records), + }, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_compare_models.py b/tests/unit/test_compare_models.py new file mode 100644 index 0000000..b783491 --- /dev/null +++ b/tests/unit/test_compare_models.py @@ -0,0 +1,80 @@ +import csv +import importlib.util +import json +from pathlib import Path + + +def _load_compare_models_module() -> object: + module_path = Path(__file__).resolve().parents[2] / "scripts" / "compare_models.py" + spec = importlib.util.spec_from_file_location("compare_models", module_path) + module = importlib.util.module_from_spec(spec) + if spec.loader is None: + raise RuntimeError("Failed to load compare_models module.") + spec.loader.exec_module(module) + return module + + +def _read_csv_rows(path: Path) -> list[dict[str, str]]: + with path.open(newline="", encoding="utf-8") as handle: + return list(csv.DictReader(handle)) + + +def test_compare_models_generates_tables(tmp_path, monkeypatch): + input_path = tmp_path / "benchmark_runs.jsonl" + records = [ + { + "model_id": "model_a", + "prompt_id": "prompt_001", + "job_status": "completed", + "analysis_status": "success", + "duration_seconds": 12.0, + "analysis_performance": {"avg_cells_per_sec": 1000.0}, + "run_directory": "/tmp/run_a", + "prompt_excerpt": "demo prompt a", + }, + { + "model_id": "model_b", + "prompt_id": "prompt_002", + "job_status": "failed", + "analysis_status": "failed", + "duration_seconds": 8.0, + "analysis_performance": {"avg_cells_per_sec": 2000.0}, + "run_directory": "/tmp/run_b", + "prompt_excerpt": "demo prompt b", + }, + ] + input_path.write_text("\n".join(json.dumps(r) for r in records)) + + output_dir = tmp_path / "out" + module = _load_compare_models_module() + monkeypatch.setattr( + "sys.argv", + [ + "compare_models.py", + "--input", + str(input_path), + "--output-dir", + str(output_dir), + ], + ) + + module.main() + + by_model_path = output_dir / "by_model.csv" + by_prompt_path = output_dir / "by_prompt.csv" + summary_path = output_dir / "summary.csv" + + assert by_model_path.exists() + assert by_prompt_path.exists() + assert summary_path.exists() + + by_model_rows = _read_csv_rows(by_model_path) + assert {row["model_id"] for row in by_model_rows} == {"model_a", "model_b"} + + model_a = next(row for row in by_model_rows if row["model_id"] == "model_a") + assert model_a["total_runs"] == "1" + assert model_a["completed_runs"] == "1" + assert model_a["analysis_success_runs"] == "1" + + summary_rows = _read_csv_rows(summary_path) + assert summary_rows[0]["total_runs"] == "2" From 3dd3d06fb7635b0b8db23355d0401b998b6676e2 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 05:57:21 -0800 Subject: [PATCH 11/36] Add metrics aggregation adapter for raw benchmark records --- scripts/aggregate_metrics.py | 92 ++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 scripts/aggregate_metrics.py diff --git a/scripts/aggregate_metrics.py b/scripts/aggregate_metrics.py new file mode 100644 index 0000000..5a9412d --- /dev/null +++ b/scripts/aggregate_metrics.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +"""Aggregate workflow metrics JSONL into raw benchmark records.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Iterable + + +def _iter_metric_files(path: Path) -> Iterable[Path]: + if path.is_file(): + return [path] + return sorted(path.rglob("metrics*.jsonl")) + + +def _load_events(path: Path) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + events.append(json.loads(line)) + return events + + +def _record_from_event(event: dict[str, Any], source: Path) -> dict[str, Any]: + data = event.get("data", {}) + context = event.get("context", {}) + models = data.get("models") or [] + providers = data.get("providers") or [] + model_id = context.get("model_id") or (models[0] if models else "unknown") + provider = context.get("provider") or (providers[0] if providers else None) + + return { + "model_id": model_id, + "provider": provider, + "prompt_id": context.get("prompt_id"), + "prompt_excerpt": context.get("prompt_excerpt"), + "job_status": data.get("job_status"), + "iteration": data.get("iteration"), + "run_directory": data.get("run_directory"), + "selected_case": context.get("selected_case"), + "tokens_total_input": data.get("tokens_total_input"), + "tokens_total_output": data.get("tokens_total_output"), + "tokens_total": data.get("tokens_total"), + "tokens_by_stage": data.get("tokens_by_stage"), + "stages": data.get("stages"), + "source": str(source), + } + + +def _write_jsonl(path: Path, records: list[dict[str, Any]]) -> None: + if not records: + return + with path.open("w", encoding="utf-8") as handle: + for record in records: + handle.write(json.dumps(record, default=str)) + handle.write("\n") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Aggregate metrics JSONL into raw benchmark records.") + parser.add_argument("--input", required=True, help="Metrics JSONL file or directory.") + parser.add_argument( + "--output", + default=None, + help="Output raw_metrics.jsonl path (defaults next to input).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + input_path = Path(args.input) + output_path = Path(args.output) if args.output else (input_path / "raw_metrics.jsonl") + if input_path.is_file(): + output_path = Path(args.output) if args.output else input_path.parent / "raw_metrics.jsonl" + + records: list[dict[str, Any]] = [] + for path in _iter_metric_files(input_path): + for event in _load_events(path): + if event.get("type") != "workflow_summary": + continue + records.append(_record_from_event(event, path)) + + _write_jsonl(output_path, records) + print(json.dumps({"output": str(output_path), "records": len(records)}, indent=2)) + + +if __name__ == "__main__": + main() From 9be57d1fe8213c5d9d16666b83a8ca206d480a14 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 06:00:31 -0800 Subject: [PATCH 12/36] Refactor benchmark runner into shared case and model modules --- scripts/run_benchmark.py | 319 ++++++-------------------- src/benchmark_runner.py | 478 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 550 insertions(+), 247 deletions(-) create mode 100644 src/benchmark_runner.py diff --git a/scripts/run_benchmark.py b/scripts/run_benchmark.py index cc3584f..1634e7e 100644 --- a/scripts/run_benchmark.py +++ b/scripts/run_benchmark.py @@ -1,17 +1,20 @@ #!/usr/bin/env python3 -"""Benchmark runner scaffold for scaling case suites.""" +"""Run benchmark suites (case grids or multi-model prompt runs).""" from __future__ import annotations import argparse -import json -import sys from datetime import datetime, timezone from pathlib import Path -from typing import Any -import yaml -from jsonschema import Draft202012Validator +from src.benchmark_runner import ( + collect_cases, + expand_case_runs, + filter_cases, + run_case_suite, + run_model_benchmark, + validate_case_files, +) DEFAULT_SCHEMA = Path("benchmark/specs/case_schema.yaml") @@ -19,100 +22,12 @@ DEFAULT_RUNS_DIR = Path("benchmark/runs") -def _load_yaml(path: Path) -> dict[str, Any]: - with path.open("r", encoding="utf-8") as handle: - return yaml.safe_load(handle) or {} - - -def _load_schema(schema_path: Path) -> dict[str, Any]: - return _load_yaml(schema_path) - - -def _iter_case_files(cases_dir: Path) -> list[Path]: - return sorted(cases_dir.glob("*.yaml")) - - -def _validate_file(data: dict[str, Any], schema: dict[str, Any], path: Path) -> list[str]: - validator = Draft202012Validator(schema) - errors = [] - for error in sorted(validator.iter_errors(data), key=lambda err: list(err.path)): - location = "/".join(str(part) for part in error.path) or "" - errors.append(f"{path}: {location}: {error.message}") - solver = data.get("solver") - for case in data.get("cases", []): - case_solver = case.get("solver") - if solver and case_solver and solver != case_solver: - errors.append( - f"{path}: cases/{case.get('id', 'unknown')}: solver {case_solver} " - f"does not match suite solver {solver}" - ) - return errors - - -def _collect_cases(cases_dir: Path) -> list[dict[str, Any]]: - cases: list[dict[str, Any]] = [] - for path in _iter_case_files(cases_dir): - data = _load_yaml(path) - for case in data.get("cases", []): - case = dict(case) - case["_suite_id"] = data.get("suite_id") - case["_suite_solver"] = data.get("solver") - case["_source"] = str(path) - cases.append(case) - return cases - - -def _filter_cases( - cases: list[dict[str, Any]], solver: str | None, case_ids: list[str] -) -> list[dict[str, Any]]: - filtered = cases - if solver: - filtered = [case for case in filtered if case.get("solver") == solver] - if case_ids: - wanted = set(case_ids) - filtered = [case for case in filtered if case.get("id") in wanted] - return filtered - - -def _expand_case_runs(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: - runs: list[dict[str, Any]] = [] - for case in cases: - for size in case["scaling"]["sizes"]: - runs.append( - { - "case_id": case["id"], - "solver": case["solver"], - "case_name": case["case_name"], - "size_label": size["label"], - "grid": size["grid"], - "amr_levels": size["amr_levels"], - "source": case["_source"], - } - ) - return runs - - -def _check_unique_ids(cases: list[dict[str, Any]]) -> list[str]: - seen: dict[str, str] = {} - errors = [] - for case in cases: - case_id = case.get("id") - if not case_id: - continue - source = case.get("_source", "") - if case_id in seen: - errors.append(f"Duplicate case id {case_id} in {source} (also {seen[case_id]})") - else: - seen[case_id] = source - return errors - - -def _print_cases(cases: list[dict[str, Any]]) -> None: +def _print_cases(cases: list[dict]) -> None: for case in cases: print(f"{case['id']}: {case['solver']} - {case['case_name']}") -def _print_plan(runs: list[dict[str, Any]]) -> None: +def _print_plan(runs: list[dict]) -> None: for run in runs: print( f"{run['case_id']} [{run['size_label']}] " @@ -120,170 +35,80 @@ def _print_plan(runs: list[dict[str, Any]]) -> None: ) -def _init_run_state(cases: list[dict[str, Any]]) -> dict[str, Any]: - created_at = datetime.now(timezone.utc).isoformat() - state = { - "schema_version": 1, - "run_id": f"run_{created_at.replace(':', '').replace('-', '')}", - "created_at": created_at, - "cases": {}, - } - for case in cases: - sizes = {} - for size in case["scaling"]["sizes"]: - sizes[size["label"]] = { - "grid": size["grid"], - "amr_levels": size["amr_levels"], - "status": "pending", - } - state["cases"][case["id"]] = { - "solver": case["solver"], - "case_name": case["case_name"], - "source": case["_source"], - "status": "pending", - "sizes": sizes, - } - return state - - -def _write_state(path: Path, state: dict[str, Any]) -> None: - path.write_text(json.dumps(state, indent=2, sort_keys=True) + "\n", encoding="utf-8") +def _cases_parser(subparsers: argparse._SubParsersAction) -> None: + cases_parser = subparsers.add_parser("cases", help="Operate on benchmark case suites") + cases_parser.add_argument("--cases-dir", type=Path, default=DEFAULT_CASES_DIR) + cases_parser.add_argument("--schema", type=Path, default=DEFAULT_SCHEMA) + cases_parser.add_argument("--solver", choices=["PeleC", "PeleLMeX", "ERF", "REMORA"]) + cases_parser.add_argument("--case-id", action="append", default=[]) + case_sub = cases_parser.add_subparsers(dest="command", required=True) + case_sub.add_parser("list", help="List benchmark cases") + case_sub.add_parser("validate", help="Validate benchmark case files") + case_sub.add_parser("plan", help="Print expanded run plan") -def _load_state(path: Path) -> dict[str, Any]: - return json.loads(path.read_text(encoding="utf-8")) - - -def _format_command(case: dict[str, Any], size: dict[str, Any]) -> str | None: - run_info = case.get("run") or {} - command = run_info.get("command") - if not command: - return None - return command.format( - solver=case.get("solver"), - case_dir=case.get("case_dir"), - inputs=case.get("inputs"), - grid=size.get("grid"), - label=size.get("label"), - ) - - -def _execute_case(case: dict[str, Any], size: dict[str, Any], dry_run: bool) -> tuple[str, str]: - command = _format_command(case, size) - if not command: - return "skipped", "no command defined" - if dry_run: - return "planned", command - import subprocess # Imported lazily to keep startup fast. - - working_dir = (case.get("run") or {}).get("working_dir") or case.get("case_dir") - result = subprocess.run(command, shell=True, cwd=working_dir) - if result.returncode == 0: - return "complete", command - return "failed", command - - -def _run_cases( - cases: list[dict[str, Any]], - run_dir: Path, - resume: bool, - dry_run: bool, -) -> int: - run_dir.mkdir(parents=True, exist_ok=True) - state_path = run_dir / "run_state.json" - - if resume: - if not state_path.exists(): - print(f"Missing run state at {state_path}", file=sys.stderr) - return 2 - state = _load_state(state_path) - else: - state = _init_run_state(cases) - _write_state(state_path, state) - - for case in cases: - case_state = state["cases"].get(case["id"]) - if not case_state: - continue - for size in case["scaling"]["sizes"]: - size_state = case_state["sizes"][size["label"]] - if size_state["status"] == "complete": - continue - status, detail = _execute_case(case, size, dry_run) - size_state["status"] = status - size_state["detail"] = detail - if status == "failed": - case_state["status"] = "failed" - _write_state(state_path, state) - print(f"Failed: {case['id']} [{size['label']}]", file=sys.stderr) - return 1 - if all(entry["status"] == "complete" for entry in case_state["sizes"].values()): - case_state["status"] = "complete" - elif all(entry["status"] == "skipped" for entry in case_state["sizes"].values()): - case_state["status"] = "skipped" - else: - case_state["status"] = "in_progress" - _write_state(state_path, state) - - return 0 - - -def main() -> int: - parser = argparse.ArgumentParser(description="Scaling benchmark runner scaffold") - parser.add_argument("--cases-dir", type=Path, default=DEFAULT_CASES_DIR) - parser.add_argument("--schema", type=Path, default=DEFAULT_SCHEMA) - parser.add_argument("--solver", choices=["PeleC", "PeleLMeX", "ERF", "REMORA"]) - parser.add_argument("--case-id", action="append", default=[]) - - subparsers = parser.add_subparsers(dest="command", required=True) - - subparsers.add_parser("list", help="List benchmark cases") - subparsers.add_parser("validate", help="Validate benchmark case files") - subparsers.add_parser("plan", help="Print expanded run plan") - - run_parser = subparsers.add_parser("run", help="Execute benchmark cases (stub)") + run_parser = case_sub.add_parser("run", help="Execute benchmark cases (stub)") run_parser.add_argument("--run-dir", type=Path, default=None) run_parser.add_argument("--resume", action="store_true") run_parser.add_argument("--execute", action="store_true") - args = parser.parse_args() - cases_dir = args.cases_dir - schema = _load_schema(args.schema) +def _models_parser(subparsers: argparse._SubParsersAction) -> None: + model_parser = subparsers.add_parser("models", help="Run multi-model prompt benchmarks") + model_parser.add_argument("--config", required=True, help="Benchmark config file (json/yaml).") + model_parser.add_argument( + "--output-dir", + default="output/benchmarks", + help="Base output directory for benchmark artifacts.", + ) + model_parser.add_argument("--run-name", default=None, help="Optional run name override.") - if args.command == "validate": - errors: list[str] = [] - for path in _iter_case_files(cases_dir): - data = _load_yaml(path) - errors.extend(_validate_file(data, schema, path)) - cases = _collect_cases(cases_dir) - errors.extend(_check_unique_ids(cases)) - if errors: - for error in errors: - print(error, file=sys.stderr) - return 1 - print("All benchmark case files are valid.") - return 0 - cases = _collect_cases(cases_dir) - cases = _filter_cases(cases, args.solver, args.case_id) +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark runner") + subparsers = parser.add_subparsers(dest="mode", required=True) + _cases_parser(subparsers) + _models_parser(subparsers) + return parser.parse_args() - if args.command == "list": - _print_cases(cases) - return 0 - if args.command == "plan": - runs = _expand_case_runs(cases) - _print_plan(runs) +def main() -> int: + args = parse_args() + + if args.mode == "cases": + if args.command == "validate": + errors = validate_case_files(args.schema, args.cases_dir) + if errors: + for error in errors: + print(error) + return 1 + print("All benchmark case files are valid.") + return 0 + + cases = collect_cases(args.cases_dir) + cases = filter_cases(cases, args.solver, args.case_id) + + if args.command == "list": + _print_cases(cases) + return 0 + + if args.command == "plan": + runs = expand_case_runs(cases) + _print_plan(runs) + return 0 + + if args.command == "run": + run_dir = args.run_dir + if run_dir is None: + run_dir = DEFAULT_RUNS_DIR / datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + dry_run = not args.execute + return run_case_suite(cases, run_dir, args.resume, dry_run) + + if args.mode == "models": + result = run_model_benchmark(Path(args.config), Path(args.output_dir), args.run_name) + print(result) return 0 - if args.command == "run": - run_dir = args.run_dir - if run_dir is None: - run_dir = DEFAULT_RUNS_DIR / datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") - dry_run = not args.execute - return _run_cases(cases, run_dir, args.resume, dry_run) - return 0 diff --git a/src/benchmark_runner.py b/src/benchmark_runner.py new file mode 100644 index 0000000..1805a41 --- /dev/null +++ b/src/benchmark_runner.py @@ -0,0 +1,478 @@ +"""Benchmark runners for case suites and model comparisons.""" + +from __future__ import annotations + +import json +import os +import re +import subprocess +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import yaml +from jsonschema import Draft202012Validator + + +# ===== Shared helpers ===== + +def _load_yaml(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + return yaml.safe_load(handle) or {} + + +def _load_data(path: Path) -> dict[str, Any]: + suffix = path.suffix.lower() + if suffix == ".json": + return json.loads(path.read_text(encoding="utf-8")) + if suffix in {".yaml", ".yml"}: + return _load_yaml(path) + raise ValueError(f"Unsupported config type: {path}") + + +# ===== Case suite runner ===== + +def validate_case_files(schema_path: Path, cases_dir: Path) -> list[str]: + schema = _load_yaml(schema_path) + errors: list[str] = [] + validator = Draft202012Validator(schema) + for path in sorted(cases_dir.glob("*.yaml")): + data = _load_yaml(path) + for error in sorted(validator.iter_errors(data), key=lambda err: list(err.path)): + location = "/".join(str(part) for part in error.path) or "" + errors.append(f"{path}: {location}: {error.message}") + solver = data.get("solver") + for case in data.get("cases", []): + case_solver = case.get("solver") + if solver and case_solver and solver != case_solver: + errors.append( + f"{path}: cases/{case.get('id', 'unknown')}: solver {case_solver} " + f"does not match suite solver {solver}" + ) + return errors + + +def collect_cases(cases_dir: Path) -> list[dict[str, Any]]: + cases: list[dict[str, Any]] = [] + for path in sorted(cases_dir.glob("*.yaml")): + data = _load_yaml(path) + for case in data.get("cases", []): + case = dict(case) + case["_suite_id"] = data.get("suite_id") + case["_suite_solver"] = data.get("solver") + case["_source"] = str(path) + cases.append(case) + return cases + + +def filter_cases( + cases: list[dict[str, Any]], solver: str | None, case_ids: list[str] +) -> list[dict[str, Any]]: + filtered = cases + if solver: + filtered = [case for case in filtered if case.get("solver") == solver] + if case_ids: + wanted = set(case_ids) + filtered = [case for case in filtered if case.get("id") in wanted] + return filtered + + +def expand_case_runs(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: + runs: list[dict[str, Any]] = [] + for case in cases: + for size in case["scaling"]["sizes"]: + runs.append( + { + "case_id": case["id"], + "solver": case["solver"], + "case_name": case["case_name"], + "size_label": size["label"], + "grid": size["grid"], + "amr_levels": size["amr_levels"], + "source": case["_source"], + } + ) + return runs + + +def init_case_run_state(cases: list[dict[str, Any]]) -> dict[str, Any]: + created_at = datetime.now(timezone.utc).isoformat() + state = { + "schema_version": 1, + "run_id": f"run_{created_at.replace(':', '').replace('-', '')}", + "created_at": created_at, + "cases": {}, + } + for case in cases: + sizes = {} + for size in case["scaling"]["sizes"]: + sizes[size["label"]] = { + "grid": size["grid"], + "amr_levels": size["amr_levels"], + "status": "pending", + } + state["cases"][case["id"]] = { + "solver": case["solver"], + "case_name": case["case_name"], + "source": case["_source"], + "status": "pending", + "sizes": sizes, + } + return state + + +def _write_state(path: Path, state: dict[str, Any]) -> None: + path.write_text(json.dumps(state, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def _load_state(path: Path) -> dict[str, Any]: + return json.loads(path.read_text(encoding="utf-8")) + + +def _format_command(case: dict[str, Any], size: dict[str, Any]) -> str | None: + run_info = case.get("run") or {} + command = run_info.get("command") + if not command: + return None + return command.format( + solver=case.get("solver"), + case_dir=case.get("case_dir"), + inputs=case.get("inputs"), + grid=size.get("grid"), + label=size.get("label"), + ) + + +def _execute_case(case: dict[str, Any], size: dict[str, Any], dry_run: bool) -> tuple[str, str]: + command = _format_command(case, size) + if not command: + return "skipped", "no command defined" + if dry_run: + return "planned", command + + working_dir = (case.get("run") or {}).get("working_dir") or case.get("case_dir") + result = subprocess.run(command, shell=True, cwd=working_dir) + if result.returncode == 0: + return "complete", command + return "failed", command + + +def run_case_suite( + cases: list[dict[str, Any]], + run_dir: Path, + resume: bool, + dry_run: bool, +) -> int: + run_dir.mkdir(parents=True, exist_ok=True) + state_path = run_dir / "run_state.json" + + if resume: + if not state_path.exists(): + raise FileNotFoundError(f"Missing run state at {state_path}") + state = _load_state(state_path) + else: + state = init_case_run_state(cases) + _write_state(state_path, state) + + for case in cases: + case_state = state["cases"].get(case["id"]) + if not case_state: + continue + for size in case["scaling"]["sizes"]: + size_state = case_state["sizes"][size["label"]] + if size_state["status"] == "complete": + continue + status, detail = _execute_case(case, size, dry_run) + size_state["status"] = status + size_state["detail"] = detail + if status == "failed": + case_state["status"] = "failed" + _write_state(state_path, state) + return 1 + if all(entry["status"] == "complete" for entry in case_state["sizes"].values()): + case_state["status"] = "complete" + elif all(entry["status"] == "skipped" for entry in case_state["sizes"].values()): + case_state["status"] = "skipped" + else: + case_state["status"] = "in_progress" + _write_state(state_path, state) + + return 0 + + +# ===== Multi-model benchmark runner ===== + +def _merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(base.get(key), dict): + merged[key] = {**base[key], **value} + else: + merged[key] = value + return merged + + +def _ensure_prompt_entry(entry: Any, index: int) -> dict[str, Any]: + if isinstance(entry, str): + return {"id": f"prompt_{index:03d}", "prompt": entry} + if isinstance(entry, dict): + prompt_id = entry.get("id") or f"prompt_{index:03d}" + return { + "id": prompt_id, + "prompt": entry.get("prompt"), + "prompt_path": entry.get("prompt_path"), + } + raise ValueError(f"Invalid prompt entry at index {index}: {entry!r}") + + +def _load_prompt_text(entry: dict[str, Any]) -> str | None: + prompt = entry.get("prompt") + if prompt: + return str(prompt) + prompt_path = entry.get("prompt_path") + if prompt_path: + path = Path(prompt_path) + return path.read_text().strip() + return None + + +def _build_config_for_model(model: dict[str, Any], output_dir: Path) -> Path | None: + base_config = {} + config_path = model.get("config_path") + if config_path: + base_config = _load_data(Path(config_path)) + overrides = model.get("overrides") or {} + merged = _merge_dicts(base_config, overrides) + if not merged: + return None + config_dir = output_dir / "configs" + config_dir.mkdir(parents=True, exist_ok=True) + model_id = model.get("id", "model") + model_slug = _slugify(model_id) + config_file = config_dir / f"{model_slug}.json" + config_file.write_text(json.dumps(merged, indent=2)) + return config_file + + +def _build_command( + config_path: Path | None, + prompt_entry: dict[str, Any], + output_dir: Path, + run_args: dict[str, Any], +) -> list[str]: + cmd = [os.environ.get("AMREX_AGENT_PYTHON", "python"), "amrex_agent.py", "--json"] + if prompt_entry.get("prompt"): + cmd.extend(["--prompt", prompt_entry["prompt"]]) + elif prompt_entry.get("prompt_path"): + cmd.extend(["--prompt-path", prompt_entry["prompt_path"]]) + else: + raise ValueError(f"Prompt entry missing prompt text/path: {prompt_entry!r}") + + if config_path: + cmd.extend(["--config", str(config_path)]) + cmd.extend(["--output-dir", str(output_dir)]) + + flag_map = { + "run_mode": "--run-mode", + "environment": "--environment", + "inputs_file_strategy": "--inputs-file-strategy", + "remap_strategy": "--remap-strategy", + "inputs_file_override": "--inputs-file-override", + "baseline_override": "--baseline-override", + "baseline_switch_after_retries": "--baseline-switch-after-retries", + "llm_gate_strategy": "--llm-gate-strategy", + "indexing_strategy": "--indexing-strategy", + "run_ntasks": "--run-ntasks", + } + for key, flag in flag_map.items(): + value = run_args.get(key) + if value is not None: + cmd.extend([flag, str(value)]) + + if run_args.get("run_serial"): + cmd.append("--run-serial") + if run_args.get("dry_run"): + cmd.append("--dry-run") + if run_args.get("preconfirm"): + cmd.append("--preconfirm") + if run_args.get("save_workflow"): + cmd.append("--save-workflow") + if run_args.get("save_transcript"): + cmd.append("--save-transcript") + if run_args.get("save_log"): + cmd.append("--save-log") + if run_args.get("verbose"): + cmd.append("--verbose") + + extra_args = run_args.get("extra_args") or [] + if extra_args: + cmd.extend([str(arg) for arg in extra_args]) + + return cmd + + +def _write_jsonl(path: Path, payload: dict[str, Any]) -> None: + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, default=str)) + handle.write("\n") + + +def _slugify(value: str) -> str: + return re.sub(r"[^A-Za-z0-9._-]+", "_", value) + + +def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | None) -> dict[str, Any]: + bench_config = _load_data(config_path) + + prompts_raw = bench_config.get("prompts") or [] + if not prompts_raw: + raise ValueError("Benchmark config must include prompts.") + models = bench_config.get("models") or [] + if not models: + raise ValueError("Benchmark config must include models.") + + run_args = bench_config.get("run_args") or {} + env_common = bench_config.get("env") or {} + + run_name = run_name or datetime.now().strftime("bench_%Y%m%d_%H%M%S") + run_dir = output_dir / run_name + run_dir.mkdir(parents=True, exist_ok=True) + runs_dir = run_dir / "runs" + runs_dir.mkdir(parents=True, exist_ok=True) + + manifest = { + "run_name": run_name, + "created_at": datetime.now().isoformat(), + "benchmark_config": str(config_path), + "models": [], + "prompts": [], + "run_args": run_args, + "env_keys": sorted(set(env_common.keys())), + } + + prompt_entries = [] + for idx, entry in enumerate(prompts_raw, start=1): + normalized = _ensure_prompt_entry(entry, idx) + prompt_text = _load_prompt_text(normalized) + if not prompt_text: + raise ValueError(f"Prompt entry missing text/path: {normalized!r}") + prompt_entries.append({**normalized, "prompt": prompt_text}) + manifest["prompts"].append({ + "id": normalized["id"], + "prompt_path": normalized.get("prompt_path"), + "prompt_excerpt": prompt_text[:160], + }) + + raw_metrics_path = run_dir / "benchmark_runs.jsonl" + + for model in models: + model_id = model.get("id") + if not model_id: + raise ValueError("Each model must include an id.") + model_slug = _slugify(model_id) + model_env = dict(env_common) + model_env.update(model.get("env") or {}) + + model_config_path = _build_config_for_model(model, run_dir) + manifest["models"].append({ + "id": model_id, + "slug": model_slug, + "config_path": str(model_config_path) if model_config_path else None, + "config_source": model.get("config_path"), + "override_keys": sorted((model.get("overrides") or {}).keys()), + "env_keys": sorted((model.get("env") or {}).keys()), + }) + + for prompt in prompt_entries: + prompt_id = prompt["id"] + prompt_dir = runs_dir / model_slug / prompt_id + prompt_dir.mkdir(parents=True, exist_ok=True) + env = os.environ.copy() + env.update({k: str(v) for k, v in model_env.items()}) + + cmd = _build_command(model_config_path, prompt, prompt_dir, run_args) + + started_at = datetime.now().isoformat() + start_time = time.time() + result_data: dict[str, Any] | None = None + error = None + exit_code = None + stdout = "" + stderr = "" + try: + timeout = run_args.get("timeout_seconds") + proc = subprocess.run( + cmd, + cwd=Path(__file__).resolve().parents[1], + env=env, + text=True, + capture_output=True, + timeout=timeout, + ) + exit_code = proc.returncode + stdout = proc.stdout + stderr = proc.stderr + if stdout.strip(): + result_data = json.loads(stdout) + except subprocess.TimeoutExpired as exc: + error = f"timeout_after_{exc.timeout}s" + stdout = exc.stdout or "" + stderr = exc.stderr or "" + except json.JSONDecodeError as exc: + error = f"json_parse_failed: {exc}" + except Exception as exc: + error = f"runner_error: {exc}" + + ended_at = datetime.now().isoformat() + duration = time.time() - start_time + + analysis_report = None + run_directory = None + job_status = None + selected_case = None + if result_data: + run_directory = result_data.get("run_directory") + job_status = result_data.get("job_status") + selected_case = result_data.get("selected_case") + analysis_report = result_data.get("analysis_report") + + if not analysis_report and run_directory: + report_path = Path(run_directory) / "analysis_report.json" + if report_path.exists(): + try: + analysis_report = json.loads(report_path.read_text()) + except Exception: + analysis_report = None + + record = { + "model_id": model_id, + "prompt_id": prompt_id, + "prompt_excerpt": prompt["prompt"][:200], + "job_status": job_status or "unknown", + "analysis_status": (analysis_report or {}).get("status"), + "analysis_performance": (analysis_report or {}).get("performance"), + "analysis_issues": (analysis_report or {}).get("issues"), + "run_directory": run_directory, + "selected_case": selected_case, + "started_at": started_at, + "ended_at": ended_at, + "duration_seconds": duration, + "exit_code": exit_code, + "error": error, + "stderr_excerpt": stderr[:2000] if stderr else None, + } + _write_jsonl(raw_metrics_path, record) + + per_run = prompt_dir / "result.json" + per_run.write_text(json.dumps({ + "command": cmd, + "exit_code": exit_code, + "stdout": stdout, + "stderr": stderr, + "result": result_data, + "record": record, + }, indent=2, default=str)) + + (run_dir / "manifest.json").write_text(json.dumps(manifest, indent=2)) + return {"run_dir": str(run_dir), "metrics": str(raw_metrics_path)} From 429ccde2a4576e65f1ba1561b20fdd79d262df82 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 06:48:09 -0800 Subject: [PATCH 13/36] Add difficulty and novelty metadata to benchmark cases --- benchmark/cases/erf.yaml | 89 ++++++++++++++++++++------------ benchmark/cases/pelec.yaml | 29 ++++++++++- benchmark/cases/pelelmex.yaml | 39 +++++++++++--- benchmark/cases/remora.yaml | 43 +++++++++++---- benchmark/specs/case_schema.yaml | 24 +++++++++ 5 files changed, 174 insertions(+), 50 deletions(-) diff --git a/benchmark/cases/erf.yaml b/benchmark/cases/erf.yaml index 6f35ed5..f2e3bd8 100644 --- a/benchmark/cases/erf.yaml +++ b/benchmark/cases/erf.yaml @@ -2,37 +2,19 @@ schema_version: "1.0" suite_id: "erf_scaling_v1" solver: "ERF" cases: - - id: "erf_bubble_strong" - solver: "ERF" - case_name: "Bubble" - case_dir: "Exec/RegTests/Bubble" - inputs: "inputs" - description: "Thermal bubble rise for atmospheric dynamics scaling." - dimension: 3 - physics: [buoyancy, stratification, atmosphere] - scaling: - type: "strong" - sizes: - - label: "S" - grid: "128x128x128" - amr_levels: 0 - - label: "M" - grid: "256x256x256" - amr_levels: 0 - - label: "L" - grid: "384x384x384" - amr_levels: 0 - tags: [regtest, atmosphere] - status: "ready" - - id: "erf_density_current_strong" solver: "ERF" case_name: "DensityCurrent" - case_dir: "Exec/RegTests/DensityCurrent" - inputs: "inputs" - description: "Density current propagation for shear-driven flow scaling." + case_dir: "Exec/DryRegTests/DensityCurrent" + inputs: "inputs_amr" + description: "Classic density (gravity) current benchmark (Straka 1993)." dimension: 3 physics: [density_current, stratification, atmosphere] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: false + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -52,10 +34,15 @@ cases: solver: "ERF" case_name: "ABL-Neutral" case_dir: "Exec/ABL" - inputs: "inputs" - description: "Neutral atmospheric boundary layer for weak scaling." + inputs: "inputs_smagorinsky" + description: "Atmospheric boundary layer with turbulence scheme and MOST options." dimension: 3 physics: [abl, turbulence, atmosphere] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "weak" sizes: @@ -75,10 +62,15 @@ cases: solver: "ERF" case_name: "ABL-Stable" case_dir: "Exec/ABL" - inputs: "inputs_stable" - description: "Stable boundary layer for night-time stratification scaling." + inputs: "mrf_stable_gabls" + description: "Stable ABL configuration with hydrostatic sounding inputs." dimension: 3 physics: [abl, stable, atmosphere] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "config-extension" scaling: type: "weak" sizes: @@ -98,10 +90,15 @@ cases: solver: "ERF" case_name: "ABL-Convective" case_dir: "Exec/ABL" - inputs: "inputs_convective" - description: "Convective boundary layer with surface heating for weak scaling." + inputs: "mrf_unstable" + description: "Unstable/convective ABL configuration with perturbations." dimension: 3 physics: [abl, convection, atmosphere] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "config-extension" scaling: type: "weak" sizes: @@ -116,3 +113,31 @@ cases: amr_levels: 0 tags: [abl, convection] status: "draft" + + - id: "erf_taylor_green_strong" + solver: "ERF" + case_name: "TaylorGreenVortex" + case_dir: "Exec/DryRegTests/TaylorGreenVortex" + inputs: "inputs_advdiff" + description: "Taylor-Green vortex benchmark for advection/diffusion terms." + dimension: 3 + physics: [turbulence, vortex, atmosphere] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: false + novelty_tier: "parameter-only" + scaling: + type: "strong" + sizes: + - label: "S" + grid: "128x128x128" + amr_levels: 0 + - label: "M" + grid: "256x256x256" + amr_levels: 0 + - label: "L" + grid: "384x384x384" + amr_levels: 0 + tags: [regtest, turbulence] + status: "ready" diff --git a/benchmark/cases/pelec.yaml b/benchmark/cases/pelec.yaml index ca61c44..54bef08 100644 --- a/benchmark/cases/pelec.yaml +++ b/benchmark/cases/pelec.yaml @@ -10,6 +10,11 @@ cases: description: "Premixed methane flame; baseline reacting flow scaling case." dimension: 3 physics: [premixed, combustion, compressible] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "weak" sizes: @@ -33,6 +38,11 @@ cases: description: "Sedov blast wave for compressible hydro strong scaling." dimension: 3 physics: [blast, hydro, compressible] + difficulty_tier: "easy" + prompt_length_band: "short" + concept_density: "low" + specialized_knowledge: false + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -53,9 +63,14 @@ cases: case_name: "TaylorGreen" case_dir: "Exec/RegTests/TG" inputs: "tg-1.inp" - description: "Taylor-Green vortex for turbulence scaling." + description: "Taylor-Green vortex (High-Order CFD workshop benchmark)." dimension: 3 physics: [turbulence, vortex, compressible] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: false + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -76,9 +91,14 @@ cases: case_name: "TGReact" case_dir: "Exec/RegTests/TGReact" inputs: "tgreact.inp" - description: "Reacting Taylor-Green vortex for weak scaling." + description: "Reacting Taylor-Green vortex (combustion DNS workshop setup)." dimension: 3 physics: [turbulence, combustion, compressible] + difficulty_tier: "hard" + prompt_length_band: "long" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "config-extension" scaling: type: "weak" sizes: @@ -102,6 +122,11 @@ cases: description: "Turbulent jet flame production case for weak scaling." dimension: 3 physics: [jet, combustion, compressible] + difficulty_tier: "hard" + prompt_length_band: "long" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "config-extension" scaling: type: "weak" sizes: diff --git a/benchmark/cases/pelelmex.yaml b/benchmark/cases/pelelmex.yaml index d079624..9380822 100644 --- a/benchmark/cases/pelelmex.yaml +++ b/benchmark/cases/pelelmex.yaml @@ -7,9 +7,14 @@ cases: case_name: "FlameSheet" case_dir: "Exec/RegTests/FlameSheet" inputs: "inputs" - description: "Low-Mach diffusion flame sheet for strong scaling." + description: "Harmonically perturbed flame sheet with Cantera initial solution." dimension: 3 physics: [diffusion_flame, combustion, low_mach] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -30,9 +35,14 @@ cases: case_name: "TaylorGreen" case_dir: "Exec/RegTests/TaylorGreen" inputs: "inputs" - description: "Low-Mach Taylor-Green vortex for turbulence scaling." + description: "Taylor-Green vortex pulled from PeleC High-Order CFD workshop case." dimension: 3 physics: [turbulence, vortex, low_mach] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: false + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -53,9 +63,14 @@ cases: case_name: "CounterFlow" case_dir: "Exec/Production/CounterFlow" inputs: "inputs" - description: "Counterflow diffusion flame for weak scaling." + description: "Counterflow diffusion flame with coolflow pre-run and ignition restart." dimension: 3 physics: [counterflow, combustion, low_mach] + difficulty_tier: "hard" + prompt_length_band: "long" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "additional-runs" scaling: type: "weak" sizes: @@ -79,6 +94,11 @@ cases: description: "Jet in crossflow combustion for weak scaling." dimension: 3 physics: [jet, combustion, low_mach] + difficulty_tier: "hard" + prompt_length_band: "long" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "additional-runs" scaling: type: "weak" sizes: @@ -96,12 +116,17 @@ cases: - id: "pelelmex_eb_c7_strong" solver: "PeleLMeX" - case_name: "EB-C7" - case_dir: "Exec/RegTests/EB-C7" - inputs: "inputs" - description: "Embedded boundary low-Mach case for geometry-aware scaling." + case_name: "EB_BackwardStepFlame" + case_dir: "Exec/RegTests/EB_BackwardStepFlame" + inputs: "eb_bfs.inp" + description: "Laminar flame stabilized behind a backward-facing step (EB reactive flow)." dimension: 3 physics: [embedded_boundary, combustion, low_mach] + difficulty_tier: "hard" + prompt_length_band: "medium" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "build-dependent" scaling: type: "strong" sizes: diff --git a/benchmark/cases/remora.yaml b/benchmark/cases/remora.yaml index a964284..b1a523e 100644 --- a/benchmark/cases/remora.yaml +++ b/benchmark/cases/remora.yaml @@ -7,9 +7,14 @@ cases: case_name: "Seamount" case_dir: "Exec/Seamount" inputs: "inputs" - description: "Seamount circulation case for strong scaling." + description: "Stably stratified fluid at rest over a seamount (ROMS standard test)." dimension: 3 - physics: [ocean, topography, circulation] + physics: [ocean, stratification, bathymetry] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -30,9 +35,14 @@ cases: case_name: "Upwelling" case_dir: "Exec/Upwelling" inputs: "inputs" - description: "Coastal upwelling case for strong scaling." + description: "Wind-driven upwelling over a periodic channel (ROMS upwelling case)." dimension: 3 - physics: [ocean, upwelling, coastal] + physics: [ocean, upwelling, wind_forcing] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -53,9 +63,14 @@ cases: case_name: "DoubleGyre" case_dir: "Exec/DoubleGyre" inputs: "inputs" - description: "Wind-driven double gyre for weak scaling." + description: "Classic wind-driven double gyre test (ROMS double gyre case)." dimension: 3 physics: [ocean, gyre, wind_forcing] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "config-extension" scaling: type: "weak" sizes: @@ -76,9 +91,14 @@ cases: case_name: "Channel_Test" case_dir: "Exec/Channel_Test" inputs: "inputs" - description: "Idealized channel flow test for strong scaling." + description: "Reentrant channel test for turbulence and GLS mixing scheme." dimension: 3 - physics: [ocean, channel, circulation] + physics: [ocean, channel, turbulence] + difficulty_tier: "hard" + prompt_length_band: "short" + concept_density: "high" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "strong" sizes: @@ -99,9 +119,14 @@ cases: case_name: "DoublyPeriodic" case_dir: "Exec/DoublyPeriodic" inputs: "inputs" - description: "Doubly periodic ocean box for weak scaling." + description: "Doubly periodic domain with depth-dependent velocity/temperature profile." dimension: 3 - physics: [ocean, periodic, idealized] + physics: [ocean, periodic, stratification] + difficulty_tier: "medium" + prompt_length_band: "medium" + concept_density: "medium" + specialized_knowledge: true + novelty_tier: "parameter-only" scaling: type: "weak" sizes: diff --git a/benchmark/specs/case_schema.yaml b/benchmark/specs/case_schema.yaml index 74b9fb5..8a11a91 100644 --- a/benchmark/specs/case_schema.yaml +++ b/benchmark/specs/case_schema.yaml @@ -29,6 +29,11 @@ properties: - description - dimension - physics + - difficulty_tier + - prompt_length_band + - concept_density + - specialized_knowledge + - novelty_tier - scaling - tags properties: @@ -58,6 +63,25 @@ properties: minItems: 1 items: type: string + difficulty_tier: + type: string + enum: [easy, medium, hard] + prompt_length_band: + type: string + enum: [short, medium, long] + concept_density: + type: string + enum: [low, medium, high] + specialized_knowledge: + type: boolean + novelty_tier: + type: string + enum: + - parameter-only + - config-extension + - code-touch + - build-dependent + - additional-runs scaling: type: object additionalProperties: false From 6b797bfbe50fcbce0158b5d0a2e639767d55b9fd Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 07:04:32 -0800 Subject: [PATCH 14/36] Add benchmark metrics context and CSV aggregation --- scripts/aggregate_metrics.py | 133 +++++++++++++++++++++++++++++++++++ src/benchmark_runner.py | 15 ++++ src/main.py | 34 ++++++--- src/utils/metrics.py | 13 ++++ 4 files changed, 186 insertions(+), 9 deletions(-) diff --git a/scripts/aggregate_metrics.py b/scripts/aggregate_metrics.py index 5a9412d..5142b27 100644 --- a/scripts/aggregate_metrics.py +++ b/scripts/aggregate_metrics.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse +import csv import json from pathlib import Path from typing import Any, Iterable @@ -31,12 +32,18 @@ def _record_from_event(event: dict[str, Any], source: Path) -> dict[str, Any]: providers = data.get("providers") or [] model_id = context.get("model_id") or (models[0] if models else "unknown") provider = context.get("provider") or (providers[0] if providers else None) + strategy = _extract_strategy(data) return { "model_id": model_id, "provider": provider, "prompt_id": context.get("prompt_id"), "prompt_excerpt": context.get("prompt_excerpt"), + "case_id": context.get("case_id"), + "solver": context.get("solver"), + "difficulty_tier": context.get("difficulty_tier"), + "novelty_tier": context.get("novelty_tier"), + "retrieval_strategy": strategy, "job_status": data.get("job_status"), "iteration": data.get("iteration"), "run_directory": data.get("run_directory"), @@ -59,6 +66,65 @@ def _write_jsonl(path: Path, records: list[dict[str, Any]]) -> None: handle.write("\n") +def _extract_strategy(data: dict[str, Any]) -> str | None: + stages = data.get("stages") or {} + for summary in stages.values(): + retrieval = summary.get("retrieval") if isinstance(summary, dict) else None + if not retrieval: + continue + last = retrieval.get("last", {}) if isinstance(retrieval, dict) else {} + strategy = last.get("strategy") + if strategy: + return strategy + return None + + +def _write_csv(path: Path, rows: list[dict[str, Any]], headers: list[str]) -> None: + if not rows: + return + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=headers) + writer.writeheader() + for row in rows: + writer.writerow({key: row.get(key) for key in headers}) + + +def _group_summary(records: list[dict[str, Any]], key: str, label: str) -> list[dict[str, Any]]: + grouped: dict[str, list[dict[str, Any]]] = {} + for record in records: + value = record.get(key) or "unknown" + grouped.setdefault(str(value), []).append(record) + + rows = [] + for value, items in sorted(grouped.items(), key=lambda item: item[0]): + rows.append(_summarize_items(items, label, value)) + return rows + + +def _summarize_items(items: list[dict[str, Any]], label: str, value: str) -> dict[str, Any]: + total = len(items) + success = sum(1 for item in items if item.get("job_status") == "completed") + tokens_total = _avg([item.get("tokens_total") for item in items]) + tokens_input = _avg([item.get("tokens_total_input") for item in items]) + tokens_output = _avg([item.get("tokens_total_output") for item in items]) + return { + label: value, + "total_runs": total, + "success_runs": success, + "success_rate": round(success / total, 4) if total else 0.0, + "avg_tokens_total": tokens_total, + "avg_tokens_input": tokens_input, + "avg_tokens_output": tokens_output, + } + + +def _avg(values: list[Any]) -> float: + filtered = [v for v in values if isinstance(v, (int, float))] + if not filtered: + return 0.0 + return round(sum(filtered) / len(filtered), 2) + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Aggregate metrics JSONL into raw benchmark records.") parser.add_argument("--input", required=True, help="Metrics JSONL file or directory.") @@ -85,6 +151,73 @@ def main() -> None: records.append(_record_from_event(event, path)) _write_jsonl(output_path, records) + output_dir = output_path.parent + summary_headers = [ + "group", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ] + summary_row = [_summarize_items(records, "group", "all")] + _write_csv(output_dir / "summary.csv", summary_row, summary_headers) + + by_model = _group_summary(records, "model_id", "model_id") + _write_csv(output_dir / "by_model.csv", by_model, [ + "model_id", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ]) + + by_solver = _group_summary(records, "solver", "solver") + _write_csv(output_dir / "by_solver.csv", by_solver, [ + "solver", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ]) + + by_strategy = _group_summary(records, "retrieval_strategy", "retrieval_strategy") + _write_csv(output_dir / "by_strategy.csv", by_strategy, [ + "retrieval_strategy", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ]) + + by_difficulty = _group_summary(records, "difficulty_tier", "difficulty_tier") + _write_csv(output_dir / "by_difficulty.csv", by_difficulty, [ + "difficulty_tier", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ]) + + by_novelty = _group_summary(records, "novelty_tier", "novelty_tier") + _write_csv(output_dir / "by_novelty.csv", by_novelty, [ + "novelty_tier", + "total_runs", + "success_runs", + "success_rate", + "avg_tokens_total", + "avg_tokens_input", + "avg_tokens_output", + ]) print(json.dumps({"output": str(output_path), "records": len(records)}, indent=2)) diff --git a/src/benchmark_runner.py b/src/benchmark_runner.py index 1805a41..d32f674 100644 --- a/src/benchmark_runner.py +++ b/src/benchmark_runner.py @@ -222,6 +222,10 @@ def _ensure_prompt_entry(entry: Any, index: int) -> dict[str, Any]: "id": prompt_id, "prompt": entry.get("prompt"), "prompt_path": entry.get("prompt_path"), + "case_id": entry.get("case_id"), + "solver": entry.get("solver"), + "difficulty_tier": entry.get("difficulty_tier"), + "novelty_tier": entry.get("novelty_tier"), } raise ValueError(f"Invalid prompt entry at index {index}: {entry!r}") @@ -373,6 +377,7 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non model_slug = _slugify(model_id) model_env = dict(env_common) model_env.update(model.get("env") or {}) + provider = (model.get("overrides") or {}).get("llm_provider") model_config_path = _build_config_for_model(model, run_dir) manifest["models"].append({ @@ -390,6 +395,16 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non prompt_dir.mkdir(parents=True, exist_ok=True) env = os.environ.copy() env.update({k: str(v) for k, v in model_env.items()}) + benchmark_env = { + "BENCHMARK_PROMPT_ID": str(prompt_id), + "BENCHMARK_CASE_ID": prompt.get("case_id"), + "BENCHMARK_SOLVER": prompt.get("solver"), + "BENCHMARK_DIFFICULTY_TIER": prompt.get("difficulty_tier"), + "BENCHMARK_NOVELTY_TIER": prompt.get("novelty_tier"), + "BENCHMARK_MODEL_ID": model_id, + "BENCHMARK_PROVIDER": provider, + } + env.update({k: str(v) for k, v in benchmark_env.items() if v}) cmd = _build_command(model_config_path, prompt, prompt_dir, run_args) diff --git a/src/main.py b/src/main.py index 792b2d8..cde5db8 100644 --- a/src/main.py +++ b/src/main.py @@ -690,7 +690,7 @@ def main(args: list[str] | None = None) -> None: # Save metrics JSONL (if enabled) try: - from src.utils.metrics import metrics_collector + from src.utils.metrics import metrics_collector, metrics_extra if getattr(config, "metrics_enabled", True) and metrics_collector.events(): summary = metrics_collector.build_workflow_summary() @@ -699,13 +699,14 @@ def main(args: list[str] | None = None) -> None: "iteration": result.get("iteration", 0), "run_directory": result.get("run_directory"), }) - metrics_collector.record_event( - "workflow_summary", - summary, - stage="workflow", - node="main", - iteration=result.get("iteration", 0), - ) + with metrics_extra(result.get("metrics_context") or None): + metrics_collector.record_event( + "workflow_summary", + summary, + stage="workflow", + node="main", + iteration=result.get("iteration", 0), + ) if 'run_directory' in result: run_dir = Path(result['run_directory']) metrics_path = run_dir / getattr(config, "metrics_filename", "metrics.jsonl") @@ -877,6 +878,17 @@ def initialize_state(user_requirement: str, config: AMReXAgentConfig) -> dict[st if not prompt_content: raise ValueError("User requirement prompt cannot be empty") + metrics_context = { + "case_id": os.getenv("BENCHMARK_CASE_ID"), + "solver": os.getenv("BENCHMARK_SOLVER"), + "difficulty_tier": os.getenv("BENCHMARK_DIFFICULTY_TIER"), + "novelty_tier": os.getenv("BENCHMARK_NOVELTY_TIER"), + "prompt_id": os.getenv("BENCHMARK_PROMPT_ID"), + "model_id": os.getenv("BENCHMARK_MODEL_ID"), + "provider": os.getenv("BENCHMARK_PROVIDER"), + } + metrics_context = {k: v for k, v in metrics_context.items() if v} + # 3. Initialize state with defaults return { # Inputs @@ -897,6 +909,7 @@ def initialize_state(user_requirement: str, config: AMReXAgentConfig) -> dict[st "modifications": [], "workflow_history": [], "history": [], # Legacy field required by visualization_node + "metrics_context": metrics_context, } @@ -1037,7 +1050,10 @@ def run_agent(user_requirement: str, config: AMReXAgentConfig) -> dict[str, Any] try: from langgraph.errors import GraphRecursionError - final_state = app.invoke(initial_state, run_config) + from src.utils.metrics import metrics_extra + + with metrics_extra(initial_state.get("metrics_context") or None): + final_state = app.invoke(initial_state, run_config) status = final_state.get("job_status", "unknown") logger.info("-" * 80) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index ab82999..0939682 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -41,6 +41,19 @@ def metrics_context( token.var.reset(token) +@contextmanager +def metrics_extra(extra: dict[str, Any] | None): + """Set metrics context extra fields without changing stage.""" + if extra is None: + yield + return + token = _extra_var.set(extra) + try: + yield + finally: + _extra_var.reset(token) + + class MetricsCollector: """Collects instrumentation events for aggregation and export.""" From 417e0534d66255528da9a6859ebf8a5ea78dce4a Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 06:51:09 -0800 Subject: [PATCH 15/36] Add difficulty and novelty aggregates to metrics output --- scripts/aggregate_metrics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/aggregate_metrics.py b/scripts/aggregate_metrics.py index 5142b27..b6f7121 100644 --- a/scripts/aggregate_metrics.py +++ b/scripts/aggregate_metrics.py @@ -32,17 +32,17 @@ def _record_from_event(event: dict[str, Any], source: Path) -> dict[str, Any]: providers = data.get("providers") or [] model_id = context.get("model_id") or (models[0] if models else "unknown") provider = context.get("provider") or (providers[0] if providers else None) - strategy = _extract_strategy(data) + strategy = context.get("strategy") or data.get("strategy") or _extract_strategy(data) return { "model_id": model_id, "provider": provider, "prompt_id": context.get("prompt_id"), "prompt_excerpt": context.get("prompt_excerpt"), - "case_id": context.get("case_id"), - "solver": context.get("solver"), - "difficulty_tier": context.get("difficulty_tier"), - "novelty_tier": context.get("novelty_tier"), + "case_id": context.get("case_id") or data.get("case_id"), + "solver": context.get("solver") or data.get("solver"), + "difficulty_tier": context.get("difficulty_tier") or data.get("difficulty_tier"), + "novelty_tier": context.get("novelty_tier") or data.get("novelty_tier"), "retrieval_strategy": strategy, "job_status": data.get("job_status"), "iteration": data.get("iteration"), From 2374c399e8a4025553ef35be8118e5ef97fa381c Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 07:01:57 -0800 Subject: [PATCH 16/36] Remove env-var metrics context and refine concept density --- benchmark/specs/case_schema.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benchmark/specs/case_schema.yaml b/benchmark/specs/case_schema.yaml index 8a11a91..4cc54ed 100644 --- a/benchmark/specs/case_schema.yaml +++ b/benchmark/specs/case_schema.yaml @@ -66,14 +66,18 @@ properties: difficulty_tier: type: string enum: [easy, medium, hard] + description: "Derived from prompt length, concept density, and specialized knowledge." prompt_length_band: type: string enum: [short, medium, long] + description: "Short: single sentence. Medium: 2-3 sentences. Long: multi-clause or multi-paragraph." concept_density: type: string enum: [low, medium, high] + description: "Low: direct parameter tweaks. Medium: one term implies 2-3 dependent changes. High: multiple terms imply 4+ dependent changes." specialized_knowledge: type: boolean + description: "True if requires domain-specific mechanisms, EB, multiphase, or specialized turbulence." novelty_tier: type: string enum: @@ -82,6 +86,7 @@ properties: - code-touch - build-dependent - additional-runs + description: "Change distance from baseline: inputs-only → config → code → build → extra runs." scaling: type: object additionalProperties: false From 142e2414155c2e7f63ca91f6350497de18fb1d5d Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 07:19:31 -0800 Subject: [PATCH 17/36] Use benchmark context sidecar for metrics --- src/benchmark_runner.py | 26 +++++++++++++---------- src/main.py | 46 +++++++++++++++++++++++++++-------------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/benchmark_runner.py b/src/benchmark_runner.py index d32f674..2563892 100644 --- a/src/benchmark_runner.py +++ b/src/benchmark_runner.py @@ -264,6 +264,7 @@ def _build_command( prompt_entry: dict[str, Any], output_dir: Path, run_args: dict[str, Any], + benchmark_context: Path | None, ) -> list[str]: cmd = [os.environ.get("AMREX_AGENT_PYTHON", "python"), "amrex_agent.py", "--json"] if prompt_entry.get("prompt"): @@ -276,6 +277,8 @@ def _build_command( if config_path: cmd.extend(["--config", str(config_path)]) cmd.extend(["--output-dir", str(output_dir)]) + if benchmark_context: + cmd.extend(["--benchmark-context", str(benchmark_context)]) flag_map = { "run_mode": "--run-mode", @@ -395,18 +398,19 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non prompt_dir.mkdir(parents=True, exist_ok=True) env = os.environ.copy() env.update({k: str(v) for k, v in model_env.items()}) - benchmark_env = { - "BENCHMARK_PROMPT_ID": str(prompt_id), - "BENCHMARK_CASE_ID": prompt.get("case_id"), - "BENCHMARK_SOLVER": prompt.get("solver"), - "BENCHMARK_DIFFICULTY_TIER": prompt.get("difficulty_tier"), - "BENCHMARK_NOVELTY_TIER": prompt.get("novelty_tier"), - "BENCHMARK_MODEL_ID": model_id, - "BENCHMARK_PROVIDER": provider, - } - env.update({k: str(v) for k, v in benchmark_env.items() if v}) + benchmark_context = prompt_dir / "benchmark_context.json" + benchmark_context.write_text(json.dumps({ + "prompt_id": prompt_id, + "prompt_excerpt": prompt["prompt"][:160], + "case_id": prompt.get("case_id"), + "solver": prompt.get("solver"), + "difficulty_tier": prompt.get("difficulty_tier"), + "novelty_tier": prompt.get("novelty_tier"), + "model_id": model_id, + "provider": provider, + }, indent=2, default=str)) - cmd = _build_command(model_config_path, prompt, prompt_dir, run_args) + cmd = _build_command(model_config_path, prompt, prompt_dir, run_args, benchmark_context) started_at = datetime.now().isoformat() start_time = time.time() diff --git a/src/main.py b/src/main.py index cde5db8..7289873 100644 --- a/src/main.py +++ b/src/main.py @@ -526,6 +526,13 @@ def parse_arguments(args: list[str] | None = None) -> argparse.Namespace: dest='mpi_ranks', help='Local run tasks (mpirun -np)' ) + parser.add_argument( + '--benchmark-context', + type=str, + default=None, + dest='benchmark_context', + help='Optional JSON/YAML file with benchmark metadata to attach to metrics' + ) return parser.parse_args(args) @@ -597,6 +604,23 @@ def _warn_if_schema_missing(config: AMReXAgentConfig, baseline_override: str | N logging.getLogger(__name__).warning("Schema missing for %s. Run: %s", solver_name, schema_cmd) +def _load_benchmark_context(path: str | None) -> dict[str, Any] | None: + if not path: + return None + context_path = Path(path) + if not context_path.exists(): + raise FileNotFoundError(f"Benchmark context file not found: {context_path}") + suffix = context_path.suffix.lower() + if suffix in {".yaml", ".yml"}: + import yaml + data = yaml.safe_load(context_path.read_text(encoding="utf-8")) + else: + data = json.loads(context_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise ValueError("Benchmark context file must contain a JSON/YAML object.") + return data + + def main(args: list[str] | None = None) -> None: """ Run the AMReXAgent CLI workflow. @@ -686,7 +710,11 @@ def main(args: list[str] | None = None) -> None: # Run Agent logger.debug("Starting AMReXAgent workflow...") - result = run_agent(user_requirement, config) + benchmark_context = _load_benchmark_context(parsed_args.benchmark_context) + from src.utils.metrics import metrics_extra + + with metrics_extra(benchmark_context): + result = run_agent(user_requirement, config) # Save metrics JSONL (if enabled) try: @@ -699,7 +727,7 @@ def main(args: list[str] | None = None) -> None: "iteration": result.get("iteration", 0), "run_directory": result.get("run_directory"), }) - with metrics_extra(result.get("metrics_context") or None): + with metrics_extra(benchmark_context): metrics_collector.record_event( "workflow_summary", summary, @@ -865,8 +893,6 @@ def initialize_state(user_requirement: str, config: AMReXAgentConfig) -> dict[st dict Initialized graph state for the workflow. """ - import os - # 1. Load prompt (file or string) if os.path.exists(user_requirement) and os.path.isfile(user_requirement): with open(user_requirement) as f: @@ -878,17 +904,6 @@ def initialize_state(user_requirement: str, config: AMReXAgentConfig) -> dict[st if not prompt_content: raise ValueError("User requirement prompt cannot be empty") - metrics_context = { - "case_id": os.getenv("BENCHMARK_CASE_ID"), - "solver": os.getenv("BENCHMARK_SOLVER"), - "difficulty_tier": os.getenv("BENCHMARK_DIFFICULTY_TIER"), - "novelty_tier": os.getenv("BENCHMARK_NOVELTY_TIER"), - "prompt_id": os.getenv("BENCHMARK_PROMPT_ID"), - "model_id": os.getenv("BENCHMARK_MODEL_ID"), - "provider": os.getenv("BENCHMARK_PROVIDER"), - } - metrics_context = {k: v for k, v in metrics_context.items() if v} - # 3. Initialize state with defaults return { # Inputs @@ -909,7 +924,6 @@ def initialize_state(user_requirement: str, config: AMReXAgentConfig) -> dict[st "modifications": [], "workflow_history": [], "history": [], # Legacy field required by visualization_node - "metrics_context": metrics_context, } From 28cd4fe088308d4c5fb4dcd3a1cc36e893cdd094 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 09:28:38 -0800 Subject: [PATCH 18/36] Skip REST mkdir when SFAPI creds are available --- src/services/run_superfacility_tools.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index b0e7d32..aff9c89 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -1008,6 +1008,9 @@ def resolve_remote_output_dir( logger = logging.getLogger(__name__) candidates: list[Path] = [] suffix: str | None = None + key_path = _resolve_sfapi_key_path() + client_id, secret = _resolve_sfapi_credentials() + sfapi_available = bool(key_path or (client_id and secret)) if preferred_output_dir: preferred_path = Path(os.path.expandvars(str(preferred_output_dir))) if not str(preferred_path).startswith("/global/cfs/cdirs/"): @@ -1035,6 +1038,10 @@ def resolve_remote_output_dir( except Exception as exc: logger.debug("Remote output dir check failed for %s: %s", candidate, exc) continue + if sfapi_available and not nersc_session: + logger.debug("Using SFAPI credentials for %s; skipping REST mkdir check", candidate) + logger.info("Using remote output dir: %s", candidate) + return candidate try: ensure_remote_directory_rest( remote_run_dir=str(candidate), From 188500cfdb1656121c3709615be1834e256b0cb6 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 13:37:54 -0800 Subject: [PATCH 19/36] Fix LLM client unwrapping and SFAPI test --- src/config.py | 14 +++++++++----- tests/unit/test_superfacility_tools.py | 8 ++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/config.py b/src/config.py index 1a4d9fa..fd7da12 100644 --- a/src/config.py +++ b/src/config.py @@ -983,6 +983,10 @@ def _wrap_llm_client_with_retry(client, config: AMReXAgentConfig): def _wrap_llm_client_with_metrics(client, config: AMReXAgentConfig): if getattr(config, "metrics_enabled", True) is False: return client + if not getattr(client, "chat", None): + return client + if not getattr(getattr(client, "chat", None), "completions", None): + return client if isinstance(client, _LLMMetricsClient): return client return _LLMMetricsClient(client, config) @@ -1000,11 +1004,11 @@ def wrap_llm_client_with_retry(client, config: AMReXAgentConfig): def unwrap_llm_client(client): """Return the underlying client if wrapped by the LLM gate.""" - if isinstance(client, _LLMGateClient): - return client._client - if isinstance(client, _LLMRetryClient): - return client._client - return client + wrapped_types = (_LLMGateClient, _LLMRetryClient, _LLMMetricsClient) + current = client + while isinstance(current, wrapped_types): + current = current._client + return current class _LLMRetryClient: diff --git a/tests/unit/test_superfacility_tools.py b/tests/unit/test_superfacility_tools.py index 496207f..b74d51c 100644 --- a/tests/unit/test_superfacility_tools.py +++ b/tests/unit/test_superfacility_tools.py @@ -120,6 +120,14 @@ def fake_ensure_remote_directory_rest(remote_run_dir, **_kwargs): "src.services.run_superfacility_tools.ensure_remote_directory_rest", fake_ensure_remote_directory_rest, ) + monkeypatch.setattr( + "src.services.run_superfacility_tools._resolve_sfapi_key_path", + lambda: None, + ) + monkeypatch.setattr( + "src.services.run_superfacility_tools._resolve_sfapi_credentials", + lambda: (None, None), + ) result = resolve_remote_output_dir( preferred_output_dir="/global/cfs/cdirs/acct/superfacility/output", From 2d9fe1c48f202b5759bd725dd642511ccf1d822f Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 13:51:21 -0800 Subject: [PATCH 20/36] Document instructor usage map --- docs/instructor_use.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 docs/instructor_use.md diff --git a/docs/instructor_use.md b/docs/instructor_use.md new file mode 100644 index 0000000..3c0e22c --- /dev/null +++ b/docs/instructor_use.md @@ -0,0 +1,15 @@ +# Instructor Usage Map + +This table tracks where instructor-based structured LLM calls are used and +what prompt templates and fallbacks are in play. + +| Location | Purpose | Prompt Template Source | Response Mode | +| --- | --- | --- | --- | +| `src/services/architect.py` | Solver selection, schema scan, plan fallback | `ArchitectService._resolve_llm_prompt_template(...)` and solver config templates | Instructor + raw fallback | +| `src/services/cases.py` | Case selection (`find_best_match`) | Solver config templates | Instructor + raw fallback | +| `src/services/inputs_file_selector.py` | `llm_compare` inputs selection | `config_cls.get_prompt_templates().get("misc")["inputs_select"]` or inline | Instructor + raw fallback | +| `src/services/knowledge.py` | `generate_questions_from_prompt` | Solver config `knowledge.question_generator` | Instructor + raw fallback | +| `src/services/config_model_factory.py` | Remap failed modifications | Solver config `remap.template` or BaseAMReXConfig | Instructor only (exceptions caught) | +| `src/nodes/reviewer_node.py` | Retry guidance refinement | Solver config `misc.retry_guidance` | Instructor + raw JSON fallback | +| `src/nodes/analysis_node.py` | Retry guidance refinement | BaseAMReXConfig `misc.retry_guidance` | Instructor + raw JSON fallback | +| `src/services/config_service.py` | LLM connectivity check | `misc_prompts` or inline default | Raw only | From 8fcde6d81432c962073d6649fb3af4b94b8cac9f Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 14:19:55 -0800 Subject: [PATCH 21/36] Add shared LLM call helper skeleton --- src/utils/llm_calls.py | 95 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 src/utils/llm_calls.py diff --git a/src/utils/llm_calls.py b/src/utils/llm_calls.py new file mode 100644 index 0000000..49154c8 --- /dev/null +++ b/src/utils/llm_calls.py @@ -0,0 +1,95 @@ +"""Shared LLM call helper with policy hooks and instructor fallback.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from src.utils.metrics import metrics_extra + + +@dataclass(frozen=True) +class LLMCallSpec: + model: str + messages: list[dict[str, Any]] + temperature: float | None = None + max_tokens: int | None = None + response_model: Any | None = None + max_retries: int | None = None + purpose: str | None = None + template_name: str | None = None + template_source: str | None = None + + +class LLMPolicy: + """Policy interface for governed LLM calls.""" + + def check_budget(self, _spec: LLMCallSpec) -> tuple[bool, str | None]: + return True, None + + def check_rate_limit(self, _spec: LLMCallSpec) -> tuple[bool, str | None]: + return True, None + + def check_approval(self, _spec: LLMCallSpec) -> tuple[bool, str | None]: + return True, None + + def audit(self, _event: str, _spec: LLMCallSpec, _details: dict[str, Any]) -> None: + return None + + +class PolicyViolation(RuntimeError): + pass + + +def call_llm( + client: Any, + spec: LLMCallSpec, + *, + config: Any | None = None, + policy: LLMPolicy | None = None, + extra_context: dict[str, Any] | None = None, + fallback_parser: Callable[[Any], Any] | None = None, +) -> Any: + policy = policy or LLMPolicy() + ok, reason = policy.check_rate_limit(spec) + if not ok: + policy.audit("llm_denied", spec, {"reason": reason, "type": "rate_limit"}) + raise PolicyViolation(reason or "rate_limit") + ok, reason = policy.check_budget(spec) + if not ok: + policy.audit("llm_denied", spec, {"reason": reason, "type": "budget"}) + raise PolicyViolation(reason or "budget") + ok, reason = policy.check_approval(spec) + if not ok: + policy.audit("llm_denied", spec, {"reason": reason, "type": "approval"}) + raise PolicyViolation(reason or "approval") + + with metrics_extra(extra_context): + policy.audit("llm_allowed", spec, {}) + if spec.response_model: + try: + import instructor + from src.config import unwrap_llm_client, wrap_llm_client + + base_client = unwrap_llm_client(client) + instr_client = instructor.from_openai(base_client) + instr_client = wrap_llm_client(instr_client, config or {}) + return instr_client.chat.completions.create( + model=spec.model, + response_model=spec.response_model, + messages=spec.messages, + temperature=spec.temperature, + max_retries=spec.max_retries, + ) + except (ImportError, ModuleNotFoundError): + pass + + response = client.chat.completions.create( + model=spec.model, + messages=spec.messages, + temperature=spec.temperature, + max_tokens=spec.max_tokens, + ) + if fallback_parser: + return fallback_parser(response) + return response From cf0b7803e2e2895ce03590ccf0824fe8ac6b2dea Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Mon, 16 Feb 2026 14:33:10 -0800 Subject: [PATCH 22/36] Refactor LLM call sites to use helper --- src/nodes/analysis_node.py | 64 +++++++------- src/nodes/reviewer_node.py | 64 +++++++------- src/services/architect.py | 125 ++++++++++++--------------- src/services/cases.py | 31 ++++--- src/services/config_model_factory.py | 17 ++-- src/services/config_service.py | 10 ++- src/services/inputs_file_selector.py | 46 +++++----- src/services/knowledge.py | 68 +++++++-------- src/utils/llm_calls.py | 2 + 9 files changed, 204 insertions(+), 223 deletions(-) diff --git a/src/nodes/analysis_node.py b/src/nodes/analysis_node.py index c422d07..19cba38 100644 --- a/src/nodes/analysis_node.py +++ b/src/nodes/analysis_node.py @@ -304,50 +304,46 @@ def analysis_node(state: GraphState) -> dict[str, Any]: analysis_issues="\n".join(issues) if issues else "none", ) try: - import instructor + import json from pydantic import BaseModel, Field - from src.config import unwrap_llm_client, wrap_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm class RetryGuidance(BaseModel): inputs_base_action: str = Field(description="keep or switch") baseline_base_action: str = Field(description="keep or switch") rationale: str | None = None + spec = LLMCallSpec( + model=config.llm_model, + response_model=RetryGuidance, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_retries=2, + purpose="retry_guidance", + template_name="retry_guidance", + template_source="base_config.misc", + ) with metrics_context("analysis", node="analysis", iteration=iteration): - base_client = unwrap_llm_client(llm_client) - instr_client = instructor.from_openai(base_client) - instr_client = wrap_llm_client(instr_client, config) - parsed = instr_client.chat.completions.create( - model=config.llm_model, - response_model=RetryGuidance, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_retries=2, - ) - retry_guidance.update({ - "inputs_base_action": parsed.inputs_base_action or "keep", - "baseline_base_action": parsed.baseline_base_action or "keep", - "inputs_reason": parsed.rationale, - "baseline_reason": parsed.rationale, - }) - except (ImportError, ModuleNotFoundError): - with metrics_context("analysis", node="analysis", iteration=iteration): - response = llm_client.chat.completions.create( - model=config.llm_model, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_tokens=200, - ) - content = response.choices[0].message.content.strip() - import json - parsed = json.loads(content) - if isinstance(parsed, dict): + result = call_llm(llm_client, spec, config=config) + if hasattr(result, "inputs_base_action"): retry_guidance.update({ - "inputs_base_action": parsed.get("inputs_base_action", "keep"), - "baseline_base_action": parsed.get("baseline_base_action", "keep"), - "inputs_reason": parsed.get("rationale"), - "baseline_reason": parsed.get("rationale"), + "inputs_base_action": result.inputs_base_action or "keep", + "baseline_base_action": result.baseline_base_action or "keep", + "inputs_reason": result.rationale, + "baseline_reason": result.rationale, }) + else: + content = result.choices[0].message.content.strip() + parsed = json.loads(content) + if isinstance(parsed, dict): + retry_guidance.update({ + "inputs_base_action": parsed.get("inputs_base_action", "keep"), + "baseline_base_action": parsed.get("baseline_base_action", "keep"), + "inputs_reason": parsed.get("rationale"), + "baseline_reason": parsed.get("rationale"), + }) + except Exception as exc: + logger.debug(f"Retry guidance LLM unavailable: {exc}") except Exception as exc: logger.debug(f"Retry guidance LLM unavailable: {exc}") diff --git a/src/nodes/reviewer_node.py b/src/nodes/reviewer_node.py index c59383b..f4403b3 100644 --- a/src/nodes/reviewer_node.py +++ b/src/nodes/reviewer_node.py @@ -643,50 +643,46 @@ def _derive_retry_guidance(violations, rejected_inputs_file, rejected_baseline_c analysis_issues="none", ) try: - import instructor + import json from pydantic import BaseModel, Field - from src.config import unwrap_llm_client, wrap_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm class RetryGuidance(BaseModel): inputs_base_action: str = Field(description="keep or switch") baseline_base_action: str = Field(description="keep or switch") rationale: str | None = None + spec = LLMCallSpec( + model=config.llm_model, + response_model=RetryGuidance, + messages=[{"role": "user", "content": filled}], + temperature=0.0, + max_retries=2, + purpose="retry_guidance", + template_name="retry_guidance", + template_source="solver_config.misc", + ) with metrics_context("reviewer", node="reviewer", iteration=iteration): - base_client = unwrap_llm_client(llm_client) - instr_client = instructor.from_openai(base_client) - instr_client = wrap_llm_client(instr_client, config) - parsed = instr_client.chat.completions.create( - model=config.llm_model, - response_model=RetryGuidance, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_retries=2, - ) - retry_guidance.update({ - "inputs_base_action": parsed.inputs_base_action or "keep", - "baseline_base_action": parsed.baseline_base_action or "keep", - "inputs_reason": parsed.rationale, - "baseline_reason": parsed.rationale, - }) - except (ImportError, ModuleNotFoundError): - with metrics_context("reviewer", node="reviewer", iteration=iteration): - response = llm_client.chat.completions.create( - model=config.llm_model, - messages=[{"role": "user", "content": filled}], - temperature=0.0, - max_tokens=200, - ) - content = response.choices[0].message.content.strip() - import json - parsed = json.loads(content) - if isinstance(parsed, dict): + result = call_llm(llm_client, spec, config=config) + if hasattr(result, "inputs_base_action"): retry_guidance.update({ - "inputs_base_action": parsed.get("inputs_base_action", "keep"), - "baseline_base_action": parsed.get("baseline_base_action", "keep"), - "inputs_reason": parsed.get("rationale"), - "baseline_reason": parsed.get("rationale"), + "inputs_base_action": result.inputs_base_action or "keep", + "baseline_base_action": result.baseline_base_action or "keep", + "inputs_reason": result.rationale, + "baseline_reason": result.rationale, }) + else: + content = result.choices[0].message.content.strip() + parsed = json.loads(content) + if isinstance(parsed, dict): + retry_guidance.update({ + "inputs_base_action": parsed.get("inputs_base_action", "keep"), + "baseline_base_action": parsed.get("baseline_base_action", "keep"), + "inputs_reason": parsed.get("rationale"), + "baseline_reason": parsed.get("rationale"), + }) + except Exception as exc: + logger.debug(f"Retry guidance LLM unavailable: {exc}") except Exception as exc: logger.debug(f"Retry guidance LLM unavailable: {exc}") diff --git a/src/services/architect.py b/src/services/architect.py index 297dc2d..e13456f 100644 --- a/src/services/architect.py +++ b/src/services/architect.py @@ -2184,9 +2184,10 @@ def _normalize_llm_value(value: str) -> str: ) try: - import instructor - from src.config import get_llm_client, unwrap_llm_client, wrap_llm_client + import json from pydantic import BaseModel, Field + from src.config import get_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm class Modification(BaseModel): parameter: str @@ -2195,58 +2196,49 @@ class Modification(BaseModel): class ModificationExtraction(BaseModel): working: str = Field(description="Step-by-step reasoning showing parameter mapping and unit conversions") modifications: list[Modification] + if client is None: - base_client = unwrap_llm_client(get_llm_client(self.config)) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, self.config) - result = client.chat.completions.create( + client = get_llm_client(self.config) + spec = LLMCallSpec( model=self.config.llm_model, response_model=ModificationExtraction, + response_format={"type": "json_object"}, messages=[{"role": "user", "content": prompt}], - temperature=0.1 + temperature=0.1, + purpose="modification_extraction", + template_name="modification_extraction", + template_source="solver_config", ) + result = call_llm(client, spec, config=self.config) + if hasattr(result, "modifications"): + modifications = [(m.parameter, _normalize_llm_value(m.value)) for m in result.modifications] + + # Log reasoning for debugging + logger.debug(f"[LLM] Working: {result.working}") + logger.debug(f"[LLM] Extracted {len(modifications)} modifications") + if not modifications: + redactor = _redactor() + raw = result.model_dump_json() if hasattr(result, "model_dump_json") else str(result) + logger.debug("[LLM] Empty modifications (raw response): %s", redactor(raw)) - modifications = [(m.parameter, _normalize_llm_value(m.value)) for m in result.modifications] + return {"success": True, "modifications": modifications} - # Log reasoning for debugging - logger.debug(f"[LLM] Working: {result.working}") + raw_content = result.choices[0].message.content + parsed = json.loads(raw_content) + modifications = [ + (m['parameter'], _normalize_llm_value(m.get('value', ''))) + for m in parsed.get('modifications', []) + ] + + # Log reasoning from fallback mode too + logger.debug(f"[LLM] Working: {parsed.get('working', 'N/A')}") logger.debug(f"[LLM] Extracted {len(modifications)} modifications") if not modifications: redactor = _redactor() - raw = result.model_dump_json() if hasattr(result, "model_dump_json") else str(result) - logger.debug("[LLM] Empty modifications (raw response): %s", redactor(raw)) + logger.debug("[LLM] Empty modifications (raw response): %s", redactor(raw_content)) return {"success": True, "modifications": modifications} - except (ModuleNotFoundError, ImportError): - try: - import json - from src.config import get_llm_client - client = get_llm_client(self.config) - response = client.chat.completions.create( - model=self.config.llm_model, - messages=[{"role": "user", "content": prompt}], - response_format={"type": "json_object"}, - temperature=0.1 - ) - raw_content = response.choices[0].message.content - result = json.loads(raw_content) - modifications = [ - (m['parameter'], _normalize_llm_value(m.get('value', ''))) - for m in result.get('modifications', []) - ] - - # Log reasoning from fallback mode too - logger.debug(f"[LLM] Working: {result.get('working', 'N/A')}") - logger.debug(f"[LLM] Extracted {len(modifications)} modifications") - if not modifications: - redactor = _redactor() - logger.debug("[LLM] Empty modifications (raw response): %s", redactor(raw_content)) - - return {"success": True, "modifications": modifications} - except Exception as e: - return {"success": False, "error": e} - except Exception as e: return {"success": False, "error": e} @@ -2273,29 +2265,31 @@ def _call_llm_for_schema_scan( ) try: - import instructor - from src.config import get_llm_client, unwrap_llm_client, wrap_llm_client from pydantic import BaseModel, Field + from src.config import get_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm class SchemaScanResult(BaseModel): unresolved_concepts: list[str] = Field(default_factory=list) notes: str = "" if client is None: - base_client = unwrap_llm_client(get_llm_client(self.config)) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, self.config) - result = client.chat.completions.create( + client = get_llm_client(self.config) + spec = LLMCallSpec( model=self.config.llm_model, response_model=SchemaScanResult, messages=[{"role": "user", "content": prompt}], - temperature=0.1 + temperature=0.1, + purpose="schema_scan", + template_name="schema_scan", + template_source="solver_config", ) - - return { - "unresolved_concepts": result.unresolved_concepts, - "notes": result.notes, - } + result = call_llm(client, spec, config=self.config) + if hasattr(result, "unresolved_concepts"): + return { + "unresolved_concepts": result.unresolved_concepts, + "notes": result.notes, + } except Exception as e: logger.debug(f"[LLM] Schema scan failed: {e}") @@ -4018,15 +4012,8 @@ def _llm_fallback_structured(self, query: str, baseline: dict, baseline_case: st Dict compatible with plan_modifications return format, or None if failed """ try: - import instructor - from src.services.plan import SimulationPlan - from src.config import unwrap_llm_client, wrap_llm_client - - # Wrap LLM client with instructor - base_client = unwrap_llm_client(self.llm_client) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, self.config) + from src.utils.llm_calls import LLMCallSpec, call_llm # Extract solver from baseline metadata solver_name = baseline.get('metadata', {}).get('solver') @@ -4051,13 +4038,20 @@ def _llm_fallback_structured(self, query: str, baseline: dict, baseline_case: st logger.info("[LLM] Requesting complete SimulationPlan via structured output...") # Get structured SimulationPlan from LLM - plan = client.chat.completions.create( + spec = LLMCallSpec( model=self.config.llm_model, response_model=SimulationPlan, messages=[{"role": "user", "content": prompt}], temperature=0.2, - max_retries=2 + max_retries=2, + purpose="structured_fallback_planning", + template_name="structured_fallback_planning", + template_source="solver_config", ) + plan = call_llm(self.llm_client, spec, config=self.config) + if not hasattr(plan, "modifications"): + logger.warning("[LLM] Structured output unavailable - falling back to JSON parsing") + return self._llm_fallback_json(query, baseline, baseline_case) # Override fields we already know (LLM might hallucinate these) plan.selected_solver = solver_name @@ -4082,11 +4076,6 @@ def _llm_fallback_structured(self, query: str, baseline: dict, baseline_case: st "_full_plan": plan # Include full plan for debugging } - except ImportError: - logger.warning("[LLM] instructor not installed - falling back to JSON parsing") - logger.info("[LLM] Install with: pip install instructor") - return self._llm_fallback_json(query, baseline, baseline_case) - except Exception as e: logger.error(f"[LLM] Structured fallback failed: {e}") import traceback diff --git a/src/services/cases.py b/src/services/cases.py index de6b5f0..62d4e9d 100644 --- a/src/services/cases.py +++ b/src/services/cases.py @@ -402,39 +402,46 @@ def find_best_match(self, user_prompt: str, llm_client, case_match = None use_plain = False try: - import instructor from pydantic import BaseModel, Field - from src.config import unwrap_llm_client, wrap_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm class CaseSelection(BaseModel): code: str = Field(description="Code name from the available list") case: str = Field(description="Case path from the selected code") - base_client = unwrap_llm_client(llm_client) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, self.config) - result = client.chat.completions.create( + spec = LLMCallSpec( model=self.config.llm_model, response_model=CaseSelection, messages=[{"role": "user", "content": prompt}], temperature=0.1, max_retries=2, + purpose="case_selection", + template_name="case_selection", + template_source="solver_config", ) - code_match = result.code.strip() - case_match = result.case.strip() - except (ImportError, ModuleNotFoundError): - use_plain = True + result = call_llm(llm_client, spec, config=self.config) + if hasattr(result, "code") and hasattr(result, "case"): + code_match = result.code.strip() + case_match = result.case.strip() + else: + use_plain = True except Exception as exc: logger.warning("LLM structured selection failed: %s", exc) use_plain = True if use_plain: - response = llm_client.chat.completions.create( + from src.utils.llm_calls import LLMCallSpec, call_llm + + spec = LLMCallSpec( model=self.config.llm_model, messages=[{"role": "user", "content": prompt}], temperature=0.1, - max_tokens=100 + max_tokens=100, + purpose="case_selection_plain", + template_name="case_selection", + template_source="solver_config", ) + response = call_llm(llm_client, spec, config=self.config) content = response.choices[0].message.content.strip() diff --git a/src/services/config_model_factory.py b/src/services/config_model_factory.py index 9db87c3..b988e3e 100644 --- a/src/services/config_model_factory.py +++ b/src/services/config_model_factory.py @@ -792,14 +792,11 @@ def remap_failed_modifications( schema_list = "\n".join(f"- {f}" for f in schema_fields[:150]) try: - import instructor - from src.config import get_llm_client, unwrap_llm_client, wrap_llm_client + from src.config import get_llm_client + from src.utils.llm_calls import LLMCallSpec, call_llm - base_client = unwrap_llm_client(get_llm_client(config_service)) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, config_service) - - result = client.chat.completions.create( + client = get_llm_client(config_service) + spec = LLMCallSpec( model=config_service.llm_model, response_model=MappingExtraction, messages=[{"role": "user", "content": prompt_template.format( @@ -807,8 +804,12 @@ def remap_failed_modifications( schema_list=schema_list, format_hints=format_hints or "None provided" )}], - temperature=0.1 + temperature=0.1, + purpose="remap_failed_modifications", + template_name="remap", + template_source="solver_config", ) + result = call_llm(client, spec, config=config_service) logger.debug(f"[Remap] LLM response: {len(result.mappings)} mappings") diff --git a/src/services/config_service.py b/src/services/config_service.py index ade7ddd..42100dd 100644 --- a/src/services/config_service.py +++ b/src/services/config_service.py @@ -260,11 +260,17 @@ def test_llm_connection(self, config: AMReXAgentConfig) -> bool: prompt = prompt logger.debug(" Testing API connection...") - response = client.chat.completions.create( + from src.utils.llm_calls import LLMCallSpec, call_llm + + spec = LLMCallSpec( model=config.llm_model, messages=[{"role": "user", "content": prompt}], - max_tokens=5 + max_tokens=5, + purpose="llm_connectivity_test", + template_name="llm_connectivity_test", + template_source="base_config.misc", ) + response = call_llm(client, spec, config=config) if config.llm_provider == "cborg": logger.debug("[ OK ] CBORG API connection successful") else: diff --git a/src/services/inputs_file_selector.py b/src/services/inputs_file_selector.py index cd8a935..1fa51e2 100644 --- a/src/services/inputs_file_selector.py +++ b/src/services/inputs_file_selector.py @@ -283,33 +283,27 @@ def _select_with_llm( prompt = prompt.format(case_name=case_hint, candidates="\n\n".join(summaries)) try: - try: - import instructor - from pydantic import BaseModel, Field - from src.config import unwrap_llm_client, wrap_llm_client - - class InputsSelection(BaseModel): - filename: str = Field(description="Selected inputs filename") - - base_client = unwrap_llm_client(client) - instr_client = instructor.from_openai(base_client) - instr_client = wrap_llm_client(instr_client, config) - result = instr_client.chat.completions.create( - model=config.llm_model, - response_model=InputsSelection, - messages=[{"role": "user", "content": prompt}], - temperature=0.0, - max_retries=2, - ) + from pydantic import BaseModel, Field + from src.utils.llm_calls import LLMCallSpec, call_llm + + class InputsSelection(BaseModel): + filename: str = Field(description="Selected inputs filename") + + spec = LLMCallSpec( + model=config.llm_model, + response_model=InputsSelection, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_retries=2, + purpose="inputs_file_selection", + template_name="inputs_select", + template_source="solver_config.misc", + ) + result = call_llm(client, spec, config=config) + if hasattr(result, "filename"): content = result.filename.strip() - except (ImportError, ModuleNotFoundError): - response = client.chat.completions.create( - model=config.llm_model, - messages=[{"role": "user", "content": prompt}], - temperature=0.0, - max_tokens=50 - ) - content = response.choices[0].message.content.strip() + else: + content = result.choices[0].message.content.strip() except Exception as exc: logger.debug(f"LLM compare failed: {exc}") return None diff --git a/src/services/knowledge.py b/src/services/knowledge.py index 47028d4..a6d92ad 100644 --- a/src/services/knowledge.py +++ b/src/services/knowledge.py @@ -452,49 +452,39 @@ def generate_questions_from_prompt(self, user_prompt: str, llm_client) -> list[s prompt = prompt_template.format(user_prompt=user_prompt) try: - try: - import instructor - from pydantic import BaseModel, Field - from src.config import unwrap_llm_client, wrap_llm_client - - class QuestionList(BaseModel): - questions: list[str] = Field(description="List of short questions") - - base_client = unwrap_llm_client(llm_client) - client = instructor.from_openai(base_client) - client = wrap_llm_client(client, self.config) - result = client.chat.completions.create( - model=self.config.llm_model, - response_model=QuestionList, - messages=[{"role": "user", "content": prompt}], - temperature=0.3, - max_retries=2, - ) + from pydantic import BaseModel, Field + from src.utils.llm_calls import LLMCallSpec, call_llm + + class QuestionList(BaseModel): + questions: list[str] = Field(description="List of short questions") + + spec = LLMCallSpec( + model=self.config.llm_model, + response_model=QuestionList, + messages=[{"role": "user", "content": prompt}], + temperature=0.3, + max_retries=2, + purpose="knowledge_question_generation", + template_name="question_generator", + template_source="solver_config.knowledge", + ) + result = call_llm(llm_client, spec, config=self.config) + if hasattr(result, "questions"): return result.questions - except (ImportError, ModuleNotFoundError): - response = llm_client.chat.completions.create( - model=self.config.llm_model, - messages=[{"role": "user", "content": prompt}], - temperature=0.3, - max_tokens=500 - ) - - # Parse JSON response - content = response.choices[0].message.content.strip() + response = result + content = response.choices[0].message.content.strip() - # Extract JSON array (handle markdown code blocks) - if "```json" in content: - content = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - content = content.split("```")[1].split("```")[0].strip() + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() - questions = json.loads(content) + questions = json.loads(content) - if isinstance(questions, list): - return questions - else: - logger.warning(f"[WARN] LLM returned non-list: {questions}") - return [] + if isinstance(questions, list): + return questions + logger.warning(f"[WARN] LLM returned non-list: {questions}") + return [] except Exception as e: logger.error(f"[ERROR] Failed to generate questions: {e}") return self._get_knowledge_fallback_questions() diff --git a/src/utils/llm_calls.py b/src/utils/llm_calls.py index 49154c8..b84ee34 100644 --- a/src/utils/llm_calls.py +++ b/src/utils/llm_calls.py @@ -15,6 +15,7 @@ class LLMCallSpec: temperature: float | None = None max_tokens: int | None = None response_model: Any | None = None + response_format: dict[str, Any] | None = None max_retries: int | None = None purpose: str | None = None template_name: str | None = None @@ -89,6 +90,7 @@ def call_llm( messages=spec.messages, temperature=spec.temperature, max_tokens=spec.max_tokens, + response_format=spec.response_format, ) if fallback_parser: return fallback_parser(response) From 4854ad00323ea0fa4ddacacb7fab9a11587c5a3c Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 02:34:30 -0800 Subject: [PATCH 23/36] Add filesystem write policy guard --- src/services/files.py | 3 ++ src/services/input_writer.py | 3 ++ src/utils/write_policy.py | 68 +++++++++++++++++++++++++++++++++ tests/unit/test_write_policy.py | 26 +++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 src/utils/write_policy.py create mode 100644 tests/unit/test_write_policy.py diff --git a/src/services/files.py b/src/services/files.py index 817d3ca..5608b7a 100644 --- a/src/services/files.py +++ b/src/services/files.py @@ -328,6 +328,9 @@ def write_full_setup(self, config_json: str, output_dir: str | Path, JSON-encoded mapping of generated files. """ output_dir = Path(output_dir) + from src.utils.write_policy import ensure_write_allowed + + ensure_write_allowed(output_dir, self.config, purpose="files.write_full_setup") output_dir.mkdir(parents=True, exist_ok=True) from src.services.file_generation import FileGenerationService diff --git a/src/services/input_writer.py b/src/services/input_writer.py index f1a47ff..36e85ec 100644 --- a/src/services/input_writer.py +++ b/src/services/input_writer.py @@ -194,6 +194,9 @@ def apply_plan(self, # Ensure output directory exists output_dir = Path(output_dir) + from src.utils.write_policy import ensure_write_allowed + + ensure_write_allowed(output_dir, self.config, purpose="input_writer.apply_plan") output_dir.mkdir(parents=True, exist_ok=True) try: diff --git a/src/utils/write_policy.py b/src/utils/write_policy.py new file mode 100644 index 0000000..2aac9d4 --- /dev/null +++ b/src/utils/write_policy.py @@ -0,0 +1,68 @@ +"""Filesystem write policy helpers.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class WritePolicyViolation(RuntimeError): + pass + + +def _is_within(child: Path, root: Path) -> bool: + try: + return child.resolve().is_relative_to(root.resolve()) + except Exception: + return False + + +def ensure_write_allowed( + output_dir: str | Path, + config: Any | None, + *, + purpose: str | None = None, +) -> None: + """Ensure output_dir is allowed by policy. + + Policy is configured via config attributes: + - write_policy_mode: "off" | "warn" | "deny" (default: "warn") + - allow_write_paths: list[str | Path] (optional) + - require_run_dir_prefix: bool (optional) + - run_dir_prefix: str (default: "run_") + - output_dir / metrics_output_dir / remote_output_dir: allowed roots + """ + mode = getattr(config, "write_policy_mode", "warn") + if mode == "off": + return + + output_path = Path(output_dir) + allow_paths = [] + for attr in ("output_dir", "metrics_output_dir", "remote_output_dir"): + value = getattr(config, attr, None) if config is not None else None + if value: + allow_paths.append(Path(value)) + + allow_list = getattr(config, "allow_write_paths", None) if config is not None else None + if allow_list: + allow_paths.extend(Path(p) for p in allow_list) + + allowed = any(_is_within(output_path, root) for root in allow_paths) + + require_prefix = bool(getattr(config, "require_run_dir_prefix", False)) if config is not None else False + if require_prefix: + prefix = getattr(config, "run_dir_prefix", "run_") + allowed = allowed and output_path.name.startswith(prefix) + + if allowed: + return + + context = f" ({purpose})" if purpose else "" + roots = ", ".join(str(p) for p in allow_paths) if allow_paths else "none" + message = f"Write blocked{context}: {output_path} not under allowed roots ({roots})" + if mode == "deny": + raise WritePolicyViolation(message) + logger.warning(message) diff --git a/tests/unit/test_write_policy.py b/tests/unit/test_write_policy.py new file mode 100644 index 0000000..9824bb7 --- /dev/null +++ b/tests/unit/test_write_policy.py @@ -0,0 +1,26 @@ +from pathlib import Path + +import pytest + +from src.utils.write_policy import ensure_write_allowed, WritePolicyViolation + + +class DummyConfig: + def __init__(self, output_dir, mode="deny"): + self.output_dir = Path(output_dir) + self.write_policy_mode = mode + + +def test_allows_write_under_output_dir(tmp_path): + config = DummyConfig(output_dir=tmp_path, mode="deny") + run_dir = tmp_path / "run_20250101_120000" + + ensure_write_allowed(run_dir, config, purpose="unit_test") + + +def test_denies_write_outside_output_dir(tmp_path): + config = DummyConfig(output_dir=tmp_path, mode="deny") + outside = tmp_path.parent / "other_dir" + + with pytest.raises(WritePolicyViolation): + ensure_write_allowed(outside, config, purpose="unit_test") From 0cfaca2bde8331dc95ad6bfd5de48b5b75e5c73a Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 02:38:32 -0800 Subject: [PATCH 24/36] Add write policy defaults to config --- src/config.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/config.py b/src/config.py index fd7da12..1ab0526 100644 --- a/src/config.py +++ b/src/config.py @@ -656,6 +656,22 @@ class AMReXAgentConfig(BaseModel): default="metrics.jsonl", description="Metrics JSONL filename for workflow summaries." ) + write_policy_mode: str = Field( + default="warn", + description="Filesystem write policy mode: off, warn, or deny." + ) + allow_write_paths: List[Path] = Field( + default=[], + description="Optional allowlist of writable path roots." + ) + require_run_dir_prefix: bool = Field( + default=False, + description="Require run directories to start with run_dir_prefix." + ) + run_dir_prefix: str = Field( + default="run_", + description="Prefix required for run directories when require_run_dir_prefix is true." + ) save_intermediate: bool = Field( default=True, description="Save intermediate results (plans, configs, etc.)" From 226bb36a8b277361e4852dd930065155ab3d3a2d Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 02:39:19 -0800 Subject: [PATCH 25/36] Add test enforcing LLM helper usage --- tests/unit/test_llm_call_enforcement.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 tests/unit/test_llm_call_enforcement.py diff --git a/tests/unit/test_llm_call_enforcement.py b/tests/unit/test_llm_call_enforcement.py new file mode 100644 index 0000000..6e4e409 --- /dev/null +++ b/tests/unit/test_llm_call_enforcement.py @@ -0,0 +1,15 @@ +from pathlib import Path + + +def test_llm_calls_go_through_helper(): + repo_root = Path(__file__).resolve().parents[2] + src_root = repo_root / "src" + offenders = [] + for path in src_root.rglob("*.py"): + if path.name == "llm_calls.py": + continue + text = path.read_text() + if "chat.completions.create" in text: + offenders.append(path.relative_to(repo_root)) + + assert offenders == [], f"Direct chat.completions.create usage found: {offenders}" From 25f8e6579c08e95846099e864cfccd5472bc038d Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:25:03 -0800 Subject: [PATCH 26/36] Add privacy scrubbing modes and hooks --- src/config.py | 8 ++ src/main.py | 156 ++++++++++++++++++++++--------------- src/utils/llm_calls.py | 8 ++ src/utils/metrics.py | 9 ++- src/utils/privacy.py | 155 ++++++++++++++++++++++++++++++++++++ tests/unit/test_privacy.py | 39 ++++++++++ 6 files changed, 310 insertions(+), 65 deletions(-) create mode 100644 src/utils/privacy.py create mode 100644 tests/unit/test_privacy.py diff --git a/src/config.py b/src/config.py index 1ab0526..4fe9e15 100644 --- a/src/config.py +++ b/src/config.py @@ -656,6 +656,14 @@ class AMReXAgentConfig(BaseModel): default="metrics.jsonl", description="Metrics JSONL filename for workflow summaries." ) + privacy_mode: Literal["off", "shared", "strict"] = Field( + default="off", + description="Privacy mode for prompt/log persistence: off, shared, or strict." + ) + privacy_hash_salt: Optional[str] = Field( + default_factory=lambda: os.getenv("AMREX_PRIVACY_SALT"), + description="Optional salt for prompt hashing in privacy modes." + ) write_policy_mode: str = Field( default="warn", description="Filesystem write policy mode: off, warn, or deny." diff --git a/src/main.py b/src/main.py index 7289873..afebe0c 100644 --- a/src/main.py +++ b/src/main.py @@ -748,7 +748,7 @@ def main(args: list[str] | None = None) -> None: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"metrics_{timestamp}.jsonl" metrics_path = base_dir / filename - metrics_collector.write_jsonl(str(metrics_path)) + metrics_collector.write_jsonl(str(metrics_path), config=config) logger.info(f"Metrics saved to {metrics_path}") except Exception as e: logger.warning(f"Failed to save metrics JSONL: {e}") @@ -764,8 +764,14 @@ def main(args: list[str] | None = None) -> None: base_dir.mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") workflow_path = base_dir / f"workflow_history_{timestamp}.json" + from src.utils.privacy import sanitize_payload + + workflow_payload = sanitize_payload( + result.get('workflow_history', []), + config=config, + ) with open(workflow_path, 'w') as f: - json.dump(result.get('workflow_history', []), f, indent=2, default=str) + json.dump(workflow_payload, f, indent=2, default=str) logger.info(f"Workflow history saved to {workflow_path}") except Exception as e: logger.warning(f"Failed to save workflow history: {e}") @@ -775,67 +781,91 @@ def main(args: list[str] | None = None) -> None: run_dir = Path(result['run_directory']) transcript_path = run_dir / "agent_transcript.txt" try: - transcript_lines = ["=== Agent Transcript ===\n\n"] - transcript_lines.append(f"User Prompt:\n{user_requirement}\n\n") - transcript_lines.append("=" * 80 + "\n\n") - - for entry in result.get('workflow_history', []): - if not isinstance(entry, dict): - continue - node = entry.get('node', 'unknown') - action = entry.get('action', 'unknown') - timestamp = entry.get('timestamp', '') - details = entry.get('details', {}) - - transcript_lines.append(f"[{timestamp}] NODE: {node.upper()}\n") - transcript_lines.append(f"ACTION: {action}\n") - - if node == 'architect' and 'level0_routing' in details: - routing = details['level0_routing'] - transcript_lines.append(" Level 0 Routing Decision:\n") - transcript_lines.append(f" Selected Code: {routing.get('selected_code')}\n") - transcript_lines.append(f" Confidence: {routing.get('confidence', 0):.2f}\n") - transcript_lines.append(f" Reasoning: {routing.get('reasoning')}\n") - - if node == 'architect' and 'level2_cbr' in details: - cbr = details['level2_cbr'] - match = cbr.get('top_match', {}) - transcript_lines.append(" Level 2 Case-Based Reasoning:\n") - transcript_lines.append(f" Best Match: {match.get('case_name')}\n") - transcript_lines.append(f" Path: {match.get('repo_path')}\n") - transcript_lines.append(f" Similarity: {match.get('similarity_score', 0):.2f}\n") - transcript_lines.append(f" Reason: {match.get('match_reason')}\n") - - if node == 'architect' and 'modifications' in details: - mods = details.get('modifications', []) - if mods: - transcript_lines.append(f" Planned Modifications: {len(mods)} changes\n") - for mod in mods[:3]: # Show first 3 - if isinstance(mod, dict): - section = mod.get('section', '') - param = mod.get('parameter', '') - reason = mod.get('reason', '') - label = f"{section}.{param}".strip('.') - transcript_lines.append(f" - {label}: {reason}\n") - elif isinstance(mod, (list, tuple)) and len(mod) == 2: - param, value = mod - transcript_lines.append(f" - {param} = {value}\n") - else: - transcript_lines.append(f" - {mod}\n") - - if node == 'analysis' and 'report' in details: - report = details.get('report', {}) - transcript_lines.append(" Analysis Results:\n") - transcript_lines.append(f" Status: {report.get('status')}\n") - if report.get('performance'): - perf = report.get('performance', {}) - transcript_lines.append(f" Performance: {perf.get('avg_cells_per_sec', 0):,.0f} cells/sec\n") - - transcript_lines.append("\n") - - with open(transcript_path, 'w') as f: - f.writelines(transcript_lines) - logger.info(f"Agent transcript saved to {transcript_path}") + from src.utils.privacy import get_privacy_mode, scrub_text + + privacy_mode = get_privacy_mode(config) + if privacy_mode == "strict": + logger.info("Privacy mode strict: skipping transcript output.") + else: + transcript_lines = ["=== Agent Transcript ===\n\n"] + prompt_text = user_requirement + if privacy_mode == "shared": + prompt_text = scrub_text( + prompt_text, + mode=privacy_mode, + salt=getattr(config, "privacy_hash_salt", None), + ).text + transcript_lines.append(f"User Prompt:\n{prompt_text}\n\n") + transcript_lines.append("=" * 80 + "\n\n") + + for entry in result.get('workflow_history', []): + if not isinstance(entry, dict): + continue + node = entry.get('node', 'unknown') + action = entry.get('action', 'unknown') + timestamp = entry.get('timestamp', '') + details = entry.get('details', {}) + + transcript_lines.append(f"[{timestamp}] NODE: {node.upper()}\n") + transcript_lines.append(f"ACTION: {action}\n") + + if node == 'architect' and 'level0_routing' in details: + routing = details['level0_routing'] + transcript_lines.append(" Level 0 Routing Decision:\n") + transcript_lines.append(f" Selected Code: {routing.get('selected_code')}\n") + transcript_lines.append(f" Confidence: {routing.get('confidence', 0):.2f}\n") + transcript_lines.append(f" Reasoning: {routing.get('reasoning')}\n") + + if node == 'architect' and 'level2_cbr' in details: + cbr = details['level2_cbr'] + match = cbr.get('top_match', {}) + transcript_lines.append(" Level 2 Case-Based Reasoning:\n") + transcript_lines.append(f" Best Match: {match.get('case_name')}\n") + transcript_lines.append(f" Path: {match.get('repo_path')}\n") + transcript_lines.append(f" Similarity: {match.get('similarity_score', 0):.2f}\n") + transcript_lines.append(f" Reason: {match.get('match_reason')}\n") + + if node == 'architect' and 'modifications' in details: + mods = details.get('modifications', []) + if mods: + transcript_lines.append(f" Planned Modifications: {len(mods)} changes\n") + for mod in mods[:3]: # Show first 3 + if isinstance(mod, dict): + section = mod.get('section', '') + param = mod.get('parameter', '') + reason = mod.get('reason', '') + label = f"{section}.{param}".strip('.') + transcript_lines.append(f" - {label}: {reason}\n") + elif isinstance(mod, (list, tuple)) and len(mod) == 2: + param, value = mod + transcript_lines.append(f" - {param} = {value}\n") + else: + transcript_lines.append(f" - {mod}\n") + + if node == 'analysis' and 'report' in details: + report = details.get('report', {}) + transcript_lines.append(" Analysis Results:\n") + transcript_lines.append(f" Status: {report.get('status')}\n") + if report.get('performance'): + perf = report.get('performance', {}) + transcript_lines.append( + f" Performance: {perf.get('avg_cells_per_sec', 0):,.0f} cells/sec\n" + ) + + transcript_lines.append("\n") + + if privacy_mode == "shared": + transcript_lines = [ + scrub_text( + line, + mode=privacy_mode, + salt=getattr(config, "privacy_hash_salt", None), + ).text + for line in transcript_lines + ] + with open(transcript_path, 'w') as f: + f.writelines(transcript_lines) + logger.info(f"Agent transcript saved to {transcript_path}") except Exception as e: logger.warning(f"Failed to save transcript: {e}") diff --git a/src/utils/llm_calls.py b/src/utils/llm_calls.py index b84ee34..3106cde 100644 --- a/src/utils/llm_calls.py +++ b/src/utils/llm_calls.py @@ -51,6 +51,14 @@ def call_llm( extra_context: dict[str, Any] | None = None, fallback_parser: Callable[[Any], Any] | None = None, ) -> Any: + if config is not None: + from src.utils.privacy import enforce_strict + + for message in spec.messages: + content = message.get("content") if isinstance(message, dict) else None + if isinstance(content, str): + enforce_strict(content, config=config, purpose="llm_call") + policy = policy or LLMPolicy() ok, reason = policy.check_rate_limit(spec) if not ok: diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 0939682..18a0585 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -165,13 +165,18 @@ def build_workflow_summary(self, stages: list[str] | None = None) -> dict[str, A def events(self) -> list[dict[str, Any]]: return list(self._events) - def write_jsonl(self, path: str) -> None: + def write_jsonl(self, path: str, *, config: Any | None = None) -> None: if not self._events: return try: with open(path, "w", encoding="utf-8") as handle: for event in self._events: - handle.write(json.dumps(event, default=str)) + payload = event + if config is not None: + from src.utils.privacy import sanitize_payload + + payload = sanitize_payload(event, config=config) + handle.write(json.dumps(payload, default=str)) handle.write("\n") except Exception as exc: logger.warning("Failed to write metrics JSONL to %s: %s", path, exc) diff --git a/src/utils/privacy.py b/src/utils/privacy.py new file mode 100644 index 0000000..b74c5a5 --- /dev/null +++ b/src/utils/privacy.py @@ -0,0 +1,155 @@ +"""Privacy scrubbing utilities for prompts and persisted artifacts.""" + +from __future__ import annotations + +import hashlib +import re +from dataclasses import dataclass +from typing import Any, Iterable + + +class PrivacyViolation(RuntimeError): + pass + + +_EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}") +_PATH_RE = re.compile(r"(?:/home/[^\s/]+/|/Users/[^\s/]+/|C:\\Users\\[^\s\\]+\\)") +_JWT_RE = re.compile(r"[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+") +_PEM_RE = re.compile(r"-----BEGIN [A-Z0-9 _-]+-----.*?-----END [A-Z0-9 _-]+-----", re.DOTALL) +_API_KEY_RE = re.compile(r"(?i)(api_key|apikey|token|secret)\s*[:=]\s*[^\s]+") +_SK_RE = re.compile(r"sk-[A-Za-z0-9]{20,}") + +_SENSITIVE_KEYS = { + "prompt", + "user_prompt", + "prompt_text", + "messages", + "content", + "reasoning", + "inputs_content", + "question", + "answer", + "query", + "raw_prompt", +} + + +@dataclass(frozen=True) +class ScrubResult: + text: str + hash: str + detections: list[str] + + +def _hash_text(text: str, salt: str | None) -> str: + digest = hashlib.sha256() + if salt: + digest.update(salt.encode("utf-8")) + digest.update(text.encode("utf-8")) + return digest.hexdigest() + + +def _detect(text: str) -> list[str]: + detections: list[str] = [] + if _EMAIL_RE.search(text): + detections.append("email") + if _PATH_RE.search(text): + detections.append("path") + if _JWT_RE.search(text): + detections.append("jwt") + if _PEM_RE.search(text): + detections.append("pem") + if _API_KEY_RE.search(text) or _SK_RE.search(text): + detections.append("key") + return detections + + +def scrub_text(text: str, *, mode: str, salt: str | None = None) -> ScrubResult: + detections = _detect(text) + redacted = text + if mode == "shared": + if detections: + redacted = _PEM_RE.sub("[REDACTED:PEM]", redacted) + redacted = _SK_RE.sub("[REDACTED:KEY]", redacted) + redacted = _API_KEY_RE.sub(r"\1=[REDACTED:KEY]", redacted) + redacted = _JWT_RE.sub("[REDACTED:JWT]", redacted) + redacted = _EMAIL_RE.sub("[REDACTED:EMAIL]", redacted) + redacted = _PATH_RE.sub("[REDACTED:PATH]", redacted) + elif mode == "strict": + redacted = "[REDACTED:HASH]" + return ScrubResult(text=redacted, hash=_hash_text(text, salt), detections=detections) + + +def get_privacy_mode(config: Any | None) -> str: + return getattr(config, "privacy_mode", "off") if config is not None else "off" + + +def get_privacy_salt(config: Any | None) -> str | None: + return getattr(config, "privacy_hash_salt", None) if config is not None else None + + +def should_scrub_key(key: str) -> bool: + return key in _SENSITIVE_KEYS + + +def sanitize_payload( + payload: Any, + *, + config: Any | None, + mode: str | None = None, + sensitive_keys: Iterable[str] | None = None, +) -> Any: + if mode is None: + mode = get_privacy_mode(config) + if mode == "off": + return payload + + salt = get_privacy_salt(config) + sensitive = set(sensitive_keys) if sensitive_keys is not None else _SENSITIVE_KEYS + + if isinstance(payload, dict): + sanitized: dict[str, Any] = {} + for key, value in payload.items(): + if isinstance(value, str) and (key in sensitive or _detect(value)): + result = scrub_text(value, mode=mode, salt=salt) + if mode == "strict" and key in sensitive: + sanitized[key] = result.hash + else: + sanitized[key] = result.text + continue + sanitized[key] = sanitize_payload( + value, + config=config, + mode=mode, + sensitive_keys=sensitive, + ) + return sanitized + + if isinstance(payload, list): + return [ + sanitize_payload(item, config=config, mode=mode, sensitive_keys=sensitive) + for item in payload + ] + + if isinstance(payload, tuple): + return tuple( + sanitize_payload(item, config=config, mode=mode, sensitive_keys=sensitive) + for item in payload + ) + + if isinstance(payload, str): + if _detect(payload): + result = scrub_text(payload, mode=mode, salt=salt) + return result.hash if mode == "strict" else result.text + return payload + + return payload + + +def enforce_strict(text: str, *, config: Any | None, purpose: str) -> None: + mode = get_privacy_mode(config) + if mode != "strict": + return + detections = _detect(text) + if detections: + raise PrivacyViolation(f"Privacy mode strict blocked {purpose}: {detections}") diff --git a/tests/unit/test_privacy.py b/tests/unit/test_privacy.py new file mode 100644 index 0000000..bcfe022 --- /dev/null +++ b/tests/unit/test_privacy.py @@ -0,0 +1,39 @@ +from src.utils.privacy import PrivacyViolation, enforce_strict, sanitize_payload, scrub_text + + +class DummyConfig: + def __init__(self, mode): + self.privacy_mode = mode + self.privacy_hash_salt = "salt" + + +def test_scrub_text_shared_redacts_markers(): + text = "email me at user@example.com and use sk-abcdefghijklmnopqrstuv" + result = scrub_text(text, mode="shared", salt="salt") + assert "[REDACTED:EMAIL]" in result.text + assert "[REDACTED:KEY]" in result.text + assert result.hash + + +def test_sanitize_payload_shared_redacts_sensitive_key(): + payload = {"prompt": "reach me at user@example.com"} + result = sanitize_payload(payload, config=DummyConfig("shared")) + assert result["prompt"] != payload["prompt"] + assert "[REDACTED:EMAIL]" in result["prompt"] + + +def test_sanitize_payload_strict_hashes_sensitive_key(): + payload = {"prompt": "reach me at user@example.com"} + result = sanitize_payload(payload, config=DummyConfig("strict")) + assert result["prompt"] != payload["prompt"] + assert len(result["prompt"]) == 64 + + +def test_enforce_strict_blocks_on_detection(): + config = DummyConfig("strict") + try: + enforce_strict("path /home/user/secret", config=config, purpose="unit_test") + except PrivacyViolation as exc: + assert "unit_test" in str(exc) + else: + raise AssertionError("PrivacyViolation not raised") From 4ac9eb8cb236a67b0e1810fb5ddec849dd685905 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:27:15 -0800 Subject: [PATCH 27/36] Add privacy persistence tests --- tests/unit/test_privacy_persistence.py | 46 ++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/unit/test_privacy_persistence.py diff --git a/tests/unit/test_privacy_persistence.py b/tests/unit/test_privacy_persistence.py new file mode 100644 index 0000000..dd15804 --- /dev/null +++ b/tests/unit/test_privacy_persistence.py @@ -0,0 +1,46 @@ +import json + +from src.utils.metrics import MetricsCollector +from src.utils.privacy import sanitize_payload + + +class DummyConfig: + def __init__(self, mode): + self.privacy_mode = mode + self.privacy_hash_salt = "salt" + + +def test_metrics_jsonl_scrubbed_shared(tmp_path): + config = DummyConfig("shared") + collector = MetricsCollector() + collector.record_event( + "prompt_event", + {"prompt": "email me at user@example.com"}, + stage="workflow", + node="main", + ) + path = tmp_path / "metrics.jsonl" + collector.write_jsonl(str(path), config=config) + + content = path.read_text() + assert "user@example.com" not in content + assert "[REDACTED:EMAIL]" in content + + +def test_workflow_history_scrubbed_strict(tmp_path): + config = DummyConfig("strict") + history = [ + { + "node": "architect", + "details": { + "prompt": "email me at user@example.com", + }, + } + ] + payload = sanitize_payload(history, config=config) + path = tmp_path / "workflow_history.json" + path.write_text(json.dumps(payload)) + + content = path.read_text() + assert "user@example.com" not in content + assert "prompt" in content From d8835e75015811d4570eb2f3b25ed19c4697c267 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:38:53 -0800 Subject: [PATCH 28/36] Make privacy scrubber pluggable --- docs/standards.md | 16 ++++++++ src/config.py | 4 ++ src/main.py | 12 +++--- src/utils/privacy.py | 88 ++++++++++++++++++++++++++++++++++++-------- 4 files changed, 100 insertions(+), 20 deletions(-) diff --git a/docs/standards.md b/docs/standards.md index 10d6c91..8e83dc0 100644 --- a/docs/standards.md +++ b/docs/standards.md @@ -81,6 +81,22 @@ GitHub reference: https://github.com/agents4science/agents4science.github.io/tre Note: dynamic DAG construction means the workflow graph can be built or pruned at runtime based on context, instead of following a fixed node sequence. +## Privacy and prompt scrubbing + +Policy summary: +- Privacy modes: off, shared, strict. +- Shared mode redacts sensitive content in persisted artifacts and logs. +- Strict mode blocks LLM calls when sensitive content is detected and skips transcripts. + +Implementation: +- Config: `privacy_mode`, `privacy_scrubber`, `privacy_hash_salt` in `src/config.py`. +- Scrubber: `src/utils/privacy.py`. +- Hooks: `src/utils/llm_calls.py`, `src/utils/metrics.py`, `src/main.py`. + +Tests: +- `tests/unit/test_privacy.py` +- `tests/unit/test_privacy_persistence.py` + ## Collaboration expectations (Agents4Science) | Expectation | Status | Notes | diff --git a/src/config.py b/src/config.py index 4fe9e15..73f6cdb 100644 --- a/src/config.py +++ b/src/config.py @@ -660,6 +660,10 @@ class AMReXAgentConfig(BaseModel): default="off", description="Privacy mode for prompt/log persistence: off, shared, or strict." ) + privacy_scrubber: Literal["builtin", "scrubadub", "presidio"] = Field( + default="builtin", + description="Scrubber backend for privacy modes: builtin, scrubadub, or presidio." + ) privacy_hash_salt: Optional[str] = Field( default_factory=lambda: os.getenv("AMREX_PRIVACY_SALT"), description="Optional salt for prompt hashing in privacy modes." diff --git a/src/main.py b/src/main.py index afebe0c..a075401 100644 --- a/src/main.py +++ b/src/main.py @@ -794,6 +794,7 @@ def main(args: list[str] | None = None) -> None: prompt_text, mode=privacy_mode, salt=getattr(config, "privacy_hash_salt", None), + config=config, ).text transcript_lines.append(f"User Prompt:\n{prompt_text}\n\n") transcript_lines.append("=" * 80 + "\n\n") @@ -856,11 +857,12 @@ def main(args: list[str] | None = None) -> None: if privacy_mode == "shared": transcript_lines = [ - scrub_text( - line, - mode=privacy_mode, - salt=getattr(config, "privacy_hash_salt", None), - ).text + scrub_text( + line, + mode=privacy_mode, + salt=getattr(config, "privacy_hash_salt", None), + config=config, + ).text for line in transcript_lines ] with open(transcript_path, 'w') as f: diff --git a/src/utils/privacy.py b/src/utils/privacy.py index b74c5a5..6e53816 100644 --- a/src/utils/privacy.py +++ b/src/utils/privacy.py @@ -3,10 +3,13 @@ from __future__ import annotations import hashlib +import logging import re from dataclasses import dataclass from typing import Any, Iterable +logger = logging.getLogger(__name__) + class PrivacyViolation(RuntimeError): pass @@ -64,7 +67,23 @@ def _detect(text: str) -> list[str]: return detections -def scrub_text(text: str, *, mode: str, salt: str | None = None) -> ScrubResult: +def get_privacy_mode(config: Any | None) -> str: + return getattr(config, "privacy_mode", "off") if config is not None else "off" + + +def get_privacy_scrubber(config: Any | None) -> str: + return getattr(config, "privacy_scrubber", "builtin") if config is not None else "builtin" + + +def get_privacy_salt(config: Any | None) -> str | None: + return getattr(config, "privacy_hash_salt", None) if config is not None else None + + +def should_scrub_key(key: str) -> bool: + return key in _SENSITIVE_KEYS + + +def _scrub_with_builtin(text: str, mode: str) -> tuple[str, list[str]]: detections = _detect(text) redacted = text if mode == "shared": @@ -77,21 +96,60 @@ def scrub_text(text: str, *, mode: str, salt: str | None = None) -> ScrubResult: redacted = _PATH_RE.sub("[REDACTED:PATH]", redacted) elif mode == "strict": redacted = "[REDACTED:HASH]" + return redacted, detections + + +def _scrub_with_scrubadub(text: str) -> tuple[str, list[str]]: + try: + import scrubadub + except (ImportError, ModuleNotFoundError): + logger.warning("privacy_scrubber=scrubadub requested but scrubadub is not installed; using builtin.") + return _scrub_with_builtin(text, "shared") + scrubber = scrubadub.Scrubber() + redacted = scrubber.scrub(text) + detections = ["scrubadub"] if redacted != text else [] + return redacted, detections + + +def _scrub_with_presidio(text: str) -> tuple[str, list[str]]: + try: + from presidio_analyzer import AnalyzerEngine + from presidio_anonymizer import AnonymizerEngine + except (ImportError, ModuleNotFoundError): + logger.warning("privacy_scrubber=presidio requested but presidio is not installed; using builtin.") + return _scrub_with_builtin(text, "shared") + try: + analyzer = AnalyzerEngine() + anonymizer = AnonymizerEngine() + results = analyzer.analyze(text=text, language="en") + anonymized = anonymizer.anonymize(text=text, analyzer_results=results) + redacted = anonymized.text + detections = ["presidio"] if redacted != text else [] + return redacted, detections + except Exception: + logger.warning("privacy_scrubber=presidio failed to initialize; using builtin.") + return _scrub_with_builtin(text, "shared") + + +def scrub_text( + text: str, + *, + mode: str, + salt: str | None = None, + config: Any | None = None, +) -> ScrubResult: + scrubber = get_privacy_scrubber(config) + if mode == "shared" and scrubber != "builtin": + if scrubber == "scrubadub": + redacted, detections = _scrub_with_scrubadub(text) + else: + redacted, detections = _scrub_with_presidio(text) + return ScrubResult(text=redacted, hash=_hash_text(text, salt), detections=detections) + + redacted, detections = _scrub_with_builtin(text, mode) return ScrubResult(text=redacted, hash=_hash_text(text, salt), detections=detections) -def get_privacy_mode(config: Any | None) -> str: - return getattr(config, "privacy_mode", "off") if config is not None else "off" - - -def get_privacy_salt(config: Any | None) -> str | None: - return getattr(config, "privacy_hash_salt", None) if config is not None else None - - -def should_scrub_key(key: str) -> bool: - return key in _SENSITIVE_KEYS - - def sanitize_payload( payload: Any, *, @@ -111,7 +169,7 @@ def sanitize_payload( sanitized: dict[str, Any] = {} for key, value in payload.items(): if isinstance(value, str) and (key in sensitive or _detect(value)): - result = scrub_text(value, mode=mode, salt=salt) + result = scrub_text(value, mode=mode, salt=salt, config=config) if mode == "strict" and key in sensitive: sanitized[key] = result.hash else: @@ -139,7 +197,7 @@ def sanitize_payload( if isinstance(payload, str): if _detect(payload): - result = scrub_text(payload, mode=mode, salt=salt) + result = scrub_text(payload, mode=mode, salt=salt, config=config) return result.hash if mode == "strict" else result.text return payload From b7e99b8ab87f15962aa709deaef20390f62aa6d9 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:43:17 -0800 Subject: [PATCH 29/36] Add scrubber selection integration test --- .../test_privacy_scrubber_selection.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/integration/test_privacy_scrubber_selection.py diff --git a/tests/integration/test_privacy_scrubber_selection.py b/tests/integration/test_privacy_scrubber_selection.py new file mode 100644 index 0000000..e0fe89b --- /dev/null +++ b/tests/integration/test_privacy_scrubber_selection.py @@ -0,0 +1,19 @@ +from src.utils.privacy import scrub_text + + +class DummyConfig: + def __init__(self): + self.privacy_mode = "shared" + self.privacy_scrubber = "scrubadub" + self.privacy_hash_salt = "salt" + + +def test_scrubber_selection_falls_back_to_builtin(): + config = DummyConfig() + result = scrub_text( + "email me at user@example.com", + mode=config.privacy_mode, + salt=config.privacy_hash_salt, + config=config, + ) + assert "[REDACTED:EMAIL]" in result.text From 57272c9233b613b43d86be781c2d79f0b7120a7c Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:46:32 -0800 Subject: [PATCH 30/36] Add optional scrubadub dependency --- environment.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/environment.yaml b/environment.yaml index b778357..3fbbb55 100644 --- a/environment.yaml +++ b/environment.yaml @@ -88,3 +88,7 @@ dependencies: # Additional utilities - python-json-logger>=4.0 # Structured logging - sfapi-client + # Optional privacy scrubbers (requires extra model downloads for presidio) + - scrubadub>=2.0 + # - presidio-analyzer>=2.2 + # - presidio-anonymizer>=2.2 From 6165f08353d03180a76b3af29a9a687ae2018451 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 03:50:30 -0800 Subject: [PATCH 31/36] Scrub sensitive content from logs --- src/main.py | 43 ++++++++++++++++++++++++++++++++++++++ src/utils/privacy.py | 18 ++++++++++++++++ tests/unit/test_privacy.py | 23 +++++++++++++++++++- 3 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index a075401..8b808aa 100644 --- a/src/main.py +++ b/src/main.py @@ -78,6 +78,47 @@ def filter(self, record: logging.LogRecord) -> bool: return True +class PrivacyFilter(logging.Filter): + """Scrub log messages based on configured privacy mode.""" + + def __init__(self, config: AMReXAgentConfig): + super().__init__() + self._config = config + + def filter(self, record: logging.LogRecord) -> bool: + try: + from src.utils.privacy import scrub_log_message + + msg = record.getMessage() + scrubbed = scrub_log_message(msg, config=self._config) + if scrubbed != msg: + record.msg = scrubbed + record.args = () + except Exception: + pass + return True + + +_privacy_filter_installed = False + + +def apply_privacy_log_filter(config: AMReXAgentConfig) -> None: + global _privacy_filter_installed + if _privacy_filter_installed: + return + from src.utils.privacy import get_privacy_mode + + if get_privacy_mode(config) == "off": + return + privacy_filter = PrivacyFilter(config) + root_logger = logging.getLogger() + for handler in root_logger.handlers: + handler.addFilter(privacy_filter) + for name in ("httpx", "openai", "anthropic"): + logging.getLogger(name).addFilter(privacy_filter) + _privacy_filter_installed = True + + class ColorFormatter(logging.Formatter): """Optional ANSI color formatting for log levels.""" @@ -698,6 +739,8 @@ def main(args: list[str] | None = None) -> None: raise ValueError("--run-ntasks must be >= 1") config.mpi_ranks = parsed_args.mpi_ranks + apply_privacy_log_filter(config) + _warn_if_schema_missing(config, getattr(parsed_args, "baseline_override", None)) # Disable schema validator temporarily (modifications format issue) diff --git a/src/utils/privacy.py b/src/utils/privacy.py index 6e53816..ea0f36e 100644 --- a/src/utils/privacy.py +++ b/src/utils/privacy.py @@ -83,6 +83,10 @@ def should_scrub_key(key: str) -> bool: return key in _SENSITIVE_KEYS +def detect_text(text: str) -> list[str]: + return _detect(text) + + def _scrub_with_builtin(text: str, mode: str) -> tuple[str, list[str]]: detections = _detect(text) redacted = text @@ -150,6 +154,20 @@ def scrub_text( return ScrubResult(text=redacted, hash=_hash_text(text, salt), detections=detections) +def scrub_log_message(message: str, *, config: Any | None) -> str: + mode = get_privacy_mode(config) + if mode == "off": + return message + detections = detect_text(message) + if not detections: + return message + salt = get_privacy_salt(config) + if mode == "shared": + return scrub_text(message, mode=mode, salt=salt, config=config).text + result = scrub_text(message, mode="strict", salt=salt, config=config) + return f"[REDACTED:HASH:{result.hash}]" + + def sanitize_payload( payload: Any, *, diff --git a/tests/unit/test_privacy.py b/tests/unit/test_privacy.py index bcfe022..7e24a12 100644 --- a/tests/unit/test_privacy.py +++ b/tests/unit/test_privacy.py @@ -1,4 +1,11 @@ -from src.utils.privacy import PrivacyViolation, enforce_strict, sanitize_payload, scrub_text +from src.utils.privacy import ( + PrivacyViolation, + detect_text, + enforce_strict, + sanitize_payload, + scrub_log_message, + scrub_text, +) class DummyConfig: @@ -37,3 +44,17 @@ def test_enforce_strict_blocks_on_detection(): assert "unit_test" in str(exc) else: raise AssertionError("PrivacyViolation not raised") + + +def test_scrub_log_message_shared_redacts(): + config = DummyConfig("shared") + message = "contact user@example.com" + assert "user@example.com" in message + scrubbed = scrub_log_message(message, config=config) + assert "user@example.com" not in scrubbed + assert "[REDACTED:EMAIL]" in scrubbed + + +def test_detect_text_reports_matches(): + detections = detect_text("token=abc123 user@example.com") + assert "email" in detections From 4d4404a9d249b4f5d8f857f7287ce62d597dbe38 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 04:02:30 -0800 Subject: [PATCH 32/36] Scrub benchmark artifacts for privacy --- src/benchmark_runner.py | 61 +++++++++++++++++++++++----- tests/unit/test_benchmark_privacy.py | 16 ++++++++ 2 files changed, 67 insertions(+), 10 deletions(-) create mode 100644 tests/unit/test_benchmark_privacy.py diff --git a/src/benchmark_runner.py b/src/benchmark_runner.py index 2563892..b31a86c 100644 --- a/src/benchmark_runner.py +++ b/src/benchmark_runner.py @@ -10,6 +10,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any +from types import SimpleNamespace import yaml from jsonschema import Draft202012Validator @@ -319,12 +320,27 @@ def _build_command( return cmd -def _write_jsonl(path: Path, payload: dict[str, Any]) -> None: +def _write_jsonl(path: Path, payload: dict[str, Any], config: Any | None = None) -> None: with path.open("a", encoding="utf-8") as handle: + if config is not None: + from src.utils.privacy import sanitize_payload + + payload = sanitize_payload(payload, config=config) handle.write(json.dumps(payload, default=str)) handle.write("\n") +def _privacy_config(run_args: dict[str, Any]) -> Any | None: + mode = run_args.get("privacy_mode") + if not mode: + return None + return SimpleNamespace( + privacy_mode=mode, + privacy_scrubber=run_args.get("privacy_scrubber", "builtin"), + privacy_hash_salt=run_args.get("privacy_hash_salt"), + ) + + def _slugify(value: str) -> str: return re.sub(r"[^A-Za-z0-9._-]+", "_", value) @@ -341,6 +357,7 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non run_args = bench_config.get("run_args") or {} env_common = bench_config.get("env") or {} + privacy_config = _privacy_config(run_args) run_name = run_name or datetime.now().strftime("bench_%Y%m%d_%H%M%S") run_dir = output_dir / run_name @@ -365,11 +382,16 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non if not prompt_text: raise ValueError(f"Prompt entry missing text/path: {normalized!r}") prompt_entries.append({**normalized, "prompt": prompt_text}) - manifest["prompts"].append({ + prompt_manifest = { "id": normalized["id"], "prompt_path": normalized.get("prompt_path"), "prompt_excerpt": prompt_text[:160], - }) + } + if privacy_config is not None: + from src.utils.privacy import sanitize_payload + + prompt_manifest = sanitize_payload(prompt_manifest, config=privacy_config) + manifest["prompts"].append(prompt_manifest) raw_metrics_path = run_dir / "benchmark_runs.jsonl" @@ -383,14 +405,19 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non provider = (model.get("overrides") or {}).get("llm_provider") model_config_path = _build_config_for_model(model, run_dir) - manifest["models"].append({ + model_manifest = { "id": model_id, "slug": model_slug, "config_path": str(model_config_path) if model_config_path else None, "config_source": model.get("config_path"), "override_keys": sorted((model.get("overrides") or {}).keys()), "env_keys": sorted((model.get("env") or {}).keys()), - }) + } + if privacy_config is not None: + from src.utils.privacy import sanitize_payload + + model_manifest = sanitize_payload(model_manifest, config=privacy_config) + manifest["models"].append(model_manifest) for prompt in prompt_entries: prompt_id = prompt["id"] @@ -399,7 +426,7 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non env = os.environ.copy() env.update({k: str(v) for k, v in model_env.items()}) benchmark_context = prompt_dir / "benchmark_context.json" - benchmark_context.write_text(json.dumps({ + context_payload = { "prompt_id": prompt_id, "prompt_excerpt": prompt["prompt"][:160], "case_id": prompt.get("case_id"), @@ -408,7 +435,12 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non "novelty_tier": prompt.get("novelty_tier"), "model_id": model_id, "provider": provider, - }, indent=2, default=str)) + } + if privacy_config is not None: + from src.utils.privacy import sanitize_payload + + context_payload = sanitize_payload(context_payload, config=privacy_config) + benchmark_context.write_text(json.dumps(context_payload, indent=2, default=str)) cmd = _build_command(model_config_path, prompt, prompt_dir, run_args, benchmark_context) @@ -481,17 +513,26 @@ def run_model_benchmark(config_path: Path, output_dir: Path, run_name: str | Non "error": error, "stderr_excerpt": stderr[:2000] if stderr else None, } - _write_jsonl(raw_metrics_path, record) + _write_jsonl(raw_metrics_path, record, config=privacy_config) per_run = prompt_dir / "result.json" - per_run.write_text(json.dumps({ + per_run_payload = { "command": cmd, "exit_code": exit_code, "stdout": stdout, "stderr": stderr, "result": result_data, "record": record, - }, indent=2, default=str)) + } + if privacy_config is not None: + from src.utils.privacy import sanitize_payload + + per_run_payload = sanitize_payload(per_run_payload, config=privacy_config) + per_run.write_text(json.dumps(per_run_payload, indent=2, default=str)) + + if privacy_config is not None: + from src.utils.privacy import sanitize_payload + manifest = sanitize_payload(manifest, config=privacy_config) (run_dir / "manifest.json").write_text(json.dumps(manifest, indent=2)) return {"run_dir": str(run_dir), "metrics": str(raw_metrics_path)} diff --git a/tests/unit/test_benchmark_privacy.py b/tests/unit/test_benchmark_privacy.py new file mode 100644 index 0000000..5d03841 --- /dev/null +++ b/tests/unit/test_benchmark_privacy.py @@ -0,0 +1,16 @@ +from types import SimpleNamespace + +from src.benchmark_runner import _write_jsonl + + +def test_benchmark_jsonl_scrubs_prompt(tmp_path): + config = SimpleNamespace( + privacy_mode="shared", + privacy_scrubber="builtin", + privacy_hash_salt="salt", + ) + path = tmp_path / "bench.jsonl" + _write_jsonl(path, {"prompt_excerpt": "email user@example.com"}, config=config) + content = path.read_text() + assert "user@example.com" not in content + assert "[REDACTED:EMAIL]" in content From 15b5ede17bda0f59c566ee6b23401fbc852f01d8 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 04:15:39 -0800 Subject: [PATCH 33/36] Document benchmark privacy run_args --- docs/standards.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/standards.md b/docs/standards.md index 8e83dc0..9c70640 100644 --- a/docs/standards.md +++ b/docs/standards.md @@ -90,6 +90,7 @@ Policy summary: Implementation: - Config: `privacy_mode`, `privacy_scrubber`, `privacy_hash_salt` in `src/config.py`. +- Benchmark runner: `run_args.privacy_mode`, `run_args.privacy_scrubber`, `run_args.privacy_hash_salt` in `src/benchmark_runner.py`. - Scrubber: `src/utils/privacy.py`. - Hooks: `src/utils/llm_calls.py`, `src/utils/metrics.py`, `src/main.py`. From bced396b4964c4df02c07451f049bf38a59c24ff Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 05:02:01 -0800 Subject: [PATCH 34/36] Create remote run dirs via REST mkdir --- src/services/run_superfacility_tools.py | 74 +++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index aff9c89..8de55a8 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -549,6 +549,7 @@ def stage_run_directory_sfapi_client( secret: str | None = None, key_path: str | Path | None = None, exclude_names: list[str] | None = None, + upload_host: str = "perlmutter", ) -> None: """ Stage a local run directory to a remote filesystem via sfapi_client. @@ -592,10 +593,17 @@ def stage_run_directory_sfapi_client( target_dir = None if target_dir is None: - raise FileNotFoundError( - f"Remote run directory not found: {remote_run_dir}. " - "Create it on Perlmutter or via the SFAPI client." - ) + try: + ensure_remote_directory_rest( + remote_run_dir=remote_run_dir, + upload_host=upload_host, + ) + [target_dir] = perlmutter.ls(remote_run_dir, directory=True) + except Exception as exc: + raise FileNotFoundError( + f"Remote run directory not found: {remote_run_dir}. " + "Create it on Perlmutter or via the SFAPI client." + ) from exc for item in sorted(local_run_dir.iterdir()): if not item.is_file(): @@ -675,7 +683,10 @@ def ensure_remote_directory_rest( Ensure remote directory exists by uploading a placeholder file via REST. """ import io + import json import requests + import shlex + import time if nersc_session is None: clients = find_nersc_clients() @@ -685,9 +696,64 @@ def ensure_remote_directory_rest( if nersc_session: break + if not nersc_session: + key_path = _resolve_sfapi_key_path() + if key_path: + try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc7523 import PrivateKeyJWT + + key_path = Path(key_path) + key_lines = key_path.read_text().splitlines() + if key_lines: + client_id = key_lines[0].strip() + private_key = "\n".join(key_lines[1:]).strip() or None + if client_id and private_key: + token_url = "https://oidc.nersc.gov/c2id/token" + session = OAuth2Session( + client_id, + private_key, + PrivateKeyJWT(token_url), + grant_type="client_credentials", + token_endpoint=token_url, + ) + session.fetch_token() + nersc_session = {"type": "oauth", "session": session} + except Exception: + pass + if not nersc_session: raise RuntimeError("No NERSC session for REST upload staging") + if nersc_session["type"] == "oauth": + try: + command_url = f"https://api.nersc.gov/api/v1.2/utilities/command/{upload_host}" + executable = f"bash -lc {shlex.quote(f'mkdir -p {remote_run_dir}')}" + response = nersc_session["session"].post( + command_url, + data={"executable": executable}, + ) + response.raise_for_status() + payload = response.json() + task_id = payload.get("task_id") + if task_id: + tasks_url = f"https://api.nersc.gov/api/v1.2/tasks/{task_id}" + for _ in range(15): + task_resp = nersc_session["session"].get(tasks_url) + task_resp.raise_for_status() + task = task_resp.json() + if task.get("status") == "completed": + result = task.get("result") + if isinstance(result, str): + result = json.loads(result) + if isinstance(result, dict) and result.get("status") == "ok": + return + raise RuntimeError(f"REST mkdir task failed: {result}") + time.sleep(1) + raise RuntimeError(f"REST mkdir task did not complete: {task_id}") + except Exception: + pass + upload_url = f"https://api.nersc.gov/api/v1.2/utilities/upload/{upload_host}" remote_path = str(Path(remote_run_dir) / ".keep") payload = io.BytesIO(b"") From e85d598acc3a57da9401361857e9a8b3bb9510f1 Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 05:08:43 -0800 Subject: [PATCH 35/36] Allow mkdir fallback when remote output missing --- src/services/run_superfacility_tools.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/services/run_superfacility_tools.py b/src/services/run_superfacility_tools.py index 8de55a8..dbb1ed7 100644 --- a/src/services/run_superfacility_tools.py +++ b/src/services/run_superfacility_tools.py @@ -1098,12 +1098,26 @@ def resolve_remote_output_dir( if candidate not in candidates: candidates.append(candidate) + create_budget = 2 for candidate in candidates: try: list_remote_entries(str(candidate), nersc_session=nersc_session, system=system) except Exception as exc: - logger.debug("Remote output dir check failed for %s: %s", candidate, exc) - continue + if create_budget > 0: + try: + ensure_remote_directory_rest( + remote_run_dir=str(candidate), + nersc_session=nersc_session, + upload_host=system, + ) + create_budget -= 1 + list_remote_entries(str(candidate), nersc_session=nersc_session, system=system) + except Exception as mkdir_exc: + logger.debug("Remote output dir check failed for %s: %s", candidate, mkdir_exc) + continue + else: + logger.debug("Remote output dir check failed for %s: %s", candidate, exc) + continue if sfapi_available and not nersc_session: logger.debug("Using SFAPI credentials for %s; skipping REST mkdir check", candidate) logger.info("Using remote output dir: %s", candidate) From 0b7a5804845280da0392e5281ffbccd53a2443cc Mon Sep 17 00:00:00 2001 From: "Jean M. Sexton" Date: Tue, 17 Feb 2026 06:12:21 -0800 Subject: [PATCH 36/36] Ensure run_benchmark.py sets PYTHONPATH --- scripts/run_benchmark.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/run_benchmark.py b/scripts/run_benchmark.py index 1634e7e..6357dd1 100644 --- a/scripts/run_benchmark.py +++ b/scripts/run_benchmark.py @@ -4,9 +4,12 @@ from __future__ import annotations import argparse +import sys from datetime import datetime, timezone from pathlib import Path +sys.path.append(str(Path(__file__).resolve().parents[1])) + from src.benchmark_runner import ( collect_cases, expand_case_runs,