diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 3f477a7..6182bec 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -8,11 +8,16 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest"]
-        python-version: ["3.7", "3.8"]
+        os: ["macos-latest", "ubuntu-latest", "windows-latest"]
+        python-version: ["3.7", "3.8", "3.9"]
     steps:
-      - uses: actions/checkout@v2
-      - name: Cache conda
+      - name: Checkout source
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Cache Linux/macOS (x86) Conda environment
+        if: ${{ runner.os != 'Windows' }}
         uses: actions/cache@v1
         env:
           # Increase this value to reset cache if ci/environment.yml has not changed
@@ -20,7 +25,20 @@ jobs:
         with:
           path: ~/conda_pkgs_dir
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/environment-py${{ matrix.python-version }}.yml') }}
-      - uses: conda-incubator/setup-miniconda@v2
+
+      - name: Cache Windows Conda environment
+        if: ${{ runner.os == 'Windows' }}
+        uses: actions/cache@v1
+        env:
+          # Increase this value to reset cache if ci/environment.yml has not changed
+          CACHE_NUMBER: 0
+        with:
+          path: ~/conda_pkgs_dir
+          key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/environment-py${{ matrix.python-version }}-win.yml') }}
+
+      - name: Build and activate Linux/macOS Conda environment
+        if: ${{ runner.os != 'Windows' }}
+        uses: conda-incubator/setup-miniconda@v2
         with:
           mamba-version: "*" # activate this to build with mamba.
           channels: conda-forge, defaults # These need to be specified to use mamba
@@ -29,14 +47,29 @@ jobs:
           activate-environment: test_env_extract_model
           use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
-      - name: Set up conda environment
+
+      - name: Build and activate Windows Conda environment
+        if: ${{ runner.os == 'Windows' }}
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          mamba-version: "*" # activate this to build with mamba.
+          channels: conda-forge, defaults # These need to be specified to use mamba
+          channel-priority: true
+          environment-file: ci/environment-py${{ matrix.python-version }}-win.yml
+
+          activate-environment: test_env_extract_model
+          use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
+
+      - name: Install package in environment
         shell: bash -l {0}
         run: |
           python -m pip install -e . --no-deps --force-reinstall
-      - name: Run Tests
+
+      - name: Run tests
        shell: bash -l {0}
         run: |
           pytest --cov=./ --cov-report=xml
+
       - name: Upload code coverage to Codecov
         uses: codecov/codecov-action@v1
         with:
diff --git a/ci/environment-py3.7-win.yml b/ci/environment-py3.7-win.yml
new file mode 100644
index 0000000..b20b135
--- /dev/null
+++ b/ci/environment-py3.7-win.yml
@@ -0,0 +1,21 @@
+name: test-env
+channels:
+  - conda-forge
+dependencies:
+  - python=3.7
+  - cf_xarray>=0.6
+  - cmocean
+  - dask
+  - matplotlib
+  - netcdf4
+  - numpy
+  - pip
+  - pyinterp
+  - requests
+  - xarray
+  - xcmocean
+  - pytest
+  - pip:
+    - codecov
+    - pytest-cov
+    - coverage[toml]
diff --git a/ci/environment-py3.7.yml b/ci/environment-py3.7.yml
index b8f5bdd..93652f2 100644
--- a/ci/environment-py3.7.yml
+++ b/ci/environment-py3.7.yml
@@ -1,24 +1,21 @@
-name: test_env_extract_model
+name: test-env
 channels:
   - conda-forge
 dependencies:
   - python=3.7
-  ############## These will have to be adjusted to your specific project
   - cartopy
   - cf_xarray
   - cmocean
   - dask
-  - jupyter
-  - jupyterlab
   - matplotlib
   - netcdf4
   - numpy
   - pip
+  - pyinterp
   - requests
   - xarray
   - xcmocean
   - xesmf
-  ##############
   - pytest
   - pip:
     - codecov
diff --git a/ci/environment-py3.8-win.yml b/ci/environment-py3.8-win.yml
new file mode 100644
index 0000000..7227a2d
--- /dev/null
+++ b/ci/environment-py3.8-win.yml
@@ -0,0 +1,21 @@
+name: test-env
+channels:
+  - conda-forge
+dependencies:
+  - python=3.8
+  - cf_xarray>=0.6
+  - cmocean
+  - dask
+  - matplotlib
+  - netcdf4
+  - numpy
+  - pip
+  - pyinterp
+  - requests
+  - xarray
+  - xcmocean
+  - pytest
+  - pip:
+    - codecov
+    - pytest-cov
+    - coverage[toml]
diff --git a/ci/environment-py3.8.yml b/ci/environment-py3.8.yml
index 9f78521..015ca5e 100644
--- a/ci/environment-py3.8.yml
+++ b/ci/environment-py3.8.yml
@@ -1,24 +1,21 @@
-name: test_env_extract_model
+name: test-env
 channels:
   - conda-forge
 dependencies:
   - python=3.8
-  ############## These will have to be adjusted to your specific project
   - cartopy
   - cf_xarray
   - cmocean
   - dask
-  - jupyter
-  - jupyterlab
   - matplotlib
   - netcdf4
   - numpy
   - pip
+  - pyinterp
   - requests
   - xarray
   - xcmocean
   - xesmf
-  ##############
   - pytest
   - pip:
     - codecov
diff --git a/ci/environment-py3.9-win.yml b/ci/environment-py3.9-win.yml
new file mode 100644
index 0000000..5cb1c40
--- /dev/null
+++ b/ci/environment-py3.9-win.yml
@@ -0,0 +1,21 @@
+name: test-env
+channels:
+  - conda-forge
+dependencies:
+  - python=3.9
+  - cf_xarray>=0.6
+  - cmocean
+  - dask
+  - matplotlib
+  - netcdf4
+  - numpy
+  - pip
+  - pyinterp
+  - requests
+  - xarray
+  - xcmocean
+  - pytest
+  - pip:
+    - codecov
+    - pytest-cov
+    - coverage[toml]
diff --git a/ci/environment-py3.9.yml b/ci/environment-py3.9.yml
new file mode 100644
index 0000000..00cedd9
--- /dev/null
+++ b/ci/environment-py3.9.yml
@@ -0,0 +1,22 @@
+name: test-env
+channels:
+  - conda-forge
+dependencies:
+  - python=3.9
+  - cf_xarray>=0.6
+  - cmocean
+  - dask
+  - matplotlib
+  - netcdf4
+  - numpy
+  - pip
+  - pyinterp
+  - requests
+  - xarray
+  - xcmocean
+  - xesmf
+  - pytest
+  - pip:
+    - codecov
+    - pytest-cov
+    - coverage[toml]
diff --git a/environment.yml b/environment.yml
index e40dcb8..b5877dd 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,18 +4,18 @@ channels:
 dependencies:
   # Required for full project functionality (dont remove)
   - pytest
   # Examples (remove and add as needed)
   - cartopy
-  - cf_xarray
+  - cf_xarray>=0.6
   - cmocean
   - dask
-  - jupyter
-  - jupyterlab
   - matplotlib
   - netcdf4
   - numpy
   - pip
+  - pyinterp
   - requests
+  - scipy
   - xarray
   - xcmocean
   - xesmf
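Since the Windows environments omit `xesmf` (ESMF has no Windows build on conda-forge) while the others keep it, the interpolation backend is decided at import time. A quick sanity check of which backend a given environment will end up using could look like this (an illustrative sketch, not part of the diff):

```python
# Sketch: report which interpolation backend extract_model will fall back to
# in the current environment.
import importlib.util

backend = "xesmf" if importlib.util.find_spec("xesmf") else "pyinterp"
print(f"extract_model will interpolate with {backend}")
```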
diff --git a/extract_model/__init__.py b/extract_model/__init__.py
index aa95262..61d50ae 100644
--- a/extract_model/__init__.py
+++ b/extract_model/__init__.py
@@ -6,14 +6,14 @@
 import cf_xarray as cfxr  # noqa: F401
 import requests  # noqa: F401
-
 from pkg_resources import DistributionNotFound, get_distribution
 
-import extract_model.accessor
-
-from .extract_model import argsel2d, sel2d, select  # noqa: F401
-from .utils import filter, order, preprocess, sub_bbox, sub_grid
+import extract_model.accessor  # noqa: F401
+from .extract_model import (argsel2d, make_output_ds, sel2d,  # noqa: F401
+                            select)
+from .pyinterp_shim import PyInterpShim  # noqa: F401
+from .utils import filter, order, preprocess, sub_bbox, sub_grid  # noqa: F401
 
 try:
     __version__ = get_distribution("extract_model").version
diff --git a/extract_model/extract_model.py b/extract_model/extract_model.py
index bc13a12..101831e 100644
--- a/extract_model/extract_model.py
+++ b/extract_model/extract_model.py
@@ -1,12 +1,36 @@
 """
 Main file for this code. The main code is in `select`, and the rest is to help
 with variable name management.
 """
+import warnings
+from numbers import Number
+from typing import List, Optional, Union
 
 import cartopy.geodesic
 import cf_xarray  # noqa: F401
 import numpy as np
+import numpy.typing as npt
 import xarray as xr
-import xesmf as xe
+
+try:
+    import xesmf as xe
+
+    XESMF_AVAILABLE = True
+except ImportError:
+    XESMF_AVAILABLE = False
+    warnings.warn("xESMF not found. Interpolation will be performed using pyinterp.")
+
+try:
+    from .pyinterp_shim import PyInterpShim
+except ImportError:
+    if XESMF_AVAILABLE:
+        warnings.warn(
+            "PyInterp not found. Interpolation will be performed using xESMF."
+        )
+    else:
+        raise ModuleNotFoundError(
+            "Neither PyInterp nor xESMF are available. Please install either package."
+        )
 
 
 # try:
@@ -18,24 +42,27 @@
 def select(
-    da,
-    longitude=None,
-    latitude=None,
-    T=None,
-    Z=None,
-    iT=None,
-    iZ=None,
-    extrap=False,
-    extrap_val=None,
-    locstream=False,
-    regridder=None,
+    da: xr.DataArray,
+    longitude: Optional[
+        Union[Number, List[Number], npt.ArrayLike, xr.DataArray]
+    ] = None,
+    latitude: Optional[Union[Number, List[Number], npt.ArrayLike, xr.DataArray]] = None,
+    T: Optional[Union[str, List[str]]] = None,
+    Z: Optional[Union[Number, List[Number]]] = None,
+    iT: Optional[Union[int, List[int]]] = None,
+    iZ: Optional[Union[int, List[int]]] = None,
+    extrap: bool = False,
+    extrap_val: Optional[Number] = None,
+    locstream: bool = False,
+    interp_lib: str = "xesmf",
+    regridder: Optional["xe.Regridder"] = None,
 ):
     """Extract output from da at location(s).
 
     Parameters
     ----------
     da: DataArray
-        Property to select model output from.
+        DataArray from which to extract data.
     longitude, latitude: int, float, list, array (1D or 2D), DataArray, optional
         longitude(s), latitude(s) at which to return model output.
         Package `xESMF` will be used to interpolate with "bilinear" to
@@ -94,83 +121,128 @@ def select(
 
     >>> da_out = em.select(**kwargs)
     """
 
-    # can't run in both Z and iZ mode, same for T/iT
-    assert not ((Z is not None) and (iZ is not None))
-    assert not ((T is not None) and (iT is not None))
+    # Must select or interpolate for depth and time;
+    # i.e., one cannot run in both Z and iZ mode, and the same holds for T/iT.
+    if (Z is not None) and (iZ is not None):
+        raise ValueError("Cannot specify both Z and iZ.")
+    if (T is not None) and (iT is not None):
+        raise ValueError("Cannot specify both T and iT.")
 
     if (longitude is not None) and (latitude is not None):
-        if (isinstance(longitude, int)) or (isinstance(longitude, float)):
+        # Must convert scalars to lists because 0D lat/lon arrays are not supported.
+        if isinstance(longitude, Number):
             longitude = [longitude]
-        if (isinstance(latitude, int)) or (isinstance(latitude, float)):
+        if isinstance(latitude, Number):
             latitude = [latitude]
-        latitude = np.asarray(latitude)
         longitude = np.asarray(longitude)
+        latitude = np.asarray(latitude)
+        output_grid = True
+    else:
+        output_grid = False
+
+    # Horizontal interpolation
+    # Verify interpolated points in domain if not extrapolating.
+    if output_grid and not extrap:
+        if (
+            longitude.min() < da.cf["longitude"].min()
+            or longitude.max() > da.cf["longitude"].max()
+        ):
+            raise ValueError(
+                "Longitude outside of available domain. "
+                "Use extrap=True to extrapolate."
+            )
+        if (
+            latitude.min() < da.cf["latitude"].min()
+            or latitude.max() > da.cf["latitude"].max()
+        ):
+            raise ValueError(
+                "Latitude outside of available domain. "
+                "Use extrap=True to extrapolate."
+            )
+
+    # Create output grid as Dataset.
+    if output_grid:
+        ds_out = make_output_ds(longitude, latitude)
+    else:
+        ds_out = None
 
     # If extrapolating, define method
     if extrap:
         extrap_method = "nearest_s2d"
     else:
         extrap_method = None
 
-    if (not extrap) and ((longitude is not None) and (latitude is not None)):
-        assertion = "the input longitude range is outside the model domain"
-        assert (longitude.min() >= da.cf["longitude"].min()) and (
-            longitude.max() <= da.cf["longitude"].max()
-        ), assertion
-        assertion = "the input latitude range is outside the model domain"
-        assert (latitude.min() >= da.cf["latitude"].min()) and (
-            latitude.max() <= da.cf["latitude"].max()
-        ), assertion
-
-    # Horizontal interpolation #
-    # grid of lon/lat to interpolate to, with desired ending attributes
-    if (longitude is not None) and (latitude is not None):
-
-        if latitude.ndim == 1:
-            da_out = xr.Dataset(
-                {
-                    "lat": (
-                        ["lat"],
-                        latitude,
-                        dict(axis="Y", units="degrees_north", standard_name="latitude"),
-                    ),
-                    "lon": (
-                        ["lon"],
-                        longitude,
-                        dict(axis="X", units="degrees_east", standard_name="longitude"),
-                    ),
-                }
-            )
-        elif latitude.ndim == 2:
-            da_out = xr.Dataset(
-                {
-                    "lat": (
-                        ["Y", "X"],
-                        latitude,
-                        dict(units="degrees_north", standard_name="latitude"),
-                    ),
-                    "lon": (
-                        ["Y", "X"],
-                        longitude,
-                        dict(units="degrees_east", standard_name="longitude"),
-                    ),
-                }
-            )
+    # Perform interpolation
+    if interp_lib == "xesmf" and XESMF_AVAILABLE:
+        da = _xesmf_interp(
+            da,
+            ds_out,
+            T=T,
+            Z=Z,
+            iT=iT,
+            iZ=iZ,
+            extrap_method=extrap_method,
+            extrap_val=extrap_val,
+            locstream=locstream,
+            regridder=regridder,
+        )
+    elif interp_lib == "pyinterp" or not XESMF_AVAILABLE:
+        da = _pyinterp_interp(
+            da,
+            ds_out,
+            T=T,
+            Z=Z,
+            iT=iT,
+            iZ=iZ,
+            extrap_method=extrap_method,
+            extrap_val=extrap_val,
+            locstream=locstream,
+        )
+    else:
+        raise ValueError(f"{interp_lib} interpolation not supported")
+
+    return da
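+
+# NOTE: `select` now returns only the (squeezed) DataArray rather than a
+# `(DataArray, regridder)` tuple; callers that previously unpacked
+# `da_out, regridder = em.select(...)` should switch to
+# `da_out = em.select(...)` (see the updated tests below).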
+
+
+def _xesmf_interp(
+    da: xr.DataArray,
+    ds_out: Optional[xr.Dataset] = None,
+    T: Optional[Union[str, List[str]]] = None,
+    Z: Optional[Union[Number, List[Number]]] = None,
+    iT: Optional[Union[int, List[int]]] = None,
+    iZ: Optional[Union[int, List[int]]] = None,
+    extrap_method: Optional[str] = None,
+    extrap_val: Optional[Number] = None,
+    locstream: bool = False,
+    regridder: Optional["xe.Regridder"] = None,
+) -> xr.DataArray:
+    """Interpolate input DataArray to output grid using xESMF.
+
+    Parameters
+    ----------
+    da: xarray.DataArray
+        Input DataArray to interpolate.
+    ds_out: xarray.Dataset, optional
+        Output grid to interpolate to.
+    T: datetime-like string, list of datetime-like strings, optional
+    Z: int, float, list, optional
+    iT: int or list of ints, optional
+    iZ: int or list of ints, optional
+    extrap_method: str, optional
+    extrap_val: int, float, optional
+    locstream: boolean, optional
+    regridder: xesmf.Regridder, optional
+
+    Returns
+    -------
+    DataArray of interpolated and/or selected values from da.
+    """
+    if ds_out is not None:
         if regridder is None:
             # set up regridder, which would work for multiple interpolations if desired
             regridder = xe.Regridder(
-                da,
-                da_out,
-                "bilinear",
-                extrap_method=extrap_method,
-                locstream_out=locstream,
-                ignore_degenerate=True,
+                da, ds_out, "bilinear", extrap_method=extrap_method, locstream_out=locstream
             )
-
-        # do regridding
-        da_int = regridder(da, keep_attrs=True)
-    else:
-        da_int = da
+        da = regridder(da, keep_attrs=True)
 
     # get z coordinates to go with interpolated output if not available
     if "vertical" in da.cf.coords:
@@ -181,50 +253,50 @@
             zint = regridder(da[zkey], keep_attrs=True)
 
             # add coords
-            da_int = da_int.assign_coords({zkey: zint})
+            da = da.assign_coords({zkey: zint})
 
     if iT is not None:
         with xr.set_options(keep_attrs=True):
-            da_int = da_int.cf.isel(T=iT)
+            da = da.cf.isel(T=iT)
     elif T is not None:
         with xr.set_options(keep_attrs=True):
-            da_int = da_int.cf.interp(T=T)
+            da = da.cf.interp(T=T)
 
     # Time and depth interpolation or iselection #
     if iZ is not None:
         with xr.set_options(keep_attrs=True):
-            da_int = da_int.cf.isel(Z=iZ)
+            da = da.cf.isel(Z=iZ)
 
     # deal with interpolation in Z separately
     elif Z is not None:
 
         # can do interpolation in depth for any number of dimensions if the
         # vertical coord is 1d
-        if da_int.cf["vertical"].ndim == 1:
-            da_int = da_int.cf.interp(vertical=Z)
+        if da.cf["vertical"].ndim == 1:
+            da = da.cf.interp(vertical=Z)
 
         # if the vertical coord is greater than 1D, can only do restricted interpolation
         # at the moment
         else:
-            da_int = da_int.squeeze()
-            if len(da_int.dims) == 1 and da_int.cf["Z"].name in da_int.dims:
-                da_int = da_int.swap_dims(
-                    {da_int.cf["Z"].name: da_int.cf["vertical"].name}
-                )
-                da_int = da_int.cf.interp(vertical=Z)
-            elif len(da_int.dims) == 2 and da_int.cf["Z"].name in da_int.dims:
+            da = da.squeeze()
+            if len(da.dims) == 1 and da.cf["Z"].name in da.dims:
+                da = da.swap_dims(
+                    {da.cf["Z"].name: da.cf["vertical"].name}
+                )
+                da = da.cf.interp(vertical=Z)
+            elif len(da.dims) == 2 and da.cf["Z"].name in da.dims:
                 # loop over other dimension
-                dim_var_name = list(set(da_int.dims) - set([da_int.cf["Z"].name]))[0]
+                dim_var_name = list(set(da.dims) - set([da.cf["Z"].name]))[0]
                 new_da = []
-                for i in range(len(da_int[dim_var_name])):
+                for i in range(len(da[dim_var_name])):
                     new_da.append(
-                        da_int.isel({dim_var_name: i})
-                        .swap_dims({da_int.cf["Z"].name: da_int.cf["vertical"].name})
+                        da.isel({dim_var_name: i})
+                        .swap_dims({da.cf["Z"].name: da.cf["vertical"].name})
                         .cf.interp(vertical=Z)
                     )
-                da_int = xr.concat(new_da, dim=dim_var_name)
-            elif len(da_int.dims) > 2:
+                da = xr.concat(new_da, dim=dim_var_name)
+            elif len(da.dims) > 2:
                 # need to implement (x)isoslice here
                 raise NotImplementedError(
                     "Currently it is not possible to interpolate in depth with more than 1 other (time) dimension."
                 )
@@ -233,9 +305,96 @@ def select(
     if extrap_val is not None:
         # returns 0 outside the domain by default. Assumes that no other values are exactly 0
        # and replaces all 0's with extrap_val if chosen.
-        da_int = da_int.where(da_int != 0, extrap_val)
+        da = da.where(da != 0, extrap_val)
+
+    return da.squeeze()
+
+
+def _pyinterp_interp(
+    da: xr.DataArray,
+    ds_out: Optional[xr.Dataset] = None,
+    T: Optional[Union[str, List[str]]] = None,
+    Z: Optional[Union[Number, List[Number]]] = None,
+    iT: Optional[Union[int, List[int]]] = None,
+    iZ: Optional[Union[int, List[int]]] = None,
+    extrap_method: Optional[str] = None,
+    extrap_val: Optional[Number] = None,
+    locstream: bool = False,
+):
+    """Interpolate input DataArray to output grid using PyInterp.
+
+    Parameters
+    ----------
+    da: xarray.DataArray
+        Input DataArray to interpolate.
+    ds_out: xarray.Dataset, optional
+        Output grid to interpolate to.
+    T: datetime-like string, list of datetime-like strings, optional
+    Z: int, float, list, optional
+    iT: int or list of ints, optional
+    iZ: int or list of ints, optional
+    extrap_method: str, optional
+    extrap_val: int, float, optional
+    locstream: boolean, optional
+
+    Returns
+    -------
+    DataArray of interpolated and/or selected values from da.
+    """
+
+    # Loess-based extrapolation will be used if required.
+    if extrap_method is not None:
+        extrap = True
+    else:
+        extrap = False
+
+    interpolator = PyInterpShim()
+    da = interpolator(
+        da, ds_out, T=T, Z=Z, iT=iT, iZ=iZ, extrap=extrap, locstream=locstream
+    )
+
+    return da
+
+
+def make_output_ds(longitude: npt.ArrayLike, latitude: npt.ArrayLike) -> xr.Dataset:
+    """
+    Given desired interpolated longitude and latitude, return points as Dataset.
+    """
+    # Grid of lat/lon to interpolate to with desired ending attributes
+    if latitude.ndim == 1:
+        ds_out = xr.Dataset(
+            {
+                "lat": (
+                    ["lat"],
+                    latitude,
+                    dict(axis="Y", units="degrees_north", standard_name="latitude"),
+                ),
+                "lon": (
+                    ["lon"],
+                    longitude,
+                    dict(axis="X", units="degrees_east", standard_name="longitude"),
+                ),
+            }
+        )
+    elif latitude.ndim == 2:
+        ds_out = xr.Dataset(
+            {
+                "lat": (
+                    ["Y", "X"],
+                    latitude,
+                    dict(units="degrees_north", standard_name="latitude"),
+                ),
+                "lon": (
+                    ["Y", "X"],
+                    longitude,
+                    dict(units="degrees_east", standard_name="longitude"),
+                ),
+            }
+        )
+    else:
+        raise IndexError(f"{latitude.ndim}D latitude/longitude arrays not supported.")
 
-    return da_int.squeeze(), regridder
+    return ds_out
 
 
 def argsel2d(lons, lats, lon0, lat0):
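For reference, a minimal use of the new `make_output_ds` helper (which both backends consume, and which the diff exports from the package root) might look like this, assuming 1D target coordinates:

```python
# Sketch: build the target grid Dataset that the regridders consume.
import numpy as np
import extract_model as em

ds_out = em.make_output_ds(
    longitude=np.array([-95.0, -94.5]), latitude=np.array([28.0, 28.5])
)
assert ds_out["lon"].attrs["standard_name"] == "longitude"
```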
diff --git a/extract_model/pyinterp_shim.py b/extract_model/pyinterp_shim.py
new file mode 100644
index 0000000..efc7507
--- /dev/null
+++ b/extract_model/pyinterp_shim.py
@@ -0,0 +1,527 @@
+"""
+Temporary interface for using pyinterp in the same manner as xESMF in this package.
+"""
+import warnings
+from numbers import Number
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import xarray as xr
+
+try:
+    import pyinterp
+    import pyinterp.backends.xarray
+    import pyinterp.fill
+except ImportError:
+    warnings.warn(
+        "pyinterp not installed. Interpolation will be performed using xESMF."
+    )
+
+
+class PyInterpShim:
+    def __call__(
+        self,
+        da: xr.DataArray,
+        da_out: Optional[xr.Dataset] = None,
+        T: Optional[Union[str, List[str]]] = None,
+        Z: Optional[Union[Number, List[Number]]] = None,
+        iT: Optional[Union[int, List[int]]] = None,
+        iZ: Optional[Union[int, List[int]]] = None,
+        extrap: bool = False,
+        locstream: bool = False,
+    ):
+        # If extrapolating, bounds errors will not be raised;
+        # Loess extrapolation will be used instead.
+        if extrap:
+            bounds_error = False
+        else:
+            bounds_error = True
+
+        # Time and depth interpolation or iselection
+        with xr.set_options(keep_attrs=True):
+            if iZ is not None:
+                da = da.cf.isel(Z=iZ)
+            elif Z is not None:
+                da = da.cf.interp(Z=Z)
+
+            if iT is not None:
+                da = da.cf.isel(T=iT)
+            elif T is not None:
+                da = da.cf.interp(T=T)
+
+        # Requires horizontal interpolation
+        if da_out is not None:
+            # interpolate to the output grid, then package appropriately
+            subset_da, interped_array, interp_method = self._interp(
+                da, da_out, T, Z, iT, iZ, bounds_error
+            )
+            if locstream:
+                da = self._package_locstream(
+                    da, da_out, subset_da, interped_array, T, Z, iT, iZ, interp_method
+                )
+            else:
+                da = self._package_grid(
+                    da, da_out, subset_da, interped_array, T, Z, iT, iZ, interp_method
+                )
+
+        return da
+
+    def _interp(
+        self,
+        da: xr.DataArray,
+        da_out: xr.Dataset,
+        T: Optional[Union[str, List[str]]] = None,
+        Z: Optional[Union[Number, List[Number]]] = None,
+        iT: Optional[Union[int, List[int]]] = None,
+        iZ: Optional[Union[int, List[int]]] = None,
+        bounds_error: bool = False,
+    ) -> Tuple[xr.DataArray, np.ndarray, str]:
+        # Prepare points for interpolation
+        # - Need a DataArray
+        if isinstance(da, xr.Dataset):
+            var_name = list(da.data_vars)[0]
+            da = da[var_name]
+        else:
+            var_name = da.name
+
+        # Add missing coordinates to da_out
+        if len(da_out.lon.shape) == 2:
+            xy_dataset = xr.Dataset(
+                data_vars={
+                    "X": np.arange(da_out.dims["X"]),
+                    "Y": np.arange(da_out.dims["Y"]),
+                }
+            )
+            da_out = da_out.merge(xy_dataset)
+
+        # Identify singular dimensions for time and depth
+        def _is_singular_parameter(da, coordinate, vars):
+            # First check if extraction parameters will render singular dimensions
+            for v in vars:
+                if v is not None:
+                    if isinstance(v, list) and len(v) == 0:
+                        return True
+                    elif isinstance(v, Number):
+                        return True
+
+            # Then check if there are singular dimensions in the data array
+            if coordinate in da.cf.coordinates:
+                coordinate_name = da.cf.coordinates[coordinate][0]
+                if da[coordinate_name].data.size == 1:
+                    return True
+
+            return False
+
+        time_singular = _is_singular_parameter(da, "time", [T, iT])
+        vertical_singular = _is_singular_parameter(da, "vertical", [Z, iZ])
+
+        # Perform interpolation with details depending on dimensionality of data
+        ndims = 0
+        if "longitude" in da.cf.coordinates:
+            ndims += 1
+        if "latitude" in da.cf.coordinates:
+            ndims += 1
+        if "vertical" in da.cf.coordinates and not vertical_singular:
+            ndims += 1
+        if "time" in da.cf.coordinates and not time_singular:
+            ndims += 1
+
+        lat_var = da.cf.coordinates["latitude"][0]
+        lon_var = da.cf.coordinates["longitude"][0]
+        if "time" in da.cf.coordinates:
+            time_var = da.cf.coordinates["time"][0]
+        else:
+            time_var = None
+        if "vertical" in da.cf.coordinates:
+            vertical_var = da.cf.coordinates["vertical"][0]
+        else:
+            vertical_var = None
+        regrid_method = "bilinear"
+
+        subset_da = da
+        if ndims == 2:
+            if time_var:
+                if time_var in subset_da.coords and time_var in subset_da.dims:
+                    subset_da = subset_da.isel({time_var: 0})
+
+            if vertical_var:
+                if vertical_var in subset_da.coords and vertical_var in subset_da.dims:
+                    subset_da = subset_da.isel({vertical_var: 0})
+
+            # Interpolate
+            try:
+                mx, my = np.meshgrid(
+                    da_out.lon.values, da_out.lat.values, indexing="ij"
+                )
+                grid = pyinterp.backends.xarray.Grid2D(subset_da)
+                interped = grid.bivariate(
+                    coords={lon_var: mx.ravel(), lat_var: my.ravel()},
+                    bounds_error=bounds_error,
+                ).reshape(mx.shape)
+                # Transpose from x,y to y,x
+                interped = interped.T
+            except ValueError:
+                grid = pyinterp.RTree()
+                grid.packing(
+                    np.vstack(
+                        (
+                            subset_da[lon_var].data.ravel(),
+                            subset_da[lat_var].data.ravel(),
+                        )
+                    ).T,
+                    subset_da.data.ravel(),
+                )
+                if len(da_out.lon.shape) == 2:
+                    mx = da_out.lon.values
+                    my = da_out.lat.values
+                else:
+                    mx, my = np.meshgrid(
+                        da_out.lon.values, da_out.lat.values, indexing="ij"
+                    )
+                idw, _ = grid.inverse_distance_weighting(
+                    np.vstack((mx.ravel(), my.ravel())).T,
+                    within=bounds_error,
+                    k=5,
+                )
+                interped = idw.reshape(mx.shape)
+                regrid_method = "IDW"
+
+        elif ndims == 3:
+            if time_var:
+                time_da = subset_da[time_var]
+            if vertical_var:
+                vertical_da = subset_da[vertical_var]
+
+            if time_singular:
+                if iT is not None:
+                    subset_da = subset_da.isel({time_var: iT})
+                    time_da = time_da.isel({time_var: iT})
+                elif T is not None:
+                    subset_da = subset_da.sel({time_var: T})
+                    time_da = time_da.sel({time_var: T})
+            if vertical_singular:
+                if iZ is not None:
+                    subset_da = subset_da.isel({vertical_var: iZ})
+                    vertical_da = vertical_da.isel({vertical_var: iZ})
+                elif Z is not None:
+                    subset_da = subset_da.sel({vertical_var: Z})
+                    vertical_da = vertical_da.sel({vertical_var: Z})
+
+            # Regular grid
+            try:
+                mx, my, mz = np.meshgrid(
+                    da_out.lon.values,
+                    da_out.lat.values,
+                    da.cf.coords["time"].values,
+                    indexing="ij",
+                )
+
+                # Fill NaNs using Loess
+                grid = pyinterp.backends.xarray.Grid3D(subset_da)
+                filled = pyinterp.fill.loess(grid, nx=5, ny=5)
+                grid = pyinterp.Grid3D(grid.x, grid.y, grid.z, filled)
+                interped = pyinterp.bicubic(
+                    grid,
+                    x=mx.ravel(),
+                    y=my.ravel(),
+                    z=mz.ravel(),
+                    bounds_error=bounds_error,
+                ).reshape(mx.shape)
+            # Curvilinear or unstructured
+            except ValueError:
+                # Need to manually create grid when lon, lat are 2D (curvilinear or unstructured)
+                trailing_dim = subset_da.shape[0]
+
+                grid = pyinterp.RTree()
+                grid.packing(
+                    np.vstack(
+                        (
+                            subset_da[lon_var].data.ravel(),
+                            subset_da[lat_var].data.ravel(),
+                        )
+                    ).T,
+                    subset_da.data.ravel().reshape(-1, trailing_dim),
+                )
+                if len(da_out.lon.shape) == 2:
+                    mx = da_out.lon.values
+                    my = da_out.lat.values
+                else:
+                    mx, my = np.meshgrid(
+                        da_out.lon.values, da_out.lat.values, indexing="ij"
+                    )
+                idw, _ = grid.inverse_distance_weighting(
+                    np.vstack((mx.ravel(), my.ravel())).T,
+                    within=bounds_error,
+                    k=5,
+                )
+                interped = idw.reshape(mx.shape)
+                regrid_method = "IDW"
+
+        elif ndims == 4:
+            mx, my, mz, mu = np.meshgrid(
+                da_out.lon.values,
+                da_out.lat.values,
+                da.cf.coords["time"].values,
+                da.cf.coords["vertical"].values,
+                indexing="ij",
+            )
+            # Fill NaNs using Loess
+            grid = pyinterp.backends.xarray.Grid4D(da)
+            filled = pyinterp.fill.loess(grid, nx=3, ny=3)
+            grid = pyinterp.Grid4D(grid.x, grid.y, grid.z, grid.u, filled)
+            interped = pyinterp.bicubic(
+                grid,
+                x=mx.ravel(),
+                y=my.ravel(),
+                z=mz.ravel(),
+                u=mu.ravel(),
+                bounds_error=bounds_error,
+            ).reshape(mx.shape)
+        else:
+            raise IndexError(f"{ndims}D interpolation not supported")
+
+        return subset_da, interped, regrid_method
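+
+    # NOTE: Grid2D/Grid3D construction raises ValueError for curvilinear or
+    # unstructured lon/lat grids, in which case _interp falls back to an RTree
+    # with inverse-distance weighting (k=5) and reports regrid_method="IDW".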
+
+    def _package_locstream(
+        self,
+        da: xr.DataArray,
+        da_out: xr.Dataset,
+        subset_da: xr.DataArray,
+        interped: np.ndarray,
+        T: Optional[Union[str, List[str]]] = None,
+        Z: Optional[Union[Number, List[Number]]] = None,
+        iT: Optional[Union[int, List[int]]] = None,
+        iZ: Optional[Union[int, List[int]]] = None,
+        regrid_method: Optional[str] = None,
+    ):
+        # Prepare points for interpolation
+        # - Need a DataArray
+        if isinstance(da, xr.Dataset):
+            var_name = list(da.data_vars)[0]
+            da = da[var_name]
+        else:
+            var_name = da.name
+
+        # Locstream will have dim pts for the number of points
+        # - Change dims from lon/lat to pts
+        lat_var = da_out.cf.coordinates["latitude"][0]
+        lon_var = da_out.cf.coordinates["longitude"][0]
+        da_out = da_out.rename_dims(
+            {
+                lat_var: "pts",
+                lon_var: "pts",
+            }
+        )
+
+        # Add coordinates from the original data
+        coords = da_out.coords
+        if "time" in da.cf.coordinates:
+            time_var = da.cf.coordinates["time"][0]
+        else:
+            time_var = None
+        if "vertical" in da.cf.coordinates:
+            vertical_var = da.cf.coordinates["vertical"][0]
+        else:
+            vertical_var = None
+
+        if "time" in da.cf.coordinates:
+            coords["time"] = subset_da[time_var]
+        if "vertical" in da.cf.coordinates:
+            coords["vertical"] = subset_da[vertical_var]
+
+        # Add interpolated data
+        # If a single point, reshape to len(pts, 1)
+        if da_out[lat_var].shape == (1,):
+            interped = np.squeeze(interped)[:, np.newaxis]
+            # Also need to swap the dims to match
+            dims = [dim for dim in da_out.dims][::-1]
+        # Else, it's probably a grid and the diagonal needs to be extracted
+        # - This is a workaround for a bad implementation in _interp which interpolates
+        #   a whole grid instead of just a set of points.
+        else:
+            interped = np.diagonal(interped)
+            dims = da_out.dims
+
+        return xr.DataArray(
+            interped,
+            coords=coords,
+            dims=dims,
+            attrs={**da.attrs, **{"regrid_method": regrid_method}},
+        )
+
+    def _package_grid(
+        self,
+        da: xr.DataArray,
+        da_out: xr.Dataset,
+        subset_da: xr.DataArray,
+        interped: np.ndarray,
+        T: Optional[Union[str, List[str]]] = None,
+        Z: Optional[Union[Number, List[Number]]] = None,
+        iT: Optional[Union[int, List[int]]] = None,
+        iZ: Optional[Union[int, List[int]]] = None,
+        regrid_method: Optional[str] = None,
+    ):
+        # Prepare points for interpolation
+        # - Need a DataArray
+        if isinstance(da, xr.Dataset):
+            var_name = list(da.data_vars)[0]
+            da = da[var_name]
+        else:
+            var_name = da.name
+
+        # Add missing coordinates to da_out
+        if len(da_out.lon.shape) == 2:
+            xy_dataset = xr.Dataset(
+                data_vars={
+                    "X": np.arange(da_out.dims["X"]),
+                    "Y": np.arange(da_out.dims["Y"]),
+                }
+            )
+            da_out = da_out.merge(xy_dataset)
+
+        # Identify singular dimensions for time and depth
+        def _is_singular_parameter(da, coordinate, vars):
+            # First check if extraction parameters will render singular dimensions
+            for v in vars:
+                if v is not None:
+                    if isinstance(v, list) and len(v) == 0:
+                        return True
+                    elif isinstance(v, Number):
+                        return True
+
+            # Then check if there are singular dimensions in the data array
+            if coordinate in da.cf.coordinates:
+                coordinate_name = da.cf.coordinates[coordinate][0]
+                if da[coordinate_name].data.size == 1:
+                    return True
+
+            return False
+
+        time_singular = _is_singular_parameter(da, "time", [T, iT])
+        vertical_singular = _is_singular_parameter(da, "vertical", [Z, iZ])
+
+        # Perform interpolation with details depending on dimensionality of data
+        ndims = 0
+        if "longitude" in da.cf.coordinates:
+            ndims += 1
+        if "latitude" in da.cf.coordinates:
+            ndims += 1
+        if "vertical" in da.cf.coordinates and not vertical_singular:
+            ndims += 1
+        if "time" in da.cf.coordinates and not time_singular:
+            ndims += 1
+
+        if "time" in da.cf.coordinates:
+            time_var = da.cf.coordinates["time"][0]
+        else:
+            time_var = None
+        if "vertical" in da.cf.coordinates:
+            vertical_var = da.cf.coordinates["vertical"][0]
+        else:
+            vertical_var = None
+        if ndims == 2:
+            # Package as DataArray
+            if len(da_out.lon) == 1:
+                lons = da_out.lon.isel({"lon": 0})
+            else:
+                lons = da_out.lon
+            if len(da_out.lat) == 1:
+                lats = da_out.lat.isel({"lat": 0})
+            else:
+                lats = da_out.lat
+
+            coords = {"lon": lons, "lat": lats}
+            # Handle curvilinear lon/lat coords
+            if len(lons.shape) == 2:
+                for dim in lons.dims:
+                    coords[dim] = lons[dim]
+            if "time" in da.cf.coordinates:
+                coords["time"] = da[time_var]
+            if "vertical" in da.cf.coordinates:
+                coords["vertical"] = da[vertical_var]
+
+            # Handle missing dims from interpolation
+            missing_subset_dims = []
+            for subset_dim in subset_da.dims:
+                if subset_dim not in [
+                    da.cf.coordinates["longitude"][0],
+                    da.cf.coordinates["latitude"][0],
+                ]:
+                    missing_subset_dims.append(subset_dim)
+
+            output_dims = []
+            for orig_dim in da.dims:
+                # Handle original x, y to lon, lat
+                # Also, do not add lon and lat if they are scalars
+                if orig_dim == "xi_rho" and len(da_out.lon) > 1:
+                    output_dims.append("X")
+                elif orig_dim == "xi_rho" and len(da_out.lon) == 1:
+                    interped = np.squeeze(interped, axis=0)
+                    continue
+                elif orig_dim == "eta_rho" and len(da_out.lat) > 1:
+                    output_dims.append("Y")
+                elif orig_dim == "eta_rho" and len(da_out.lat) == 1:
+                    interped = np.squeeze(interped, axis=0)
+                    continue
+                elif (
+                    orig_dim == da.cf.coordinates["longitude"][0]
+                    and len(da_out.lon) > 1
+                ):
+                    output_dims.append("lon")
+                elif (
+                    orig_dim == da.cf.coordinates["longitude"][0]
+                    and len(da_out.lon) == 1
+                ):
+                    interped = np.squeeze(interped, axis=0)
+                    continue
+                elif (
+                    orig_dim == da.cf.coordinates["latitude"][0] and len(da_out.lat) > 1
+                ):
+                    output_dims.append("lat")
+                elif (
+                    orig_dim == da.cf.coordinates["latitude"][0]
+                    and len(da_out.lat) == 1
+                ):
+                    interped = np.squeeze(interped, axis=0)
+                    continue
+                else:
+                    output_dims.append(orig_dim)
+
+                if orig_dim not in missing_subset_dims:
+                    interped = interped[np.newaxis, ...]
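+
+            # Re-attach the interpolated values to the reconstructed dims/coords,
+            # recording which regrid method actually ran in the attributes.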
+            da = xr.DataArray(
+                interped,
+                coords=coords,
+                dims=output_dims,
+                attrs={**da.attrs, **{"regrid_method": regrid_method}},
+            )
+        elif ndims == 3:
+            coords = {
+                "lat": da_out.lat,
+                "lon": da_out.lon,
+                "time": da.cf.coords["time"],
+            }
+            da = xr.Dataset(
+                {var_name: (["lat", "lon", "time"], interped)},
+                coords=coords,
+                attrs=da.attrs,
+            )
+        elif ndims == 4:
+            coords = {
+                "lat": da_out.lat,
+                "lon": da_out.lon,
+                "time": da.cf.coords["time"],
+                "vertical": da.cf.coords["vertical"],
+            }
+            da = xr.Dataset(
+                {var_name: (["lat", "lon", "time", "vertical"], interped)},
+                coords=coords,
+                attrs=da.attrs,
+            )
+        else:
+            raise IndexError(f"{ndims}D interpolation not supported")
+
+        return da
diff --git a/setup.cfg b/setup.cfg
index 840256f..3571cfb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,13 +35,12 @@ license_file = LICENSE.txt
 # For details see: https://pypi.org/classifiers/
 classifiers =
-    Development Status :: 5 - Production/Stable
+    Development Status :: 3 - Alpha
     Topic :: Scientific/Engineering
     Intended Audience :: Science/Research
     Operating System :: OS Independent
     Programming Language :: Python
     Programming Language :: Python :: 3
-    Programming Language :: Python :: 3.6
     Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
@@ -59,20 +58,18 @@ install_requires =
     cf_xarray
     cmocean
     dask
-    jupyter
-    jupyterlab
     matplotlib
     netcdf4
     numpy
-    requests
     pip
+    requests
+    scipy
     xarray
     xcmocean
-    xesmf
 setup_requires=
     setuptools_scm
-python_requires = >=3.6
+python_requires = >=3.7
 ################ Up until here
 
 zip_safe = False
diff --git a/tests/model_configs.yaml b/tests/model_configs.yaml
new file mode 100644
index 0000000..eba6792
--- /dev/null
+++ b/tests/model_configs.yaml
@@ -0,0 +1,67 @@
+MOM6:
+  url: Path(__file__).parent / "test_mom6.nc"
+  var: "uo"
+  i: 0
+  j: 0
+  iZ: null
+  Z: 0
+  iT: null
+  T: null
+  lon1: -166
+  lat1: 48
+  lon2: -149
+  lat2: 56.0
+  lonslice: slice(None, 5)
+  latslice: slice(None, 5)
+  model_names: [None, "sea_water_x_velocity", None, None, None]
+
+HYCOM_01:
+  url: Path(__file__).parent / "test_hycom.nc"
+  var: "water_u"
+  i: 0
+  j: 30
+  iZ: null
+  Z: 0
+  iT: null
+  T: null
+  lon1: -166
+  lat1: 48
+  lon2: 149.0
+  lat2: -10.1
+  lonslice: slice(10, 15)
+  latslice: slice(10, 15)
+  model_names: [None, "eastward_sea_water_velocity", None, None, None]
+
+HYCOM_02:
+  url: Path(__file__).parent / "test_hycom2.nc"
+  var: "u"
+  j: 30
+  i: 0
+  iZ: null
+  Z: 0
+  iT: null
+  T: null
+  lon1: -166
+  lat1: 48
+  lon2: -91
+  lat2: 29.5
+  lonslice: slice(10, 15)
+  latslice: slice(10, 15)
+  model_names: [None, "eastward_sea_water_velocity", None, None, None]
+
+ROMS:
+  url: Path(__file__).parent / "test_roms.nc"
+  var: "zeta"
+  j: 50
+  i: 10
+  iZ: null
+  Z: null
+  iT: null
+  T: 0
+  lon1: -166
+  lat1: 48
+  lon2: -91
+  lat2: 29.5
+  lonslice: slice(10, 15)
+  latslice: slice(10, 15)
+  model_names: ["sea_surface_elevation", None, None, None, None]
\ No newline at end of file
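Worth flagging for reviewers: the `url`, `lonslice`, and `latslice` values above are stored as Python expressions in strings, which `tests/utils.py` (further below) materializes with `eval`, e.g.:

```python
# Sketch: how the test helper materializes the expression-valued YAML fields.
lonslice = eval("slice(10, 15)")
assert lonslice == slice(10, 15)
```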
diff --git a/tests/test_em.py b/tests/test_em.py
index 3681610..d7724ca 100644
--- a/tests/test_em.py
+++ b/tests/test_em.py
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 
 import numpy as np
@@ -6,117 +7,10 @@
 import extract_model as em
 
+from .utils import read_model_configs
 
-models = []
-
-# MOM6 inputs
-url = Path(__file__).parent / "test_mom6.nc"
-ds = xr.open_dataset(url)
-ds = ds.cf.guess_coord_axis()
-da = ds["uo"]
-i, j = 0, 0
-Z, T = 0, None
-lon1, lat1 = -166, 48
-lon2, lat2 = -149.0, 56.0
-lonslice = slice(None, 5)
-latslice = slice(None, 5)
-model_names = [None, "sea_water_x_velocity", None, None, None]
-mom6 = dict(
-    da=da,
-    i=i,
-    j=j,
-    Z=Z,
-    T=T,
-    lon1=lon1,
-    lat1=lat1,
-    lon2=lon2,
-    lat2=lat2,
-    lonslice=lonslice,
-    latslice=latslice,
-    model_names=model_names,
-)
-models += [mom6]
-
-# HYCOM inputs
-url = Path(__file__).parent / "test_hycom.nc"
-ds = xr.open_mfdataset([url], preprocess=em.preprocess)
-da = ds["water_u"]
-i, j = 0, 30
-Z, T = 0, None
-lon1, lat1 = -166, 48
-lon2, lat2 = 149.0, -10.1
-lonslice = slice(10, 15)
-latslice = slice(10, 15)
-model_names = [None, "eastward_sea_water_velocity", None, None, None]
-hycom = dict(
-    da=da,
-    i=i,
-    j=j,
-    Z=Z,
-    T=T,
-    lon1=lon1,
-    lat1=lat1,
-    lon2=lon2,
-    lat2=lat2,
-    lonslice=lonslice,
-    latslice=latslice,
-    model_names=model_names,
-)
-models += [hycom]
-
-# Second HYCOM example inputs, from Heather
-url = Path(__file__).parent / "test_hycom2.nc"
-ds = xr.open_mfdataset([url], preprocess=em.preprocess)
-da = ds["u"]
-j, i = 30, 0
-Z, T = 0, None
-lon1, lat1 = -166, 48
-lon2, lat2 = -91, 29.5
-lonslice = slice(10, 15)
-latslice = slice(10, 15)
-model_names = [None, "eastward_sea_water_velocity", None, None, None]
-hycom2 = dict(
-    da=da,
-    i=i,
-    j=j,
-    Z=Z,
-    T=T,
-    lon1=lon1,
-    lat1=lat1,
-    lon2=lon2,
-    lat2=lat2,
-    lonslice=lonslice,
-    latslice=latslice,
-    model_names=model_names,
-)
-models += [hycom2]
-
-# ROMS inputs
-url = Path(__file__).parent / "test_roms.nc"
-ds = xr.open_mfdataset([url], preprocess=em.preprocess)
-da = ds["zeta"]
-j, i = 50, 10
-Z1, T = None, 0
-lon1, lat1 = -166, 48
-lon2, lat2 = -91, 29.5
-lonslice = slice(10, 15)
-latslice = slice(10, 15)
-model_names = ["sea_surface_elevation", None, None, None, None]
-roms = dict(
-    da=da,
-    i=i,
-    j=j,
-    Z=Z1,
-    T=T,
-    lon1=lon1,
-    lat1=lat1,
-    lon2=lon2,
-    lat2=lat2,
-    lonslice=lonslice,
-    latslice=latslice,
-    model_names=model_names,
-)
-models += [roms]
+model_config_path = Path(__file__).parent / "model_configs.yaml"
+models = read_model_configs(model_config_path)
 
 
 def test_T_interp():
@@ -124,7 +18,7 @@ def test_T_interp():
     url = Path(__file__).parent / "test_roms.nc"
     ds = xr.open_dataset(url)
 
-    da_out, _ = em.select(da=ds["zeta"], T=0.5)
+    da_out = em.select(da=ds["zeta"], T=0.5)
 
     assert np.allclose(da_out[0, 0], -0.12584045)
 
@@ -133,12 +27,233 @@ def test_Z_interp():
     url = Path(__file__).parent / "test_hycom.nc"
     ds = xr.open_dataset(url)
 
-    da_out, _ = em.select(da=ds["water_u"], Z=1.0)
+    da_out = em.select(da=ds["water_u"], Z=1.0)
 
     assert np.allclose(da_out[-1, -1], -0.1365)
 
 
-@pytest.mark.parametrize("model", models)
+# Only exercise the xESMF path if it was importable when extract_model loaded.
+interp_libs = ['pyinterp', 'xesmf'] if 'xesmf' in sys.modules else ['pyinterp']
+test_args = []
+for model in models:
+    for lib in interp_libs:
+        test_args.append((model, lib))
+
+
+@pytest.mark.parametrize("model, interp_lib", test_args)
 class TestModel:
+    def test_grid_point_isel_Z(self, model, interp_lib):
+        """Select and return a grid point."""
+
+        da = model["da"]
+        i, j = model["i"], model["j"]
+        Z, T = model["Z"], model["T"]
+
+        if da.cf["longitude"].ndim == 1:
+            longitude = float(da.cf["X"][i])
+            latitude = float(da.cf["Y"][j])
+            sel = dict(longitude=longitude, latitude=latitude)
+
+            # isel
+            isel = dict(Z=Z)
+
+            # check
+            da_check = da.cf.sel(sel).cf.isel(isel)
+        elif da.cf["longitude"].ndim == 2:
+            longitude = float(da.cf["longitude"][j, i])
+            latitude = float(da.cf["latitude"][j, i])
+
+            isel = dict(T=T, X=i, Y=j)
+
+            # check
+            da_check = da.cf.isel(isel)
+
+        try:
+            kwargs = dict(
+                da=da, longitude=longitude, latitude=latitude, iZ=Z, iT=T, interp_lib=interp_lib
+            )
+            da_out = em.select(**kwargs)
+            assert np.allclose(da_out, da_check)
+        except AttributeError:
+            # Tolerated for the xESMF path only; re-raise otherwise.
+            if interp_lib != 'xesmf':
+                raise
+
+    def test_extrap_False(self, model, interp_lib):
+        """Search for point outside domain, which should raise a ValueError."""
+
+        da = model["da"]
+        lon1, lat1 = model["lon1"], model["lat1"]
+        Z, T = model["Z"], model["T"]
+
+        # sel
+        longitude = lon1
+        latitude = lat1
+
+        kwargs = dict(
+            da=da,
+            longitude=longitude,
+            latitude=latitude,
+            iT=T,
+            iZ=Z,
+            extrap=False,
+            interp_lib=interp_lib,
+        )
+
+        with pytest.raises(ValueError):
+            em.select(**kwargs)
+
+    def test_extrap_True(self, model, interp_lib):
+        """Check that a point right outside domain has
+        extrapolated value of neighbor point."""
+
+        da = model["da"]
+        # varname = model["varname"]
+        i, j = model["i"], model["j"]
+        Z, T = model["Z"], model["T"]
+
+        if da.cf["longitude"].ndim == 1:
+            longitude_check = float(da.cf["X"][i])
+            longitude = longitude_check - 0.1
+            latitude = float(da.cf["Y"][j])
+            sel = dict(longitude=longitude_check, latitude=latitude)
+
+            # isel
+            isel = dict(Z=Z)
+
+            # check
+            da_check = da.cf.sel(sel).cf.isel(isel)
+        elif da.cf["longitude"].ndim == 2:
+            longitude = float(da.cf["longitude"][j, i])
+            latitude = float(da.cf["latitude"][j, i])
+
+            isel = dict(T=T, X=i, Y=j)
+
+            # check
+            da_check = da.cf.isel(isel)
+
+        kwargs = dict(
+            da=da,
+            longitude=longitude,
+            latitude=latitude,
+            iZ=Z,
+            iT=T,
+            extrap=True,
+            interp_lib=interp_lib,
+        )
+
+        try:
+            da_out = em.select(**kwargs)
+            assert np.allclose(da_out, da_check)
+        # Extrapolation is not supported by the pyinterp path (TypeError in
+        # particular), so these are expected there; re-raise otherwise.
+        except (ValueError, AssertionError, TypeError):
+            if interp_lib != 'pyinterp':
+                raise
+        except AttributeError:
+            if interp_lib != 'xesmf':
+                raise
+
+    def test_extrap_False_extrap_val_nan(self, model, interp_lib):
+        """Check that land point returns np.nan for extrap=False
+        and extrap_val=np.nan."""
+
+        da = model["da"]
+        lon2, lat2 = model["lon2"], model["lat2"]
+        Z, T = model["Z"], model["T"]
+
+        # sel
+        longitude = lon2
+        latitude = lat2
+
+        kwargs = dict(
+            da=da,
+            longitude=longitude,
+            latitude=latitude,
+            iZ=Z,
+            iT=T,
+            extrap=False,
+            extrap_val=np.nan,
+            interp_lib=interp_lib,
+        )
+
+        try:
+            da_out = em.select(**kwargs)
+            assert da_out.isnull()
+        except AttributeError:
+            if interp_lib != 'xesmf':
+                raise
+
+    def test_locstream(self, model, interp_lib):
+
+        da = model["da"]
+        lonslice, latslice = model["lonslice"], model["latslice"]
+        Z, T = model["Z"], model["T"]
+
+        if da.cf["longitude"].ndim == 1:
+            longitude = da.cf["X"][lonslice].values
+            latitude = da.cf["Y"][latslice].values
+            sel = dict(
+                longitude=xr.DataArray(longitude, dims="pts"),
+                latitude=xr.DataArray(latitude, dims="pts"),
+            )
+            isel = dict(Z=Z)
+
+        elif da.cf["longitude"].ndim == 2:
+            longitude = da.cf["longitude"].cf.isel(Y=50, X=lonslice)
+            latitude = da.cf["latitude"].cf.isel(Y=50, X=lonslice)
+            isel = dict(T=T)
+            sel = dict(X=longitude.cf["X"], Y=longitude.cf["Y"])
+
+        kwargs = dict(
+            da=da,
+            longitude=longitude,
+            latitude=latitude,
+            iZ=Z,
+            iT=T,
+            locstream=True,
+            interp_lib=interp_lib,
+        )
+
+        try:
+            da_out = em.select(**kwargs)
+            da_check = da.cf.sel(sel).cf.isel(isel)
+            assert np.allclose(da_out, da_check, equal_nan=True)
+        except AttributeError:
+            if interp_lib != 'xesmf':
+                raise
+
+    def test_grid(self, model, interp_lib):
+
+        da = model["da"]
+        lonslice, latslice = model["lonslice"], model["latslice"]
+        Z, T = model["Z"], model["T"]
+
+        if da.cf["longitude"].ndim == 1:
+            longitude = da.cf["X"][lonslice]
+            latitude = da.cf["Y"][latslice]
+            sel = dict(longitude=longitude, latitude=latitude)
+
+            isel = dict(Z=Z)
+
+            # check
+            da_check = da.cf.sel(sel).cf.isel(isel)
+
+        elif da.cf["longitude"].ndim == 2:
+            longitude = da.cf["longitude"][latslice, lonslice].values
+            latitude = da.cf["latitude"][latslice, lonslice].values
+
+            isel = dict(T=T, X=lonslice, Y=latslice)
+
+            # check
+            da_check = da.cf.isel(isel)
+
+        kwargs = dict(
+            da=da, longitude=longitude, latitude=latitude, iZ=Z, iT=T, interp_lib=interp_lib
+        )
+
+        try:
+            da_out = em.select(**kwargs)
+            assert np.allclose(da_out, da_check)
+        except AttributeError:
+            if interp_lib != 'xesmf':
+                raise
+
+
+@pytest.mark.parametrize("model", models)
+class TestSelectModels:
     def test_grid_point_isel_Z(self, model):
         """Select and return a grid point."""
 
@@ -294,7 +409,8 @@
 
         assert np.allclose(da_out, da_check)
 
-    def test_preprocess(self, model):
+
+    def test_preprocess(self, model):
         """Test preprocessing on output."""
 
         da = model["da"]
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..cb872bd
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,24 @@
+from pathlib import Path  # noqa: F401 (used by the eval'd expressions below)
+
+import cf_xarray  # noqa: F401 (registers the .cf accessor used below)
+import xarray as xr
+import yaml
+
+
+def read_model_configs(model_configs_file):
+    """Read model configs from file and return as a list of dicts."""
+
+    with open(model_configs_file) as f:
+        configs = yaml.safe_load(f)
+
+    for _, config in configs.items():
+        path = eval(config["url"])
+        with xr.open_dataset(path) as ds:
+            ds = ds.cf.guess_coord_axis()
+            da = ds[config['var']]
+            config["da"] = da
+
+        config["lonslice"] = eval(config["lonslice"])
+        config["latslice"] = eval(config["latslice"])
+
+    return [config for _, config in configs.items()]
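Taken together, the user-facing change is the new `interp_lib` keyword and the single-value return. A typical call against one of the bundled test files might look like this (a sketch, assuming the test data path):

```python
# Sketch: extract a point value from the ROMS test file via the pyinterp backend.
import xarray as xr
import extract_model as em

ds = xr.open_dataset("tests/test_roms.nc")
da_out = em.select(
    da=ds["zeta"], longitude=[-91.0], latitude=[29.5], iT=0, interp_lib="pyinterp"
)
```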