diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index ef8ffc5..25914c3 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -98,6 +98,8 @@ class Segmentation: mri (MRIData): The MRIData object containing the segmentation volume and affine. lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. + Assumes that entries are indexed by the "label" column. If there is no "label" column + the current index is renamed to "label" """ mri: MRIData @@ -111,16 +113,27 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): self.rois = np.unique(self.mri.data[self.mri.data > 0]) if lut is None: - self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) - else: - self.lut = lut + lut = pd.DataFrame( + { + "label": self.rois.astype(int), + "description": self.rois.astype(int).astype(str), + } + ).set_index("label") + + self.set_lut(lut, label_column="label" if "label" in lut.columns else None) + self._preprocess_lut() - # Identify the primary label column dynamically - self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + def _preprocess_lut(self) -> pd.DataFrame: + # dummy function for subclasses to override if they need to preprocess the LUT after loading + pass @classmethod def from_file( - cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + cls, + seg_path: Path, + dtype: npt.DTypeLike | None = None, + orient: bool = True, + lut_path: Path | None = None, ) -> "Segmentation": """Loads a Segmentation from a NIfTI file. @@ -136,19 +149,29 @@ def from_file( logger.info(f"Loading segmentation from {seg_path}.") mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient) - if lut_path is None and seg_path.with_suffix(".json").exists(): - lut_path = seg_path.with_suffix(".json") + if lut_path is None: + if seg_path.with_suffix(".csv").exists(): + lut_path = seg_path.with_suffix(".csv") + lut = pd.read_csv(lut_path) + elif seg_path.with_suffix(".json").exists(): + lut_path = seg_path.with_suffix(".json") + lut = pd.read_json(lut_path) if lut_path is not None: logger.info(f"Loading LUT from {lut_path}.") - lut = pd.read_json(lut_path) else: - rois = np.unique(mri.data[mri.data > 0]) - lut = pd.DataFrame({"Label": rois}, index=rois) + lut = None return cls(mri=mri, lut=lut) - def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None): + def save( + self, + output_path: Path, + dtype: npt.DTypeLike | None = None, + intent_code: int = 1006, + lut_path: Path | None = None, + lut_suffix=".csv", + ): """Saves the Segmentation to a NIfTI file. Args: @@ -157,25 +180,35 @@ def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_cod intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL). """ self.mri.save(output_path, dtype=dtype, intent_code=intent_code) + if lut_path is not None: - self.lut.to_json(lut_path, orient="index") + write_lut(lut_path, self.lut) else: - self.lut.to_json(output_path.with_suffix(".json"), orient="index") + filename = output_path.name.removesuffix("".join(output_path.suffixes)) + write_lut(output_path.parent.joinpath(filename).with_suffix(lut_suffix), self.lut) - def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): + def set_lut(self, lut: pd.DataFrame, label_column: str | None = None): """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. Args: lut (pd.DataFrame): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. label_column (str, optional): The name of the column in the LUT that contains the label - descriptions. Defaults to "Label". + descriptions which will be used as the index. If None, use the current index. Defaults to None. + If the index is not already named, it is renamed to "label". """ self.lut = lut - self.label_name = label_column - if self.label_name not in self.lut.columns: - raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.") + + if label_column is not None: + self.lut = lut.set_index(label_column) + self.label_name = label_column + else: + if lut.index.name is not None: # If lut index already is named, use it + self.label_name = lut.index.name + else: # Use label as default name for axis + self.label_name = "label" + self.lut = lut.rename_axis(self.label_name) @property def num_rois(self) -> int: @@ -209,8 +242,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr if not np.isin(rois, self.rois).all(): raise ValueError("Some of the provided ROIs are not present in the segmentation.") - - return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() + return self.lut.loc[rois.astype(self.lut.index.dtype)] def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation": """ @@ -292,7 +324,11 @@ class FreeSurferSegmentation(Segmentation): @classmethod def from_file( - cls, filepath: Path | str, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + cls, + filepath: Path | str, + dtype: npt.DTypeLike | None = None, + orient: bool = True, + lut_path: Path | None = None, ) -> "FreeSurferSegmentation": """ Load a FreeSurfer segmentation from a NIfTI file, automatically resolving the LUT. @@ -309,13 +345,13 @@ def from_file( """ resolved_lut_path = resolve_freesurfer_lut_path(lut_path) lut = read_freesurfer_lut(resolved_lut_path) - - # FreeSurfer LUTs index by the "label" column - lut = lut.set_index("label") if "label" in lut.columns else lut - mri = MRIData.from_file(filepath, dtype=dtype, orient=orient) return cls(mri=mri, lut=lut) + def _preprocess_lut(self) -> pd.DataFrame: + # FreeSurfer LUTs index by the "label" column + self.lut = self.lut.query("label < 10000") # Most used FreeSurfer labels + class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): """ @@ -326,6 +362,22 @@ class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): the base FreeSurfer anatomical label (modulus 10000). """ + def _preprocess_lut(self) -> pd.DataFrame: + super()._preprocess_lut() + + # Add CSF and dura tags + base_lut = self.lut.copy() + for i, tissue_type in enumerate(["CSF", "Dura"]): + tissue_lut = base_lut.copy() + tissue_lut.index += 10000 if tissue_type == "CSF" else 20000 + tissue_lut["description"] = tissue_lut["description"] + f"-{tissue_type}" + if np.all(np.isin(["R", "G", "B"], base_lut.columns)): + for col in ["R", "G", "B"]: + tissue_lut[col] = np.clip( + tissue_lut[col] * (0.5 + 0.5 * i), 0, 1 + ) # Shift colors towards blue for CSF and red for Dura + self.lut = pd.concat([self.lut, tissue_lut]) + def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame: """ Retrieves descriptive mappings including the augmented tissue type classifications. @@ -338,21 +390,12 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr pd.DataFrame: A DataFrame mapping the requested ROIs to their base descriptions and their computed 'tissue_type'. """ - rois = self.rois if rois is None else rois - # Use modulus 10000 to extract the base anatomical label from the superclass LUT - freesurfer_labels = super().get_roi_labels(rois % 10000).rename(columns={"ROI": "FreeSurfer_ROI"}) + roi_labels = super().get_roi_labels(rois) - # Get the broad tissue categories based on the numerical offsets + # Add column specifying tissue_type: tissue_type = self.get_tissue_type(rois) - - # Merge the base anatomical names with the tissue types - return freesurfer_labels.merge( - tissue_type, - left_on="FreeSurfer_ROI", - right_on="FreeSurfer_ROI", - how="outer", - ).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]] + return pd.merge(roi_labels, tissue_type, on="label") def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame: """ @@ -372,15 +415,14 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF """ rois = self.rois if rois is None else rois - tissue_types = pd.Series( - data=np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")), - index=rois, - name="tissue_type", - ) + tissue_types = pd.DataFrame( + { + self.label_name: rois, + "tissue_type": np.where(rois < 10000, "Parenchyma", np.where(rois < 20000, "CSF", "Dura")), + } + ).set_index(self.label_name) - ret = pd.DataFrame(tissue_types, columns=["tissue_type"]).rename_axis("ROI").reset_index() - ret["FreeSurfer_ROI"] = ret["ROI"] % 10000 - return ret + return tissue_types @dataclass @@ -574,15 +616,63 @@ def write_lut(filename: Path, table: pd.DataFrame): """ newtable = table.copy() - # Re-scale RGB values to [0, 255] - for col in ["R", "G", "B"]: - newtable[col] = (newtable[col] * 255).astype(int) + if np.all(np.isin(["R", "G", "B"], table.columns)): + # Re-scale RGB values to [0, 255] + for col in ["R", "G", "B"]: + newtable[col] = (newtable[col] * 255).astype(int) - # Reverse Alpha inversion and scale to [0, 255] - newtable["A"] = 255 - (newtable["A"] * 255).astype(int) + # Reverse Alpha inversion and scale to [0, 255] + newtable["A"] = 255 - (newtable["A"] * 255).astype(int) # Save as tab-separated values without headers or indices - newtable.to_csv(filename, sep="\t", index=False, header=False) + if filename.suffix == ".csv": + newtable.to_csv(filename, sep="\t", index=True, header=False) + elif filename.suffix == ".json": + newtable.to_json(filename, index=True, header=False) + else: + newtable.to_csv(filename, sep="\t", index=True, header=False) + + +def procedural_freesurfer_lut(labels: list, descriptions: list, cmap: str | None = None) -> pd.DataFrame: + """ + Generates a FreeSurfer compatible lut with colors for each label in a procedural manner + + Args: + labels (list): list of labels to include in the lut + descriptions (list): list of descriptions associated to each label + cmap (str, optional): Colormap for label regions. Defaults to "hsv". + + Returns: + pd.DataFrame: DataFrame indexed by the label, with RGBA columns + """ + N = len(labels) + if not N == len(descriptions): + raise ValueError("Label and descriptions lists must have same length") + + if cmap is not None: # If a colormap is specified, use cmap from matplotlib + import matplotlib.pyplot as plt + + # Get evenly spaced values between 0 and 1 based on the number of labels + color_indices = np.linspace(0, 0.95, N) + # Sample a colormap + rgb_float = plt.get_cmap(cmap)(color_indices) + else: + rgb_float = [] + import colorsys + + for i in range(N): + h = i / N + rgb = list(colorsys.hsv_to_rgb(h, 1.0, 1.0)) + rgb.append(1.0) # Add transparency + rgb_float.append(rgb) + rgb_float = np.array(rgb_float) + + # Create the DataFrame + df_colors = pd.DataFrame(rgb_float, columns=["R", "G", "B", "A"], index=labels) + df_colors.index.name = "label" + df_colors["description"] = descriptions + lut = df_colors[["description", "R", "G", "B", "A"]] + return lut def add_arguments( @@ -592,7 +682,9 @@ def add_arguments( subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing") resample_parser = subparser.add_parser( - "resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class + "resample", + help="Resample a segmentation to match the space of a reference MRI", + formatter_class=parser.formatter_class, ) resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") resample_parser.add_argument( @@ -602,19 +694,43 @@ def add_arguments( help="Path to the reference MRI \ - usually a registered T1 weighted anatomical scan", ) - resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation") + resample_parser.add_argument( + "-o", + "--output", + type=Path, + help="Desired output path for the resampled segmentation", + ) smooth_parser = subparser.add_parser( "smooth", help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", formatter_class=parser.formatter_class, ) - smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file") - smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") smooth_parser.add_argument( - "-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)" + "-i", + "--input", + type=Path, + help="Path to the input (refined) segmentation NIfTI file", + ) + smooth_parser.add_argument( + "-s", + "--sigma", + type=float, + help="Standard deviation for the Gaussian kernel used in smoothing", + ) + smooth_parser.add_argument( + "-c", + "--cutoff", + type=float, + default=0.5, + help="Cutoff score to remove low-confidence voxels (default: 0.5)", + ) + smooth_parser.add_argument( + "-o", + "--output", + type=Path, + help="Desired output path for the smoothed segmentation", ) - smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation") refine_parser = subparser.add_parser( "refine", @@ -629,8 +745,18 @@ def add_arguments( help="Path to the reference MRI \ - usually a registered T1 weighted anatomical scan", ) - refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") - refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation") + refine_parser.add_argument( + "-s", + "--smooth", + type=float, + help="Standard deviation for the Gaussian kernel used in smoothing", + ) + refine_parser.add_argument( + "-o", + "--output", + type=Path, + help="Desired output path for the refined segmentation", + ) if extra_args_cb is not None: extra_args_cb(resample_parser) diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index b5ccc10..cdd4a26 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -12,7 +12,11 @@ import tqdm.rich from ..data import MRIData -from ..segmentation import Segmentation, default_segmentation_groups, read_freesurfer_lut +from ..segmentation import ( + Segmentation, + default_segmentation_groups, + read_freesurfer_lut, +) from ..testing import assert_same_space from .stat_functions import Mean, Median, Statistic, Std from .utils import find_timestamp, prepend_info, voxel_count_to_ml_scale @@ -221,37 +225,28 @@ def generate_stats_dataframe_rois( # Verify that segmentation and MRI are in the same space assert_same_space(seg.mri, mri) - qoi_records = [] # Collects records related to qois - roi_records = [] # Collects records related to ROIs, - # Mask infinite values finite_mask = np.isfinite(mri.data) - for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)): - # Identify rois in segmentation - region_mask = (seg.mri.data == roi) * finite_mask - # print(region_mask.shape) - region_data = mri.data[region_mask] - nb_nans = np.isnan(region_data).sum() - - voxelcount = len(region_data) - roi_records.append( - { - "ROI": roi, - "voxel_count": voxelcount, - "volume_ml": seg.mri.voxel_ml_volume * voxelcount, - "num_nan_values": nb_nans, - } - ) - # Iterate qoi functions - for qoi in qois: - qoi_value = qoi(region_data) - # Store the qoi value in a dataframe, along with the roi label and description - qoi_records.append({"ROI": roi, "statistic": qoi.name, "value": qoi_value}) - - df = pd.DataFrame.from_records(qoi_records) - df_roi = pd.DataFrame.from_records(roi_records) - df = df.merge(df_roi, on="ROI", how="left") + stats_df = pd.DataFrame( + { + seg.label_name: seg.mri.data.ravel(), + "value": (mri.data * finite_mask).ravel(), + } + ) + stats_df = stats_df[stats_df[seg.label_name] > 0] # Remove background + + agg_funcs = [(qoi.name, qoi.func) for qoi in qois] + [ + ("voxel_count", "count"), + ("num_nan_values", lambda x: np.isnan(x).sum()), + ] + stats = stats_df.groupby(seg.label_name)["value"].agg(agg_funcs).reset_index() + stats["volume_ml"] = stats["voxel_count"] * seg.mri.voxel_ml_volume + df = stats.melt( + id_vars=[seg.label_name, "voxel_count", "volume_ml", "num_nan_values"], + var_name="statistic", + value_name="value", + ) # Add some metadata to each row if metadata is not None: diff --git a/tests/test_mri_stats.py b/tests/test_mri_stats.py index d76fbf4..1a8a1a9 100644 --- a/tests/test_mri_stats.py +++ b/tests/test_mri_stats.py @@ -33,7 +33,7 @@ def test_compute_stats_default(example_segmentation: Segmentation, example_value "timestamp": "timestamp", }, ) - print(dataframe.columns) + assert set(dataframe["statistic"]) == {"mean", "std", "median"} def test_compute_stats_default_gonzo(mri_data_dir: Path): diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 899cfd4..7ceeacf 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -28,7 +28,7 @@ def test_segmentation_initialization(example_segmentation: Segmentation): assert example_segmentation.num_rois == 3 assert set(example_segmentation.roi_labels) == {1, 2, 3} assert example_segmentation.lut.shape == (3, 1) - assert set(example_segmentation.lut.columns) == {"Label"} + assert example_segmentation.lut.index.name == "label" def test_freesurfer_segmentation_labels(mri_data_dir: Path): @@ -43,7 +43,7 @@ def test_freesurfer_segmentation_labels(mri_data_dir: Path): labels = fs_seg.get_roi_labels() assert not labels.empty - assert set(labels["ROI"]) == set(fs_seg.roi_labels) + assert set(labels.index.values) == set(fs_seg.roi_labels) def test_extended_freesurfer_segmentation_labels(example_segmentation: Segmentation, mri_data_dir: Path): @@ -53,12 +53,11 @@ def test_extended_freesurfer_segmentation_labels(example_segmentation: Segmentat ext_fs_seg = ExtendedFreeSurferSegmentation(MRIData(data=data, affine=np.eye(4))) labels = ext_fs_seg.get_roi_labels() - - assert set(labels["ROI"]) == set(ext_fs_seg.roi_labels) - assert labels.loc[labels["ROI"] == 10001, "tissue_type"].iloc[0] == "CSF" - assert labels.loc[labels["ROI"] == 20001, "tissue_type"].iloc[0] == "Dura" - assert labels.loc[labels["ROI"] == 10001, "Label"].iloc[0] == labels.loc[labels["ROI"] == 1, "Label"].iloc[0] - assert labels.loc[labels["ROI"] == 20001, "Label"].iloc[0] == labels.loc[labels["ROI"] == 1, "Label"].iloc[0] + assert set(labels.index.values) == set(ext_fs_seg.roi_labels) + assert labels.loc[10001, "tissue_type"] == "CSF" + assert labels.loc[20001, "tissue_type"] == "Dura" + assert labels.loc[10001, "description"] == "1-CSF" + assert labels.loc[20001, "description"] == "1-Dura" def test_default_segmentation_groups(): @@ -169,10 +168,24 @@ def test_write_lut_file_io(tmp_path): # Mock DataFrame matching the parsed structure (normalized floats) data = [ - {"label": 4, "description": "Left-Lateral-Ventricle", "R": 120 / 255.0, "G": 18 / 255.0, "B": 134 / 255.0, "A": 1.0}, - {"label": 5, "description": "Left-Inf-Lat-Vent", "R": 198 / 255.0, "G": 51 / 255.0, "B": 122 / 255.0, "A": 1.0}, + { + "label": 4, + "description": "Left-Lateral-Ventricle", + "R": 120 / 255.0, + "G": 18 / 255.0, + "B": 134 / 255.0, + "A": 1.0, + }, + { + "label": 5, + "description": "Left-Inf-Lat-Vent", + "R": 198 / 255.0, + "G": 51 / 255.0, + "B": 122 / 255.0, + "A": 1.0, + }, ] - df = pd.DataFrame(data) + df = pd.DataFrame(data).set_index("label") write_lut(dummy_lut_file, df) @@ -211,6 +224,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty smoothing = 1 piece_fs_seg = Segmentation(mri=piece_fs_seg_data) result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) + smoothed = result.smooth(sigma=smoothing) result.mri.data = smoothed.mri.data result.save(test_output, dtype=np.int32) @@ -256,7 +270,18 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): def test_dispatch_resample(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation resample.""" - mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) + mritk.cli.main( + [ + "seg", + "resample", + "-i", + "mock_in.nii.gz", + "-r", + "mock_ref.nii.gz", + "-o", + "mock_out.nii.gz", + ] + ) mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"))