Skip to content

Commit

Permalink
perf(io): implement lazy loading for h5netcdf and nexusformat imports
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Nov 19, 2024
1 parent aa6f5d2 commit 3f219ae
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 47 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def setup(app):
"NXgroup": "`NXgroup <nexusformat.nexus.tree.NXgroup>`",
"NXlink": "`NXlink <nexusformat.nexus.tree.NXlink>`",
"NXdata": "`NXdata <nexusformat.nexus.tree.NXdata>`",
"NXentry": "`NXentry <nexusformat.nexus.tree.NXentry>`",
"lmfit.Parameters": "`lmfit.Parameters <lmfit.parameter.Parameters>`",
"lmfit.Model": "`lmfit.Model <lmfit.model.Model>`",
}
Expand Down
88 changes: 51 additions & 37 deletions src/erlab/io/nexusutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,27 @@

import os
from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

import h5py
import numpy as np
import xarray as xr
from nexusformat.nexus import nxload
from nexusformat.nexus.tree import NXattr, NXentry, NXfield, NXgroup, NXlink

if TYPE_CHECKING:
import h5py
from nexusformat import nexus

else:
from erlab.utils.misc import LazyImport

nexus = LazyImport("nexusformat.nexus")
h5py = LazyImport("h5py")


def _parse_value(value):
"""Convert numpy scalars or bytes to native Python types and regular strings."""
if isinstance(value, bytes | bytearray):
return value.decode("utf-8")
if isinstance(value, NXattr):
if isinstance(value, nexus.NXattr):
return _parse_value(value.nxdata)
if isinstance(value, np.ndarray) and value.size == 1:
return _parse_value(np.atleast_1d(value)[0])
Expand All @@ -32,7 +39,7 @@ def _parse_value(value):

def _parse_field(field):
"""Convert a NeXus field to a native Python type."""
if not isinstance(field, NXfield):
if not isinstance(field, nexus.NXfield):
return field

if field.size == 1:
Expand All @@ -51,7 +58,7 @@ def _remove_axis_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]:
return out


def _parse_h5py(obj: h5py.Group | h5py.Dataset, out: dict) -> None:
def _parse_h5py(obj: "h5py.Group | h5py.Dataset", out: dict) -> None:
if isinstance(obj, h5py.Group):
for v in obj.values():
_parse_h5py(v, out)
Expand All @@ -60,7 +67,10 @@ def _parse_h5py(obj: h5py.Group | h5py.Dataset, out: dict) -> None:


