From ce6eb7771a8b1e4eb915224861c4cf2a92e7149f Mon Sep 17 00:00:00 2001 From: erasdna Date: Fri, 13 Mar 2026 10:44:57 +0100 Subject: [PATCH 1/6] add segmentation class, rework statistics dataframe --- src/mritk/data/base.py | 18 ++ src/mritk/data/io.py | 62 +++--- src/mritk/segmentation/__init__.py | 4 +- src/mritk/segmentation/lookup_table.py | 32 ++- src/mritk/segmentation/segmentation.py | 43 ++++ src/mritk/statistics/__init__.py | 4 +- src/mritk/statistics/cli.py | 133 ++++++------ src/mritk/statistics/compute_stats.py | 155 ++++---------- src/mritk/statistics/metadata.py | 45 ++++ src/mritk/statistics/stat_functions.py | 53 +++++ src/mritk/statistics/utils.py | 5 +- test/conftest.py | 28 +++ test/test_metadata.py | 27 +++ test/test_mri_io.py | 45 +++- test/test_mri_stats.py | 271 ++++++++++++++----------- 15 files changed, 575 insertions(+), 350 deletions(-) create mode 100644 src/mritk/segmentation/segmentation.py create mode 100644 src/mritk/statistics/metadata.py create mode 100644 src/mritk/statistics/stat_functions.py create mode 100644 test/test_metadata.py diff --git a/src/mritk/data/base.py b/src/mritk/data/base.py index 7880e7d..08e8b53 100644 --- a/src/mritk/data/base.py +++ b/src/mritk/data/base.py @@ -6,12 +6,23 @@ import numpy as np +from pathlib import Path +from .io import load_mri_data, save_mri_data class MRIData: def __init__(self, data: np.ndarray, affine: np.ndarray): self.data = data self.affine = affine + self.dtype = data.dtype + + @classmethod + def from_file(cls, filepath: Path): + data, affine = load_mri_data(filepath, np.float64) + return cls(data=data, affine=affine) + + def save_mri_data(self, save_path: Path, intent_code: int | None = None): + save_mri_data(self.data, self.affine, save_path, self.dtype, intent_code) def get_data(self): return self.data @@ -22,3 +33,10 @@ def get_metadata(self): @property def shape(self) -> tuple[int, ...]: return self.data.shape + + @property + def voxel_ml_volume(self) -> float: + # Calculate the volume of a single voxel in milliliters + voxel_volume_mm3 = abs(np.linalg.det(self.affine[:3, :3])) + voxel_volume_ml = voxel_volume_mm3 / 1000.0 # Convert from mm^3 to ml + return voxel_volume_ml diff --git a/src/mritk/data/io.py b/src/mritk/data/io.py index ba3547c..601ac42 100644 --- a/src/mritk/data/io.py +++ b/src/mritk/data/io.py @@ -9,53 +9,47 @@ import nibabel import numpy as np import numpy.typing as npt -import re from typing import Optional -from .base import MRIData -from .orientation import data_reorientation - - -def load_mri_data( - path: Path | str, - dtype: type = np.float64, - orient: bool = True, -) -> MRIData: - suffix_regex = re.compile(r".+(?P(\.nii(\.gz|)|\.mg(z|h)))") - m = suffix_regex.match(Path(path).name) - if (m is not None) and (m.groupdict()["suffix"] in (".nii", ".nii.gz")): - mri = nibabel.nifti1.load(path) - elif (m is not None) and (m.groupdict()["suffix"] in (".mgz", ".mgh")): - mri = nibabel.freesurfer.mghformat.load(path) + +def check_suffix(filepath: Path): + suffix = filepath.suffix + if suffix == ".gz": + suffixes = filepath.suffixes + if len(suffixes) >= 2 and suffixes[-2] == ".nii": + return ".nii.gz" + return suffix + + +def load_mri_data(filepath: Path, dtype: type = np.float64) -> tuple[np.ndarray, np.ndarray]: + suffix = check_suffix(filepath) + if suffix in (".nii", ".nii.gz"): + mri = nibabel.nifti1.load(filepath) + elif suffix in (".mgz", ".mgh"): + mri = nibabel.freesurfer.mghformat.load(filepath) else: - raise ValueError(f"Invalid suffix {path}, should be either '.nii', or '.mgz'") + raise ValueError(f"Invalid suffix {filepath}, should be either '.nii', or '.mgz'") affine = mri.affine if affine is None: - raise RuntimeError("MRI do not contain affine") + raise RuntimeError("MRI does not contain an affine") data = np.asarray(mri.get_fdata("unchanged"), dtype=dtype) - mri = MRIData(data=data, affine=affine) - if orient: - return data_reorientation(mri) - else: - return mri + return data, affine -def save_mri_data(mri: MRIData, path: Path, dtype: npt.DTypeLike, intent_code: Optional[int] = None): - # TODO : Choose other way to check extension than regex ? - suffix_regex = re.compile(r".+(?P(\.nii(\.gz|)|\.mg(z|h)))") - m = suffix_regex.match(Path(path).name) - if (m is not None) and (m.groupdict()["suffix"] in (".nii", ".nii.gz")): - nii = nibabel.nifti1.Nifti1Image(mri.data.astype(dtype), mri.affine) +def save_mri_data(data: np.ndarray, affine: np.ndarray, save_path: Path, dtype: npt.DTypeLike, intent_code: Optional[int] = None): + suffix = check_suffix(save_path) + if suffix in (".nii", ".nii.gz"): + nii = nibabel.nifti1.Nifti1Image(data.astype(dtype), affine) if intent_code is not None: nii.header.set_intent(intent_code) - nibabel.nifti1.save(nii, path) - elif (m is not None) and (m.groupdict()["suffix"] in (".mgz", ".mgh")): - mgh = nibabel.freesurfer.mghformat.MGHImage(mri.data.astype(dtype), mri.affine) + nibabel.nifti1.save(nii, save_path) + elif suffix in (".mgz", ".mgh"): + mgh = nibabel.freesurfer.mghformat.MGHImage(data.astype(dtype), affine) if intent_code is not None: mgh.header.set_intent(intent_code) - nibabel.freesurfer.mghformat.save(mgh, path) + nibabel.freesurfer.mghformat.save(mgh, save_path) else: - raise ValueError(f"Invalid suffix {path}, should be either '.nii', or '.mgz'") + raise ValueError(f"Invalid suffix {save_path}, should be either '.nii', or '.mgz'") diff --git a/src/mritk/segmentation/__init__.py b/src/mritk/segmentation/__init__.py index 67400a7..689410d 100644 --- a/src/mritk/segmentation/__init__.py +++ b/src/mritk/segmentation/__init__.py @@ -3,6 +3,6 @@ # Copyright (C) 2026 Simula Research Laboratory -from . import groups, lookup_table +from . import groups, lookup_table, segmentation -__all__ = ["groups", "lookup_table"] +__all__ = ["groups", "lookup_table", "segmentation"] diff --git a/src/mritk/segmentation/lookup_table.py b/src/mritk/segmentation/lookup_table.py index 9d55348..dabe313 100644 --- a/src/mritk/segmentation/lookup_table.py +++ b/src/mritk/segmentation/lookup_table.py @@ -5,7 +5,6 @@ # Copyright (C) 2026 Simula Research Laboratory -import re import os from pathlib import Path import pandas as pd @@ -23,12 +22,21 @@ def read_lut(filename: Path | str | None) -> pd.DataFrame: if not filename.exists(): url = "https://github.com/freesurfer/freesurfer/raw/dev/distribution/FreeSurferColorLUT.txt" urlretrieve(url, filename) - lut_regex = re.compile( - r"^(?P