Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/bumpy_hpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ metrics:
static:
global-mu: "23 42 66"

wlm:
gpu-count: 0
cpu-count: 2
mem-gb: 4

hpo:
x:
type: float
Expand Down
92 changes: 42 additions & 50 deletions examples/wlm_plugins/ccc_plugin.sh
Original file line number Diff line number Diff line change
@@ -1,76 +1,68 @@
#!/usr/bin/env bash
# =============================================================================
# Trial script for iterate2 - CCC (IBM Spectrum LSF) backend
# iterate2 WLM plugin - CCC (IBM Spectrum LSF)
#
# Called once per Optuna trial. Activates the venv, builds the training
# command from ITERATE_PARAM_* env vars, and submits it via bsub -K.
# Invocation contract (set by iterate2 when --wlm-plugin is used):
#
# Environment variables provided by iterate2
# ------------------------------------------
# ITERATE_TRIAL_NUMBER integer trial ID
# ITERATE_OUT_FILE metric lines are read from here by iterate2
# ITERATE_ERR_FILE path for error output
# ITERATE_PARAM_<KEY> one variable per HPO + static parameter
# (key uppercased, hyphens -> underscores)
# ccc_plugin.sh <script> [--<wlm-key> <value>]...
#
# Path overrides (set in the run script or your environment)
# ----------------------------------------------------------
# GRIDFM_ROOT repo root (default below)
# GRIDFM_VENV Python venv path (default below)
# CUDA_BASE CUDA install root (default below)
# where <script> is the executable passed as --script, and the --<wlm-key>
# flags are the keys of the YAML 'wlm:' section, e.g.
# --gpu-count 1 --cpu-count 16 --mem-gb 32 \
# --lsf-gpu-config "num=1:mode=exclusive_process:mps=no"
#
# iterate2 also exports the per-trial env vars (ITERATE_TRIAL_NUMBER,
# ITERATE_OUT_FILE, ITERATE_ERR_FILE, ITERATE_PARAM_*) which bsub forwards
# to the job so the wrapped <script> can read them.
# =============================================================================

set -euo pipefail

# -- iterate2 standard vars ---------------------------------------------------
SCRIPT="${1:?usage: ccc_plugin.sh <script> [--key value]...}"
shift

# Defaults; overridden by --<key> <value> argv pairs below.
GPU_COUNT=1
CPU_COUNT=4
MEM_GB=16
LSF_GPU_CONFIG=""
QUEUE=""

# Parse the '--<wlm-key> <value>' pairs that follow <script>.
# Every flag takes exactly one value, so the arity is checked up front: without
# this, a trailing flag with no value would trip 'set -u' on "$2" (and make
# 'shift 2' fail) and the script would die with an opaque "unbound variable".
while [[ $# -gt 0 ]]; do
    if [[ $# -lt 2 ]]; then
        echo "ccc_plugin: missing value for flag $1" >&2
        exit 2
    fi
    case "$1" in
        --gpu-count)      GPU_COUNT="$2" ;;
        --cpu-count)      CPU_COUNT="$2" ;;
        --mem-gb)         MEM_GB="$2" ;;
        --lsf-gpu-config) LSF_GPU_CONFIG="$2" ;;
        --queue)          QUEUE="$2" ;;
        # Unknown keys are skipped (with a warning) rather than rejected.
        *) echo "ccc_plugin: ignoring unknown flag $1 $2" >&2 ;;
    esac
    # Consume the flag and its value together.
    shift 2
done

TRIAL_NUMBER="${ITERATE_TRIAL_NUMBER:?ITERATE_TRIAL_NUMBER not set}"
OUT_FILE="${ITERATE_OUT_FILE:?ITERATE_OUT_FILE not set}"
ERR_FILE="${ITERATE_ERR_FILE:?ITERATE_ERR_FILE not set}"

# -- Paths (override via env) -------------------------------------------------
GRIDFM_ROOT="${GRIDFM_ROOT:-/dccstor/terratorch/users/rkie/gitco/gridfm-graphkit}"
GRIDFM_VENV="${GRIDFM_VENV:-/u/rkie/venvs/venv_gridfm-graphkit}"
CUDA_BASE="${CUDA_BASE:-/opt/share/cuda-12.8.1}"

