Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions gridfm_graphkit/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ def main():
default=False,
help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').",
)
_deterministic_kwargs = dict(
dest="deterministic",
type=str,
nargs="?",
const="warn",
default=None,
choices=["true", "warn"],
help=(
"Enable deterministic CUDA/cuDNN algorithms via Lightning Trainer(deterministic=...). "
"Pass --deterministic (alone) for 'warn' mode, or --deterministic true for strict. "
"Requires CUBLAS_WORKSPACE_CONFIG to be set (e.g. ':4096:8') for CUDA>=10.2."
),
)
_mp_context_kwargs = dict(
dest="mp_context",
type=str,
Expand Down Expand Up @@ -174,6 +187,7 @@ def main():
help="Print the last training epoch time and a single test metric to stdout.",
)
train_parser.add_argument("--mp_context", **_mp_context_kwargs)
train_parser.add_argument("--deterministic", **_deterministic_kwargs)

# ---- FINETUNE SUBCOMMAND ----
finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
Expand Down
4 changes: 4 additions & 0 deletions gridfm_graphkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ def main_cli(args):
default_root_dir=args.log_dir,
max_epochs=config_args.training.epochs,
callbacks=training_callbacks,
deterministic=(
True if getattr(args, "deterministic", None) == "true"
else (getattr(args, "deterministic", None) or False)
),
**trainer_kwargs,
profiler=profiler,
)
Expand Down
34 changes: 22 additions & 12 deletions integrationtests/test_base_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@


def execute_and_live_output(cmd) -> None:
subprocess.run(cmd, text=True, shell=True, check=True)
env = os.environ.copy()
env.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
subprocess.run(cmd, text=True, shell=True, check=True, env=env)


def collect_metrics_from_log(log_base: str, metric_keys: list) -> dict:
Expand Down Expand Up @@ -194,7 +196,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level):
f"--data_path data_out/ "
f"--exp_name exp1 "
f"--run_name run{run_i + 1} "
f"--log_dir logs",
f"--log_dir logs "
f"--mp_context spawn "
f"--deterministic warn",
)
metrics = collect_metrics_from_log("logs", pf_metric_keys)
all_runs.append(metrics)
Expand All @@ -214,7 +218,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level):
f"--data_path data_out/ "
f"--exp_name exp1 "
f"--run_name retry{attempt} "
f"--log_dir logs",
f"--log_dir logs "
f"--mp_context spawn "
f"--deterministic warn",
)
metrics = collect_metrics_from_log("logs", pf_metric_keys)
else:
Expand Down Expand Up @@ -309,7 +315,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level):
f"--data_path {opf_data_dir}/ "
f"--exp_name exp_opf "
f"--run_name run{run_i + 1} "
f"--log_dir logs_opf",
f"--log_dir logs_opf "
f"--mp_context spawn "
f"--deterministic warn",
)
metrics = collect_metrics_from_log("logs_opf", opf_metric_keys)
all_runs.append(metrics)
Expand All @@ -319,16 +327,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level):
return

checks = {
"Avg. active res. (MW)": (0.2067, 0.4619),
"Avg. reactive res. (MVar)": (0.0825, 0.1492),
"RMSE PG generators (MW)": (2.6480, 2.8693),
"Mean optimality gap (%)": (1.1039, 1.4934),
"Avg. active res. (MW)": (0.2559, 0.2611),
"Avg. reactive res. (MVar)": (0.1028, 0.1048),
"RMSE PG generators (MW)": (2.7297, 2.7850),
"Mean optimality gap (%)": (1.2041, 1.2285),
"Mean branch thermal violation from (MVA)": (0.0, 0.0),
"Mean branch thermal violation to (MVA)": (0.0, 0.0),
"Mean branch angle difference violation (radians)": (0.0, 0.0),
"Mean Qg violation PV buses": (0.0167, 0.1546),
"Mean Qg violation REF buses": (-0.0693, 0.4241),
"Mean Qg violation": (0.0771, 0.1322),
"Mean Qg violation PV buses": (0.0782, 0.0798),
"Mean Qg violation REF buses": (0.1251, 0.1277),
"Mean Qg violation": (0.0879, 0.0897),
}

MAX_RETRIES = 5
Expand All @@ -342,7 +350,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level):
f"--data_path {opf_data_dir}/ "
f"--exp_name exp_opf "
f"--run_name retry{attempt} "
f"--log_dir logs_opf",
f"--log_dir logs_opf "
f"--mp_context spawn "
f"--deterministic warn",
)
metrics = collect_metrics_from_log("logs_opf", opf_metric_keys)
else:
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespaces = false
[project]
name = "gridfm-graphkit"
description = "Grid Foundation Model"
version = "0.0.7"
version = "0.8.0rc2"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.10,<3.13"
Expand Down Expand Up @@ -47,10 +47,10 @@ dependencies = [
"pandas>=2.3.0",
"plotly>=6.1.2",
"pyyaml>=6.0.2",
"torch>=2.7.1,<2.9",
"torch>=2.10",
"torch-geometric>=2.6.1",
"torchaudio>=2.7.1",
"torchvision>=0.22.1",
"torchaudio>=2.10",
"torchvision>=0.25",
"lightning",
"seaborn",
"urllib3>=2.6.0",
Expand Down
Loading