diff --git a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py index 2541214..0be9d7e 100644 --- a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py +++ b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py @@ -11,6 +11,7 @@ from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.profilers import SimpleProfiler +from torch.autograd import inference_mode from ..datamodules.datamodule import DataModule from ..modules.base_model import BaseModel @@ -124,6 +125,7 @@ def run(self) -> Dict[str, Any]: val_metrics: Dict[str, Any] = self.run_validation() return {**train_metrics, **val_metrics} + @inference_mode() def run_train_validation(self, **additional_metrics) -> Dict[str, Any]: start_time = time.perf_counter() try: @@ -150,6 +152,7 @@ def run_train_validation(self, **additional_metrics) -> Dict[str, Any]: self.save_metrics_to_checkpoint_folder(metrics_0, name="train_metrics") return metrics_0 + @inference_mode() def run_validation(self) -> Dict[str, Any]: start_time = time.perf_counter() try: @@ -174,6 +177,7 @@ def run_validation(self) -> Dict[str, Any]: self.save_metrics_to_checkpoint_folder(metrics_0, name="validation_metrics") return metrics_0 + @inference_mode() def run_test(self, ckpt_path="best") -> Dict[str, Any]: start_time = time.perf_counter() metrics: Dict[str, Any] = self.trainer.test( # type: ignore