Skip to content

Commit

Permalink
Start implementing nifti file wrapper (#79)
Browse files Browse the repository at this point in the history
Implement nifti file wrapper, update MRC error handling
  • Loading branch information
constantinpape authored Dec 4, 2023
1 parent c0f5f61 commit a47d46d
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .github/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ dependencies:
- imageio
- intern
- mrcfile
- nibabel
- nifty >=1.1
- numba
- pandas
- pip
- python
- pytest
- scikit-image
- skan
- tqdm
Expand Down
8 changes: 8 additions & 0 deletions elf/io/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .image_stack_wrapper import ImageStackFile, ImageStackDataset
from .knossos_wrapper import KnossosFile, KnossosDataset
from .mrc_wrapper import MRCFile, MRCDataset
from .nifti_wrapper import NiftiFile, NiftiDataset
from .intern_wrapper import InternFile, InternDataset


Expand Down Expand Up @@ -80,6 +81,13 @@ def register_filetype(constructor, extensions=(), groups=(), datasets=()):
except ImportError:
intern = None

# add nifti extensions if we have nibabel
try:
    import nibabel
    # NOTE(review): register_filetype(constructor, extensions=(), groups=(), datasets=())
    # defaults its groups/datasets parameters to sequences, but NiftiFile and
    # NiftiDataset are passed here as bare classes — confirm register_filetype
    # normalizes a single class, or wrap them as [NiftiFile], [NiftiDataset].
    register_filetype(NiftiFile, [".nii.gz", ".nii"], NiftiFile, NiftiDataset)
except ImportError:
    nibabel = None


def identity(arg):
    """Return *arg* unchanged."""
    return arg
Expand Down
17 changes: 15 additions & 2 deletions elf/io/files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path

from .extensions import (
FILE_CONSTRUCTORS, GROUP_LIKE, DATASET_LIKE,
h5py, z5py, pyn5, zarr,
Expand Down Expand Up @@ -27,12 +28,23 @@ def open_file(path, mode="a", ext=None, **kwargs):
ext [str] - file extension. This can be used to force an extension
if it cannot be inferred from the filename. (default: None)
"""

# Before checking the extension suffix, check for "protocol-style"
# cloud provider prefixes.
if "://" in path:
ext = path.split("://")[0] + "://"

ext = os.path.splitext(path.rstrip("/"))[1] if ext is None else ext
elif ext is None:
path_ = Path(path.rstrip("/"))
suffixes = path_.suffixes
# We need to treat .nii.gz differently
if len(suffixes) == 2 and "".join(suffixes) == ".nii.gz":
ext = ".nii.gz"
elif len(suffixes) == 0:
ext = ""
else:
ext = suffixes[-1]

try:
constructor = FILE_CONSTRUCTORS[ext.lower()]
except KeyError:
Expand All @@ -42,6 +54,7 @@ def open_file(path, mode="a", ext=None, **kwargs):
f"{' '.join(supported_extensions())}. "
f"You may need to install additional dependencies (h5py, z5py, zarr, intern)."
)

return constructor(path, mode=mode, **kwargs)


Expand Down
19 changes: 11 additions & 8 deletions elf/io/mrc_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,37 +49,40 @@ class MRCFile(Mapping):
""" Wrapper for an mrc file
"""

def __init__(self, path, mode="r"):
    """Open an mrc file and memory-map (or eagerly read) its data.

    Arguments:
        path [str] - path to the mrc/rec file
        mode [str] - file mode, passed through to mrcfile (default: "r")

    Raises:
        AttributeError - if the mrcfile package is not installed.
    """
    self.path = path
    self.mode = mode
    if mrcfile is None:
        raise AttributeError("mrcfile is required to read mrc or rec files, but is not installed")
    try:
        self._f = mrcfile.mmap(self.path, self.mode)
    except ValueError as e:
        # check if the error comes from an old version of SerialEM used for
        # acquisition, which writes an unrecognised machine stamp; in that
        # case retry in permissive mode (mmap first, plain open as fallback)
        if (
            "Unrecognised machine stamp: 0x44 0x00 0x00 0x00" in str(e) or
            "Unrecognised machine stamp: 0x00 0x00 0x00 0x00" in str(e)
        ):
            try:
                # FIX: pass the boolean True, not the string "True"; the
                # string only worked by being truthy
                self._f = mrcfile.mmap(self.path, self.mode, permissive=True)
            except ValueError:
                self._f = mrcfile.open(self.path, self.mode, permissive=True)
        else:
            self._f = mrcfile.open(self.path, self.mode)

def __getitem__(self, key):
    """Return the single dataset; "data" is the only valid key."""
    if key == "data":
        return MRCDataset(self._f.data)
    raise KeyError(f"Could not find key {key}")

def __iter__(self):
    """Iterate over the single dataset key."""
    return iter(("data",))

def __len__(self):
    # exactly one dataset ("data") is exposed
    return 1

def __contains__(self, name):
    """Only the "data" key is a member of this file wrapper."""
    return name in ("data",)

def __enter__(self):
    # support use as a context manager
    return self
Expand Down
80 changes: 80 additions & 0 deletions elf/io/nifti_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import Mapping
from ..util import normalize_index, squeeze_singletons

import numpy as np
try:
import nibabel
except ImportError:
nibabel = None


class NiftiFile(Mapping):
    """h5py.File-like wrapper around a nifti image.

    The image is exposed as a single dataset under the key "data".
    """

    def __init__(self, path, mode="r"):
        if nibabel is None:
            raise AttributeError("nibabel is required for nifti images, but is not installed.")
        self.path = path
        self.mode = mode
        self.nifti = nibabel.load(self.path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # nothing to release; nibabel handles the underlying file
        pass

    # dummy attrs to be compatible with the h5py/z5py/zarr API;
    # alternatively the nifti header could be mapped to attributes
    @property
    def attrs(self):
        return {}

    def __getitem__(self, key):
        if key == "data":
            return NiftiDataset(self.nifti)
        raise KeyError(f"Could not find key {key}")

    def __iter__(self):
        return iter(("data",))

    def __len__(self):
        return 1

    def __contains__(self, name):
        return name in ("data",)


class NiftiDataset:
    """Dataset-like wrapper around a nifti image, compatible with the h5py API.

    Nifti images are stored in Fortran (column-major) order, so the shape and
    all indexing are reversed here to present the usual C-order convention.
    """

    def __init__(self, data):
        # `data` is the nibabel image object, not a numpy array
        self._data = data

    @property
    def dtype(self):
        # FIX: was `self.data.get_data_dtype()` — the class has no `data`
        # attribute, so this always raised AttributeError
        return self._data.get_data_dtype()

    @property
    def ndim(self):
        return self._data.ndim

    @property
    def chunks(self):
        # nifti has no chunking concept
        return None

    @property
    def shape(self):
        # reverse the on-disk (Fortran-order) shape to C order
        return self._data.shape[::-1]

    def __getitem__(self, key):
        # normalize the key against the C-order shape, then reverse it and
        # transpose the result to match the on-disk Fortran order
        key, to_squeeze = normalize_index(key, self.shape)
        transposed_key = key[::-1]
        data = self._data.dataobj[transposed_key].T
        return squeeze_singletons(data, to_squeeze)

    @property
    def size(self):
        return np.prod(self._data.shape)

    # dummy attrs to be compatible with the h5py/z5py/zarr API;
    # alternatively the nifti header could be mapped to attributes
    @property
    def attrs(self):
        return {}
61 changes: 61 additions & 0 deletions test/io_tests/test_nifti_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import unittest
from glob import glob

import numpy as np

try:
import nibabel
except ImportError:
nibabel = None


@unittest.skipIf(nibabel is None, "Needs nibabel")
class TestNiftiWrapper(unittest.TestCase):
    """Test elf's nifti wrapper against nibabel's bundled sample images."""

    def _check_data(self, expected_data, f):
        dset = f["data"]
        self.assertEqual(expected_data.shape, dset.shape)
        shape = dset.shape
        ndim = dset.ndim

        # bounding boxes for testing sub-sampling: a scalar index, the full
        # slice, the lower and upper half along each axis, and a central crop
        bbs = [0, np.s_[:]]
        for axis in range(ndim):
            lower_half = tuple(
                slice(0, shape[d] // 2) if d == axis else slice(None) for d in range(ndim)
            )
            upper_half = tuple(
                slice(shape[d] // 2, None) if d == axis else slice(None) for d in range(ndim)
            )
            bbs += [lower_half, upper_half]
        bbs.append(tuple(slice(shape[d] // 4, 3 * shape[d] // 4) for d in range(ndim)))

        for bb in bbs:
            self.assertTrue(np.allclose(dset[bb], expected_data[bb]))

    def test_read_nifti(self):
        from elf.io import open_file
        from nibabel.testing import data_path

        for path in glob(os.path.join(data_path, "*.nii")):
            # the resampled image causes errors
            if os.path.basename(path).startswith("resampled"):
                continue
            expected_data = np.asarray(nibabel.load(path).dataobj).T
            with open_file(path, "r") as f:
                self._check_data(expected_data, f)

    def test_read_nifti_compressed(self):
        from elf.io import open_file
        from nibabel.testing import data_path

        for path in glob(os.path.join(data_path, "*.nii.gz")):
            expected_data = np.asarray(nibabel.load(path).dataobj).T
            with open_file(path, "r") as f:
                self._check_data(expected_data, f)


# allow running this test module directly
if __name__ == "__main__":
    unittest.main()

0 comments on commit a47d46d

Please sign in to comment.