Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xarray.DataArray imagej metadata and improved axis/scale logic #247

Draft
wants to merge 41 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
728619c
Add Linear and Enumerated axis classes to jc
elevans Jan 27, 2023
53d23ea
Add _create_image_metadata function
elevans Jan 30, 2023
d5375dc
Update _assign_axes to use "imagej" metadata
elevans Jan 30, 2023
9adf026
Add metadata for all CalibratedAxis types
elevans Jan 31, 2023
ed8f794
Refactor dims._assign_axes()
elevans Jan 31, 2023
719fe25
Remove deprecated Linear and Enumerated axis func
elevans Jan 31, 2023
1c0f385
Check for "Hello" key in attributes
elevans Jan 31, 2023
ab2238c
Add case for singleton coordinates array
elevans Feb 4, 2023
8cf19e5
Remove unnecessary to_java() call on coords
elevans Feb 4, 2023
8ed5fa4
Fix typo in jimport of VariableAxis
elevans Feb 10, 2023
d58d255
Use jc.CalibratedAxis for type hint for axes seqs
elevans Feb 10, 2023
fac7f33
Use if clause with ValueError instead of assert
elevans Mar 28, 2023
0f175dd
Add metadata submodule
elevans Apr 2, 2023
61a814c
Use one-liner to fetch dict key
elevans Apr 4, 2023
56ca272
Refactor metadata module to be more pythonic
elevans Apr 14, 2023
8af2bdc
Refactor create imagej metadata function
elevans Apr 17, 2023
6cd7eb8
Refactor _assign_axes to use new metadata layout
elevans Apr 25, 2023
1956351
Improve docstring for axis functions
elevans Apr 25, 2023
615ee25
Add type hint to create_imagej_metadata function
elevans Apr 25, 2023
88f5040
Remove axis submodule and axis to str methods
elevans May 18, 2023
bcfa002
Refactor the xarray metadata creation funciton
elevans May 18, 2023
ccee1b3
Add is_rgb_merged method to check if image is RGB
elevans May 18, 2023
80ed95a
Use the metadata module for image metadata
elevans May 18, 2023
28ab749
Apply Black formatting
elevans May 18, 2023
73c5734
Add array module for xarray accessors
elevans May 18, 2023
b88e513
Add set and get methods to MetadataAccessor
elevans May 18, 2023
205cff9
Add axes property to MetadataAccessor
elevans May 18, 2023
097c907
Add tree method to MetadataAccessor
elevans May 18, 2023
94db806
Add type check for JavaMap to metadata tree func
elevans May 18, 2023
e281221
Use xarray MetadataAccessor class for metadata
elevans May 18, 2023
f2aef68
Add Flake8 bypass for imagej.array import
elevans May 18, 2023
4c5bba6
Store the metadata dict in the xarray global attr
elevans May 19, 2023
93d6ba9
Add _update method to MetadataAccessor
elevans May 22, 2023
30c4fa3
Pre-create the "image" xarr attr dict
elevans May 22, 2023
b8f3b6c
Add "axisLength" metadata update to array module
elevans May 22, 2023
58f19b1
Add logic for is_rgb method to ImgAccessor
elevans May 22, 2023
3beccd9
Fix metadata axes sorting
elevans May 22, 2023
d0b208b
Remove duplicate comment
elevans May 22, 2023
12da19e
Add check for axis labels in xarray dims
elevans May 22, 2023
5b611c4
Use MetadataAccessor to assign dataset axes
elevans May 22, 2023
0b1ab8b
Remove metadata module
elevans May 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/imagej/_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses):
significantly easier and more readable.
"""

@JavaClasses.java_import
def Double(self):
return "java.lang.Double"

@JavaClasses.java_import
def Throwable(self):
return "java.lang.Throwable"
Expand All @@ -50,6 +54,62 @@ def MetadataWrapper(self):
def LabelingIOService(self):
return "io.scif.labeling.LabelingIOService"

@JavaClasses.java_import
def ChapmanRichardsAxis(self):
return "net.imagej.axis.ChapmanRichardsAxis"

@JavaClasses.java_import
def DefaultLinearAxis(self):
return "net.imagej.axis.DefaultLinearAxis"

@JavaClasses.java_import
def EnumeratedAxis(self):
return "net.imagej.axis.EnumeratedAxis"

@JavaClasses.java_import
def ExponentialAxis(self):
return "net.imagej.axis.ExponentialAxis"

@JavaClasses.java_import
def ExponentialRecoveryAxis(self):
return "net.imagej.axis.ExponentialRecoveryAxis"

@JavaClasses.java_import
def GammaVariateAxis(self):
return "net.imagej.axis.GammaVariateAxis"

@JavaClasses.java_import
def GaussianAxis(self):
return "net.imagej.axis.GaussianAxis"

@JavaClasses.java_import
def IdentityAxis(self):
return "net.imagej.axis.IdentityAxis"

@JavaClasses.java_import
def InverseRodbardAxis(self):
return "net.imagej.axis.InverseRodbardAxis"

@JavaClasses.java_import
def LogLinearAxis(self):
return "net.imagej.axis.LogLinearAxis"

@JavaClasses.java_import
def PolynomialAxis(self):
return "net.imagej.axis.PolynomialAxis"

@JavaClasses.java_import
def PowerAxis(self):
return "net.imagej.axis.PowerAxis"

@JavaClasses.java_import
def RodbardAxis(self):
return "net.imagej.axis.RodbardAxis"

@JavaClasses.java_import
def VariableAxis(self):
return "net.imagej.axis.VariableAxis"

@JavaClasses.java_import
def Dataset(self):
return "net.imagej.Dataset"
Expand Down Expand Up @@ -106,6 +166,10 @@ def ImgView(self):
def ImgLabeling(self):
return "net.imglib2.roi.labeling.ImgLabeling"

@JavaClasses.java_import
def IntegerType(self):
return "net.imglib2.type.numeric.IntegerType"

@JavaClasses.java_import
def Named(self):
return "org.scijava.Named"
Expand Down
122 changes: 122 additions & 0 deletions src/imagej/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import numpy as np
import xarray as xr
from scyjava import _convert

import imagej.dims as dims


@xr.register_dataarray_accessor("img")
class ImgAccessor:
def __init__(self, xarr):
self._data = xarr

@property
def is_rgb(self):
"""
Returns True or False if the xarray.DataArray is an RGB image.