# -- LSF resources ------------------------------------------------------------
# GPU_COUNT comes from the HPO group param so LSF allocates the right number.
GPU_COUNT="${ITERATE_PARAM_GPU_NUM:-1}"
CPU_COUNT=16
MEM_GB=32
MEM_MB=$(( MEM_GB * 1024 ))
GPU_STRING="num=${GPU_COUNT}:mode=exclusive_process:mps=no:gmodel=NVIDIAA100_SXM4_80GB"
# QUEUE="normal" # uncomment to target a specific queue
GPU_STRING="${LSF_GPU_CONFIG:-num=${GPU_COUNT}:mode=exclusive_process:mps=no}"

# -- Build training command from ITERATE_PARAM_* vars -------------------------
# --gpu_num is NOT a gridfm_graphkit flag; GPU count is controlled via bsub -gpu.
TRAIN_CMD="gridfm_graphkit train"
TRAIN_CMD+=" --batch_size ${ITERATE_PARAM_BATCH_SIZE}"
TRAIN_CMD+=" --num_workers ${ITERATE_PARAM_NUM_WORKERS}"
TRAIN_CMD+=" --config ${ITERATE_PARAM_CONFIG}"
TRAIN_CMD+=" --data_path ${ITERATE_PARAM_DATA_PATH}"
# [[ -n "${ITERATE_PARAM_COMPILE:-}" ]] && TRAIN_CMD+=" --compile ${ITERATE_PARAM_COMPILE}"
QUEUE_FLAG=()
[[ -n "$QUEUE" ]] && QUEUE_FLAG=(-q "$QUEUE")

# -- Compose full job shell command -------------------------------------------
JOB_CMD="\
export PATH='${CUDA_BASE}/bin:\$PATH' && \
export CUDA_HOME='${CUDA_BASE}' && \
export LD_LIBRARY_PATH='${CUDA_BASE}/lib64:\$LD_LIBRARY_PATH' && \
cd '${GRIDFM_ROOT}' && \
source '${GRIDFM_VENV}/bin/activate' && \
${TRAIN_CMD}"
GPU_FLAG=()
if [[ "$GPU_COUNT" -gt 0 || -n "$LSF_GPU_CONFIG" ]]; then
GPU_FLAG=(-gpu "$GPU_STRING")
fi

# -- Submit via bsub ----------------------------------------------------------
# -K : blocks until the job completes (iterate2 runs each trial in a thread)
# -n : CPU slots
# -o/-e: redirect job stdout/stderr to paths iterate2 will read metrics from
# Submit the wrapped script and block (-K) until the job finishes.
# The optional flag arrays are expanded as ${arr[@]+"${arr[@]}"} because, under
# 'set -u', bash releases before 4.4 treat "${arr[@]}" of an EMPTY array as an
# unbound variable and abort. Both arrays are empty for a CPU-only run
# (gpu-count 0) with no queue selected.
# -gpu is passed only through GPU_FLAG, so a gpu-count of 0 requests no GPU.
bsub \
    -K \
    ${GPU_FLAG[@]+"${GPU_FLAG[@]}"} \
    ${QUEUE_FLAG[@]+"${QUEUE_FLAG[@]}"} \
    -n "${CPU_COUNT}" \
    -R "rusage[mem=${MEM_MB}]" \
    -o "${OUT_FILE}" \
    -e "${ERR_FILE}" \
    -J "hpo_trial_${TRIAL_NUMBER}" \
    "${SCRIPT}"

