diff --git a/configs/gridfm_graphkit_hpo.yaml b/configs/gridfm_graphkit_hpo.yaml index d578e19..94b9a28 100644 --- a/configs/gridfm_graphkit_hpo.yaml +++ b/configs/gridfm_graphkit_hpo.yaml @@ -46,7 +46,7 @@ hpo: choices: case118: config: ./examples/config/HGNS_PF_datakit_case118.yaml - data_path: /u/rkie/ + data_path: /u/rki/ static: run_name: run1 diff --git a/docs/iterate2.md b/docs/iterate2.md index c21359e..0a84e26 100644 --- a/docs/iterate2.md +++ b/docs/iterate2.md @@ -6,19 +6,27 @@ Key capabilities: - **Multi-objective optimisation** — extract and optimise several metrics simultaneously (Pareto front) - **Five HPO parameter types** — `float`, `int`, `categorical`, `flag` (store-true), `group` (bundled arg sets) -- **Dynamic GPU count per trial** — `gpu_num` in the HPO space controls the WLM resource request per trial +- **Dynamic GPU count per trial** — `gpu_num` in the HPO space is passed to the WLM plugin via `ITERATE_WLM_GPU_COUNT` - **Null-omission** — `null` in a `categorical` choice causes the flag to be completely absent from the command line -- **Workload manager backends** — LSF, Slurm, or direct local execution +- **WLM plugin system** — any executable (bash, Python, …) can be used as a workload-manager backend; reference implementations for LSF and Vela/OpenShift are in `examples/wlm_plugins/` ## Quick start ```sh iterate2 \ --script train.py \ - --wlm lsf \ - --gpu-count 1 \ - --cpu-count 20 \ - --mem-gb 512 \ + --wlm-plugin examples/wlm_plugins/lsf_plugin.sh \ + --optuna-study-name my_study \ + --optuna-db-path sqlite:///hpo.db \ + --optuna-n-trials 50 \ + --hpo-yaml hpo_space.yaml # wlm: section sets gpu-count, cpu-count, … +``` + +For local execution (no cluster) simply omit `--wlm-plugin`: + +```sh +iterate2 \ + --script train.py \ --optuna-study-name my_study \ --optuna-db-path sqlite:///hpo.db \ --optuna-n-trials 50 \ @@ -33,29 +41,12 @@ iterate2 \ |---|---|---| | `--script` | *(required)* | Training script to execute | | `--root-dir` | `.` 
| Working directory; derived from `--script` if omitted | -| `--venv` | `.venv` | Virtual-environment directory to activate. Set to empty string to disable | +| `--venv` | *(none)* | Virtual-environment directory to activate. Omit to skip venv activation entirely | | `--interpreter` | `python` | Python interpreter to invoke | | `--param-setter` | `None` | Use setter-style argument passing (see [Setter-style arguments](#setter-style-arguments)) | -| `--wlm` | `none` | Workload manager: `lsf`, `slurm`, `vela`, or `none` | -| `--gpu-count` | `1` | Number of GPUs per trial | -| `--cpu-count` | `4` | Number of CPUs per trial | -| `--mem-gb` | `128` | Memory (GB) per trial | -| `--lsf-gpu-config-string` | `None` | Optional verbatim LSF `-gpu` option string (see [GPU configuration](#gpu-configuration-on-lsf)) | +| `--wlm-plugin` | *(local)* | Path to an executable WLM plugin script. When omitted, trials run locally in the current process | | `--parallelism` | `1` | Number of trials to run in parallel (see [Parallel execution](#parallel-execution)) | -### Vela (OpenShift) options - -Required when `--wlm vela`. - -| Option | Default | Description | -|---|---|---| -| `--vela-job-template` | *(required)* | Path to the Vela job YAML template. 
`{{HPO_COMMAND}}` in `setupCommands` is replaced per trial | -| `--vela-chart-path` | *(required)* | Path to the `pytorchjob-generator` helm chart directory | -| `--vela-namespace` | *(current context)* | OpenShift/Kubernetes namespace | -| `--vela-cmd-placeholder` | `{{HPO_COMMAND}}` | String in `setupCommands` that is replaced with the HPO-parametrised CLI call | -| `--vela-pod-ready-timeout` | `600` | Seconds to wait for the trial pod to reach Running state | -| `--vela-job-timeout` | `86400` | Seconds to wait (streaming logs) for the job to complete | - ### Optuna options | Option | Default | Description | @@ -189,7 +180,10 @@ Optuna tracks the choice as a single categorical (`dataset = "case2000"`), but t ##### `gpu_num` — dynamic GPU count -The special key `gpu_num` (as `categorical` or `int`) overrides `--gpu-count` for the **WLM resource request** of each individual trial. It is consumed by `iterate2` and never forwarded to the wrapped script. +The special key `gpu_num` (as `categorical` or `int`) is automatically extracted +from the sampled parameters and forwarded to the WLM plugin as +`ITERATE_WLM_GPU_COUNT`. It does **not** appear in the wrapped script's command +line. The WLM plugin uses it to set the cluster resource request for the trial. ```yaml gpu_num: @@ -197,6 +191,9 @@ gpu_num: choices: [1, 2, 4] ``` +Alternatively, set a fixed `gpu-count` in the `wlm:` section of the HPO YAML +when all trials use the same number of GPUs. + ### Static arguments Arguments passed unchanged to every trial. Can be supplied inline or via file: @@ -270,46 +267,59 @@ iterate2 --param-setter set ... --- -## GPU configuration on LSF +## WLM plugin system + +iterate2 has no built-in knowledge of any workload manager. Instead it calls a +user-supplied **plugin script** once per trial. The plugin can be any +executable (bash, Python, …). -When `--wlm lsf` is selected, `iterate2` constructs a `bsub` command for each trial. 
+### Plugin interface -### Default behaviour +iterate2 calls the plugin with no positional arguments. All information is +delivered through environment variables: -| `--gpu-count` | Generated fragment | +| Variable | Description | |---|---| -| `> 0` (default `1`) | `-gpu num=` | -| `0` | *(no `-gpu` flag, CPU-only job)* | +| `ITERATE_TRIAL_NUMBER` | Integer trial ID | +| `ITERATE_TRIAL_CMD` | Full shell command (with `cd`, `source venv`) – suited for HPC WLMs | +| `ITERATE_TRIAL_CONTAINER_CMD` | Bare CLI invocation (no `cd`/`source`) – suited for container-based systems | +| `ITERATE_OUT_FILE` | File where **stdout** must be written | +| `ITERATE_ERR_FILE` | File where **stderr** must be written | +| `ITERATE_WLM_` | Every key from the YAML `wlm:` section (uppercased, hyphens → underscores) | -### `--lsf-gpu-config-string` +The plugin must exit **0** on success; any other exit code marks the trial as +failed in Optuna. -For advanced LSF GPU scheduling you can supply the full value of the `-gpu` option as a string. When set, it **completely replaces** the auto-generated `-gpu num=` fragment. +### WLM configuration in the HPO YAML -```sh -iterate2 \ - --wlm lsf \ - --lsf-gpu-config-string "num=1:mode=exclusive_process:mps=yes:gmodel=NVIDIAA100_SXM4_80GB" \ - --cpu-count 20 \ - --mem-gb 512 \ - ... -``` +All WLM-specific parameters (GPU count, memory, queue, job template path, …) +live in an optional `wlm:` section of the HPO YAML: -This produces a `bsub` submission resembling: +```yaml +hpo: + lr: { type: float, low: 1e-5, high: 1e-2, log: true } -```sh -bsub -n 20 -R "span[hosts=1]" \ - -gpu "num=1:mode=exclusive_process:mps=yes:gmodel=NVIDIAA100_SXM4_80GB" \ - -M 512G -J hpo_trial_0 \ - "cd /my/root && source .venv/bin/activate && python train.py ..." 
+static: + epochs: 50 + +# WLM config – forwarded as ITERATE_WLM_* env vars to the plugin +wlm: + gpu-count: 1 + cpu-count: 8 + mem-gb: 32 + lsf-gpu-config: "num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB" ``` -!!! note - `--gpu-count` is still used for the `rusage` memory/CPU reservation string even when `--lsf-gpu-config-string` is set. Set it to match the `num=` value in your GPU string. +### Reference plugins -!!! tip - Use exclusive process mode (`mode=exclusive_process`) together with MPS (`mps=yes`) to share a single A100 across multiple MPS clients while still pinning the job to one physical GPU. +See `examples/wlm_plugins/` for fully documented reference implementations: ---- +| Plugin | WLM | +|---|---| +| `lsf_plugin.sh` | IBM Spectrum LSF (`bsub -K`) | +| `vela_plugin.py` | OpenShift / MLBatch PyTorchJob (`helm template \| oc create`) | + +Writing a SLURM plugin follows the same pattern as `lsf_plugin.sh`. --- @@ -320,7 +330,7 @@ By default `iterate2` runs one trial at a time. Pass `--parallelism N` to run up ```sh iterate2 \ --parallelism 4 \ - --wlm lsf \ + --wlm-plugin examples/wlm_plugins/lsf_plugin.sh \ ... ``` @@ -344,12 +354,12 @@ Output from concurrent trials is prefixed so you can follow individual workers: ### Output files -| WLM | stdout | stderr | -|---|---|---| -| `none` | `trial_N.out` (written by iterate2) | `trial_N.err` (written by iterate2) | -| `lsf` / `slurm` | `trial_N.out` (written by WLM on cluster) | `trial_N.err` (written by WLM on cluster) | +iterate2 tells the plugin where to write output via `ITERATE_OUT_FILE` / +`ITERATE_ERR_FILE`. The plugin is responsible for directing its job's +stdout/stderr to those files. iterate2 extracts metrics from them after the +plugin exits. -For WLM backends the local WLM tool output (bsub/srun status messages) is written to `trial_N_wlm.out` / `trial_N_wlm.err` so the cluster-managed files are never overwritten. 
+For local execution (no plugin) iterate2 writes them directly: ### SQLite and parallelism diff --git a/examples/bumpy_function.py b/examples/bumpy_function.py index e098d76..5661fdb 100644 --- a/examples/bumpy_function.py +++ b/examples/bumpy_function.py @@ -1,6 +1,23 @@ #!/usr/bin/env python3 -import argparse +""" +Bumpy 3-D multimodal function — called by iterate2 as a trial script. + +iterate2 sets the following environment variables before calling this script: + ITERATE_TRIAL_NUMBER – integer trial index + ITERATE_OUT_FILE – path where metrics must be written + ITERATE_ERR_FILE – path for error logging + ITERATE_PARAM_X – HPO parameter x + ITERATE_PARAM_Y – HPO parameter y + ITERATE_PARAM_Z – HPO parameter z + ITERATE_PARAM_GLOBAL_MU – static parameter (three space-separated floats) + +All output that iterate2 uses to extract metrics must be written to +ITERATE_OUT_FILE (not stdout), one metric per line in "name: value" format. +""" + import math +import os +import sys def bumpy_function_3d( @@ -9,11 +26,11 @@ def bumpy_function_3d( mu_rest, sigma_rest, amps_rest, ): """ - 3D smooth multimodal function with: - - one global optimum = 1 at global_mu = (mx,my,mz) - - multiple local optima < 1 + 3D smooth multimodal function. 
+ - one global optimum = 1 at global_mu = (mx, my, mz) + - multiple local optima < 1 - f(p) = 1 - Π_k (1 - a_k * exp(-||p - mu_k||^2 / (2 sigma_k^2))) + f(p) = 1 - prod_k (1 - a_k * exp(-||p - mu_k||^2 / (2 sigma_k^2))) """ def sqdist(p, q): @@ -21,74 +38,49 @@ def sqdist(p, q): p = (x, y, z) - # Global peak (amplitude = 1) - val = 1.0 - math.exp( - -sqdist(p, global_mu) / (2.0 * global_sigma**2) - ) + val = 1.0 - math.exp(-sqdist(p, global_mu) / (2.0 * global_sigma**2)) - # Local peaks for mu_k, sig_k, a_k in zip(mu_rest, sigma_rest, amps_rest): - term = 1.0 - a_k * math.exp( - -sqdist(p, mu_k) / (2.0 * sig_k**2) - ) - val *= term + val *= 1.0 - a_k * math.exp(-sqdist(p, mu_k) / (2.0 * sig_k**2)) return 1.0 - val if __name__ == "__main__": - parser = argparse.ArgumentParser("Evaluate the 3D bumpy multimodal function.") - - parser.add_argument("--x", type=float, required=True) - parser.add_argument("--y", type=float, required=True) - parser.add_argument("--z", type=float, required=True) - parser.add_argument("--trial-number", type=int, default=0) - - parser.add_argument( - "--global-mu", - type=float, - nargs=3, - default=[0.0, 0.0, 0.0], - metavar=("MX", "MY", "MZ"), - ) - parser.add_argument("--global-sigma", type=float, default=0.7) - - parser.add_argument( - "--mu-rest", - type=float, - nargs="*", - default=[-2.0, 0.0, 0.0, 2.0, 0.0, 0.0], - help="Flat list of (x y z) triplets", - ) - parser.add_argument( - "--sigma-rest", - type=float, - nargs="*", - default=[0.6, 0.6], - ) - parser.add_argument( - "--amps-rest", - type=float, - nargs="*", - default=[0.5, 0.8], - ) - - args = parser.parse_args() - - mu_rest = [ - tuple(args.mu_rest[i:i+3]) - for i in range(0, len(args.mu_rest), 3) - ] - + # --- read parameters from environment ---------------------------------- # + try: + x = float(os.environ["ITERATE_PARAM_X"]) + y = float(os.environ["ITERATE_PARAM_Y"]) + z = float(os.environ["ITERATE_PARAM_Z"]) + global_mu = tuple(map(float, 
os.environ["ITERATE_PARAM_GLOBAL_MU"].split())) + out_file = os.environ["ITERATE_OUT_FILE"] + trial_num = os.environ.get("ITERATE_TRIAL_NUMBER", "?") + except KeyError as exc: + print(f"ERROR: missing required environment variable {exc}", file=sys.stderr) + sys.exit(1) + + if len(global_mu) != 3: + print("ERROR: ITERATE_PARAM_GLOBAL_MU must contain exactly three floats", file=sys.stderr) + sys.exit(1) + + # Fixed defaults for the local-optima configuration + mu_rest = [(-2.0, 0.0, 0.0), (2.0, 0.0, 0.0)] + sigma_rest = [0.6, 0.6] + amps_rest = [0.5, 0.8] + global_sigma = 0.7 + + # --- evaluate ---------------------------------------------------------- # yval = bumpy_function_3d( - x=args.x, - y=args.y, - z=args.z, - global_mu=tuple(args.global_mu), - global_sigma=args.global_sigma, + x=x, y=y, z=z, + global_mu=global_mu, + global_sigma=global_sigma, mu_rest=mu_rest, - sigma_rest=args.sigma_rest, - amps_rest=args.amps_rest, + sigma_rest=sigma_rest, + amps_rest=amps_rest, ) - print(f'yval: {yval}, trial_number: {args.trial_number}') + # --- write metrics to ITERATE_OUT_FILE --------------------------------- # + with open(out_file, "w") as fh: + fh.write(f"yval: {yval}\n") + + print(f"[trial-{trial_num}] yval={yval:.6f}") diff --git a/examples/bumpy_hpo.yaml b/examples/bumpy_hpo.yaml index 70a16bf..bd0d109 100644 --- a/examples/bumpy_hpo.yaml +++ b/examples/bumpy_hpo.yaml @@ -1,13 +1,15 @@ -# ======================= -# Static parameters - passed to the underlying training script as is -# ======================= - +# HPO search space for the bumpy 3-D multimodal function. 
+# +# Only three sections are recognised by iterate2: +# metrics: – names to extract from the trial script output +# static: – fixed parameters passed to every trial +# hpo: – parameters Optuna will optimise + +metrics: + - yval + static: - global-mu: 23 42 66 - -# ======================== -# Training hyperparameters - evaluated by optuna and passed to the underlying training script -# ======================== + global-mu: "23 42 66" hpo: x: diff --git a/examples/run_ccc_gridfm_example.sh b/examples/run_ccc_gridfm_example.sh new file mode 100755 index 0000000..af5a8a7 --- /dev/null +++ b/examples/run_ccc_gridfm_example.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# ============================================================================= +# Run gridfm-graphkit HPO on CCC (IBM Spectrum LSF cluster) +# +# iterate2 orchestrates Optuna trials. For each trial it calls +# examples/wlm_plugins/ccc_plugin.sh, which owns all LSF concerns: +# venv activation, CUDA setup, and bsub submission. +# +# Prerequisites +# * bsub / bjobs available on PATH +# * gridfm-graphkit installed in GRIDFM_VENV +# * configs/gridfm_graphkit_hpo.yaml present +# * psycopg2-binary: pip install 'terratorch-iterate[postgresql]' +# * POSTGRES_URL exported +# +# export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies" +# ============================================================================= + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" + +: "${POSTGRES_URL:?Please export POSTGRES_URL=postgresql://user:password@host:port/dbname}" + +# Override via env vars if your paths differ from the plugin defaults +export GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}" +export GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}" +export CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}" + +iterate \ + --script "${SCRIPT_DIR}/wlm_plugins/ccc_plugin.sh" \ + --optuna-study-name gridfm_ccc_hpo \ + --optuna-db-path "${POSTGRES_URL}" \ + --parallelism 4 \ + --optuna-n-trials 20 \ + --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" diff --git a/examples/run_lsf_gridfm_example_postgres.sh b/examples/run_lsf_gridfm_example_postgres.sh index 197998e..ed28a78 100755 --- a/examples/run_lsf_gridfm_example_postgres.sh +++ b/examples/run_lsf_gridfm_example_postgres.sh @@ -1,42 +1,23 @@ #!/usr/bin/env bash # ============================================================================= -# Example: iterate --wlm lsf with PostgreSQL coordinator for gridfm-graphkit HPO +# Example: iterate2 with LSF job submission and PostgreSQL coordinator # -# Each Optuna trial is submitted as an LSF job that looks like: +# iterate2 is a pure Optuna orchestrator - it knows nothing about LSF. +# For every trial it: +# 1. Samples hyperparameters via Optuna +# 2. Calls --script with all params exposed as ITERATE_PARAM_ env vars +# 3. 
Reads metrics from ITERATE_OUT_FILE after the script exits # -# bsub -gpu "num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB" \ -# -K -o trial.out -e trial.err \ -# -R "rusage[ngpus=1, cpu=16, mem=32GB]" \ -# -J hpo_trial_ \ -# "export PATH='/opt/share/cuda-12.8.1/bin:$PATH' && \ -# export CUDA_HOME='/opt/share/cuda-12.8.1/' && \ -# export LD_LIBRARY_PATH='/opt/share/cuda-12.8.1/lib64:$LD_LIBRARY_PATH' && \ -# cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && \ -# source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && \ -# gridfm_graphkit train " +# The trial script (examples/wlm_plugins/lsf_plugin.sh) owns ALL +# cluster concerns: venv activation, CUDA setup, bsub submission, etc. # # Prerequisites -# ------------- -# * LSF bsub/bjobs available on PATH -# * gridfm-graphkit installed in the venv below +# * LSF bsub available on PATH # * configs/gridfm_graphkit_hpo.yaml present -# * psycopg2-binary installed: pip install 'terratorch-iterate[postgresql]' -# * POSTGRES_URL set (or hard-code it in --optuna-db-path below) -# -# PostgreSQL coordinator -# ---------------------- -# Using PostgreSQL instead of SQLite / JournalFS is the recommended backend for -# high-parallelism HPO on a cluster: multiple bsub jobs can safely write trial -# results concurrently without lock contention. -# -# Set the connection URL as an env-var to avoid embedding credentials in scripts -# that may end up in version control: +# * psycopg2-binary: pip install 'terratorch-iterate[postgresql]' +# * POSTGRES_URL set # # export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies" -# -# or pass it inline: -# -# POSTGRES_URL="postgresql://..." bash run_lsf_gridfm_example_postgres.sh # ============================================================================= set -euo pipefail @@ -44,66 +25,12 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" -# --------------------------------------------------------------------------- -# Required: PostgreSQL connection URL -# --------------------------------------------------------------------------- : "${POSTGRES_URL:?Please set POSTGRES_URL=postgresql://user:password@host:port/dbname}" -# --------------------------------------------------------------------------- -# Customisable paths – override via environment variables -# --------------------------------------------------------------------------- -GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}" -GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}" -CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}" -DATA_PATH="${DATA_PATH:-/u/rkie/}" -LOG_DIR="${LOG_DIR:-logs}" - -# --------------------------------------------------------------------------- -# LSF GPU resource string -# Adjust gmodel to the GPU type available on your cluster. -# --------------------------------------------------------------------------- -LSF_GPU_CONFIG="${LSF_GPU_CONFIG:-num=1:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB}" - -# --------------------------------------------------------------------------- -# Pre-run commands executed inside every bsub job before the training script. -# Order matters: -# 1. Export CUDA paths so the GPU driver / toolkit is visible. -# 2. cd into the project root so relative config paths resolve correctly. -# 3. Activate the project venv. -# --------------------------------------------------------------------------- -PRE_RUN="\ -export PATH='${CUDA_BASE}/bin:\$PATH' && \ -export CUDA_HOME='${CUDA_BASE}' && \ -export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \ -cd '${GRIDFM_ROOT}' && \ -source '${GRIDFM_VENV}/bin/activate'" - -# --------------------------------------------------------------------------- -# Static training arguments (not part of the HPO search space). -# These are appended verbatim after the sampled hyperparameters. 
-# --------------------------------------------------------------------------- -STATIC_ARGS_JSON='{ - "log_dir": "'"${LOG_DIR}"'", - "report-performance": true -}' - -# --------------------------------------------------------------------------- -# Launch iterate -# --------------------------------------------------------------------------- iterate \ - --script "gridfm_graphkit train" \ - --interpreter "" \ - --root-dir "${GRIDFM_ROOT}" \ - --wlm lsf \ - --pre-run-commands "${PRE_RUN}" \ - --no-underscore-to-hyphen \ - --gpu-count 1 \ - --cpu-count 16 \ - --mem-gb 32 \ - #--lsf-gpu-config-string "${LSF_GPU_CONFIG}" \ - --optuna-study-name gridfm_lsf_postgres_hpo \ - --optuna-db-path "${POSTGRES_URL}" \ - --parallelism 4 \ - --optuna-n-trials 20 \ - --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" \ - --static-args-json "${STATIC_ARGS_JSON}" + --script "${SCRIPT_DIR}/wlm_plugins/lsf_plugin.sh" \ + --optuna-study-name gridfm_lsf_postgres_hpo \ + --optuna-db-path "${POSTGRES_URL}" \ + --parallelism 4 \ + --optuna-n-trials 20 \ + --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" diff --git a/examples/run_setter_example.sh b/examples/run_setter_example.sh index cb62e01..b728dd4 100755 --- a/examples/run_setter_example.sh +++ b/examples/run_setter_example.sh @@ -1,35 +1,24 @@ #!/usr/bin/env bash # ============================================================================= -# Example: iterate2 with --param-setter +# Example: run a local trial script (no cluster, no WLM) # -# Some scripts (e.g. those using Hydra, MMCV, or custom key-value CLIs) do not -# accept traditional named flags: +# iterate2 calls examples/bumpy_function.py directly for each trial. +# All hyperparameters are supplied via ITERATE_PARAM_ environment +# variables. 
The script is responsible for: +# - reading those variables +# - running the computation +# - writing "metric_name: value" lines to ITERATE_OUT_FILE # -# python script.py --learning-rate 0.001 --batch-size 32 -# -# Instead they expect a setter-style interface: -# -# python script.py --set learning_rate 0.001 --set batch_size 32 -# -# Pass --param-setter to iterate2 to switch to this style. -# Every HPO and static parameter will be forwarded as: -# -- key value -# -# This example uses examples/bumpy_setter.py which accepts --set key value. +# The metrics section in the HPO YAML tells iterate2 which names to look for. # ============================================================================= set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -iterate2 \ - --script "${SCRIPT_DIR}/bumpy_setter.py" \ - --root-dir "${SCRIPT_DIR}" \ - --venv "" \ - --param-setter set \ - --wlm none \ - --optuna-study-name bumpy_setter_study \ - --optuna-db-path "sqlite:///bumpy_setter_hpo.db" \ - --optuna-n-trials 20 \ - --hpo-yaml "${SCRIPT_DIR}/bumpy_setter_hpo.yaml" \ - --metric "yval" +iterate \ + --script "${SCRIPT_DIR}/bumpy_function.py" \ + --optuna-study-name bumpy_local_study \ + --optuna-db-path "sqlite:///bumpy_local_hpo.db" \ + --optuna-n-trials 20 \ + --hpo-yaml "${SCRIPT_DIR}/bumpy_hpo.yaml" diff --git a/examples/run_vela_example.sh b/examples/run_vela_example.sh index bb5cafc..aa5461e 100755 --- a/examples/run_vela_example.sh +++ b/examples/run_vela_example.sh @@ -1,29 +1,22 @@ #!/usr/bin/env bash # ============================================================================= -# Example: iterate2 with --wlm vela (OpenShift / MLBatch PyTorchJob) +# Example: iterate2 with Vela/OpenShift job submission (MLBatch PyTorchJob) +# +# iterate2 is a pure Optuna orchestrator – it knows nothing about Vela/OpenShift. +# For every trial it: +# 1. Samples hyperparameters via Optuna +# 2. Calls --script with all params exposed as ITERATE_PARAM_ env vars +# 3. 
Reads metrics from ITERATE_OUT_FILE after the script exits +# +# The trial script (examples/wlm_plugins/vela_plugin.py) owns ALL +# cluster concerns: job template rendering, helm/oc submission, waiting, etc. # # Prerequisites # ------------- # * helm CLI installed and on PATH # * oc CLI logged in to the target cluster # * mlbatch/tools/pytorchjob-generator/chart checked out locally -# * The gridfm HPO YAML (configs/gridfm_graphkit_hpo.yaml) present -# -# How it works -# ------------ -# 1. For each Optuna trial iterate2: -# a. Samples hyperparameters from gridfm_graphkit_hpo.yaml -# b. Builds the gridfm_graphkit CLI invocation from static + sampled params -# c. Patches vela_gridfm_template.yaml: -# - appends "-trial-" to jobName (unique resource per trial) -# - sets numGpusPerPod = gpu_num (from the HPO space) -# - replaces {{HPO_COMMAND}} (the actual CLI call) -# d. Runs: helm template -f | oc create -f- -# e. Polls until -master-0 pod is Running -# f. Streams: oc logs -f -master-0 -# (blocks until container exits; output captured for metric extraction) -# g. Checks pod exit code; deletes the PyTorchJob resource -# 2. Metrics are extracted from the captured log and returned to Optuna. +# * configs/gridfm_graphkit_hpo.yaml present # ============================================================================= set -euo pipefail @@ -31,27 +24,10 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" -# Path to the mlbatch pytorchjob-generator helm chart. 
-# Clone mlbatch first: git clone https://github.com/project-codeflare/mlbatch -CHART_PATH="${MLBATCH_CHART_PATH:-${HOME}/tmp/mlbatch/tools/pytorchjob-generator/chart}" - -NAMESPACE_ARG=() -[[ -n "${OC_NAMESPACE:-}" ]] && NAMESPACE_ARG=(--vela-namespace "${OC_NAMESPACE}") - iterate \ - --script "gridfm_graphkit train" \ - --interpreter "" \ - --wlm vela \ - --vela-job-template "${SCRIPT_DIR}/vela_gridfm_template.yaml" \ - --vela-chart-path "${CHART_PATH}" \ - "${NAMESPACE_ARG[@]}" \ - --vela-cmd-placeholder "{{HPO_COMMAND}}" \ - --vela-pod-ready-timeout 600 \ - --vela-job-timeout 86400 \ - --no-underscore-to-hyphen \ - --gpu-count 1 \ - --optuna-study-name gridfm_vela_hpo \ - --optuna-db-path "js:///gridfm_vela_hpo.journal" \ - --parallelism 16 \ - --optuna-n-trials 20 \ + --script "${SCRIPT_DIR}/wlm_plugins/vela_plugin.py" \ + --optuna-study-name gridfm_vela_hpo \ + --optuna-db-path "js:///gridfm_vela_hpo.journal" \ + --parallelism 16 \ + --optuna-n-trials 20 \ --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" diff --git a/examples/run_zuvela_gridfm_example.sh b/examples/run_zuvela_gridfm_example.sh new file mode 100755 index 0000000..7dd2888 --- /dev/null +++ b/examples/run_zuvela_gridfm_example.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# ============================================================================= +# Run gridfm-graphkit HPO on ZuVela (IBM Spectrum LSF cluster) +# +# iterate2 orchestrates Optuna trials. For each trial it calls +# examples/wlm_plugins/zuvela_plugin.sh, which owns all LSF concerns: +# micromamba environment activation and bsub submission. 
+# +# Prerequisites +# * bsub / bjobs available on PATH +# * gridfm-graphkit installed in the "gridfm" micromamba env +# * configs/gridfm_graphkit_hpo.yaml present +# * psycopg2-binary: pip install 'terratorch-iterate[postgresql]' +# * POSTGRES_URL set: export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies" +# ============================================================================= + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +# Override via env vars if your paths differ from the plugin defaults +export GRIDFM_ROOT="${GRIDFM_ROOT:-${HOME}/gitco/gridfm-graphkit}" +export MICROMAMBA_ENV="${MICROMAMBA_ENV:-gridfm}" + +# Require a PostgreSQL URL – set via environment before calling this script: +# export POSTGRES_URL="postgresql://user:password@host:5432/optuna_studies" +: "${POSTGRES_URL:?Please set POSTGRES_URL=postgresql://user:password@host:port/dbname}" +STUDY_DB="${POSTGRES_URL}" + +iterate \ + --script "${SCRIPT_DIR}/wlm_plugins/zuvela_plugin.sh" \ + --optuna-study-name gridfm_zuvela_hpo \ + --optuna-db-path "${STUDY_DB}" \ + --parallelism 4 \ + --optuna-n-trials 20 \ + --hpo-yaml "${REPO_ROOT}/configs/gridfm_graphkit_hpo.yaml" diff --git a/examples/wlm_plugins/README.md b/examples/wlm_plugins/README.md new file mode 100644 index 0000000..1dcf41c --- /dev/null +++ b/examples/wlm_plugins/README.md @@ -0,0 +1,85 @@ +# WLM Plugins + +This directory contains reference implementations of **iterate2 WLM plugins** – +executable scripts that submit, wait for, and validate individual Optuna trials +on different workload managers. + +## How the plugin system works + +When you pass `--wlm-plugin /path/to/plugin` to `iterate2`, it is invoked +**once per trial** with a set of environment variables that describe the work +to be done. The plugin is responsible for: + +1. Submitting the trial to the cluster / WLM. +2. Waiting until the job completes. +3. 
Ensuring trial stdout is written to `$ITERATE_OUT_FILE` and stderr to + `$ITERATE_ERR_FILE` (iterate2 extracts metrics from these files). +4. Exiting **0** on success, non-zero on failure. + +iterate2 marks the Optuna trial as **FAILED** on a non-zero exit code. + +### Environment variables provided by iterate2 + +| Variable | Description | +|---|---| +| `ITERATE_TRIAL_NUMBER` | Integer trial ID | +| `ITERATE_TRIAL_CMD` | Full shell command (with `cd`, `source venv`, etc.) – use for SSH / HPC WLMs | +| `ITERATE_TRIAL_CONTAINER_CMD` | Bare CLI invocation (no `cd`/`source`) – use for container-based WLMs (Vela/k8s) | +| `ITERATE_OUT_FILE` | Path where **stdout** must be written | +| `ITERATE_ERR_FILE` | Path where **stderr** must be written | +| `ITERATE_WLM_` | Every key from the `wlm:` YAML section, uppercased with hyphens→underscores | + +### WLM configuration in the HPO YAML + +All WLM-specific settings (GPU count, queue, job template path, …) live in the +`wlm:` section of the HPO YAML. This keeps the launch script clean: + +```yaml +# my_hpo.yaml +hpo: + lr: + type: float + low: 1e-5 + high: 1e-2 + log: true + +static: + epochs: 50 + +wlm: # keys forwarded as ITERATE_WLM_* env vars + gpu-count: 1 + cpu-count: 8 + mem-gb: 32 +``` + +The corresponding launch script only needs: + +```bash +iterate2 \ + --script train.py \ + --wlm-plugin wlm_plugins/lsf_plugin.sh \ + --hpo-yaml my_hpo.yaml \ + ... +``` + +## Provided plugins + +| Plugin | WLM | Notes | +|---|---|---| +| [`lsf_plugin.sh`](lsf_plugin.sh) | IBM Spectrum LSF | Uses `bsub -K`; reads `gpu-count`, `cpu-count`, `mem-gb`, `lsf-gpu-config`, `queue` from the `wlm:` section | +| [`vela_plugin.py`](vela_plugin.py) | OpenShift / MLBatch PyTorchJob | Uses `helm template \| oc create`; reads `job-template`, `chart-path`, `namespace`, `cmd-placeholder`, `pod-ready-timeout`, `job-timeout` from the `wlm:` section | + +## Writing your own plugin + +Any executable (shell script, Python script, compiled binary) works. 
Minimal +example that runs the trial locally: + +```bash +#!/usr/bin/env bash +# trivial_plugin.sh – run trial locally, redirect output to the log files +bash -c "${ITERATE_TRIAL_CMD}" \ + > "${ITERATE_OUT_FILE}" 2> "${ITERATE_ERR_FILE}" +``` + +For a SLURM example you can follow the same pattern as `lsf_plugin.sh`, +replacing `bsub` with `srun` / `sbatch`. diff --git a/examples/wlm_plugins/ccc_plugin.sh b/examples/wlm_plugins/ccc_plugin.sh new file mode 100755 index 0000000..14f1d5f --- /dev/null +++ b/examples/wlm_plugins/ccc_plugin.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +# ============================================================================= +# Trial script for iterate2 - CCC (IBM Spectrum LSF) backend +# +# Called once per Optuna trial. Activates the venv, builds the training +# command from ITERATE_PARAM_* env vars, and submits it via bsub -K. +# +# Environment variables provided by iterate2 +# ------------------------------------------ +# ITERATE_TRIAL_NUMBER integer trial ID +# ITERATE_OUT_FILE metric lines are read from here by iterate2 +# ITERATE_ERR_FILE path for error output +# ITERATE_PARAM_ one variable per HPO + static parameter +# (key uppercased, hyphens -> underscores) +# +# Path overrides (set in the run script or your environment) +# ---------------------------------------------------------- +# GRIDFM_ROOT repo root (default below) +# GRIDFM_VENV Python venv path (default below) +# CUDA_BASE CUDA install root (default below) +# ============================================================================= + +set -euo pipefail + +# -- iterate2 standard vars --------------------------------------------------- +TRIAL_NUMBER="${ITERATE_TRIAL_NUMBER:?ITERATE_TRIAL_NUMBER not set}" +OUT_FILE="${ITERATE_OUT_FILE:?ITERATE_OUT_FILE not set}" +ERR_FILE="${ITERATE_ERR_FILE:?ITERATE_ERR_FILE not set}" + +# -- Paths (override via env) ------------------------------------------------- 
+GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}" +GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}" +CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}" + +# -- LSF resources ------------------------------------------------------------ +# GPU_COUNT comes from the HPO group param so LSF allocates the right number. +GPU_COUNT="${ITERATE_PARAM_GPU_NUM:-1}" +CPU_COUNT=16 +MEM_GB=32 +MEM_MB=$(( MEM_GB * 1024 )) +GPU_STRING="num=${GPU_COUNT}:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB" +# QUEUE="normal" # uncomment to target a specific queue + +# -- Build training command from ITERATE_PARAM_* vars ------------------------- +# --gpu_num is NOT a gridfm_graphkit flag; GPU count is controlled via bsub -gpu. +TRAIN_CMD="gridfm_graphkit train" +TRAIN_CMD+=" --batch_size ${ITERATE_PARAM_BATCH_SIZE}" +TRAIN_CMD+=" --num_workers ${ITERATE_PARAM_NUM_WORKERS}" +TRAIN_CMD+=" --config ${ITERATE_PARAM_CONFIG}" +TRAIN_CMD+=" --data_path ${ITERATE_PARAM_DATA_PATH}" +# [[ -n "${ITERATE_PARAM_COMPILE:-}" ]] && TRAIN_CMD+=" --compile ${ITERATE_PARAM_COMPILE}" + +# -- Compose full job shell command ------------------------------------------- +JOB_CMD="\ +export PATH='${CUDA_BASE}/bin:\$PATH' && \ +export CUDA_HOME='${CUDA_BASE}' && \ +export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \ +cd '${GRIDFM_ROOT}' && \ +source '${GRIDFM_VENV}/bin/activate' && \ +${TRAIN_CMD}" + +# -- Submit via bsub ---------------------------------------------------------- +# -K : blocks until the job completes (iterate2 runs each trial in a thread) +# -n : CPU slots +# -o/-e: redirect job stdout/stderr to paths iterate2 will read metrics from +bsub \ + -K \ + -gpu "${GPU_STRING}" \ + -n "${CPU_COUNT}" \ + -R "rusage[mem=${MEM_MB}]" \ + -o "${OUT_FILE}" \ + -e "${ERR_FILE}" \ + -J "hpo_trial_${TRIAL_NUMBER}" \ + "${JOB_CMD}" + +echo "[ccc_plugin] trial ${TRIAL_NUMBER} finished" diff --git a/examples/wlm_plugins/lsf_plugin.sh 
b/examples/wlm_plugins/lsf_plugin.sh new file mode 100755 index 0000000..47069c6 --- /dev/null +++ b/examples/wlm_plugins/lsf_plugin.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# ============================================================================= +# Trial script for iterate2 - IBM Spectrum LSF backend +# +# iterate2 calls this script once per trial. It owns ALL cluster concerns: +# activating the venv, composing the training command from env vars, and +# submitting the job via bsub. +# +# Environment variables provided by iterate2 +# ------------------------------------------ +# ITERATE_TRIAL_NUMBER integer trial ID +# ITERATE_OUT_FILE path where metric lines must be written +# ITERATE_ERR_FILE path for error output +# ITERATE_PARAM_ one variable per HPO + static parameter +# (key uppercased, hyphens -> underscores) +# +# Customise the sections marked CONFIGURE below. +# +# Exit code +# --------- +# Exit 0 on success. iterate2 marks the Optuna trial FAILED on non-zero exit. +# ============================================================================= + +set -euo pipefail + +# -- Read iterate2 standard vars ----------------------------------------------- +TRIAL_NUMBER="${ITERATE_TRIAL_NUMBER:?ITERATE_TRIAL_NUMBER not set}" +OUT_FILE="${ITERATE_OUT_FILE:?ITERATE_OUT_FILE not set}" +ERR_FILE="${ITERATE_ERR_FILE:?ITERATE_ERR_FILE not set}" + +# -- CONFIGURE: paths ---------------------------------------------------------- +GRIDFM_ROOT="${GRIDFM_ROOT:-/path/to/gridfm-graphkit}" +GRIDFM_VENV="${GRIDFM_VENV:-/path/to/venv}" +CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}" + +# -- CONFIGURE: LSF resources -------------------------------------------------- +GPU_COUNT=1 +CPU_COUNT=16 +MEM_GB=32 +GPU_STRING="num=${GPU_COUNT}:mode=exclusive_process:mps=no" +# QUEUE="normal" # uncomment to target a specific queue + +# -- Build the training command from ITERATE_PARAM_* vars ---------------------- +# Each HPO / static parameter is available as ITERATE_PARAM_. 
+# Translate them into the CLI flags your training script expects. +TRAIN_CMD="gridfm_graphkit train" +TRAIN_CMD+=" --gpu_num ${ITERATE_PARAM_GPU_NUM}" +TRAIN_CMD+=" --batch_size ${ITERATE_PARAM_BATCH_SIZE}" +TRAIN_CMD+=" --num_workers ${ITERATE_PARAM_NUM_WORKERS}" +TRAIN_CMD+=" --config ${ITERATE_PARAM_CONFIG}" +TRAIN_CMD+=" --data_path ${ITERATE_PARAM_DATA_PATH}" +# Add further params as needed: +# [[ -n "${ITERATE_PARAM_COMPILE:-}" ]] && TRAIN_CMD+=" --compile ${ITERATE_PARAM_COMPILE}" + +# -- Compose the full job command ---------------------------------------------- +JOB_CMD="\ +export PATH='${CUDA_BASE}/bin:\$PATH' && \ +export CUDA_HOME='${CUDA_BASE}' && \ +export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \ +cd '${GRIDFM_ROOT}' && \ +source '${GRIDFM_VENV}/bin/activate' && \ +${TRAIN_CMD}" + +# -- Submit via bsub ----------------------------------------------------------- +# -K blocks until the job finishes (iterate2 runs each trial in a thread). +# -o/-e redirect LSF job stdout/stderr to the files iterate2 will read. +MEM_MB=$(( MEM_GB * 1024 )) + +bsub \ + -K \ + -gpu "${GPU_STRING}" \ + -n "${CPU_COUNT}" \ + -R "rusage[mem=${MEM_MB}]" \ + -o "${OUT_FILE}" \ + -e "${ERR_FILE}" \ + -J "hpo_trial_${TRIAL_NUMBER}" \ + "${JOB_CMD}" + +echo "[lsf_plugin] trial ${TRIAL_NUMBER} finished" diff --git a/examples/wlm_plugins/vela_plugin.py b/examples/wlm_plugins/vela_plugin.py new file mode 100755 index 0000000..67bb1ff --- /dev/null +++ b/examples/wlm_plugins/vela_plugin.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +iterate2 WLM plugin – Vela / OpenShift PyTorchJob (MLBatch) + +Submits each HPO trial as a PyTorchJob on an OpenShift cluster via +``helm template | oc create``, streams pod logs, checks the exit code, +and cleans up the job resource. 
+ +Environment variables provided by iterate2 +------------------------------------------ + ITERATE_TRIAL_NUMBER integer trial ID + ITERATE_TRIAL_CONTAINER_CMD bare CLI invocation for inside the container + (no ``cd``, no ``source venv`` – use this + one, not ITERATE_TRIAL_CMD) + ITERATE_OUT_FILE file to write trial stdout + ITERATE_ERR_FILE file to write trial stderr + +WLM configuration (from the ``wlm:`` section in the HPO YAML) +------------------------------------------------------------- +All keys from ``wlm:`` are available as ``ITERATE_WLM_`` +(hyphens → underscores). Recognised keys: + + job-template (ITERATE_WLM_JOB_TEMPLATE) REQUIRED path to PyTorchJob helm values YAML + chart-path (ITERATE_WLM_CHART_PATH) REQUIRED path to pytorchjob-generator helm chart + namespace (ITERATE_WLM_NAMESPACE) optional; uses current oc context if omitted + cmd-placeholder (ITERATE_WLM_CMD_PLACEHOLDER) default: {{HPO_COMMAND}} + gpu-count (ITERATE_WLM_GPU_COUNT) default: 1 + pod-ready-timeout (ITERATE_WLM_POD_READY_TIMEOUT) seconds; default: 600 + job-timeout (ITERATE_WLM_JOB_TIMEOUT) seconds; default: 86400 + +Usage in HPO YAML +----------------- + wlm: + job-template: examples/vela_gridfm_template.yaml + chart-path: ~/tmp/mlbatch/tools/pytorchjob-generator/chart + namespace: my-project + cmd-placeholder: "{{HPO_COMMAND}}" + gpu-count: 1 + pod-ready-timeout: 600 + job-timeout: 86400 + +Usage in the launch script +-------------------------- + iterate2 \\ + --wlm-plugin "$(dirname "$0")/wlm_plugins/vela_plugin.py" \\ + --hpo-yaml my_hpo.yaml \\ + --no-underscore-to-hyphen \\ + ... + +Exit code +--------- +Exits 0 on success, 1 on failure. iterate2 marks the Optuna trial as +FAILED on any non-zero exit. 
+""" + +import os +import re +import subprocess +import sys +import tempfile +import threading +import time +from pathlib import Path +from typing import Optional + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def env(key: str, default: Optional[str] = None, required: bool = False) -> str: + val = os.environ.get(key, default) + if required and not val: + sys.exit(f"[vela_plugin] ERROR: required env var '{key}' is not set") + return val or "" + + +def patch_job_yaml(template_path: str, trial_id: int, gpu_count: int, + container_cmd: str, placeholder: str) -> tuple[str, str]: + """Patch the helm values YAML; return (patched_text, job_name).""" + with open(template_path) as fh: + text = fh.read() + + # jobName → append -trial- + m = re.search(r'^(jobName\s*:\s*["\']?)([^"\'#\n]+)(["\']?)', text, re.MULTILINE) + if not m: + sys.exit(f"[vela_plugin] ERROR: 'jobName' key not found in '{template_path}'") + raw_name = m.group(2).strip() + job_name = f"{raw_name}-trial-{trial_id}" + text = text[:m.start(2)] + m.group(2).replace(raw_name, job_name) + text[m.end(2):] + + # numGpusPerPod → overwrite + text = re.sub( + r'^(numGpusPerPod\s*:\s*)\S+', + lambda m2: f"{m2.group(1)}{gpu_count}", + text, flags=re.MULTILINE, + ) + + # placeholder → container_cmd + if placeholder in text: + text = text.replace(placeholder, container_cmd) + else: + print(f"[vela_plugin] WARNING: placeholder '{placeholder}' not found in template – appending") + text += f"\n - {container_cmd}\n" + + return text, job_name + + +def stream_pipe(pipe, dest_file: str, prefix: str, dest_stream): + with open(dest_file, "w", encoding="utf-8", errors="replace") as fh: + for raw in pipe: + line = raw.decode("utf-8", errors="replace") + fh.write(line) + fh.flush() + dest_stream.write(f"{prefix} {line}") + dest_stream.flush() + + +# ── main ────────────────────────────────────────────────────────────────────── + +def main(): + trial_id = int(env("ITERATE_TRIAL_NUMBER", 
required=True)) + cmd = env("ITERATE_TRIAL_CONTAINER_CMD", required=True) + out_file = env("ITERATE_OUT_FILE", required=True) + err_file = env("ITERATE_ERR_FILE", required=True) + + template = env("ITERATE_WLM_JOB_TEMPLATE", required=True) + chart = env("ITERATE_WLM_CHART_PATH", required=True) + namespace = env("ITERATE_WLM_NAMESPACE", "") + placeholder = env("ITERATE_WLM_CMD_PLACEHOLDER", "{{HPO_COMMAND}}") + gpu_count = int(env("ITERATE_WLM_GPU_COUNT", "1")) + pod_timeout = int(env("ITERATE_WLM_POD_READY_TIMEOUT", "600")) + job_timeout = int(env("ITERATE_WLM_JOB_TIMEOUT", "86400")) + + ns_args = ["-n", namespace] if namespace else [] + prefix = f"[trial-{trial_id}]" + + # Resolve ~ in paths + template = str(Path(template).expanduser()) + chart = str(Path(chart).expanduser()) + + print(f"{prefix} Patching template {template}") + job_yaml, job_name = patch_job_yaml(template, trial_id, gpu_count, cmd, placeholder) + print(f"{prefix} Job name: {job_name}") + + # Write patched values to a temp file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", prefix=f"vela_trial_{trial_id}_", delete=False + ) as fh: + fh.write(job_yaml) + tmp_yaml = fh.name + + try: + # ── Submit ──────────────────────────────────────────────────────────── + ns_flag = f"-n {namespace}" if namespace else "" + create_cmd = f"helm template -f {tmp_yaml} {chart} | oc create {ns_flag} -f-" + print(f"{prefix} Submitting: {create_cmd}") + result = subprocess.run(create_cmd, shell=True, capture_output=True, text=True) + sys.stdout.write(result.stdout) + if result.returncode != 0: + sys.stderr.write(result.stderr) + sys.exit(f"{prefix} ERROR: oc create failed (rc={result.returncode})") + + master_pod = f"{job_name}-master-0" + + # ── Wait for pod to appear ──────────────────────────────────────────── + deadline = time.monotonic() + pod_timeout + print(f"{prefix} Waiting for pod {master_pod} …") + while time.monotonic() < deadline: + r = subprocess.run( + ["oc", "get", "pod", master_pod, 
"--ignore-not-found"] + ns_args, + capture_output=True, text=True, + ) + if master_pod in r.stdout: + break + time.sleep(5) + else: + sys.exit(f"{prefix} ERROR: pod '{master_pod}' did not appear within {pod_timeout}s") + + # ── Wait for pod Ready (best-effort) ───────────────────────────────── + subprocess.run( + ["oc", "wait", f"pod/{master_pod}", "--for=condition=Ready", + f"--timeout={pod_timeout}s"] + ns_args, + capture_output=True, text=True, + ) + + # ── Stream logs ─────────────────────────────────────────────────────── + print(f"{prefix} Streaming logs from {master_pod}") + log_proc = subprocess.Popen( + ["oc", "logs", "-f", master_pod] + ns_args, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + ) + t_out = threading.Thread( + target=stream_pipe, + args=(log_proc.stdout, out_file, prefix, sys.stdout), daemon=True, + ) + t_err = threading.Thread( + target=stream_pipe, + args=(log_proc.stderr, err_file, prefix, sys.stderr), daemon=True, + ) + t_out.start(); t_err.start() + try: + log_proc.wait(timeout=job_timeout) + except subprocess.TimeoutExpired: + log_proc.kill() + t_out.join(); t_err.join() + + # Catch-up read in case of early EOF disconnect + if log_proc.returncode != 0: + catchup = subprocess.run( + ["oc", "logs", "--tail=-1", master_pod] + ns_args, + capture_output=True, text=True, + ) + if catchup.stdout: + with open(out_file, "a") as fh: + fh.write(catchup.stdout) + if catchup.stderr: + with open(err_file, "a") as fh: + fh.write(catchup.stderr) + + # ── Check exit code ─────────────────────────────────────────────────── + exit_code_str = "" + for _ in range(30): + ec = subprocess.run( + ["oc", "get", "pod", master_pod, "-o", + "jsonpath={.status.containerStatuses[0].state.terminated.exitCode}"] + + ns_args, + capture_output=True, text=True, + ) + exit_code_str = ec.stdout.strip() + if exit_code_str.lstrip("-").isdigit(): + break + time.sleep(5) + + exit_code = int(exit_code_str) if exit_code_str.lstrip("-").isdigit() else 0 + print(f"{prefix} Pod 
exit code: {exit_code}") + if exit_code != 0: + sys.exit(f"{prefix} Trial FAILED: pod exited with code {exit_code}") + + finally: + # ── Cleanup ─────────────────────────────────────────────────────────── + subprocess.run( + ["oc", "delete", "pytorchjob", job_name, "--ignore-not-found"] + ns_args, + capture_output=True, + ) + try: + os.unlink(tmp_yaml) + except OSError: + pass + + +if __name__ == "__main__": + main() diff --git a/examples/wlm_plugins/zuvela_plugin.sh b/examples/wlm_plugins/zuvela_plugin.sh new file mode 100755 index 0000000..fd9bf4b --- /dev/null +++ b/examples/wlm_plugins/zuvela_plugin.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +# ============================================================================= +# Trial script for iterate2 - ZuVela (IBM Spectrum LSF) backend +# +# Called once per Optuna trial. Activates the micromamba environment, builds +# the training command from ITERATE_PARAM_* env vars, and submits via bsub -K. +# +# bsub pattern used (non-interactive, blocking): +# bsub -gpu num= -K \ +# -R "rusage[ngpus=, cpu=, mem=GB]" \ +# -J gridfm_ \ +# 'cd ~/gitco/gridfm-graphkit && source ~/.bashrc && micromamba activate gridfm && ' +# +# Environment variables provided by iterate2 +# ------------------------------------------ +# ITERATE_TRIAL_NUMBER integer trial ID +# ITERATE_OUT_FILE metric lines are read from here by iterate2 +# ITERATE_ERR_FILE path for error output +# ITERATE_PARAM_ one variable per HPO + static parameter +# (key uppercased, hyphens -> underscores) +# +# Path overrides (set in the run script or your environment) +# ---------------------------------------------------------- +# GRIDFM_ROOT repo root (default: ~/gitco/gridfm-graphkit) +# MICROMAMBA_ENV micromamba env name (default: gridfm) +# ============================================================================= + +set -euo pipefail + +# -- iterate2 standard vars --------------------------------------------------- 
+TRIAL_NUMBER="${ITERATE_TRIAL_NUMBER:?ITERATE_TRIAL_NUMBER not set}" +OUT_FILE="${ITERATE_OUT_FILE:?ITERATE_OUT_FILE not set}" +ERR_FILE="${ITERATE_ERR_FILE:?ITERATE_ERR_FILE not set}" + +# -- Paths (override via env) ------------------------------------------------- +GRIDFM_ROOT="${GRIDFM_ROOT:-${HOME}/gitco/gridfm-graphkit}" +MICROMAMBA_ENV="${MICROMAMBA_ENV:-gridfm}" + +# -- LSF resources ------------------------------------------------------------ +# GPU_COUNT comes from the HPO group param so LSF allocates the right number. +GPU_COUNT="${ITERATE_PARAM_GPU_NUM:-1}" +CPU_COUNT=16 +MEM_GB=32 + +# -- Build training command from ITERATE_PARAM_* vars ------------------------- +# --gpu_num is NOT a gridfm_graphkit flag; GPU count is controlled via bsub -gpu. +TRAIN_CMD="gridfm_graphkit train" +TRAIN_CMD+=" --batch_size ${ITERATE_PARAM_BATCH_SIZE}" +TRAIN_CMD+=" --num_workers ${ITERATE_PARAM_NUM_WORKERS}" +TRAIN_CMD+=" --config ${ITERATE_PARAM_CONFIG}" +TRAIN_CMD+=" --data_path ${ITERATE_PARAM_DATA_PATH}" +# [[ -n "${ITERATE_PARAM_COMPILE:-}" ]] && TRAIN_CMD+=" --compile ${ITERATE_PARAM_COMPILE}" +[[ -n "${ITERATE_PARAM_RUN_NAME:-}" ]] && TRAIN_CMD+=" --run_name ${ITERATE_PARAM_RUN_NAME}" +[[ -n "${ITERATE_PARAM_LOG_DIR:-}" ]] && TRAIN_CMD+=" --log_dir ${ITERATE_PARAM_LOG_DIR}" +[[ "${ITERATE_PARAM_REPORT_PERFORMANCE:-}" == "True" ]] && TRAIN_CMD+=" --report-performance" + +# -- Compose full job shell command ------------------------------------------- +# source ~/.bashrc to initialise micromamba shell hooks +JOB_CMD="\ +cd '${GRIDFM_ROOT}' && \ +source ~/.bashrc && \ +micromamba activate '${MICROMAMBA_ENV}' && \ +${TRAIN_CMD}" + +# -- Submit via bsub ---------------------------------------------------------- +# -K : blocks until the job completes (iterate2 runs each trial in a thread) +# -o/-e: write output to paths iterate2 will scan for metrics +bsub \ + -K \ + -gpu "num=${GPU_COUNT}" \ + -R "rusage[ngpus=${GPU_COUNT}, cpu=${CPU_COUNT}, mem=${MEM_GB}GB]" \ + -o 
"${OUT_FILE}" \ + -e "${ERR_FILE}" \ + -J "gridfm_trial_${TRIAL_NUMBER}" \ + "${JOB_CMD}" + +echo "[zuvela_plugin] trial ${TRIAL_NUMBER} finished" diff --git a/pyproject.toml b/pyproject.toml index d53e340..3237f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ include = ["terratorch_iterate*"] [project] name = "terratorch-iterate" -version = "0.3" +version = "0.4" requires-python = ">= 3.11" description = "A terratorch's plugin for benchmarking and hyperparameter optimization" authors = [ diff --git a/terratorch_iterate/iterate2/_iterate2.py b/terratorch_iterate/iterate2/_iterate2.py index e34e3ed..8d2e8cf 100644 --- a/terratorch_iterate/iterate2/_iterate2.py +++ b/terratorch_iterate/iterate2/_iterate2.py @@ -1,517 +1,175 @@ #!/usr/bin/env python3 +""" +iterate2 – minimal Optuna HPO launcher. + +iterate2 does exactly three things: + 1. Load the HPO search space and static parameters from a YAML file. + 2. For every Optuna trial, sample parameters and call a user-provided + script with those parameters exposed as environment variables. + 3. After the script exits, extract one or more metrics from the log file + the script wrote and return them to Optuna. + +The user script is fully in charge of *how* the trial runs – activating a +virtualenv, submitting a bsub/sbatch job, running locally, launching a +container, etc. iterate2 has no opinion on any of that. 
+ +Environment variables passed to the script +------------------------------------------ + ITERATE_TRIAL_NUMBER integer trial ID (0-based) + ITERATE_OUT_FILE path the script must write its stdout to + ITERATE_ERR_FILE path the script must write its stderr to + ITERATE_PARAM_ one variable per sampled + static parameter + (key uppercased, hyphens and spaces → underscores) + +HPO YAML format +--------------- + metrics: # list of metric names to extract from ITERATE_OUT_FILE + - val_loss + - accuracy + + static: # fixed parameters, forwarded as-is every trial + epochs: 50 + dataset: /data/my_dataset + + hpo: # parameters Optuna will optimise + learning_rate: + type: float + low: 1e-5 + high: 1e-2 + log: true + batch_size: + type: categorical + choices: [16, 32, 64] +""" import argparse -import json import logging import os +import re import subprocess import sys -import re -import tempfile import threading -import time from pathlib import Path -from typing import Dict, Any, Optional, Literal, List +from typing import List, Optional import optuna import yaml from terratorch_iterate.iterate2.plugin.coordinator import load_builtin_plugins, resolve_storage -# Load built-in coordinator plugins (sqlite, journalfs, postgresql) load_builtin_plugins() logger = logging.getLogger("iterate2") -# ============================================================ -# CLI -# ============================================================ +# ─── CLI ───────────────────────────────────────────────────────────────────── def parse_args(): - parser = argparse.ArgumentParser( - description="Generic Optuna HPO launcher with Multi-Metric support" + p = argparse.ArgumentParser( + description="Minimal Optuna HPO launcher – calls a user script per trial" ) + p.add_argument("--script", required=True, + help="Executable to call for each trial") + p.add_argument("--hpo-yaml", required=True, + help="YAML file with 'hpo:', 'static:', and 'metrics:' sections") + p.add_argument("--optuna-study-name", 
required=True) + p.add_argument("--optuna-db-path", required=True, + help="Optuna storage URL (sqlite:///hpo.db, js:///journal.log, postgresql://…)") + p.add_argument("--optuna-n-trials", type=int, default=100) + p.add_argument("--parallelism", type=int, default=1, + help="Parallel trials (threads). Use PostgreSQL/JournalStorage for >4.") + p.add_argument("--log-level", default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + return p.parse_args() + + +# ─── YAML LOADING ──────────────────────────────────────────────────────────── + +def load_yaml(path: str) -> dict: + with open(path) as f: + return yaml.safe_load(f) or {} + +def load_hpo_space(data: dict) -> dict: + space = data.get("hpo", {}) + logger.info("HPO space: %d param(s): %s", len(space), list(space.keys())) + return space - # ------------------------ - # Execution config - # ------------------------ - parser.add_argument("--script", required=True, help="Training script to execute") - parser.add_argument("--root-dir", default=None, help="Root dir (derived if omitted)") - parser.add_argument("--venv", default=".venv", help="Virtualenv dir (shortcut for source /bin/activate)") - parser.add_argument( - "--pre-run-commands", - default=None, - help=( - "Shell commands to run before the training script, joined with ' && '. " - "Useful for sourcing bashrc, activating conda/mamba envs, loading modules, etc. " - "Example: 'source ~/.bashrc && micromamba activate gridfm'. " - "When set, --venv is ignored." 
- ), - ) - parser.add_argument("--interpreter", default="python", help="Interpreter to use") - parser.add_argument("--param-setter", type=str, default=None) - parser.add_argument("--wlm", choices=["lsf", "slurm", "openshift", "vela", "none"], default="none") - parser.add_argument("--gpu-count", type=int, default=1) - parser.add_argument("--cpu-count", type=int, default=4) - parser.add_argument("--mem-gb", type=int, default=128) - parser.add_argument("--lsf-gpu-config-string", type=str, default=None) - - # ------------------------ - # Vela / OpenShift options - # ------------------------ - parser.add_argument( - "--vela-job-template", - type=str, - default=None, - help="Path to the Vela job YAML template (required when --wlm vela)", - ) - parser.add_argument( - "--vela-chart-path", - type=str, - default=None, - help="Path to the helm chart directory (required when --wlm vela)", - ) - parser.add_argument( - "--vela-namespace", - type=str, - default=None, - help="OpenShift/Kubernetes namespace (uses current context if omitted)", - ) - parser.add_argument( - "--vela-cmd-placeholder", - type=str, - default="{{HPO_COMMAND}}", - help="String in the job template's setupCommands that is replaced with the HPO command (default: '{{HPO_COMMAND}}')", - ) - parser.add_argument( - "--vela-pod-ready-timeout", - type=int, - default=600, - help="Seconds to wait for the trial pod to reach Running state (default: 600)", - ) - parser.add_argument( - "--vela-job-timeout", - type=int, - default=86400, - help="Seconds to wait for the trial job to complete (default: 86400 = 24 h)", - ) - parser.add_argument( - "--parallelism", - type=int, - default=1, - help="Number of trials to run in parallel (default: 1 = sequential). " - "Each parallel trial runs in its own thread. 
" - "For SQLite storage, values >4 may cause locking contention; " - "consider PostgreSQL for high parallelism.", - ) - parser.add_argument( - "--no-underscore-to-hyphen", - dest="underscore_to_hyphen", - action="store_false", - default=True, - help="Do not convert underscores to hyphens in arg names (default: convert)", - ) +def load_static(data: dict) -> dict: + static = data.get("static", {}) + logger.info("Static params: %d key(s): %s", len(static), list(static.keys())) + return static - # ------------------------ - # Optuna config - # ------------------------ - parser.add_argument("--optuna-study-name", required=True) - parser.add_argument("--optuna-db-path", required=True) - parser.add_argument("--optuna-n-trials", type=int, default=100) - - # ------------------------ - # HPO space - # ------------------------ - parser.add_argument("--hpo-json", type=str, default=None) - parser.add_argument("--hpo-yaml", type=str, default=None) - parser.add_argument("--static-args-json", type=str, default=None) - parser.add_argument("--static-args-yaml", type=str, default=None) - - # ------------------------ - # Metric extraction (Supports comma-separated list) - # ------------------------ - parser.add_argument( - "--metrics", - default="score_combined", - help="Comma-separated metric names to extract (e.g. 
score_linear_acc,score_modality_leak,score_combined)", - ) +def load_metrics(data: dict, fallback: str = "score") -> List[str]: + raw = data.get("metrics", None) + if raw is None: + logger.warning("No 'metrics:' key in YAML – defaulting to '%s'", fallback) + return [fallback] + if isinstance(raw, list): + return [str(m).strip() for m in raw] + return [m.strip() for m in str(raw).split(",")] - # ------------------------ - # Logging - # ------------------------ - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR"], - help="Logging verbosity (default: INFO)", - ) - return parser.parse_args() - - -# ============================================================ -# HELPERS & COMMAND BUILDERS -# ============================================================ - -def resolve_paths(script: str, root_dir: Optional[str]): - if root_dir is None: root_dir = '.' - resolved = Path(root_dir).resolve() - logger.debug("Resolved root_dir '%s' → '%s'", root_dir, resolved) - return script, resolved - -def build_launcher_command(wlm, cmd, trial_id, out_file, err_file, gpu_count, cpu_count, mem_gb, lsf_gpu_config_string): - logger.debug("Building launcher command: wlm=%s gpu_count=%d cpu_count=%d mem_gb=%d", wlm, gpu_count, cpu_count, mem_gb) - if wlm == "lsf": - gpu_fragment = f"-gpu \"{lsf_gpu_config_string}\"" if lsf_gpu_config_string else (f"-gpu num={gpu_count}" if gpu_count > 0 else "") - launcher = f"bsub {gpu_fragment} -K -o {out_file} -e {err_file} -R \"rusage[ngpus={gpu_count}, cpu={cpu_count}, mem={mem_gb}GB]\" -J hpo_trial_{trial_id} \"{cmd}\"" - elif wlm == "slurm": - launcher = f"srun --gres=gpu:{gpu_count} --cpus-per-task={cpu_count} --mem={mem_gb}G --job-name=hpo_trial_{trial_id} --output={out_file} --error={err_file} bash -c \"{cmd}\"" - elif wlm == "none": - # No embedded redirect: run_and_stream() captures stdout/stderr via PIPE - # and writes to out_file/err_file itself. 
- launcher = f'bash -c "{cmd}"' - elif wlm in ("vela",): - # Vela uses a separate submission flow; this function is not called for it. - raise ValueError("build_launcher_command must not be called for wlm='vela'; use build_vela_job_yaml + run_vela_trial instead.") - else: - raise ValueError(f"Unknown WLM: {wlm}") - logger.debug("Launcher command: %s", launcher) - return launcher - -def build_shell_command(interpreter, root_dir, script_path, venv, script_args, param_setter, underscore_to_hyphen=True, pre_run_commands=None): - parts = [f"cd {root_dir}"] - if pre_run_commands: - parts.append(pre_run_commands) - logger.debug("Pre-run commands: %s", pre_run_commands) - elif venv: - parts.append(f"source {venv}/bin/activate") - logger.debug("Activating venv: %s", venv) - arg_list = [f"{interpreter} {script_path}"] - for key, value in script_args.items(): - arg_name = key.replace("_", "-") if underscore_to_hyphen else key - if value is None: - logger.debug("Skipping arg '%s': value is None (flag omitted)", key) - continue - if param_setter: - if isinstance(value, bool): - if value: - arg_list.append(f"--{param_setter} {key}") - logger.debug("Setter flag: --%s %s (store_true)", param_setter, key) - else: - logger.debug("Skipping flag '%s': False → omitted", key) - else: - arg_list.append(f"--{param_setter} {key} {value}") - logger.debug("Setter arg: --%s %s %s", param_setter, key, value) - else: - if isinstance(value, bool): - if value: - arg_list.append(f"--{arg_name}") - logger.debug("Flag present: --%s", arg_name) - else: - logger.debug("Skipping flag '--%s': False → omitted", arg_name) - else: - arg_list.append(f"--{arg_name} {value}") - logger.debug("Arg: --%s %s", arg_name, value) - cmd = " && ".join(parts + [" ".join(arg_list)]) - logger.debug("Shell command: %s", cmd) - return cmd - - -def build_container_command(interpreter: str, script_path: str, script_args: dict, param_setter: Optional[str], underscore_to_hyphen: bool = True) -> str: - """Build a bare CLI 
invocation suitable for running inside a container. - - Unlike :func:`build_shell_command` this function does **not** prepend - ``cd`` or ``source venv`` – those are not needed (or available) inside an - already-running container image. - """ - prefix = f"{interpreter} " if interpreter else "" - arg_list = [f"{prefix}{script_path}".strip()] - for key, value in script_args.items(): - arg_name = key.replace("_", "-") if underscore_to_hyphen else key - if value is None: - logger.debug("Container cmd: skipping '%s' (None)", key) - continue - if param_setter: - if isinstance(value, bool): - if value: - arg_list.append(f"--{param_setter} {key}") - else: - pass # omit - else: - arg_list.append(f"--{param_setter} {key} {value}") - else: - if isinstance(value, bool): - if value: - arg_list.append(f"--{arg_name}") - # else omit - else: - arg_list.append(f"--{arg_name} {value}") - cmd = " ".join(arg_list) - logger.debug("Container command: %s", cmd) - return cmd - - -def build_vela_job_yaml( - template_path: str, - trial_id: int, - gpu_count: int, - container_cmd: str, - placeholder: str, -) -> tuple[str, str]: - """Load *template_path* as raw text, inject HPO parameters, return ``(yaml_str, job_name)``. - - All modifications are done via targeted regex/string substitutions on the raw - YAML text so that multi-line block scalars (e.g. awk pipelines), single-quoted - strings, and other constructs that PyYAML would mangle on a load→dump round-trip - are preserved exactly as written in the template. - - Changes applied: - * ``jobName`` gets a ``-trial-{trial_id}`` suffix (unique Kubernetes resource). - * ``numGpusPerPod`` is overwritten with *gpu_count*. - * The *placeholder* string inside ``setupCommands`` is replaced with - *container_cmd* in-place, preserving any surrounding wrapper (e.g. awk pipeline). 
- """ - with open(template_path, "r") as fh: - text = fh.read() - - # ── jobName ────────────────────────────────────────────────────────────── - job_name_match = re.search(r'^(jobName\s*:\s*["\']?)([^"\'#\n]+)(["\']?)', text, re.MULTILINE) - if not job_name_match: - raise ValueError(f"'jobName' key not found in template '{template_path}'") - raw_name = job_name_match.group(2).strip() - job_name = f"{raw_name}-trial-{trial_id}" - text = ( - text[:job_name_match.start(2)] - + job_name_match.group(2).replace(raw_name, job_name) - + text[job_name_match.end(2):] - ) - logger.debug("Vela trial %d: jobName → %s", trial_id, job_name) - - # ── numGpusPerPod ──────────────────────────────────────────────────────── - text = re.sub( - r'^(numGpusPerPod\s*:\s*)\S+', - lambda m: f"{m.group(1)}{gpu_count}", - text, - flags=re.MULTILINE, - ) - logger.debug("Vela trial %d: numGpusPerPod → %d", trial_id, gpu_count) - - # ── placeholder substitution ───────────────────────────────────────────── - if placeholder in text: - text = text.replace(placeholder, container_cmd) - logger.debug("Vela trial %d: substituted placeholder '%s'", trial_id, placeholder) - else: - logger.warning( - "Vela trial %d: placeholder '%s' not found in template '%s' – appending command", - trial_id, placeholder, template_path, - ) - text += f"\n - {container_cmd}\n" - - return text, job_name - - -def _oc(*args, namespace: Optional[str] = None, check: bool = True, capture: bool = False): - """Run an ``oc`` sub-command, optionally capturing output.""" - cmd = ["oc"] + list(args) - if namespace: - cmd += ["-n", namespace] - logger.debug("oc command: %s", " ".join(cmd)) - if capture: - return subprocess.run(cmd, check=check, capture_output=True, text=True) - return subprocess.run(cmd, check=check) - - -def run_vela_trial( - trial_id: int, - job_yaml: str, - chart_path: str, - job_name: str, - namespace: Optional[str], - out_file: str, - err_file: str, - pod_ready_timeout: int, - job_timeout: int, -) -> None: - 
"""Submit a Vela/OpenShift PyTorchJob, stream its logs, and wait for completion. - - Steps - ----- - 1. Write *job_yaml* to a temp file. - 2. ``helm template -f | oc create [-n ] -f-`` - 3. Poll until the master pod (``-master-0``) appears. - 4. ``oc logs -f `` – streams every line to stdout **and** *out_file*. - 5. After streaming ends, check the pod's terminated exit-code. - Non-zero → raise :class:`subprocess.CalledProcessError`. - 6. Cleanup: delete the PyTorchJob resource. - """ - ns_args = ["-n", namespace] if namespace else [] - prefix = f"[trial-{trial_id}]" +# ─── OPTUNA PARAM SAMPLING ─────────────────────────────────────────────────── - # Write temp YAML - with tempfile.NamedTemporaryFile( - mode="w", - suffix=".yaml", - prefix=f"vela_trial_{trial_id}_", - delete=False, - ) as fh: - fh.write(job_yaml) - tmp_yaml = fh.name - logger.debug("Vela trial %d: temp YAML written to %s", trial_id, tmp_yaml) - - try: - # ── 1. Submit ────────────────────────────────────────────────────────── - ns_flag = f"-n {namespace}" if namespace else "" - create_cmd = ( - f"helm template -f {tmp_yaml} {chart_path}" - f" | oc create {ns_flag} -f-" - ) - logger.info("Trial %d: submitting Vela job → %s", trial_id, create_cmd) - result = subprocess.run(create_cmd, shell=True, capture_output=True, text=True) - with _print_lock: - sys.stdout.write(f"{prefix} {result.stdout}") - sys.stdout.flush() - if result.returncode != 0: - raise RuntimeError( - f"Vela trial {trial_id}: oc create failed (rc={result.returncode}):\n" - f"{result.stderr}" - ) - logger.info("Trial %d: job '%s' created", trial_id, job_name) - - # ── 2. 
Wait for master pod to appear ────────────────────────────────── - master_pod = f"{job_name}-master-0" - deadline = time.monotonic() + pod_ready_timeout - logger.info("Trial %d: waiting for pod '%s' to appear (timeout %ds)…", trial_id, master_pod, pod_ready_timeout) - while time.monotonic() < deadline: - r = subprocess.run( - ["oc", "get", "pod", master_pod, "--ignore-not-found"] + ns_args, - capture_output=True, text=True, - ) - if master_pod in r.stdout: - logger.debug("Trial %d: pod '%s' found", trial_id, master_pod) - break - time.sleep(5) - else: - raise TimeoutError( - f"Vela trial {trial_id}: pod '{master_pod}' did not appear within {pod_ready_timeout}s" - ) - - # ── 3. Wait for pod to be Running/Succeeded ─────────────────────────── - logger.info("Trial %d: waiting for pod '%s' to be Running…", trial_id, master_pod) - wait_cmd = ( - ["oc", "wait", f"pod/{master_pod}", - "--for=condition=Ready", - f"--timeout={pod_ready_timeout}s"] - + ns_args - ) - wr = subprocess.run(wait_cmd, capture_output=True, text=True) - # oc wait returns non-zero if the pod is already Completed (no Ready condition); - # that's fine – the logs are still accessible. - logger.debug("Trial %d: oc wait rc=%d stderr=%s", trial_id, wr.returncode, wr.stderr.strip()) - - # ── 4. 
Stream logs ──────────────────────────────────────────────────── - log_cmd = ["oc", "logs", "-f", master_pod] + ns_args - logger.info("Trial %d: streaming logs from '%s'", trial_id, master_pod) - log_proc = subprocess.Popen( - log_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - t_out = threading.Thread( - target=_stream_pipe, - args=(log_proc.stdout, out_file, trial_id, "stdout", sys.stdout), - daemon=True, - ) - t_err = threading.Thread( - target=_stream_pipe, - args=(log_proc.stderr, err_file, trial_id, "stderr", sys.stderr), - daemon=True, - ) - t_out.start() - t_err.start() - log_proc.wait(timeout=job_timeout) - t_out.join() - t_err.join() - logger.debug("Trial %d: log stream ended (rc=%d)", trial_id, log_proc.returncode) - - # ── 4b. If oc logs exited early (e.g. "unexpected EOF"), the pod may - # still be running. Re-attach the log stream and wait for it to - # finish so we capture the full output and don't delete a live job. - if log_proc.returncode != 0: - logger.warning( - "Trial %d: oc logs exited with rc=%d (possible EOF disconnect) – " - "waiting for pod to terminate before reading exit code", - trial_id, log_proc.returncode, - ) - # Wait for pod phase Succeeded or Failed (container terminated). - oc_wait_phase = subprocess.run( - ["oc", "wait", f"pod/{master_pod}", - "--for=jsonpath={.status.phase}=Succeeded", - f"--timeout={job_timeout}s"] - + ns_args, - capture_output=True, text=True, - ) - if oc_wait_phase.returncode != 0: - # Pod may have Failed; try that phase too. - subprocess.run( - ["oc", "wait", f"pod/{master_pod}", - "--for=jsonpath={.status.phase}=Failed", - f"--timeout=30s"] - + ns_args, - capture_output=True, text=True, - ) - # Re-stream any log lines written after the disconnect into the same files. 
- catchup = subprocess.run( - ["oc", "logs", "--tail=-1", master_pod] + ns_args, - capture_output=True, text=True, - ) - if catchup.stdout: - with open(out_file, "a", encoding="utf-8", errors="replace") as fh: - fh.write(catchup.stdout) - if catchup.stderr: - with open(err_file, "a", encoding="utf-8", errors="replace") as fh: - fh.write(catchup.stderr) - - # ── 5. Check pod exit code ──────────────────────────────────────────── - # Poll until the pod has a terminated exit code (handles the race - # between oc-logs EOF and pod termination being recorded in the API). - exit_code_str = "" - for _attempt in range(30): - ec_result = subprocess.run( - ["oc", "get", "pod", master_pod, "-o", - "jsonpath={.status.containerStatuses[0].state.terminated.exitCode}"] - + ns_args, - capture_output=True, text=True, - ) - exit_code_str = ec_result.stdout.strip() - if exit_code_str.lstrip("-").isdigit(): - break - logger.debug("Trial %d: exit code not yet available, retrying in 5 s…", trial_id) - time.sleep(5) - exit_code = int(exit_code_str) if exit_code_str.lstrip("-").isdigit() else 0 - logger.info("Trial %d: pod exit code = %s", trial_id, exit_code) - if exit_code != 0: - logger.warning("Trial %d: pod exited with code %d – marking trial as pruned", trial_id, exit_code) - raise optuna.exceptions.TrialPruned(f"pod exited with code {exit_code}") - - finally: - # ── 6. 
Cleanup – delete the job ─────────────────────────────────────── - logger.debug("Trial %d: deleting PyTorchJob '%s'", trial_id, job_name) - subprocess.run( - ["oc", "delete", "pytorchjob", job_name, "--ignore-not-found"] + ns_args, - capture_output=True, - ) +def suggest(trial: optuna.Trial, name: str, spec: dict): + t = spec["type"] + if t == "float": + return trial.suggest_float(name, float(spec["low"]), float(spec["high"]), + log=spec.get("log", False)) + if t == "int": + return trial.suggest_int(name, int(spec["low"]), int(spec["high"]), + log=spec.get("log", False)) + if t == "categorical": + return trial.suggest_categorical(name, spec["choices"]) + if t == "flag": + return trial.suggest_categorical(name, [True, False]) + if t == "group": + return trial.suggest_categorical(name, list(spec["choices"].keys())) + raise ValueError(f"Unknown param type '{t}' for '{name}'") + + +# ─── METRIC EXTRACTION ─────────────────────────────────────────────────────── + +def extract_metrics(out_file: str, err_file: str, metric_names: List[str]) -> List[float]: + """Read both output files and extract the last occurrence of each metric.""" + text = "" + for path in (out_file, err_file): try: - os.unlink(tmp_yaml) - except OSError: + text += Path(path).read_text(encoding="utf-8", errors="ignore") + "\n" + except FileNotFoundError: pass -# ============================================================ -# PARALLEL STREAMING RUNNER -# ============================================================ + results = [] + for metric in metric_names: + # Support name#N for Nth-occurrence selection (0-based) + occurrence: Optional[int] = None + bare = metric + m = re.fullmatch(r'(.+)#(\d+)', metric) + if m: + bare, occurrence = m.group(1), int(m.group(2)) + + pattern = re.compile( + rf"(?:\[\w+\]\s*)?{re.escape(bare)}\s*[:=│]\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)" + ) + matches = pattern.findall(text) + if not matches: + logger.warning("Metric '%s' not found – defaulting to 0.0", metric) + 
results.append(0.0) + elif occurrence is not None: + if occurrence >= len(matches): + logger.warning("Metric '%s' occurrence #%d not found – defaulting to 0.0", + metric, occurrence) + results.append(0.0) + else: + results.append(float(matches[occurrence])) + else: + results.append(float(matches[-1])) + return results + + +# ─── SCRIPT RUNNER ─────────────────────────────────────────────────────────── _print_lock = threading.Lock() -def _stream_pipe(pipe, dest_file, trial_id: int, stream_name: str, dest_stream): - """Read lines from *pipe*, write to *dest_file* and print prefixed to *dest_stream*.""" +def _stream(pipe, dest_file: str, trial_id: int, dest_stream): prefix = f"[trial-{trial_id}]" with open(dest_file, "w", encoding="utf-8", errors="replace") as fh: for raw in pipe: @@ -522,196 +180,24 @@ def _stream_pipe(pipe, dest_file, trial_id: int, stream_name: str, dest_stream): dest_stream.write(f"{prefix} {line}") dest_stream.flush() -def run_and_stream(launcher_cmd: str, trial_id: int, out_file: str, err_file: str, wlm: str): - """ - Run *launcher_cmd* in a shell. - - For ``wlm='none'``: captures stdout and stderr via PIPE, streams every line - to the main process stdout/stderr (prefixed with ``[trial-N]``), and also - writes them to *out_file* / *err_file* for later metric extraction. - - For WLM backends (lsf, slurm, …): the WLM tool itself manages the output - files on the cluster. The local subprocess output (WLM status messages, - errors) is still streamed with the same prefix so parallel workers are - distinguishable. 
- """ - logger.debug("Trial %d: run_and_stream wlm=%s cmd=%s", trial_id, wlm, launcher_cmd) +def run_script(script: str, env: dict, trial_id: int, out_file: str, err_file: str): + """Run *script* with *env*, stream output, raise on non-zero exit.""" + logger.info("Trial %d: calling %s", trial_id, script) proc = subprocess.Popen( - launcher_cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + [script], env=env, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - - if wlm == "none": - # Full capture: write to files AND stream to console - t_out = threading.Thread( - target=_stream_pipe, - args=(proc.stdout, out_file, trial_id, "stdout", sys.stdout), - daemon=True, - ) - t_err = threading.Thread( - target=_stream_pipe, - args=(proc.stderr, err_file, trial_id, "stderr", sys.stderr), - daemon=True, - ) - else: - # WLM manages the cluster output files (out_file/err_file) itself. - # Stream only the local WLM tool output (bsub/srun status messages) - # to console; write it to separate local files to avoid clobbering the - # cluster-managed trial output files. 
- wlm_out = out_file.replace(".out", "_wlm.out") - wlm_err = err_file.replace(".err", "_wlm.err") - t_out = threading.Thread( - target=_stream_pipe, - args=(proc.stdout, wlm_out, trial_id, "wlm-stdout", sys.stdout), - daemon=True, - ) - t_err = threading.Thread( - target=_stream_pipe, - args=(proc.stderr, wlm_err, trial_id, "wlm-stderr", sys.stderr), - daemon=True, - ) - - t_out.start() - t_err.start() + import threading as _t + t_out = _t.Thread(target=_stream, args=(proc.stdout, out_file, trial_id, sys.stdout), daemon=True) + t_err = _t.Thread(target=_stream, args=(proc.stderr, err_file, trial_id, sys.stderr), daemon=True) + t_out.start(); t_err.start() proc.wait() - t_out.join() - t_err.join() - + t_out.join(); t_err.join() if proc.returncode != 0: - raise subprocess.CalledProcessError(proc.returncode, launcher_cmd) - -# ============================================================ -# MULTI-METRIC EXTRACTION -# ============================================================ - -def extract_metrics_from_log(path: str, metric_names: List[str], err_path: Optional[str] = None) -> List[float]: - """Extract metric values from a log file. - - Each entry in *metric_names* is either a plain name (uses the **last** - match) or ``name#N`` to select the **N-th occurrence** (0-based). 
This - lets you disambiguate scripts that print the same metric key multiple - times, e.g.:: - - metrics: - - "Samples/sec#0" # DataLoader throughput (first occurrence) - - "Samples/sec#1" # Training throughput (second occurrence) - - "Samples/sec#2" # Inference throughput (third occurrence) - - GFLOPS - """ - logger.debug("Extracting metrics %s from '%s'", metric_names, path) - results = [] - with open(path, "r", encoding="utf-8", errors="ignore") as f: - text = f.read() - logger.debug("Log file '%s': %d characters read", path, len(text)) - # Also read stderr — Lightning/rich writes test result tables there - if err_path: - try: - with open(err_path, "r", encoding="utf-8", errors="ignore") as f: - err_text = f.read() - logger.debug("Err file '%s': %d characters read", err_path, len(err_text)) - text = text + "\n" + err_text - except FileNotFoundError: - logger.debug("Err file '%s' not found, skipping", err_path) + raise subprocess.CalledProcessError(proc.returncode, script) - for metric in metric_names: - # Support name#N syntax for Nth-occurrence selection (0-based) - occurrence: Optional[int] = None - bare_metric = metric - idx_match = re.fullmatch(r'(.+)#(\d+)', metric) - if idx_match: - bare_metric = idx_match.group(1) - occurrence = int(idx_match.group(2)) - # Matches: key: value | key=value | [performance] key : value | Lightning table │ key │ value │ - pattern = re.compile( - rf"(?:\[\w+\]\s*)?{re.escape(bare_metric)}\s*(?:[:=│])\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)" - ) - matches = pattern.findall(text) - if not matches: - logger.warning("Metric '%s' not found in '%s' — defaulting to 0.0", metric, path) - results.append(0.0) - elif occurrence is not None: - if occurrence >= len(matches): - logger.warning( - "Metric '%s' occurrence #%d requested but only %d match(es) found — defaulting to 0.0", - metric, occurrence, len(matches), - ) - results.append(0.0) - else: - value = float(matches[occurrence]) - logger.debug("Metric '%s': using occurrence #%d = %s", 
metric, occurrence, value) - results.append(value) - else: - value = float(matches[-1]) - logger.debug("Metric '%s': found %d match(es), using last value %s", metric, len(matches), value) - results.append(value) - return results - -# ============================================================ -# MAIN -# ============================================================ - -def load_hpo_space(args): - data = {} - if args.hpo_json: - logger.debug("Loading HPO space from JSON string") - data = json.loads(args.hpo_json) - elif args.hpo_yaml: - logger.debug("Loading HPO space from YAML file: %s", args.hpo_yaml) - with open(args.hpo_yaml, "r") as f: data = yaml.safe_load(f) - space = data.get("hpo", {}) - logger.info("HPO space loaded: %d parameter(s): %s", len(space), list(space.keys())) - return space - -def load_metrics_from_yaml(args): - """Return metrics list from YAML 'metrics:' key, or None if not present.""" - data = {} - if args.hpo_json: - data = json.loads(args.hpo_json) - elif args.hpo_yaml: - with open(args.hpo_yaml, "r") as f: data = yaml.safe_load(f) - elif args.static_args_yaml: - with open(args.static_args_yaml, "r") as f: data = yaml.safe_load(f) - metrics = data.get("metrics", None) - if metrics is None: - return None - if isinstance(metrics, list): - return [m.strip() for m in metrics] - return [m.strip() for m in str(metrics).split(",")] - -def load_static_args(args): - data = {} - if args.static_args_json: - logger.debug("Loading static args from JSON string") - data = json.loads(args.static_args_json) - elif args.static_args_yaml: - logger.debug("Loading static args from YAML file: %s", args.static_args_yaml) - with open(args.static_args_yaml, "r") as f: data = yaml.safe_load(f) - elif args.hpo_yaml: - logger.debug("Loading static args from HPO YAML file: %s", args.hpo_yaml) - with open(args.hpo_yaml, "r") as f: data = yaml.safe_load(f) - static = data.get("static", data if data else {}) - logger.info("Static args loaded: %d key(s): %s", len(static), 
list(static.keys())) - return static - -def suggest_from_spec(trial, name, spec): - t = spec["type"] - if t == "float": - val = trial.suggest_float(name, float(spec["low"]), float(spec["high"]), log=spec.get("log", False)) - elif t == "int": - val = trial.suggest_int(name, int(spec["low"]), int(spec["high"]), log=spec.get("log", False)) - elif t == "categorical": - val = trial.suggest_categorical(name, spec["choices"]) - elif t == "flag": - val = trial.suggest_categorical(name, [True, False]) - elif t == "group": - val = trial.suggest_categorical(name, list(spec["choices"].keys())) - else: - raise ValueError(f"Unknown param type: {t}") - logger.debug("Suggested '%s' (%s) = %r", name, t, val) - return val +# ─── MAIN ──────────────────────────────────────────────────────────────────── def main(): args = parse_args() @@ -721,140 +207,91 @@ def main(): format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) - # Suppress noisy optuna INFO logs unless user asked for DEBUG logging.getLogger("optuna").setLevel( logging.WARNING if args.log_level == "INFO" else getattr(logging, args.log_level) ) - logger.info("iterate2 starting") - logger.info("Log level: %s", args.log_level) - logger.info("WLM: %s | interpreter: %s | script: %s", args.wlm, args.interpreter, args.script) - logger.info("Optuna study: '%s' | db: %s | n_trials: %d", args.optuna_study_name, args.optuna_db_path, args.optuna_n_trials) - - hpo_space = load_hpo_space(args) - static_args = load_static_args(args) - yaml_metrics = load_metrics_from_yaml(args) - metric_list = yaml_metrics if yaml_metrics is not None else [m.strip() for m in args.metrics.split(",")] - logger.info("Optimising metrics: %s (source: %s)", metric_list, "yaml" if yaml_metrics else "cli") + logger.info("iterate2 starting script=%s yaml=%s", args.script, args.hpo_yaml) + logger.info("study=%s db=%s n_trials=%d parallelism=%d", + args.optuna_study_name, args.optuna_db_path, + args.optuna_n_trials, 
args.parallelism) - script_path, root_dir = resolve_paths(args.script, args.root_dir) - - def objective(trial): - script_args = static_args.copy() - for name, spec in hpo_space.items(): - val = suggest_from_spec(trial, name, spec) - if spec["type"] == "group": - # Expand the chosen group's key→value pairs directly into script_args - script_args.update(spec["choices"][val]) - else: - script_args[name] = val - - # gpu_num in hpo/static overrides the CLI --gpu-count for this trial's launcher - gpu_count = int(script_args.pop("gpu_num", args.gpu_count)) - logger.debug("Trial %d: effective gpu_count=%d", trial.number, gpu_count) - logger.info("Trial %d: sampled parameters: %s", trial.number, script_args) - - out_file = f"trial_{trial.number}.out" - err_file = f"trial_{trial.number}.err" - logger.debug("Trial %d: stdout → %s | stderr → %s", trial.number, out_file, err_file) - - if args.wlm == "vela": - # ── Vela / OpenShift path ────────────────────────────────────── - if not args.vela_job_template: - raise ValueError("--vela-job-template is required when --wlm vela") - if not args.vela_chart_path: - raise ValueError("--vela-chart-path is required when --wlm vela") - container_cmd = build_container_command( - args.interpreter, script_path, script_args, - args.param_setter, args.underscore_to_hyphen, - ) - logger.info("Trial %d: container command → %s", trial.number, container_cmd) - job_yaml, job_name = build_vela_job_yaml( - args.vela_job_template, - trial.number, - gpu_count, - container_cmd, - args.vela_cmd_placeholder, - ) - logger.debug("Trial %d: job YAML (first 400 chars):\n%s", trial.number, job_yaml[:400]) - run_vela_trial( - trial_id=trial.number, - job_yaml=job_yaml, - chart_path=args.vela_chart_path, - job_name=job_name, - namespace=args.vela_namespace, - out_file=out_file, - err_file=err_file, - pod_ready_timeout=args.vela_pod_ready_timeout, - job_timeout=args.vela_job_timeout, - ) - else: - # ── Standard WLM path (lsf / slurm / none) ──────────────────── - 
shell_cmd = build_shell_command(args.interpreter, root_dir, script_path, args.venv, script_args, args.param_setter, args.underscore_to_hyphen, pre_run_commands=args.pre_run_commands) - launcher_cmd = build_launcher_command(args.wlm, shell_cmd, trial.number, out_file, err_file, gpu_count, args.cpu_count, args.mem_gb, args.lsf_gpu_config_string) - logger.info("Trial %d: submitting → %s", trial.number, launcher_cmd) - run_and_stream(launcher_cmd, trial.number, out_file, err_file, args.wlm) - - logger.info("Trial %d: job finished", trial.number) - - values = extract_metrics_from_log(out_file, metric_list, err_path=err_file) - logger.info("Trial %d: results %s", trial.number, dict(zip(metric_list, values))) - - return tuple(values) - - # Multi-objective direction - directions = ["maximize"] * len(metric_list) - logger.info("Creating Optuna study (directions: %s)", directions) + data = load_yaml(args.hpo_yaml) + hpo_space = load_hpo_space(data) + static = load_static(data) + metrics = load_metrics(data) + directions = ["maximize"] * len(metrics) + logger.info("Metrics: %s", metrics) storage = resolve_storage(args.optuna_db_path) - logger.debug("Optuna storage: %s", storage) - study = optuna.create_study( study_name=args.optuna_study_name, storage=storage, directions=directions, load_if_exists=True, ) - logger.info("Study '%s' ready (existing trials: %d)", args.optuna_study_name, len(study.trials)) + logger.info("Study '%s' ready (existing trials: %d)", + args.optuna_study_name, len(study.trials)) - # ── Re-queue failed trials (25 % retry / 75 % new) ──────────────────── - failed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL] + # ── Re-queue failed trials: 25 % retry, 75 % new ───────────────────── + failed = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL] n_total = args.optuna_n_trials - if failed_trials: - n_retry = max(1, round(0.25 * n_total)) - n_retry = min(n_retry, len(failed_trials)) # can't retry more than we 
have + if failed: + n_retry = min(max(1, round(0.25 * n_total)), len(failed)) n_new = n_total - n_retry - # enqueue the most-recent failed trials first - trials_to_retry = failed_trials[-n_retry:] - logger.info( - "Found %d failed trial(s). Re-queuing %d (25%%) and running %d new (75%%).", - len(failed_trials), n_retry, n_new, - ) - for ft in trials_to_retry: - if ft.params: # skip trials that had no params at all + logger.info("Found %d failed trial(s) – re-queuing %d (25%%), %d new (75%%)", + len(failed), n_retry, n_new) + for ft in failed[-n_retry:]: + if ft.params: study.enqueue_trial(ft.params) - logger.info(" Enqueued params from failed trial %d: %s", ft.number, ft.params) + logger.info(" Enqueued failed trial %d params: %s", ft.number, ft.params) + # ───────────────────────────────────────────────────────────────────── + + def objective(trial): + # ── Sample parameters ───────────────────────────────────────────── + params = dict(static) # start with static params + for name, spec in hpo_space.items(): + val = suggest(trial, name, spec) + if spec["type"] == "group": + params.update(spec["choices"][val]) # expand group → flat keys else: - logger.info(" Skipped failed trial %d (no params recorded).", ft.number) - # adjust total so we run exactly n_new *additional* new trials on top - n_total = n_new + n_retry # enqueued slots count toward n_trials - else: - logger.info("No failed trials found – running %d fresh trials.", n_total) - # ── end retry logic ─────────────────────────────────────────────────── - - logger.info("Parallelism: %d worker(s)", args.parallelism) + params[name] = val + + logger.info("Trial %d params: %s", trial.number, params) + + out_file = f"trial_{trial.number}.out" + err_file = f"trial_{trial.number}.err" + + # ── Build env for the script ────────────────────────────────────── + env = os.environ.copy() + env["ITERATE_TRIAL_NUMBER"] = str(trial.number) + env["ITERATE_OUT_FILE"] = out_file + env["ITERATE_ERR_FILE"] = err_file + for k, v in 
params.items(): + env_key = "ITERATE_PARAM_" + str(k).upper().replace("-", "_").replace(" ", "_") + env[env_key] = str(v) if v is not None else "" + + # ── Call the script ─────────────────────────────────────────────── + run_script(args.script, env, trial.number, out_file, err_file) + + # ── Extract metrics ─────────────────────────────────────────────── + values = extract_metrics(out_file, err_file, metrics) + logger.info("Trial %d results: %s", trial.number, dict(zip(metrics, values))) + return tuple(values) + + logger.info("Starting optimisation (%d worker(s))", args.parallelism) study.optimize( objective, n_trials=n_total, n_jobs=args.parallelism, - catch=(Exception,), # mark trial as FAILED and continue; never crash the study + catch=(Exception,), ) logger.info("=" * 60) - logger.info("OPTIMIZATION COMPLETE") - logger.info("Pareto Front Trials: %d", len(study.best_trials)) + logger.info("OPTIMISATION COMPLETE Pareto front: %d trial(s)", len(study.best_trials)) for t in study.best_trials: - logger.info(" Trial %d: Values=%s Params=%s", t.number, t.values, t.params) + logger.info(" Trial %d: values=%s params=%s", t.number, t.values, t.params) + if __name__ == "__main__": - main() \ No newline at end of file + main()
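
The rewritten `objective` no longer builds a command line. It passes each sampled parameter to the script as an `ITERATE_PARAM_*` environment variable, stringified with `str()`. Below is a minimal sketch of how a plugin or training wrapper might read those variables back. `collect_params` and `to_cli_args` are hypothetical helper names that do not appear in this diff, and the lower-casing and hyphenation of keys is an assumed convention for the target script, not something the diff prescribes.

```python
import os
from typing import Dict, List, Mapping

PREFIX = "ITERATE_PARAM_"  # prefix set by objective() in the diff above


def collect_params(environ: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Return {lower-cased name: raw string} for every ITERATE_PARAM_* variable."""
    return {
        key[len(PREFIX):].lower(): value
        for key, value in environ.items()
        if key.startswith(PREFIX)
    }


def to_cli_args(params: Mapping[str, str]) -> List[str]:
    """Turn collected params into argv tokens (hypothetical convention).

    iterate2 stringifies values with str(), so booleans arrive as
    "True"/"False" and None arrives as an empty string.
    """
    argv: List[str] = []
    for name, value in sorted(params.items()):
        flag = "--" + name.replace("_", "-")
        if value in ("", "False"):
            continue  # None or a disabled flag: omit it entirely
        if value == "True":
            argv.append(flag)  # store-true style flag
        else:
            argv.extend([flag, value])
    return argv


if __name__ == "__main__":
    demo_env = {
        "ITERATE_TRIAL_NUMBER": "3",  # not a parameter, so it is ignored
        "ITERATE_PARAM_LEARNING_RATE": "0.001",
        "ITERATE_PARAM_USE_AMP": "True",
        "ITERATE_PARAM_SCHEDULER": "",
    }
    print(to_cli_args(collect_params(demo_env)))
    # prints ['--learning-rate', '0.001', '--use-amp']
```

Because the diff upper-cases keys and maps hyphens and spaces to underscores, the original spelling of a parameter name cannot be recovered from the environment. A categorical choice whose value is the literal string `"True"` or `"False"` is also indistinguishable from a `flag` parameter, so a real plugin may need a per-parameter mapping rather than this blanket rule.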