:return: Boolean
"""
ch_labels = ["c", "ch", "Channel"]
# check if array is signed
if self._data.min() < 0:
return False
# check if array is integer dtype
if not np.issubdtype(self._data.data.dtype, np.integer):
return False
# check bitsperpixel
if self._data.dtype.itemsize * 8 != 8:
return False
# check if "channel" present
if not any(dim in self._data.dims for dim in ch_labels):
return False
# check channel length = 3 exactly
for dim in self._data.dims:
if dim in ch_labels:
loc = self._data.dims.index(dim)
if self._data.shape[loc] != 3:
return False

return True


@xr.register_dataarray_accessor("metadata")
class MetadataAccessor:
def __init__(self, xarr):
self._data = xarr
self._update()

@property
def axes(self):
"""
Returns a tuple of the ImageJ axes.

:return: A Python tuple of the ImageJ axes.
"""
return (
tuple(self._data.attrs["imagej"].get("scifio.metadata.image").get("axes"))
if "scifio.metadata.image" in self._data.attrs["imagej"]
else None
)

def set(self, metadata: dict):
"""
Set the metadata of the parent xarray.DataArray.

:param metadata: A Python dict representing the image metadata.
"""
self._data.attrs["imagej"] = metadata

def get(self):
"""
Get the metadata dict of the the parent xarray.DataArray.

