From 3f219aef8764859f1eaccb2e354c7780f41133bb Mon Sep 17 00:00:00 2001 From: Kimoon Han Date: Tue, 19 Nov 2024 20:46:19 +0900 Subject: [PATCH] perf(io): implement lazy loading for h5netcdf and nexusformat imports --- docs/source/conf.py | 1 + src/erlab/io/nexusutils.py | 88 ++++++++++++++++++++-------------- src/erlab/io/plugins/i05.py | 6 +-- src/erlab/io/plugins/lorea.py | 6 ++- src/erlab/io/plugins/ssrl52.py | 10 ++-- src/erlab/utils/misc.py | 37 +++++++++++++- 6 files changed, 101 insertions(+), 47 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 72f7e156..35cf1e8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -200,6 +200,7 @@ def setup(app): "NXgroup": "`NXgroup `", "NXlink": "`NXlink `", "NXdata": "`NXdata `", + "NXentry": "`NXentry `", "lmfit.Parameters": "`lmfit.Parameters `", "lmfit.Model": "`lmfit.Model `", } diff --git a/src/erlab/io/nexusutils.py b/src/erlab/io/nexusutils.py index a519b3eb..033088f0 100644 --- a/src/erlab/io/nexusutils.py +++ b/src/erlab/io/nexusutils.py @@ -7,20 +7,27 @@ import os from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any -import h5py import numpy as np import xarray as xr -from nexusformat.nexus import nxload -from nexusformat.nexus.tree import NXattr, NXentry, NXfield, NXgroup, NXlink + +if TYPE_CHECKING: + import h5py + from nexusformat import nexus + +else: + from erlab.utils.misc import LazyImport + + nexus = LazyImport("nexusformat.nexus") + h5py = LazyImport("h5py") def _parse_value(value): """Convert numpy scalars or bytes to native Python types and regular strings.""" if isinstance(value, bytes | bytearray): return value.decode("utf-8") - if isinstance(value, NXattr): + if isinstance(value, nexus.NXattr): return _parse_value(value.nxdata) if isinstance(value, np.ndarray) and value.size == 1: return _parse_value(np.atleast_1d(value)[0]) @@ -32,7 +39,7 @@ def _parse_value(value): def 
_parse_field(field): """Convert a NeXus field to a native Python type.""" - if not isinstance(field, NXfield): + if not isinstance(field, nexus.NXfield): return field if field.size == 1: @@ -51,7 +58,7 @@ def _remove_axis_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]: return out -def _parse_h5py(obj: h5py.Group | h5py.Dataset, out: dict) -> None: +def _parse_h5py(obj: "h5py.Group | h5py.Dataset", out: dict) -> None: if isinstance(obj, h5py.Group): for v in obj.values(): _parse_h5py(v, out) @@ -60,7 +67,10 @@ def _parse_h5py(obj: h5py.Group | h5py.Dataset, out: dict) -> None: def _parse_group( - group: NXgroup, out: dict[str, Any], exclude: Sequence[str], parse: bool + group: "nexus.NXgroup", + out: dict[str, Any], + exclude: Sequence[str], + parse: bool, ) -> None: """Recursively parse a NeXus group and its items into a nested dictionary. @@ -89,14 +99,14 @@ def _parse_group( if item.nxpath in exclude: continue - if isinstance(item, NXgroup): + if isinstance(item, nexus.NXgroup): _parse_group(item, out, exclude, parse) - elif isinstance(item, NXlink): + elif isinstance(item, nexus.NXlink): # Skip links continue - elif isinstance(item, NXfield): + elif isinstance(item, nexus.NXfield): if parse: out[item.nxpath] = _parse_field(item) else: @@ -108,37 +118,39 @@ def _parse_group( _parse_h5py(item.nxfile.get(item.nxpath), out) -def _get_primary_coords(group: NXgroup, out: list[NXfield]): +def _get_primary_coords(group: "nexus.NXgroup", out: list["nexus.NXfield"]): for item in group.values(): - if isinstance(item, NXgroup): + if isinstance(item, nexus.NXgroup): _get_primary_coords(item, out) - elif isinstance(item, NXlink): + elif isinstance(item, nexus.NXlink): # Skip links continue elif ( - isinstance(item, NXfield) + isinstance(item, nexus.NXfield) and "primary" in item.attrs and int(item.primary) == 1 ): out.append(item) -def _get_non_primary_coords(group: NXgroup, out: list[NXfield]): +def _get_non_primary_coords(group: "nexus.NXgroup", out: 
list["nexus.NXfield"]): for item in group.values(): - if isinstance(item, NXgroup): + if isinstance(item, nexus.NXgroup): _get_non_primary_coords(item, out) - elif isinstance(item, NXlink): + elif isinstance(item, nexus.NXlink): # Skip links continue elif ( - isinstance(item, NXfield) + isinstance(item, nexus.NXfield) and "axis" in item.attrs and ("primary" not in item.attrs or int(item.primary) != 1) ): out.append(item) -def get_primary_coords(group: NXgroup) -> list[NXfield]: +def get_primary_coords( + group: "nexus.NXgroup", +) -> list["nexus.NXfield"]: """Get all primary coordinates in a group. Retrieves all fields with the attribute `primary=1` in the group and its subgroups @@ -154,12 +166,14 @@ def get_primary_coords(group: NXgroup) -> list[NXfield]: fields_primary : list of NXfield """ - fields_primary: list[NXfield] = [] + fields_primary: list[nexus.NXfield] = [] _get_primary_coords(group, fields_primary) return sorted(fields_primary, key=lambda field: int(field.axis)) -def get_non_primary_coords(group: NXgroup) -> list[NXfield]: +def get_non_primary_coords( + group: "nexus.NXgroup", +) -> list["nexus.NXfield"]: """Get all non-primary coordinates in a group. Retrieves all fields with the attribute `axis` in the group and its subgroups @@ -175,13 +189,13 @@ def get_non_primary_coords(group: NXgroup) -> list[NXfield]: fields_non_primary : list of NXfield """ - fields_non_primary: list[NXfield] = [] + fields_non_primary: list[nexus.NXfield] = [] _get_non_primary_coords(group, fields_non_primary) return fields_non_primary def get_primary_coord_dict( - fields: list[NXfield], + fields: list["nexus.NXfield"], ) -> tuple[tuple[str, ...], dict[str, xr.DataArray | tuple]]: """Generate a dictionary of primary coordinates from a list of NXfields. 
@@ -212,7 +226,7 @@ def get_primary_coord_dict( # Dict to store processed axes # Ensure that if there are multiple fields with the same axis, only one is used - processed_axes: dict[int, NXfield] = {} + processed_axes: dict[int, nexus.NXfield] = {} for field in fields: if field.ndim == 1: @@ -253,7 +267,7 @@ def get_primary_coord_dict( def get_coord_dict( - group: NXgroup, + group: "nexus.NXgroup", ) -> tuple[tuple[str, ...], dict[str, xr.DataArray | tuple]]: """Generate a dictionary of coordinates from a NeXus group. @@ -269,8 +283,8 @@ def get_coord_dict( coords : dict of str to DataArray or tuple The dictionary of all coordinates in the group. """ - fields_primary: list[NXfield] = get_primary_coords(group) - fields_non_primary: list[NXfield] = get_non_primary_coords(group) + fields_primary: list[nexus.NXfield] = get_primary_coords(group) + fields_non_primary: list[nexus.NXfield] = get_non_primary_coords(group) dims, coords = get_primary_coord_dict(fields_primary) @@ -281,7 +295,7 @@ def get_coord_dict( else: # Coord depends on some other primary coordinate - associated_primary: NXfield = group.nxroot[dims[int(field.axis) - 1]] + associated_primary: nexus.NXfield = group.nxroot[dims[int(field.axis) - 1]] coords[field.nxpath] = ( associated_primary.nxpath, field.nxdata, @@ -292,7 +306,7 @@ def get_coord_dict( def nexus_group_to_dict( - group: NXgroup, + group: "nexus.NXgroup", exclude: Sequence[str] | None, relative: bool = True, replace_slash: bool = True, @@ -336,7 +350,7 @@ def nexus_group_to_dict( return out -def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray: +def nxfield_to_xarray(field: "nexus.NXfield", no_dims: bool = False) -> xr.DataArray: """Convert a coord-like 1D NeXus field to a single `xarray.DataArray`. 
Parameters @@ -366,8 +380,8 @@ def nxfield_to_xarray(field: NXfield, no_dims: bool = False) -> xr.DataArray: def nxgroup_to_xarray( - group: NXgroup, - data: str | Callable[[NXgroup], NXfield], + group: "nexus.NXgroup", + data: str | Callable[["nexus.NXgroup"], "nexus.NXfield"], without_values: bool = False, ) -> xr.DataArray: """Convert a NeXus group to an xarray DataArray. @@ -397,11 +411,11 @@ def nxgroup_to_xarray( """ if callable(data): - values: NXfield = data(group) + values: nexus.NXfield = data(group) else: values = group[data] - if isinstance(values, NXlink): + if isinstance(values, nexus.NXlink): values = values.nxlink dims, coords = get_coord_dict(group) @@ -437,7 +451,7 @@ def _make_coord_relative(t: xr.DataArray | tuple) -> xr.DataArray | tuple: return xr.DataArray(values, dims=dims, coords=coords, attrs=attrs) -def get_entry(filename: str | os.PathLike, entry: str | None = None) -> NXentry: +def get_entry(filename: str | os.PathLike, entry: str | None = None) -> "nexus.NXentry": """Get an NXentry object from a NeXus file. Parameters @@ -454,7 +468,7 @@ def get_entry(filename: str | os.PathLike, entry: str | None = None) -> NXentry: The NXentry object obtained from the file. 
""" - root = nxload(filename) + root = nexus.nxload(filename) if entry is None: return next(iter(root.entries.values())) diff --git a/src/erlab/io/plugins/i05.py b/src/erlab/io/plugins/i05.py index fedaf9ea..58f806de 100644 --- a/src/erlab/io/plugins/i05.py +++ b/src/erlab/io/plugins/i05.py @@ -4,8 +4,8 @@ import xarray as xr +import erlab.io.nexusutils from erlab.io.dataloader import LoaderBase -from erlab.io.nexusutils import get_entry, nxgroup_to_xarray class I05Loader(LoaderBase): @@ -41,8 +41,8 @@ def file_dialog_methods(self): return {"Diamond I05 Raw Data (*.nxs)": (self.load, {})} def load_single(self, file_path, without_values=False) -> xr.DataArray: - out = nxgroup_to_xarray( - get_entry(file_path), "analyser/data", without_values + out = erlab.io.nexusutils.nxgroup_to_xarray( + erlab.io.nexusutils.get_entry(file_path), "analyser/data", without_values ).squeeze() if ( diff --git a/src/erlab/io/plugins/lorea.py b/src/erlab/io/plugins/lorea.py index 44182d7f..86b39cea 100644 --- a/src/erlab/io/plugins/lorea.py +++ b/src/erlab/io/plugins/lorea.py @@ -9,8 +9,8 @@ import xarray as xr import erlab.io +import erlab.io.nexusutils from erlab.io.dataloader import LoaderBase -from erlab.io.nexusutils import get_entry, nxgroup_to_xarray def _get_data(group): @@ -50,7 +50,9 @@ def load_single(self, file_path, without_values: bool = False) -> xr.DataArray: if pathlib.Path(file_path).suffix == ".krx": return self._load_krx(file_path) - return nxgroup_to_xarray(get_entry(file_path), _get_data, without_values) + return erlab.io.nexusutils.nxgroup_to_xarray( + erlab.io.nexusutils.get_entry(file_path), _get_data, without_values + ) def identify(self, num, data_dir, krax=False): if krax: diff --git a/src/erlab/io/plugins/ssrl52.py b/src/erlab/io/plugins/ssrl52.py index a74253de..4bc9c06c 100644 --- a/src/erlab/io/plugins/ssrl52.py +++ b/src/erlab/io/plugins/ssrl52.py @@ -4,15 +4,19 @@ import os import re from collections.abc import Callable -from typing import Any, ClassVar 
+from typing import TYPE_CHECKING, Any, ClassVar
 
-import h5netcdf
 import numpy as np
 import xarray as xr
 
 import erlab.io.utils
 from erlab.io.dataloader import LoaderBase
-from erlab.utils.misc import emit_user_level_warning
+from erlab.utils.misc import LazyImport, emit_user_level_warning
+
+if TYPE_CHECKING:
+    import h5netcdf
+else:
+    h5netcdf = LazyImport("h5netcdf")
 
 
 def _format_polarization(val) -> str:
diff --git a/src/erlab/utils/misc.py b/src/erlab/utils/misc.py
index 92074a84..e1366d65 100644
--- a/src/erlab/utils/misc.py
+++ b/src/erlab/utils/misc.py
@@ -1,9 +1,11 @@
+import functools
+import importlib
 import inspect
 import pathlib
 import sys
 import warnings
-
-import xarray
+from types import ModuleType
+from typing import Any
 
 
 def _find_stack_level() -> int:
@@ -19,6 +21,8 @@ def _find_stack_level() -> int:
     stacklevel : int
         First level in the stack that is not part of erlab or stdlib.
     """
+    import xarray
+
     import erlab
 
     xarray_dir = pathlib.Path(xarray.__file__).parent
@@ -55,3 +59,54 @@ def emit_user_level_warning(message, category=None) -> None:
     """Emit a warning at the user level by inspecting the stack trace."""
     stacklevel = _find_stack_level()
     return warnings.warn(message, category=category, stacklevel=stacklevel)
+
+
+class LazyImport:
+    """Lazily import a module when an attribute is accessed.
+
+    Used to delay the import of a module until it is actually needed. The import
+    runs once, on the first attribute access that the proxy cannot answer itself,
+    and the module is then cached on the instance.
+
+    Copying or pickling the proxy recreates it from the module name only, so the
+    copy resolves the module again on its own first use.
+
+    Parameters
+    ----------
+    module_name : str
+        The name of the module to be imported lazily.
+
+    Examples
+    --------
+    >>> np = LazyImport("numpy")
+    >>> np.array([1, 2, 3])
+    array([1, 2, 3])
+
+    """
+
+    def __init__(self, module_name: str) -> None:
+        self._module_name = module_name
+
+    def __repr__(self) -> str:
+        # Reads only the stored name, so printing the proxy never triggers the import.
+        return f"{type(self).__name__}({self._module_name!r})"
+
+    def __reduce__(self) -> tuple[type, tuple[str]]:
+        # Rebuild copies and unpickled proxies through ``__init__``: a cached module
+        # object cannot be pickled, and this avoids restoring state onto an instance
+        # that was created without ``__init__``.
+        return type(self), (self._module_name,)
+
+    def __getattr__(self, item: str) -> Any:
+        # Only reached when normal lookup fails. A miss on the proxy's own attributes
+        # (e.g. an instance created without ``__init__``) must not be forwarded: that
+        # would re-enter ``_module`` -> ``_module_name`` -> ``__getattr__`` until
+        # ``RecursionError``.
+        if item in ("_module_name", "_module"):
+            raise AttributeError(item)
+        return getattr(self._module, item)
+
+    @functools.cached_property
+    def _module(self) -> ModuleType:
+        # Imported on first access; ``cached_property`` stores it on the instance.
+        return importlib.import_module(self._module_name)