diff --git a/medsegpy/config.py b/medsegpy/config.py
index 44d0e9ca..274d1dd4 100644
--- a/medsegpy/config.py
+++ b/medsegpy/config.py
@@ -127,6 +127,9 @@ class Config(object):
     VAL_DATASET = ""
     TEST_DATASET = ""
 
+    # Inference only, with no evaluation
+    INFERENCE_ONLY = False
+
     # Cross-Validation-Parameters
     USE_CROSS_VALIDATION = False
     CV_FILE = ""
@@ -550,6 +553,7 @@ def summary(self, additional_vars=None):
             "TRAIN_DATASET",
             "VAL_DATASET",
             "TEST_DATASET",
+            "INFERENCE_ONLY",
             "",
             "CATEGORIES",
             "",
diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py
index 3208192a..fa080e3b 100644
--- a/medsegpy/data/data_loader.py
+++ b/medsegpy/data/data_loader.py
@@ -18,6 +18,7 @@
 from medsegpy.config import Config
 from medsegpy.modeling import Model
 from medsegpy.utils import env
+from pydicom import dcmread
 
 from .data_utils import add_background_labels, collect_mask, compute_patches
 from .transforms import apply_transform_gens, build_preprocessing
@@ -242,8 +243,13 @@ def _load_input(self, dataset_dict):
         if self._cached_data is not None:
             image, mask = self._cached_data[(image_file, sem_seg_file)]
         else:
-            with h5py.File(image_file, "r") as f:
-                image = f["data"][:]
+            if image_file.endswith('.dcm'):
+                ds = dcmread(image_file)
+                image = ds.pixel_array
+            else:
+                with h5py.File(image_file, "r") as f:
+                    image = f["data"][:]
+
         if image.shape[-1] != 1:
             image = image[..., np.newaxis]
 
@@ -323,7 +329,11 @@ def _restructure_data(self, vols: Sequence[np.ndarray]):
             axes = (1, 2, 0)
             if v.ndim > 3:
                 axes = axes + tuple(i for i in range(3, v.ndim))
-            new_vols.append(v.transpose(axes))
+            # new_vols.append(v.transpose(axes))
+            if v.ndim == 1:
+                new_vols.append(v)
+            else:
+                new_vols.append(v.transpose(axes))
         vols = (np.squeeze(v) for v in new_vols)
         return tuple(vols)
 
@@ -337,6 +347,9 @@ def inference(self, model: Model, **kwargs):
         workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
         use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)
 
+
+        kwargs["batch_size"] = 1
+
         for scan_id in scan_ids:
             self._dataset_dicts = scan_to_dict_mapping[scan_id]
 
diff --git a/medsegpy/evaluation/evaluator.py b/medsegpy/evaluation/evaluator.py
index 7103d3f5..6230bfe2 100644
--- a/medsegpy/evaluation/evaluator.py
+++ b/medsegpy/evaluation/evaluator.py
@@ -139,7 +139,10 @@ def inference_on_dataset(
 
     eval_start = time.perf_counter()
     logger.info("Begin evaluation...")
-    results = {e.__class__.__name__: e.evaluate() for e in evaluator}
+    if any([e._config.INFERENCE_ONLY for e in evaluator]):
+        results = None
+    else:
+        results = {e.__class__.__name__: e.evaluate() for e in evaluator}
     total_eval_time = time.perf_counter() - eval_start
     logger.info("Time Elapsed: {:.4f} seconds".format(total_compute_time + total_eval_time))
     # An evaluator may return None when not in main process.
diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py
index 6b8c192a..2da2b5d8 100644
--- a/medsegpy/evaluation/sem_seg_evaluation.py
+++ b/medsegpy/evaluation/sem_seg_evaluation.py
@@ -163,15 +163,16 @@ def eval_single_scan(self, input, output, labels, time_elapsed):
 
         metrics_kwargs = {"spacing": spacing} if spacing is not None else {}
 
-        summary = metrics_manager(
-            scan_id, y_true=y_true, y_pred=labels, x=x, runtime=time_elapsed, **metrics_kwargs
-        )
-
-        logger_info_str = "Scan #{:03d} (name = {}, {:0.2f}s) = {}".format(
-            scan_cnt, scan_id, time_elapsed, summary
-        )
-        self._results_str = self._results_str + logger_info_str + "\n"
-        logger.info(logger_info_str)
+        if not self._config.INFERENCE_ONLY:
+            summary = metrics_manager(
+                scan_id, y_true=y_true, y_pred=labels, x=x, runtime=time_elapsed, **metrics_kwargs
+            )
+
+            logger_info_str = "Scan #{:03d} (name = {}, {:0.2f}s) = {}".format(
+                scan_cnt, scan_id, time_elapsed, summary
+            )
+            self._results_str = self._results_str + logger_info_str + "\n"
+            logger.info(logger_info_str)
 
         if output_dir and save_raw_data:
             save_name = "{}/{}.pred".format(output_dir, scan_id)
diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py
index d514653d..34c650be 100644
--- a/medsegpy/modeling/model.py
+++ b/medsegpy/modeling/model.py
@@ -43,7 +43,8 @@ def inference_generator(
         workers=1,
         use_multiprocessing=False,
         verbose=0,
+        batch_size=None
     ):
         return self.inference_generator_static(
             self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
         )
@@ -58,6 +59,7 @@ def inference_generator_static(
         workers=1,
         use_multiprocessing=False,
         verbose=0,
+        batch_size=None
     ):
         """Generates predictions for the input samples from a data generator and
         returns inputs, ground truth, and predictions.
@@ -116,6 +118,7 @@ def inference_generator_static(
                 workers=workers,
                 use_multiprocessing=use_multiprocessing,
                 verbose=verbose,
+                batch_size=batch_size
             )
         else:
             return model._inference_generator_tf1(
@@ -295,6 +298,7 @@ def _inference_generator_tf2(
             batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
             # tmp_batch_outputs = predict_function(iterator)
             tmp_batch_outputs = model.predict(batch_x)
+
             if data_handler.should_sync:
                 context.async_wait()  # noqa: F821
             batch_outputs = tmp_batch_outputs  # No error, now safe to assign.