Skip to content

Commit

Permalink
perf: speed up initial import
Browse files Browse the repository at this point in the history
Accelerates initial import time by refactoring heavy imports to reside inside functions.

Importing the plotting module no longer automatically imports the colormap packages `cmocean`, `cmasher`, and `colorcet`. The user must add manual import statements.
  • Loading branch information
kmnhan committed Nov 19, 2024
1 parent 3593d41 commit d7f3b3c
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 31 deletions.
17 changes: 14 additions & 3 deletions src/erlab/accessors/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import copy
import itertools
from collections.abc import Collection, Hashable, Iterable, Mapping, Sequence
from typing import Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast

import joblib
import lmfit
import numpy as np
import tqdm.auto
import xarray as xr
Expand All @@ -25,6 +23,10 @@
from erlab.utils.misc import emit_user_level_warning
from erlab.utils.parallel import joblib_progress

if TYPE_CHECKING:
# Avoid importing until runtime for initial import performance
import lmfit


def _nested_dict_vals(d):
for v in d.values():
Expand Down Expand Up @@ -55,6 +57,8 @@ def _concat_along_keys(d: dict[str, xr.DataArray], dim_name: str) -> xr.DataArra
def _parse_params(
d: dict[str, Any] | lmfit.Parameters, dask: bool
) -> xr.DataArray | _ParametersWrapper:
import lmfit