:return: A Python dict representing the image metadata.
"""
return self._data.attrs["imagej"]

def tree(self):
"""
Print a tree of the metadata of the parent xarray.DataArray.
"""
self._print_dict_tree(self._data.attrs["imagej"])

def _print_dict_tree(self, dictionary, indent="", prefix=""):
for idx, (key, value) in enumerate(dictionary.items()):
if idx == len(dictionary) - 1:
connector = "└──"
else:
connector = "├──"
print(indent + connector + prefix + " " + str(key))
if isinstance(value, (dict, _convert.JavaMap)):
if idx == len(dictionary) - 1:
self._print_dict_tree(value, indent + " ", prefix="── ")
else:
self._print_dict_tree(value, indent + "│ ", prefix="── ")

def _update(self):
if self._data.attrs.get("imagej"):
# update axes
axes = [None] * len(self._data.dims)
for i in range(len(self.axes)):
ax_label = dims._convert_dim(self.axes[i].type().getLabel(), "python")
if ax_label in self._data.dims:
axes[self._data.dims.index(ax_label)] = self.axes[i]
self._data.attrs["imagej"].get("scifio.metadata.image", {})["axes"] = axes

# update axis lengths
old_ax_len_metadata = (
self._data.attrs["imagej"]
.get("scifio.metadata.image", {})
.get("axisLengths", {})
)
new_ax_len_metadata = {}
for i in range(len(self.axes)):
ax_type = self.axes[i].type()
if ax_type in old_ax_len_metadata.keys():
ax_label = dims._convert_dim(ax_type.getLabel(), "python")
curr_ax_len = self._data.shape[self._data.dims.index(ax_label)]
new_ax_len_metadata[ax_type] = curr_ax_len
self._data.attrs["imagej"].get("scifio.metadata.image", {})[
"axisLengths"
] = new_ax_len_metadata
17 changes: 12 additions & 5 deletions src/imagej/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jpype import JByte, JException, JFloat, JLong, JObject, JShort
from labeling import Labeling

import imagej.array # noqa:F401
import imagej.dims as dims
import imagej.images as images
from imagej._java import jc
Expand Down Expand Up @@ -166,7 +167,10 @@ def xarray_to_dataset(ij: "jc.ImageJ", xarr) -> "jc.Dataset":
axes = dims._assign_axes(xarr)
dataset.setAxes(axes)
dataset.setName(xarr.name)
_assign_dataset_metadata(dataset, xarr.attrs)
if hasattr(xarr, "metadata"):
_assign_dataset_metadata(dataset, xarr.metadata.get())
else:
_assign_dataset_metadata(dataset, xarr.attrs["imagej"])

return dataset

Expand Down Expand Up @@ -230,15 +234,18 @@ def java_to_xarray(ij: "jc.ImageJ", jobj) -> xr.DataArray:
assert hasattr(permuted_rai, "dim_axes")
xr_axes = list(permuted_rai.dim_axes)
xr_dims = list(permuted_rai.dims)
xr_attrs = sj.to_python(permuted_rai.getProperties())
xr_attrs = {sj.to_python(k): sj.to_python(v) for k, v in xr_attrs.items()}
# reverse axes and dims to match narr
xr_axes.reverse()
xr_dims.reverse()
xr_dims = dims._convert_dims(xr_dims, direction="python")
xr_coords = dims._get_axes_coords(xr_axes, xr_dims, narr.shape)
name = jobj.getName() if isinstance(jobj, jc.Named) else None
return xr.DataArray(narr, dims=xr_dims, coords=xr_coords, attrs=xr_attrs, name=name)
xr_attrs = {"imagej": {}}
xarr = xr.DataArray(narr, dims=xr_dims, coords=xr_coords, name=name, attrs=xr_attrs)
# use the MetadataAccessor to add metadata to the xarray
xarr.metadata.set(dict(sj.to_python(permuted_rai.getProperties())))
xarr.metadata._update()
return xarr


def supports_java_to_ndarray(ij: "jc.ImageJ", obj) -> bool:
Expand Down Expand Up @@ -509,7 +516,7 @@ def metadata_wrapper_to_dict(ij: "jc.ImageJ", metadata_wrapper: "jc.MetadataWrap
####################


def _assign_dataset_metadata(dataset: "jc.Dataset", attrs):
def _assign_dataset_metadata(dataset: "jc.Dataset", attrs: dict):
"""
:param dataset: ImageJ2 Dataset
:param attrs: Dictionary containing metadata
Expand Down
86 changes: 39 additions & 47 deletions src/imagej/dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
Utility functions for querying and manipulating dimensional axis metadata.
"""
import logging
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import scyjava as sj
import xarray as xr
from jpype import JException, JObject
from jpype import JObject

from imagej._java import jc
from imagej.images import is_arraylike as _is_arraylike
Expand Down Expand Up @@ -177,49 +177,40 @@ def prioritize_rai_axes_order(
return permute_order


def _assign_axes(xarr: xr.DataArray):
def _assign_axes(
xarr: xr.DataArray,
) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]:
"""
Obtain xarray axes names, origin, and scale and convert into ImageJ Axis;
currently supports EnumeratedAxis
:param xarr: xarray that holds the units
:return: A list of ImageJ Axis with the specified origin and scale
Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both
DefaultLinearAxis and the newer EnumeratedAxis.
:param xarr: xarray that holds the data.
:return: A list of ImageJ Axis with the specified origin and scale.
"""
Double = sj.jimport("java.lang.Double")

axes = [""] * len(xarr.dims)

# try to get EnumeratedAxis, if not then default to LinearAxis in the loop
try:
EnumeratedAxis = _get_enumerated_axis()
except (JException, TypeError):
EnumeratedAxis = None

for dim in xarr.dims:
axis_str = _convert_dim(dim, direction="java")
axes = [""] * xarr.ndim
for i in range(xarr.ndim):
dim = xarr.dims[i]
axis_str = _convert_dim(dim, "java")
ax_type = jc.Axes.get(axis_str)
ax_num = _get_axis_num(xarr, dim)
scale = _get_scale(xarr.coords[dim])
coords_arr = xarr.coords[dim].to_numpy()

if scale is None:
# check if coords/scale is numeric
if _is_numeric_scale(coords_arr):
doub_coords = [jc.Double(np.double(x)) for x in xarr.coords[dim]]
else:
_logger.warning(
f"The {ax_type.label} axis is non-numeric and is translated "
"to a linear index."
)
doub_coords = [
Double(np.double(x)) for x in np.arange(len(xarr.coords[dim]))
jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim]))
]
else:
doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]]

# EnumeratedAxis is a new axis made for xarray, so is only present in
# ImageJ versions that are released later than March 2020.
# This actually returns a LinearAxis if using an earlier version.
if EnumeratedAxis is not None:
java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords))
# use the xarr metadata if available to assign axes
if hasattr(xarr, "metadata") and xarr.metadata.axes:
axes[ax_num] = xarr.metadata.axes[i]
else:
java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords))

axes[ax_num] = java_axis
axes[ax_num] = _get_fallback_linear_axis(ax_type, doub_coords)

return axes

Expand Down Expand Up @@ -295,27 +286,28 @@ def _get_scale(axis):
return None


def _get_enumerated_axis():
"""Get EnumeratedAxis.

EnumeratedAxis is only in releases later than March 2020. If using
an older version of ImageJ without EnumeratedAxis, use
_get_linear_axis() instead.
def _is_numeric_scale(coords_array: np.ndarray) -> bool:
"""
return sj.jimport("net.imagej.axis.EnumeratedAxis")
Checks if the coordinates array of the given axis is numeric.

:param coords_array: A 1D NumPy array.
:return: bool
"""
return np.issubdtype(coords_array.dtype, np.number)

def _get_linear_axis(axis_type: "jc.AxisType", values):
"""Get linear axis.

This is used if no EnumeratedAxis is found. If EnumeratedAxis
is available, use _get_enumerated_axis() instead.
def _get_fallback_linear_axis(axis_type: "jc.AxisType", values):
"""
Get a DefaultLinearAxis manually when all other axes
resources are unavailable.
"""
DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
origin = values[0]
scale = values[1] - values[0]
axis = DefaultLinearAxis(axis_type, scale, origin)
return axis
# calculate the slope using the values/coord array
if len(values) <= 1:
scale = 1
else:
scale = values[1] - values[0]
return jc.DefaultLinearAxis(axis_type, scale, origin)


def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
Expand Down
Loading