diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 695ba53..0fc899e 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -107,6 +107,19 @@ def main(): default=False, help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').", ) + _deterministic_kwargs = dict( + dest="deterministic", + type=str, + nargs="?", + const="warn", + default=None, + choices=["true", "warn"], + help=( + "Enable deterministic CUDA/cuDNN algorithms via Lightning Trainer(deterministic=...). " + "Pass --deterministic (alone) for 'warn' mode, or --deterministic true for strict. " + "Requires CUBLAS_WORKSPACE_CONFIG to be set (e.g. ':4096:8') for CUDA>=10.2." + ), + ) _mp_context_kwargs = dict( dest="mp_context", type=str, @@ -174,6 +187,7 @@ def main(): help="Print the last training epoch time and a single test metric to stdout.", ) train_parser.add_argument("--mp_context", **_mp_context_kwargs) + train_parser.add_argument("--deterministic", **_deterministic_kwargs) # ---- FINETUNE SUBCOMMAND ---- finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 53a970c..e3ef1df 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -255,6 +255,10 @@ def main_cli(args): default_root_dir=args.log_dir, max_epochs=config_args.training.epochs, callbacks=training_callbacks, + deterministic=( + True if getattr(args, "deterministic", None) == "true" + else (getattr(args, "deterministic", None) or False) + ), **trainer_kwargs, profiler=profiler, ) diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index 2dfe4d8..d9b8cbb 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -13,7 +13,9 @@ def execute_and_live_output(cmd) -> None: - subprocess.run(cmd, text=True, shell=True, check=True) + env = os.environ.copy() + env.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + subprocess.run(cmd, text=True, shell=True, check=True, env=env) def collect_metrics_from_log(log_base: str, metric_keys: list) -> dict: @@ -194,7 +196,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--data_path data_out/ " f"--exp_name exp1 " f"--run_name run{run_i + 1} " - f"--log_dir logs", + f"--log_dir logs " + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) all_runs.append(metrics) @@ -214,7 +218,9 @@ def test_train_pf(cleanup_test_artifacts, calibrate_runs, ci_level): f"--data_path data_out/ " f"--exp_name exp1 " f"--run_name retry{attempt} " - f"--log_dir logs", + f"--log_dir logs " + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs", pf_metric_keys) else: @@ -309,7 +315,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--data_path {opf_data_dir}/ " f"--exp_name exp_opf " f"--run_name run{run_i + 1} " - f"--log_dir logs_opf", + f"--log_dir logs_opf " + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) all_runs.append(metrics) @@ -319,16 +327,16 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): return checks = { - "Avg. active res. (MW)": (0.2067, 0.4619), - "Avg. reactive res. (MVar)": (0.0825, 0.1492), - "RMSE PG generators (MW)": (2.6480, 2.8693), - "Mean optimality gap (%)": (1.1039, 1.4934), + "Avg. active res. (MW)": (0.2559, 0.2611), + "Avg. reactive res. (MVar)": (0.1028, 0.1048), + "RMSE PG generators (MW)": (2.7297, 2.7850), + "Mean optimality gap (%)": (1.2041, 1.2285), "Mean branch thermal violation from (MVA)": (0.0, 0.0), "Mean branch thermal violation to (MVA)": (0.0, 0.0), "Mean branch angle difference violation (radians)": (0.0, 0.0), - "Mean Qg violation PV buses": (0.0167, 0.1546), - "Mean Qg violation REF buses": (-0.0693, 0.4241), - "Mean Qg violation": (0.0771, 0.1322), + "Mean Qg violation PV buses": (0.0782, 0.0798), + "Mean Qg violation REF buses": (0.1251, 0.1277), + "Mean Qg violation": (0.0879, 0.0897), } MAX_RETRIES = 5 @@ -342,7 +350,9 @@ def test_train_opf(cleanup_opf_test_artifacts, calibrate_runs, ci_level): f"--data_path {opf_data_dir}/ " f"--exp_name exp_opf " f"--run_name retry{attempt} " - f"--log_dir logs_opf", + f"--log_dir logs_opf " + f"--mp_context spawn " + f"--deterministic warn", ) metrics = collect_metrics_from_log("logs_opf", opf_metric_keys) else: diff --git a/pyproject.toml b/pyproject.toml index b4e8307..2135c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ namespaces = false [project] name = "gridfm-graphkit" description = "Grid Foundation Model" -version = "0.0.7" +version = "0.8.0rc2" readme = "README.md" license = "Apache-2.0" requires-python = ">=3.10,<3.13" @@ -47,10 +47,10 @@ dependencies = [ "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", - "torch>=2.7.1,<2.9", + "torch>=2.10", "torch-geometric>=2.6.1", - "torchaudio>=2.7.1", - "torchvision>=0.22.1", + "torchaudio>=2.10", + "torchvision>=0.25", "lightning", "seaborn", "urllib3>=2.6.0",