Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/matchcake_opt/tr_pipeline/lightning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,37 @@ def run(self) -> Dict[str, Any]:
ckpt_path=(None if self.overwrite_fit else "last"),
)
end_time = time.perf_counter()
metrics: Dict[str, Any] = self.run_validation()
metrics["training_time"] = end_time - start_time
self.save_metrics_to_checkpoint_folder(metrics, name="validation_metrics")
return metrics
train_metrics: Dict[str, Any] = self.run_train_validation()
train_metrics["training_time"] = end_time - start_time
self.save_metrics_to_checkpoint_folder(train_metrics, name="validation_metrics")
val_metrics: Dict[str, Any] = self.run_validation()
return {**train_metrics, **val_metrics}

def run_train_validation(self, **additional_metrics) -> Dict[str, Any]:
start_time = time.perf_counter()
try:
metrics: List[Dict[str, Any]] = self.trainer.validate( # type: ignore
model=self.model,
dataloaders=[self.datamodule.train_dataloader()],
verbose=self.verbose,
ckpt_path="best",
)
except ValueError:
metrics: List[Dict[str, Any]] = self.trainer.validate( # type: ignore
model=self.model,
dataloaders=[self.datamodule.train_dataloader()],
verbose=self.verbose,
ckpt_path="last",
)
if len(metrics) == 0:
return {}
metrics_0: Dict[str, Any] = metrics[0]
end_time = time.perf_counter()
metrics_0["train_validation_time"] = end_time - start_time
metrics_0.update(additional_metrics)
metrics_0 = {k.replace("val_", "train_"): v for k, v in metrics_0.items()}
self.save_metrics_to_checkpoint_folder(metrics_0, name="train_metrics")
return metrics_0

def run_validation(self) -> Dict[str, Any]:
start_time = time.perf_counter()
Expand All @@ -144,6 +171,7 @@ def run_validation(self) -> Dict[str, Any]:
metrics_0: Dict[str, Any] = metrics[0]
end_time = time.perf_counter()
metrics_0["validation_time"] = end_time - start_time
self.save_metrics_to_checkpoint_folder(metrics_0, name="validation_metrics")
return metrics_0

def run_test(self, ckpt_path="best") -> Dict[str, Any]:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_tr_pipeline/test_lightning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,9 @@ def test_add_specific_args(self):
def test_run_and_run_test(self, pipeline_instance):
metrics = pipeline_instance.run()
assert isinstance(metrics, dict)
assert "val_loss" in metrics
assert "train_loss" in metrics
assert "test_loss" not in metrics
test_metrics = pipeline_instance.run_test()
assert isinstance(test_metrics, dict)
assert "test_loss" in test_metrics