diff --git a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py index 2a4b0c3..2541214 100644 --- a/src/matchcake_opt/tr_pipeline/lightning_pipeline.py +++ b/src/matchcake_opt/tr_pipeline/lightning_pipeline.py @@ -118,10 +118,37 @@ def run(self) -> Dict[str, Any]: ckpt_path=(None if self.overwrite_fit else "last"), ) end_time = time.perf_counter() - metrics: Dict[str, Any] = self.run_validation() - metrics["training_time"] = end_time - start_time - self.save_metrics_to_checkpoint_folder(metrics, name="validation_metrics") - return metrics + train_metrics: Dict[str, Any] = self.run_train_validation() + train_metrics["training_time"] = end_time - start_time + self.save_metrics_to_checkpoint_folder(train_metrics, name="validation_metrics") + val_metrics: Dict[str, Any] = self.run_validation() + return {**train_metrics, **val_metrics} + + def run_train_validation(self, **additional_metrics) -> Dict[str, Any]: + start_time = time.perf_counter() + try: + metrics: List[Dict[str, Any]] = self.trainer.validate( # type: ignore + model=self.model, + dataloaders=[self.datamodule.train_dataloader()], + verbose=self.verbose, + ckpt_path="best", + ) + except ValueError: + metrics: List[Dict[str, Any]] = self.trainer.validate( # type: ignore + model=self.model, + dataloaders=[self.datamodule.train_dataloader()], + verbose=self.verbose, + ckpt_path="last", + ) + if len(metrics) == 0: + return {} + metrics_0: Dict[str, Any] = metrics[0] + end_time = time.perf_counter() + metrics_0["train_validation_time"] = end_time - start_time + metrics_0.update(additional_metrics) + metrics_0 = {k.replace("val_", "train_"): v for k, v in metrics_0.items()} + self.save_metrics_to_checkpoint_folder(metrics_0, name="train_metrics") + return metrics_0 def run_validation(self) -> Dict[str, Any]: start_time = time.perf_counter() @@ -144,6 +171,7 @@ def run_validation(self) -> Dict[str, Any]: metrics_0: Dict[str, Any] = metrics[0] end_time = time.perf_counter() metrics_0["validation_time"] = end_time - start_time + self.save_metrics_to_checkpoint_folder(metrics_0, name="validation_metrics") return metrics_0 def run_test(self, ckpt_path="best") -> Dict[str, Any]: diff --git a/tests/test_tr_pipeline/test_lightning_pipeline.py b/tests/test_tr_pipeline/test_lightning_pipeline.py index 7f3c41a..29a9b3a 100644 --- a/tests/test_tr_pipeline/test_lightning_pipeline.py +++ b/tests/test_tr_pipeline/test_lightning_pipeline.py @@ -58,5 +58,9 @@ def test_add_specific_args(self): def test_run_and_run_test(self, pipeline_instance): metrics = pipeline_instance.run() assert isinstance(metrics, dict) + assert "val_loss" in metrics + assert "train_loss" in metrics + assert "test_loss" not in metrics test_metrics = pipeline_instance.run_test() assert isinstance(test_metrics, dict) + assert "test_loss" in test_metrics