if isinstance(d, lmfit.Parameters):
# Input to apply_ufunc cannot be a Mapping, so wrap in a class
return _ParametersWrapper(d)
Expand All @@ -70,6 +74,8 @@ def _parse_params(


def _parse_multiple_params(d: dict[str, Any], as_str: bool) -> xr.DataArray:
import lmfit

for k in d:
if isinstance(d[k], int | float | complex | xr.DataArray):
d[k] = {"value": d[k]}
Expand Down Expand Up @@ -252,6 +258,8 @@ def __call__(
scipy.optimize.curve_fit
"""
import lmfit

# Implementation analogous to xarray.Dataset.curve_fit

if params is None:
Expand Down Expand Up @@ -539,6 +547,9 @@ def _output_wrapper(name, da, out=None) -> dict:
}

if parallel:
# Avoid importing until runtime for initial import performance
import joblib

if is_dask:
emit_user_level_warning(
"The input Dataset is chunked. Parallel fitting will not offer any "
Expand Down
5 changes: 3 additions & 2 deletions src/erlab/accessors/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from collections.abc import Hashable, Mapping
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

Expand Down Expand Up @@ -54,14 +53,16 @@ def __call__(self, *args, **kwargs):
Keyword arguments to be passed to the plotting function.
"""
import matplotlib.pyplot

from erlab.plotting.erplot import fancy_labels, plot_array

if len(self._obj.dims) == 2:
return plot_array(self._obj, *args, **kwargs)

ax = kwargs.pop("ax", None)
if ax is None:
ax = plt.gca()
ax = matplotlib.pyplot.gca()
kwargs["ax"] = ax

out = self._obj.plot(*args, **kwargs)
Expand Down
16 changes: 8 additions & 8 deletions src/erlab/io/exampledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import erlab.analysis.image
import erlab.analysis.kspace
from erlab.constants import kb_eV
import erlab.constants


def _func(kvec, a):
Expand Down Expand Up @@ -161,8 +161,8 @@ def generate_data_angles(
Erange: tuple[float, float] = (-0.45, 0.12),
hv: float = 50.0,
configuration: (
erlab.analysis.kspace.AxesConfiguration | int
) = erlab.analysis.kspace.AxesConfiguration.Type1,
erlab.constants.AxesConfiguration | int
) = erlab.constants.AxesConfiguration.Type1,
temp: float = 20.0,
a: float = 6.97,
t: float = 0.43,
Expand Down Expand Up @@ -238,8 +238,8 @@ def generate_data_angles(
alpha = np.linspace(-angrange, angrange, shape[0])
beta = np.linspace(-angrange, angrange, shape[1])

if not isinstance(configuration, erlab.analysis.kspace.AxesConfiguration):
configuration = erlab.analysis.kspace.AxesConfiguration(configuration)
if not isinstance(configuration, erlab.constants.AxesConfiguration):
configuration = erlab.constants.AxesConfiguration(configuration)

eV = np.linspace(*Erange, shape[2])

Expand Down Expand Up @@ -286,8 +286,8 @@ def generate_data_angles(

match configuration:
case (
erlab.analysis.kspace.AxesConfiguration.Type1DA
| erlab.analysis.kspace.AxesConfiguration.Type2DA
erlab.constants.AxesConfiguration.Type1DA
| erlab.constants.AxesConfiguration.Type2DA
):
out = out.assign_coords(chi=0.0)

Expand Down Expand Up @@ -369,7 +369,7 @@ def generate_gold_edge(
center = np.polynomial.polynomial.polyval(alpha, edge_coeffs)

data = (b - c + a * eV) / (
1 + np.exp((1.0 * eV - center) / max(1e-15, temp * kb_eV))
1 + np.exp((1.0 * eV - center) / max(1e-15, temp * erlab.constants.kb_eV))
) + c

background = np.polynomial.polynomial.polyval(alpha, background_coeffs).clip(min=0)
Expand Down
9 changes: 0 additions & 9 deletions src/erlab/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""

import importlib
import io
import os
import pkgutil
Expand All @@ -30,14 +29,6 @@
import matplotlib.style
import numpy as np

# Import colormaps if available
if importlib.util.find_spec("cmasher"):
importlib.import_module("cmasher")
if importlib.util.find_spec("cmocean"):
importlib.import_module("cmocean")
if importlib.util.find_spec("colorcet"):
importlib.import_module("colorcet")


def load_igor_ct(
file: str | os.PathLike | io.BytesIO, name: str, register_reversed: bool = True
Expand Down
11 changes: 7 additions & 4 deletions src/erlab/plotting/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
from typing import TYPE_CHECKING, Any, Literal, cast

import matplotlib
import matplotlib.backends.backend_pdf
import matplotlib.backends.backend_svg
import matplotlib.figure
import matplotlib.font_manager
import matplotlib.mathtext
Expand Down Expand Up @@ -286,9 +284,14 @@ def copy_mathtext(
fig.text(0, depth / height, s, fontproperties=fontproperties)

if svg:
matplotlib.backends.backend_svg.FigureCanvasSVG(fig)
from matplotlib.backends.backend_svg import FigureCanvasSVG

FigureCanvasSVG(fig)

else:
matplotlib.backends.backend_pdf.FigureCanvasPdf(fig)
from matplotlib.backends.backend_pdf import FigureCanvasPdf

Check warning on line 292 in src/erlab/plotting/annotations.py

View check run for this annotation

Codecov / codecov/patch

src/erlab/plotting/annotations.py#L292

Added line #L292 was not covered by tests

FigureCanvasPdf(fig)

Check warning on line 294 in src/erlab/plotting/annotations.py

View check run for this annotation

Codecov / codecov/patch

src/erlab/plotting/annotations.py#L294

Added line #L294 was not covered by tests

for k, v in mathtext_rc.items():
if k in ["bf", "cal", "it", "rm", "sf", "tt"] and isinstance(
Expand Down
3 changes: 2 additions & 1 deletion src/erlab/plotting/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import scipy.spatial
from matplotlib.patches import RegularPolygon

from erlab.plotting.colors import axes_textcolor
Expand Down Expand Up @@ -70,6 +69,8 @@ def get_bz_edge(
# Get index of origin
zero_ind = np.where((points == 0).all(axis=1))[0][0]

import scipy.spatial

vor = scipy.spatial.Voronoi(points)

lines = []
Expand Down
3 changes: 1 addition & 2 deletions src/erlab/plotting/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
<https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ colormaps, `cmasher
<https://cmasher.readthedocs.io>`_, `cmocean <https://matplotlib.org/cmocean/>`_, and
`colorcet <https://colorcet.holoviz.org>`_ packages can be installed to extend the
available colormaps. If these packages are installed, they will be automatically
imported upon importing `erlab.plotting`.
available colormaps.
Colormap Normalization
----------------------
Expand Down
5 changes: 3 additions & 2 deletions src/erlab/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import contextlib
import sys

import joblib
import joblib._parallel_backends
import tqdm.auto


@contextlib.contextmanager
def joblib_progress(file=None, **kwargs):
"""Patches joblib to report into a tqdm progress bar."""
import joblib

if file is None:
file = sys.stdout

Expand All @@ -39,6 +39,7 @@ def joblib_progress_qt(signal):
The number of completed tasks are emitted by the given signal.
"""
import joblib

def qt_print_progress(self) -> None:
signal.emit(self.n_completed_tasks)
Expand Down

0 comments on commit d7f3b3c

Please sign in to comment.