From 0d81314ac07a5e09eaa93bd7e97dab98fbec97dc Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 19 Jun 2026 15:59:46 +0200 Subject: [PATCH 1/5] bump torch Signed-off-by: Romeo Kienzler --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b4e8307..3becadb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,10 @@ dependencies = [ "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", - "torch>=2.7.1,<2.9", + "torch>=2.12.1", "torch-geometric>=2.6.1", - "torchaudio>=2.7.1", - "torchvision>=0.22.1", + "torchaudio>=2.11.0", + "torchvision>=0.27.1", "lightning", "seaborn", "urllib3>=2.6.0", From 24fd3aa37196c33e2a6749bc497b725e1029a8ff Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 19 Jun 2026 18:29:17 +0200 Subject: [PATCH 2/5] Bump torch to >=2.10 Co-Authored-By: Claude Opus 4.7 --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3becadb..6a73585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,10 +47,10 @@ dependencies = [ "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", - "torch>=2.12.1", + "torch>=2.10", "torch-geometric>=2.6.1", - "torchaudio>=2.11.0", - "torchvision>=0.27.1", + "torchaudio>=2.10", + "torchvision>=0.25", "lightning", "seaborn", "urllib3>=2.6.0", From 29effa3af4be2a62f1dce79d582d2d9abb7852a2 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Tue, 23 Jun 2026 09:30:52 -0400 Subject: [PATCH 3/5] Pass --mp_context spawn in integration test train calls Defensive on Linux per _warn_mp_context_on_linux. Co-Authored-By: Claude Opus 4.7 --- integrationtests/test_base_set.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index 2dfe4d8..a8703ce 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -194,7 +194,8 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--data_path data_out/ " f"--exp_name exp1 " f"--run_name run{run_i + 1} " - f"--log_dir logs", + f"--log_dir logs " + f"--mp_context spawn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) all_runs.append(metrics) @@ -214,7 +215,8 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--data_path data_out/ " f"--exp_name exp1 " f"--run_name retry{attempt} " - f"--log_dir logs", + f"--log_dir logs " + f"--mp_context spawn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) else: @@ -309,7 +311,8 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--data_path {opf_data_dir}/ " f"--exp_name exp_opf " f"--run_name run{run_i + 1} " - f"--log_dir logs_opf", + f"--log_dir logs_opf " + f"--mp_context spawn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) all_runs.append(metrics) @@ -342,7 +345,8 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--data_path {opf_data_dir}/ " f"--exp_name exp_opf " f"--run_name retry{attempt} " - f"--log_dir logs_opf", + f"--log_dir logs_opf " + f"--mp_context spawn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) else: From ef1d83d3bc1b752c40c40d916e95eaae7320c671 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Wed, 24 Jun 2026 08:11:31 -0400 Subject: [PATCH 4/5] Recompute test_train_opf CI bounds for torch>=2.10 + spawn Bounds widened/shifted due to combined effect of torch>=2.10 numerics and --mp_context spawn re-seeding DataLoader workers per run. --- integrationtests/test_base_set.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index a8703ce..6e4e5de 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -322,16 +322,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): return checks = { - "Avg. active res. (MW)": (0.2067, 0.4619), - "Avg. reactive res. (MVar)": (0.0825, 0.1492), - "RMSE PG generators (MW)": (2.6480, 2.8693), - "Mean optimality gap (%)": (1.1039, 1.4934), + "Avg. active res. (MW)": (0.2025, 0.6005), + "Avg. reactive res. (MVar)": (0.0854, 0.1194), + "RMSE PG generators (MW)": (2.7746, 3.4940), + "Mean optimality gap (%)": (1.0331, 2.1032), "Mean branch thermal violation from (MVA)": (0.0, 0.0), "Mean branch thermal violation to (MVA)": (0.0, 0.0), "Mean branch angle difference violation (radians)": (0.0, 0.0), - "Mean Qg violation PV buses": (0.0167, 0.1546), - "Mean Qg violation REF buses": (-0.0693, 0.4241), - "Mean Qg violation": (0.0771, 0.1322), + "Mean Qg violation PV buses": (0.0243, 0.1863), + "Mean Qg violation REF buses": (0.0303, 0.1683), + "Mean Qg violation": (0.0445, 0.1636), } MAX_RETRIES = 5 From 62277fa65ecc4a901dabf32258f7ece232591398 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 26 Jun 2026 16:47:51 -0400 Subject: [PATCH 5/5] =?UTF-8?q?Add=20--deterministic=20mode,=20tighten=20O?= =?UTF-8?q?PF=20CIs=20to=20=C2=B11%,=20bump=20to=200.8.0rc2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds --deterministic [true|warn] to the train subcommand and threads it through Lightning's Trainer(deterministic=...). The integration tests now run with --deterministic warn and CUBLAS_WORKSPACE_CONFIG=:4096:8, which produced Std=0 across 5 seed=0 calibration runs on case14_ieee. OPF CI bounds in test_train_opf retightened to ±1% around the new deterministic point estimates (replacing the wide bounds calibrated under CUDA non-determinism). --- gridfm_graphkit/__main__.py | 14 ++++++++++++++ gridfm_graphkit/cli.py | 4 ++++ integrationtests/test_base_set.py | 30 ++++++++++++++++++------------ pyproject.toml | 2 +- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 695ba53..0fc899e 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -107,6 +107,19 @@ def main(): default=False, help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').", ) + _deterministic_kwargs = dict( + dest="deterministic", + type=str, + nargs="?", + const="warn", + default=None, + choices=["true", "warn"], + help=( + "Enable deterministic CUDA/cuDNN algorithms via Lightning Trainer(deterministic=...). " + "Pass --deterministic (alone) for 'warn' mode, or --deterministic true for strict. " + "Requires CUBLAS_WORKSPACE_CONFIG to be set (e.g. ':4096:8') for CUDA>=10.2." + ), + ) _mp_context_kwargs = dict( dest="mp_context", type=str, @@ -174,6 +187,7 @@ def main(): help="Print the last training epoch time and a single test metric to stdout.", ) train_parser.add_argument("--mp_context", **_mp_context_kwargs) + train_parser.add_argument("--deterministic", **_deterministic_kwargs) # ---- FINETUNE SUBCOMMAND ---- finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 53a970c..e3ef1df 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -255,6 +255,10 @@ def main_cli(args): default_root_dir=args.log_dir, max_epochs=config_args.training.epochs, callbacks=training_callbacks, + deterministic=( + True if getattr(args, "deterministic", None) == "true" + else (getattr(args, "deterministic", None) or False) + ), **trainer_kwargs, profiler=profiler, ) diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index 6e4e5de..d9b8cbb 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -13,7 +13,9 @@ def execute_and_live_output(cmd) -> None: - subprocess.run(cmd, text=True, shell=True, check=True) + env = os.environ.copy() + env.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + subprocess.run(cmd, text=True, shell=True, check=True, env=env) def collect_metrics_from_log(log_base: str, metric_keys: list) -> dict: @@ -195,7 +197,8 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--exp_name exp1 " f"--run_name run{run_i + 1} " f"--log_dir logs " - f"--mp_context spawn", + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) all_runs.append(metrics) @@ -216,7 +219,8 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--exp_name exp1 " f"--run_name retry{attempt} " f"--log_dir logs " - f"--mp_context spawn", + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) else: @@ -312,7 +316,8 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--exp_name exp_opf " f"--run_name run{run_i + 1} " f"--log_dir logs_opf " - f"--mp_context spawn", + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) all_runs.append(metrics) @@ -322,16 +327,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): return checks = { - "Avg. active res. (MW)": (0.2025, 0.6005), - "Avg. reactive res. (MVar)": (0.0854, 0.1194), - "RMSE PG generators (MW)": (2.7746, 3.4940), - "Mean optimality gap (%)": (1.0331, 2.1032), + "Avg. active res. (MW)": (0.2559, 0.2611), + "Avg. reactive res. (MVar)": (0.1028, 0.1048), + "RMSE PG generators (MW)": (2.7297, 2.7850), + "Mean optimality gap (%)": (1.2041, 1.2285), "Mean branch thermal violation from (MVA)": (0.0, 0.0), "Mean branch thermal violation to (MVA)": (0.0, 0.0), "Mean branch angle difference violation (radians)": (0.0, 0.0), - "Mean Qg violation PV buses": (0.0243, 0.1863), - "Mean Qg violation REF buses": (0.0303, 0.1683), - "Mean Qg violation": (0.0445, 0.1636), + "Mean Qg violation PV buses": (0.0782, 0.0798), + "Mean Qg violation REF buses": (0.1251, 0.1277), + "Mean Qg violation": (0.0879, 0.0897), } MAX_RETRIES = 5 @@ -346,7 +351,8 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--exp_name exp_opf " f"--run_name retry{attempt} " f"--log_dir logs_opf " - f"--mp_context spawn", + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) else: diff --git a/pyproject.toml b/pyproject.toml index 6a73585..2135c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ namespaces = false [project] name = "gridfm-graphkit" description = "Grid Foundation Model" -version = "0.0.7" +version = "0.8.0rc2" readme = "README.md" license = "Apache-2.0" requires-python = ">=3.10,<3.13"