def _parse_group(
group: NXgroup, out: dict[str, Any], exclude: Sequence[str], parse: bool
group: "nexus.NXgroup",
out: dict[str, Any],
exclude: Sequence[str],
parse: bool,
) -> None:
"""Recursively parse a NeXus group and its items into a nested dictionary.
Expand Down Expand Up @@ -89,14 +99,14 @@ def _parse_group(
if item.nxpath in exclude:
continue

if isinstance(item, NXgroup):
if isinstance(item, nexus.NXgroup):
_parse_group(item, out, exclude, parse)

elif isinstance(item, NXlink):
elif isinstance(item, nexus.NXlink):
# Skip links
continue

elif isinstance(item, NXfield):
elif isinstance(item, nexus.NXfield):
if parse:
out[item.nxpath] = _parse_field(item)
else:
Expand All @@ -108,37 +118,39 @@ def _parse_group(
_parse_h5py(item.nxfile.get(item.nxpath), out)


def _get_primary_coords(group: NXgroup, out: list[NXfield]):
def _get_primary_coords(group: "nexus.NXgroup", out: list["nexus.NXfield"]):
for item in group.values():
if isinstance(item, NXgroup):
if isinstance(item, nexus.NXgroup):
_get_primary_coords(item, out)
elif isinstance(item, NXlink):
elif isinstance(item, nexus.NXlink):
# Skip links
continue
elif (
isinstance(item, NXfield)
isinstance(item, nexus.NXfield)
and "primary" in item.attrs
and int(item.primary) == 1
):
out.append(item)


def _get_non_primary_coords(group: NXgroup, out: list[NXfield]):
def _get_non_primary_coords(group: "nexus.NXgroup", out: list["nexus.NXfield"]):
for item in group.values():
if isinstance(item, NXgroup):
if isinstance(item, nexus.NXgroup):
_get_non_primary_coords(item, out)
elif isinstance(item, NXlink):
elif isinstance(item, nexus.NXlink):
# Skip links
continue
elif (
isinstance(item, NXfield)
isinstance(item, nexus.NXfield)
and "axis" in item.attrs
and ("primary" not in item.attrs or int(item.primary) != 1)
):
out.append(item)


def get_primary_coords(group: NXgroup) -> list[NXfield]:
def get_primary_coords(
group: "nexus.NXgroup",
) -> list["nexus.NXfield"]:
"""Get all primary coordinates in a group.
Retrieves all fields with the attribute `primary=1` in the group and its subgroups
Expand All @@ -154,12 +166,14 @@ def get_primary_coords(group: NXgroup) -> list[NXfield]:
fields_primary : list of NXfield
"""
fields_primary: list[NXfield] = []
fields_primary: list[nexus.NXfield] = []
_get_primary_coords(group, fields_primary)
return sorted(fields_primary, key=lambda field: int(field.axis))


def get_non_primary_coords(group: NXgroup) -> list[NXfield]:
def get_non_primary_coords(
group: "nexus.NXgroup",
) -> list["nexus.NXfield"]:
"""Get all non-primary coordinates in a group.
Retrieves all fields with the attribute `axis` in the group and its subgroups
Expand All @@ -175,13 +189,13 @@ def get_non_primary_coords(group: NXgroup) -> list[NXfield]:
fields_non_primary : list of NXfield
"""
fields_non_primary: list[NXfield] = []
fields_non_primary: list[nexus.NXfield] = []
_get_non_primary_coords(group, fields_non_primary)
return fields_non_primary


def get_primary_coord_dict(
fields: list[NXfield],
fields: list["nexus.NXfield"],
) -> tuple[tuple[str, ...], dict[str, xr.DataArray | tuple]]:
"""Generate a dictionary of primary coordinates from a list of NXfields.
Expand Down Expand Up @@ -212,7 +226,7 @@ def get_primary_coord_dict(

# Dict to store processed axes
# Ensure that if there are multiple fields with the same axis, only one is used
processed_axes: dict[int, NXfield] = {}
processed_axes: dict[int, nexus.NXfield] = {}

for field in fields:
if field.ndim == 1:
Expand Down Expand Up @@ -253,7 +267,7 @@ def get_primary_coord_dict(


def get_coord_dict(
group: NXgroup,
group: "nexus.NXgroup",
) -> tuple[tuple[str, ...], dict[str, xr.DataArray | tuple]]:
"""Generate a dictionary of coordinates from a NeXus group.
Expand All @@ -269,8 +283,8 @@ def get_coord_dict(
coords : dict of str to DataArray or tuple
The dictionary of all coordinates in the group.
"""
fields_primary: list[NXfield] = get_primary_coords(group)
fields_non_primary: list[NXfield] = get_non_primary_coords(group)
fields_primary: list[nexus.NXfield] = get_primary_coords(group)
fields_non_primary: list[nexus.NXfield] = get_non_primary_coords(group)

dims, coords = get_primary_coord_dict(fields_primary)

Expand All @@ -281,7 +295,7 @@ def get_coord_dict(

else:
# Coord depends on some other primary coordinate
associated_primary: NXfield = group.nxroot[dims[int(field.axis) - 1]]
associated_primary: nexus.NXfield = group.nxroot[dims[int(field.axis) - 1]]
coords[field.nxpath] = (
associated_primary.nxpath,
field.nxdata,
Expand All @@ -292,7 +306,7 @@ def get_coord_dict(


def nexus_group_to_dict(
group: NXgroup,
group: "nexus.NXgroup",
exclude: Sequence[str] | None,
relative: bool = True,
replace_slash: bool = True,
Expand Down Expand Up @@ -336,7 +350,7 @@ def nexus_group_to_dict(
return out


def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray:
def nxfield_to_xarray(field: "nexus.NXfield", no_dims: bool = False) -> xr.DataArray:
"""Convert a coord-like 1D NeXus field to a single `xarray.DataArray`.
Parameters
Expand Down Expand Up @@ -366,8 +380,8 @@ def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray:


def nxgroup_to_xarray(
group: NXgroup,
data: str | Callable[[NXgroup], NXfield],
group: "nexus.NXgroup",
data: str | Callable[["nexus.NXgroup"], "nexus.NXfield"],
without_values: bool = False,
) -> xr.DataArray:
"""Convert a NeXus group to an xarray DataArray.
Expand Down Expand Up @@ -397,11 +411,11 @@ def nxgroup_to_xarray(
"""
if callable(data):
values: NXfield = data(group)
values: nexus.NXfield = data(group)
else:
values = group[data]

if isinstance(values, NXlink):
if isinstance(values, nexus.NXlink):
values = values.nxlink

dims, coords = get_coord_dict(group)
Expand Down Expand Up @@ -437,7 +451,7 @@ def _make_coord_relative(t: xr.DataArray | tuple) -> xr.DataArray | tuple:
return xr.DataArray(values, dims=dims, coords=coords, attrs=attrs)


def get_entry(filename: str | os.PathLike, entry: str | None = None) -> NXentry:
def get_entry(filename: str | os.PathLike, entry: str | None = None) -> "nexus.NXentry":
"""Get an NXentry object from a NeXus file.
Parameters
Expand All @@ -454,7 +468,7 @@ def get_entry(filename: str | os.PathLike, entry: str | None = None) -> NXentry:
The NXentry object obtained from the file.
"""
root = nxload(filename)
root = nexus.nxload(filename)
if entry is None:
return next(iter(root.entries.values()))

Expand Down
6 changes: 3 additions & 3 deletions src/erlab/io/plugins/i05.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import xarray as xr

import erlab.io.nexusutils
from erlab.io.dataloader import LoaderBase
from erlab.io.nexusutils import get_entry, nxgroup_to_xarray


class I05Loader(LoaderBase):
Expand Down Expand Up @@ -41,8 +41,8 @@ def file_dialog_methods(self):
return {"Diamond I05 Raw Data (*.nxs)": (self.load, {})}

def load_single(self, file_path, without_values=False) -> xr.DataArray:
out = nxgroup_to_xarray(
get_entry(file_path), "analyser/data", without_values
out = erlab.io.nexusutils.nxgroup_to_xarray(
erlab.io.nexusutils.get_entry(file_path), "analyser/data", without_values
).squeeze()

if (
Expand Down
6 changes: 4 additions & 2 deletions src/erlab/io/plugins/lorea.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import xarray as xr

import erlab.io
import erlab.io.nexusutils
from erlab.io.dataloader import LoaderBase
from erlab.io.nexusutils import get_entry, nxgroup_to_xarray


def _get_data(group):
Expand Down Expand Up @@ -50,7 +50,9 @@ def load_single(self, file_path, without_values: bool = False) -> xr.DataArray:
if pathlib.Path(file_path).suffix == ".krx":
return self._load_krx(file_path)

return nxgroup_to_xarray(get_entry(file_path), _get_data, without_values)
return erlab.io.nexusutils.nxgroup_to_xarray(
erlab.io.nexusutils.get_entry(file_path), _get_data, without_values
)

def identify(self, num, data_dir, krax=False):
if krax:
Expand Down
10 changes: 7 additions & 3 deletions src/erlab/io/plugins/ssrl52.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import os
import re
from collections.abc import Callable
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar

import h5netcdf
import numpy as np
import xarray as xr

import erlab.io.utils
from erlab.io.dataloader import LoaderBase
from erlab.utils.misc import emit_user_level_warning
from erlab.utils.misc import LazyImport, emit_user_level_warning

if TYPE_CHECKING:
import h5netcdf
else:
h5netcdf = LazyImport("h5netcdf")


def _format_polarization(val) -> str:
Expand Down
37 changes: 35 additions & 2 deletions src/erlab/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
import importlib
import inspect
import pathlib
import sys
import warnings

import xarray
from types import ModuleType
from typing import Any


def _find_stack_level() -> int:
Expand All @@ -19,6 +21,8 @@ def _find_stack_level() -> int:
stacklevel : int
First level in the stack that is not part of erlab or stdlib.
"""
import xarray

import erlab

xarray_dir = pathlib.Path(xarray.__file__).parent
Expand Down Expand Up @@ -55,3 +59,32 @@ def emit_user_level_warning(message, category=None) -> None:
"""Emit a warning at the user level by inspecting the stack trace."""
stacklevel = _find_stack_level()
return warnings.warn(message, category=category, stacklevel=stacklevel)


class LazyImport:
"""Lazily import a module when an attribute is accessed.
Used to delay the import of a module until it is actually needed.
Parameters
----------
module_name : str
The name of the module to be imported lazily.
Examples
--------
>>> np = LazyImport("numpy")
>>> np.array([1, 2, 3])
array([1, 2, 3])
"""

def __init__(self, module_name: str) -> None:
self._module_name = module_name

def __getattr__(self, item: str) -> Any:
return getattr(self._module, item)

@functools.cached_property
def _module(self) -> ModuleType:
return importlib.import_module(self._module_name)

0 comments on commit 3f219ae

Please sign in to comment.