From a47d46ddb443962c640081052ac936f6d7922e61 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 4 Dec 2023 14:10:21 +0100 Subject: [PATCH] Start implementing nifti file wrapper (#79) Implement nifti file wrapper, update MRC error handling --- .github/environment.yaml | 2 + elf/io/extensions.py | 8 +++ elf/io/files.py | 17 +++++- elf/io/mrc_wrapper.py | 19 ++++--- elf/io/nifti_wrapper.py | 80 +++++++++++++++++++++++++++++ test/io_tests/test_nifti_wrapper.py | 61 ++++++++++++++++++++++ 6 files changed, 177 insertions(+), 10 deletions(-) create mode 100644 elf/io/nifti_wrapper.py create mode 100644 test/io_tests/test_nifti_wrapper.py diff --git a/.github/environment.yaml b/.github/environment.yaml index 23fff27..a4848de 100644 --- a/.github/environment.yaml +++ b/.github/environment.yaml @@ -8,11 +8,13 @@ dependencies: - imageio - intern - mrcfile + - nibabel - nifty >=1.1 - numba - pandas - pip - python + - pytest - scikit-image - skan - tqdm diff --git a/elf/io/extensions.py b/elf/io/extensions.py index 24dba0a..1746456 100644 --- a/elf/io/extensions.py +++ b/elf/io/extensions.py @@ -5,6 +5,7 @@ from .image_stack_wrapper import ImageStackFile, ImageStackDataset from .knossos_wrapper import KnossosFile, KnossosDataset from .mrc_wrapper import MRCFile, MRCDataset +from .nifti_wrapper import NiftiFile, NiftiDataset from .intern_wrapper import InternFile, InternDataset @@ -80,6 +81,13 @@ def register_filetype(constructor, extensions=(), groups=(), datasets=()): except ImportError: intern = None +# add nifti extensions if we have nibabel +try: + import nibabel + register_filetype(NiftiFile, [".nii.gz", ".nii"], NiftiFile, NiftiDataset) +except ImportError: + nibabel = None + def identity(arg): return arg diff --git a/elf/io/files.py b/elf/io/files.py index 6cf9493..f7eb29d 100644 --- a/elf/io/files.py +++ b/elf/io/files.py @@ -1,4 +1,5 @@ -import os +from pathlib import Path + from .extensions import ( FILE_CONSTRUCTORS, GROUP_LIKE, DATASET_LIKE, h5py, z5py, 
pyn5, zarr, @@ -27,12 +28,23 @@ def open_file(path, mode="a", ext=None, **kwargs): ext [str] - file extension. This can be used to force an extension if it cannot be inferred from the filename. (default: None) """ + # Before checking the extension suffix, check for "protocol-style" # cloud provider prefixes. if "://" in path: ext = path.split("://")[0] + "://" - ext = os.path.splitext(path.rstrip("/"))[1] if ext is None else ext + elif ext is None: + path_ = Path(path.rstrip("/")) + suffixes = path_.suffixes + # We need to treat .nii.gz differently + if len(suffixes) == 2 and "".join(suffixes) == ".nii.gz": + ext = ".nii.gz" + elif len(suffixes) == 0: + ext = "" + else: + ext = suffixes[-1] + try: constructor = FILE_CONSTRUCTORS[ext.lower()] except KeyError: @@ -42,6 +54,7 @@ def open_file(path, mode="a", ext=None, **kwargs): f"{' '.join(supported_extensions())}. " f"You may need to install additional dependencies (h5py, z5py, zarr, intern)." ) + return constructor(path, mode=mode, **kwargs) diff --git a/elf/io/mrc_wrapper.py b/elf/io/mrc_wrapper.py index 7fbe357..0a060e1 100755 --- a/elf/io/mrc_wrapper.py +++ b/elf/io/mrc_wrapper.py @@ -49,37 +49,40 @@ class MRCFile(Mapping): """ Wrapper for an mrc file """ - def __init__(self, path, mode='r'): + def __init__(self, path, mode="r"): self.path = path self.mode = mode if mrcfile is None: - raise AttributeError("mrcfile is not available") + raise AttributeError("mrcfile is required to read mrc or rec files, but is not installed") try: self._f = mrcfile.mmap(self.path, self.mode) except ValueError as e: # check if error comes from old version of SerialEM used for acquisition - if "Unrecognised machine stamp: 0x44 0x00 0x00 0x00" in str(e): + if ( + "Unrecognised machine stamp: 0x44 0x00 0x00 0x00" in str(e) or + "Unrecognised machine stamp: 0x00 0x00 0x00 0x00" in str(e) + ): try: - self._f = mrcfile.mmap(self.path, self.mode, permissive='True') + self._f = mrcfile.mmap(self.path, self.mode, permissive="True") except 
class NiftiFile(Mapping):
    """File-like wrapper around a nifti image loaded with nibabel.

    The image is exposed as a mapping with the single key "data", which
    mirrors the way the h5py/z5py/zarr file objects are accessed.
    """

    def __init__(self, path, mode="r"):
        """Load the nifti image stored at the given path.

        Arguments:
            path [str] - path to the nifti file (passed to nibabel.load).
            mode [str] - stored on the object; it is not passed to nibabel. (default: "r")

        Raises:
            AttributeError - if nibabel is not installed.
        """
        if nibabel is None:
            raise AttributeError("nibabel is required for nifti images, but is not installed.")
        self.path = path
        self.mode = mode
        self.nifti = nibabel.load(self.path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing is released here; the method only exists so that the
        # object can be used in a with-statement.
        pass

    # dummy attrs to be compatible with h5py/z5py/zarr API
    # alternatively we could also map the header to attributes
    @property
    def attrs(self):
        """Return an empty attribute dictionary."""
        return {}

    def __getitem__(self, key):
        """Return the image wrapped in a NiftiDataset; "data" is the only valid key.

        Raises:
            KeyError - for every key other than "data".
        """
        if key != "data":
            raise KeyError(f"Could not find key {key}")
        return NiftiDataset(self.nifti)

    def __iter__(self):
        yield "data"

    def __len__(self):
        return 1

    def __contains__(self, name):
        return name == "data"


class NiftiDataset:
    """Dataset-like wrapper around a nibabel image.

    The axis order is reversed with respect to the wrapped image: `shape`
    reports the reversed image shape, and `__getitem__` reverses the index
    before reading and transposes the result.
    """

    def __init__(self, data):
        """Wrap an image object.

        Arguments:
            data - the image to wrap. This class uses its `shape`, `ndim` and
                `dataobj` attributes and its `get_data_dtype()` method.
        """
        self._data = data

    @property
    def dtype(self):
        """Return the data type reported by the image's get_data_dtype()."""
        # The image is stored as `_data`; there is no `data` attribute.
        return self._data.get_data_dtype()

    @property
    def ndim(self):
        """Return the number of dimensions of the wrapped image."""
        return self._data.ndim

    @property
    def chunks(self):
        """Return None; no chunk shape is reported for nifti data."""
        return None

    @property
    def shape(self):
        """Return the shape of the wrapped image in reversed axis order."""
        return self._data.shape[::-1]

    def __getitem__(self, key):
        """Read the region selected by `key`, given in the reversed axis order.

        NOTE(review): `normalize_index` is assumed to return one index entry per
        axis plus the axes to squeeze, and `squeeze_singletons` to drop those
        axes - confirm against elf.util.
        """
        key, to_squeeze = normalize_index(key, self.shape)
        # The key refers to the reversed axis order, so it is reversed again
        # to index the image, and the result is transposed back.
        transposed_key = key[::-1]
        data = self._data.dataobj[transposed_key].T
        return squeeze_singletons(data, to_squeeze)

    @property
    def size(self):
        """Return the total number of elements, i.e. the product of the shape."""
        return np.prod(self._data.shape)

    # dummy attrs to be compatible with h5py/z5py/zarr API
    # alternatively we could also map the header to attributes
    @property
    def attrs(self):
        """Return an empty attribute dictionary."""
        return {}
glob(os.path.join(data_path, "*.nii.gz")) + for path in paths: + expected_data = np.asarray(nibabel.load(path).dataobj).T + with open_file(path, "r") as f: + self._check_data(expected_data, f) + + +if __name__ == "__main__": + unittest.main()