From ce5e357d4fa6e547684dbd2e591e0d1ebb0fa4fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:54:05 -0500 Subject: [PATCH 1/2] Use inference_mode for evaluation methods Import torch.autograd.inference_mode and apply @inference_mode() to run_train_validation, run_validation, and run_test. This disables autograd during evaluation/metric gathering, reducing memory usage and overhead and improving runtime performance. The import was added near the other imports at the top of lightning_pipeline.py. --- src/matchcake_opt/tr_pipeline/lightning_pipeline.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py index 2541214..0497349 100644 --- a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py +++ b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Type, Union import torch +from torch.autograd import inference_mode from lightning import Trainer from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.loggers import TensorBoardLogger @@ -124,6 +125,7 @@ def run(self) -> Dict[str, Any]: val_metrics: Dict[str, Any] = self.run_validation() return {**train_metrics, **val_metrics} + @inference_mode() def run_train_validation(self, **additional_metrics) -> Dict[str, Any]: start_time = time.perf_counter() try: @@ -150,6 +152,7 @@ def run_train_validation(self, **additional_metrics) -> Dict[str, Any]: self.save_metrics_to_checkpoint_folder(metrics_0, name="train_metrics") return metrics_0 + @inference_mode() def run_validation(self) -> Dict[str, Any]: start_time = time.perf_counter() try: @@ -174,6 +177,7 @@ def run_validation(self) -> Dict[str, Any]: self.save_metrics_to_checkpoint_folder(metrics_0, name="validation_metrics") return metrics_0 + @inference_mode() def run_test(self, ckpt_path="best") -> Dict[str, Any]: start_time = time.perf_counter() metrics: Dict[str, Any] = self.trainer.test( # type: ignore From 4f9df9f371f01d5f17c90a050ec6736d7a1cd067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gince?= <50332514+JeremieGince@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:57:18 -0500 Subject: [PATCH 2/2] Reorder imports in lightning_pipeline.py Move the `from torch.autograd import inference_mode` import below the Lightning imports to fix import ordering and reduce the chance of import-time issues or linter complaints. No functional logic changed. --- src/matchcake_opt/tr_pipeline/lightning_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py index 0497349..0be9d7e 100644 --- a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py +++ b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py @@ -7,11 +7,11 @@ from typing import Any, Dict, List, Optional, Type, Union import torch -from torch.autograd import inference_mode from lightning import Trainer from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.profilers import SimpleProfiler +from torch.autograd import inference_mode from ..datamodules.datamodule import DataModule from ..modules.base_model import BaseModel