echo "[ccc_plugin] trial ${TRIAL_NUMBER} finished"
35 changes: 28 additions & 7 deletions terratorch_iterate/iterate2/_iterate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ def parse_args():
)
p.add_argument("--script", required=True,
help="Executable to call for each trial")
p.add_argument("--wlm-plugin", default=None,
help="Optional wrapper executable. When set, iterate2 invokes "
"'<wlm-plugin> <script> [--<wlm-key> <value>]...' per trial, "
"where the wlm keys come from the YAML 'wlm:' section.")
p.add_argument("--hpo-yaml", required=True,
help="YAML file with 'hpo:', 'static:', and 'metrics:' sections")
help="YAML file with 'hpo:', 'static:', 'metrics:', and optional 'wlm:' sections")
p.add_argument("--optuna-study-name", required=True)
p.add_argument("--optuna-db-path", required=True,
help="Optuna storage URL (sqlite:///hpo.db, js:///journal.log, postgresql://…)")
Expand Down Expand Up @@ -97,6 +101,13 @@ def load_static(data: dict) -> dict:
logger.info("Static params: %d key(s): %s", len(static), list(static.keys()))
return static

def load_wlm(data: dict) -> dict:
    """Return the optional ``wlm:`` section of the parsed HPO YAML.

    Args:
        data: Top-level mapping loaded from the ``--hpo-yaml`` file.

    Returns:
        The ``wlm`` mapping, or an empty dict when the section is missing
        or empty (``wlm:`` with no value parses to ``None``).

    Raises:
        TypeError: If ``wlm`` is present and non-empty but is not a mapping
            (for example a list or a bare string).
    """
    # Missing key, explicit null and empty values all normalise to {}.
    wlm = data.get("wlm") or {}
    if not isinstance(wlm, dict):
        # Fail here with a clear message instead of an AttributeError from
        # .keys() further down.
        raise TypeError(
            f"'wlm' section must be a mapping of key: value pairs, "
            f"got {type(wlm).__name__}"
        )
    if wlm:
        logger.info("WLM config: %d key(s): %s", len(wlm), list(wlm.keys()))
    return wlm


def load_metrics(data: dict, fallback: str = "score") -> List[str]:
raw = data.get("metrics", None)
if raw is None:
Expand Down Expand Up @@ -180,11 +191,19 @@ def _stream(pipe, dest_file: str, trial_id: int, dest_stream):
dest_stream.write(f"{prefix} {line}")
dest_stream.flush()

def run_script(script: str, env: dict, trial_id: int, out_file: str, err_file: str):
"""Run *script* with *env*, stream output, raise on non-zero exit."""
logger.info("Trial %d: calling %s", trial_id, script)
def run_script(script: str, env: dict, trial_id: int, out_file: str, err_file: str,
wlm_plugin: Optional[str] = None, wlm_config: Optional[dict] = None):
"""Run *script* (or *wlm_plugin* wrapping it) with *env*, stream output."""
if wlm_plugin:
argv = [wlm_plugin, script]
for k, v in (wlm_config or {}).items():
argv.extend([f"--{k}", str(v)])
logger.info("Trial %d: calling plugin %s with argv %s", trial_id, wlm_plugin, argv)
else:
argv = [script]
logger.info("Trial %d: calling %s", trial_id, script)
proc = subprocess.Popen(
[script], env=env,
argv, env=env,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
)
import threading as _t
Expand Down Expand Up @@ -219,6 +238,7 @@ def main():
data = load_yaml(args.hpo_yaml)
hpo_space = load_hpo_space(data)
static = load_static(data)
wlm_config = load_wlm(data)
metrics = load_metrics(data)
directions = ["maximize"] * len(metrics)
logger.info("Metrics: %s", metrics)
Expand Down Expand Up @@ -271,8 +291,9 @@ def objective(trial):
env_key = "ITERATE_PARAM_" + str(k).upper().replace("-", "_").replace(" ", "_")
env[env_key] = str(v) if v is not None else ""

# ── Call the script ───────────────────────────────────────────────
run_script(args.script, env, trial.number, out_file, err_file)
# ── Call the script (optionally wrapped by --wlm-plugin) ──────────
run_script(args.script, env, trial.number, out_file, err_file,
wlm_plugin=args.wlm_plugin, wlm_config=wlm_config)

# ── Extract metrics ───────────────────────────────────────────────
values = extract_metrics(out_file, err_file, metrics)
Expand Down
Loading