diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..37f4c245 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +*.nc filter=lfs diff=lfs merge=lfs -text +*.nc4 filter=lfs diff=lfs merge=lfs -text +*.res filter=lfs diff=lfs merge=lfs -text +*.odb filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/codestyle.yml b/.github/workflows/codestyle.yml index f09c7a68..397d3d41 100644 --- a/.github/workflows/codestyle.yml +++ b/.github/workflows/codestyle.yml @@ -9,6 +9,11 @@ jobs: name: Check Python Coding Norms runs-on: ubuntu-latest steps: + - name: Install ubuntu dependencies + run: | + sudo apt-get install libproj-dev proj-data proj-bin + sudo apt-get install libgeos-dev musl-dev libc-dev + sudo ln -s /usr/lib/x86_64-linux-musl/libc.so /lib/libc.musl-x86_64.so.1 - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: diff --git a/README.md b/README.md index 162d858a..feb6f9cc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ + # Evaluation and Verification of the Analysis (EVA) ### Continuous integration: @@ -6,6 +7,14 @@ | --------- | --------| | Python coding norms | ![Status](https://github.com/danholdaway/eva/actions/workflows/codestyle.yml/badge.svg) | +### Licence: + +(C) Copyright 2021-2022 United States Government as represented by the Administrator of the National +Aeronautics and Space Administration. All Rights Reserved. + +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + + ### Installation The eva package can be installed using pip: @@ -14,30 +23,28 @@ The eva package can be installed using pip: ### Usage -eva uses a strict dictionary only API. This ensures flexible use for different applications. The most straightforward use is achieved by choosing the Class containing the diagnostic and passing a yaml configuration file: - - eva ObsCorrelationScatter obs_correlation.yaml - -Alternatively you can invoke eva passing the class name within the yaml configuration: +eva uses a strict dictionary/configuration only API. This ensures flexible use for different applications. The most straightforward use of eva is achieved by passing it a YAML configuration file on the command line: eva obs_correlation.yaml -Where the yaml must contain a list of the diagnostics to be used: +Where the YAML must contain a list of the diagnostics to be used in the following format: ``` -applications: - - application name: ObsCorrelationScatter +diagnostics: + - diagnostic name: ObsCorrelationScatter ... - - application name: ObsMapScatter + - diagnostic name: ObsMapScatter ... ``` -eva can also be invoked from another Python module that passes a dictionary, rather than a Yaml file that is later translated into a dictionary. This is achieved as follows: +eva can also be invoked from another Python module that passes it a dictionary, rather than a YAML file that is later translated into a dictionary. This is achieved as follows: ``` -from eva.base import create_and_run +from eva.eva_base import eva -create_and_run("ObsCorrelationScatter", eva_dict) +eva(eva_dict) ``` -Note that this also allows for use of eva within Jupyter notebooks. +The dictionary must take the same hierarchy as shown above in the YAML file, i.e. with a list of diagnostics to be run. Note that the calling routine can still pass a string with a path to the YAML file if so desired. + +Note that eva can also be invoked within Jupyter notebooks using the above API. diff --git a/requirements.txt b/requirements.txt index 746950df..4e48ca34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ pyyaml>=5.4 pycodestyle>=2.8.0 netCDF4>=1.5.7 matplotlib>=3.4.3 +cartopy==0.19.0.post1 +scikit-learn>=1.0.2 diff --git a/setup.py b/setup.py index 9bf2560b..4a355219 100644 --- a/setup.py +++ b/setup.py @@ -36,10 +36,18 @@ 'pycodestyle>=2.8.0', 'netCDF4>=1.5.7', 'matplotlib>=3.4.3', + 'cartopy==0.19.0.post1', + 'scikit-learn>=1.0.2', ], + package_data={ + '': [ + 'tests/config/*', + 'tests/data/*', + ], + }, entry_points={ 'console_scripts': [ - 'eva = eva.base:main', + 'eva = eva.eva_base:main', ], }, ) diff --git a/src/eva/base.py b/src/eva/base.py deleted file mode 100644 index 079087e8..00000000 --- a/src/eva/base.py +++ /dev/null @@ -1,179 +0,0 @@ -#!/usr/bin/env python - -# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the -# National Aeronautics and Space Administration. All Rights Reserved. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. - -# imports -from abc import ABC, abstractmethod -import argparse -import importlib -import os -import sys -import yaml - -# local imports -from eva.utilities.logger import Logger -from eva.utilities.utils import camelcase_to_underscore - - -# -------------------------------------------------------------------------------------------------- - - -class Config(dict): - - def __init__(self, dict_or_yaml): - - # Program can recieve a dictionary or a yaml file - if type(dict_or_yaml) is dict: - config = dict_or_yaml - else: - with open(dict_or_yaml, 'r') as ymlfile: - config = yaml.safe_load(ymlfile) - - # Initialize the parent class with the config - super().__init__(config) - - -# -------------------------------------------------------------------------------------------------- - - -class Base(ABC): - - # Base class constructor - def __init__(self, eva_class_name, config, logger): - - print("\nInitializing eva with the following parameters:") - print(" Diagnostic: ", eva_class_name) - print(" Configuration: ", config) - - # Create message logger - # --------------------- - if logger is None: - self.logger = Logger(eva_class_name) - else: - self.logger = logger - - # Create a configuration object - # ----------------------------- - self.config = Config(config) - - @abstractmethod - def execute(self): - ''' - Each class must implement this method and it is where it will do all of its work. - ''' - pass - - -# -------------------------------------------------------------------------------------------------- - - -class Factory(): - - def create_object(self, eva_class_name, config, logger): - - # Convert capitilized string to one with underscores - eva_module_name = camelcase_to_underscore(eva_class_name) - - # Import class based on user selected task - eva_class = getattr(importlib.import_module("eva."+eva_module_name), eva_class_name) - - # Return implementation of the class (calls base class constructor that is above) - return eva_class(eva_class_name, config, logger) - - -# -------------------------------------------------------------------------------------------------- - - -def create_and_run(eva_class_name, config, logger=None): - - ''' - Given a class name and a config this method will create an object of the class name and execute - the diagnostic defined therein. The config will determine how the diagnostic behaves. The - config can be passed in using a path to the Yaml file or an already parsed dictionary. - - Args: - eva_class_name : (str) Name of the class to be instantiated - config : (str or dictionary) configuation that will guide the diagnostic - ''' - - # Create the diagnostic object - creator = Factory() - eva_object = creator.create_object(eva_class_name, config, logger) - - # Execute the diagnostic - eva_object.execute() - - -# -------------------------------------------------------------------------------------------------- - - -def loop_and_create_and_run(config): - - # Create dictionary from the input file - with open(config, 'r') as ymlfile: - app_dict = yaml.safe_load(ymlfile) - - # Get the list of applications - try: - apps = app_dict['applications'] - except Exception: - print('ABORT: When running standalone the input config must contain \'applications\' as ' + - 'a list') - sys.exit("ABORT") - - # Loop over the applications and run - for app in apps: - app_name = app['application name'] - create_and_run(app_name, app) - - -# -------------------------------------------------------------------------------------------------- - - -def main(): - - # Arguments - # --------- - parser = argparse.ArgumentParser() - parser.add_argument('args', nargs='+', type=str, help='Application name [optional] followed ' + - 'by the configuration file [madatory]. E.g. eva ObsCorrelationScatter ' + - 'conf.yaml') - - args = parser.parse_args() - args_list = args.args - - # Make sure only 1 or 2 arguments are present - assert len(args_list) <= 2, "The maximum number of arguments is two." - - # Check the file exists - # --------------------- - if len(args_list) == 2: - application = args_list[0] - config_in = args_list[1] - else: - application = None - config_in = args_list[0] - - assert os.path.exists(config_in), "File " + config_in + "not found" - - # Run application or determine application(s) to run from config. - if application is not None: - # User specifies e.g. eva ObsCorrelationScatter ObsCorrelationScatterDriver.yaml - create_and_run(application, config_in) - else: - # User specifies e.g. eva ObsCorrelationScatter ObsCorrelationScatterDriver.yaml - loop_and_create_and_run(config_in) - - -# -------------------------------------------------------------------------------------------------- - - -if __name__ == "__main__": - main() - - -# -------------------------------------------------------------------------------------------------- diff --git a/src/eva/diagnostics/__init__.py b/src/eva/diagnostics/__init__.py new file mode 100644 index 00000000..ac1c0bc4 --- /dev/null +++ b/src/eva/diagnostics/__init__.py @@ -0,0 +1,9 @@ +# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + +import os + +repo_directory = os.path.dirname(__file__) diff --git a/src/eva/obs_correlation_scatter.py b/src/eva/diagnostics/obs_correlation_scatter.py similarity index 81% rename from src/eva/obs_correlation_scatter.py rename to src/eva/diagnostics/obs_correlation_scatter.py index b6d650d2..4af16d45 100644 --- a/src/eva/obs_correlation_scatter.py +++ b/src/eva/diagnostics/obs_correlation_scatter.py @@ -1,3 +1,5 @@ +# (C) Copyright 2021-2022 NOAA/NWS/EMC +# # (C) Copyright 2021-2022 United States Government as represented by the Administrator of the # National Aeronautics and Space Administration. All Rights Reserved. # @@ -8,10 +10,11 @@ # -------------------------------------------------------------------------------------------------- -from eva.base import Base +from eva.eva_base import EvaBase from eva.utilities import ioda_definitions from eva.utilities import ioda_netcdf_api -from eva.plot_tools.scatter_correlation import scatter_correlation_plot +from eva.plot_tools.figure import CreatePlot, CreateFigure +from eva.plot_tools.plots import Scatter, LinePlot import netCDF4 import numpy as np @@ -23,7 +26,7 @@ # TODO: needs to be ioda-erized and r2d2-erized -class ObsCorrelationScatter(Base): +class ObsCorrelationScatter(EvaBase): def execute(self): @@ -179,9 +182,30 @@ def execute(self): plot_title = platform_long_name + ' | ' + variable_name_no_ # Create the plot - scatter_correlation_plot(data_ref, data_exp, ref_metric_long_name, - exp_metric_long_name, plot_title, output_file, - marker_size=marker_size) + # set up the scatter layer + scatter = Scatter(data_ref, data_exp) + scatter.markersize = marker_size + scatter.color = 'blue' + data_min = min(min(data_ref), min(data_exp)) + data_max = max(max(data_ref), max(data_exp)) + data_diff = data_max - data_min + plotmin = data_min - (0.1 * data_diff) + plotmax = data_max + (0.1 * data_diff) + # add a 1:1 line layer + line = LinePlot([plotmin, plotmax], [plotmin, plotmax]) + line.color = 'black' + # set up the plot + plot = CreatePlot(plot_layers=[line, scatter]) + plot.add_title(plot_title) + plot.set_xlim([plotmin, plotmax]) + plot.set_ylim([plotmin, plotmax]) + plot.add_xlabel(ref_metric_long_name) + plot.add_ylabel(exp_metric_long_name) + # create the figure + fig = CreateFigure(figsize=(8, 8)) + fig.plot_list = [plot] + fig.create_figure() + fig.save_figure(output_file) # Close files fh_exp.close() diff --git a/src/eva/eva_base.py b/src/eva/eva_base.py new file mode 100644 index 00000000..d645f7ec --- /dev/null +++ b/src/eva/eva_base.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python + +# (C) Copyright 2021-2022 NOAA/NWS/EMC +# +# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + +# imports +from abc import ABC, abstractmethod +import argparse +import importlib +import os +import sys +import yaml + +# local imports +from eva.eva_path import return_eva_path +from eva.utilities.logger import Logger +from eva.utilities.utils import camelcase_to_underscore + + +# -------------------------------------------------------------------------------------------------- +def load_yaml_file(eva_config, logger): + # utility function to help load a yaml file into a dict. + + if logger is None: + logger = Logger('EvaSetup') + + try: + with open(eva_config, 'r') as eva_config_opened: + eva_dict = yaml.safe_load(eva_config_opened) + except Exception as e: + logger.abort('Eva diagnostics is expecting a valid yaml file, but it encountered ' + + f'errors when attempting to load: {eva_config}, error: {e}') + + return eva_dict + + +class Config(dict): + + def __init__(self, dict_or_yaml): + + logger = Logger('EvaSetup') + + # Program can recieve a dictionary or a yaml file + if isinstance(dict_or_yaml, dict): + config = dict_or_yaml + else: + config = load_yaml_file(dict_or_yaml, logger) + + # Initialize the parent class with the config + super().__init__(config) + + +# -------------------------------------------------------------------------------------------------- + + +class EvaBase(ABC): + + # Base class constructor + def __init__(self, eva_class_name, config, eva_logger): + + # Replace logger + # -------------- + if eva_logger is None: + self.logger = Logger(eva_class_name) + else: + self.logger = eva_logger + + self.logger.info(" Initializing eva with the following parameters:") + self.logger.info(" Diagnostic: " + eva_class_name) + self.logger.info(" ") + + # Create a configuration object + # ----------------------------- + self.config = Config(config) + + @abstractmethod + def execute(self): + ''' + Each class must implement this method and it is where it will do all of its work. + ''' + pass + + +# -------------------------------------------------------------------------------------------------- + + +class EvaFactory(): + + def create_eva_object(self, eva_class_name, config, eva_logger): + + # Create temporary logger + logger = Logger('EvaFactory') + + # Convert capitilized string to one with underscores + # -------------------------------------------------- + eva_module_name = camelcase_to_underscore(eva_class_name) + + # Check user provided class name against valid tasks + # -------------------------------------------------- + # List of diagnostics in directory + valid_diagnostics = os.listdir(os.path.join(return_eva_path(), 'diagnostics')) + # Remove files like __* + valid_diagnostics = [vd for vd in valid_diagnostics if '__' not in vd] + # Remove trailing .py + valid_diagnostics = [vd.replace(".py", "") for vd in valid_diagnostics] + # Abort if not found + if (eva_module_name not in valid_diagnostics): + logger.abort('Expecting to find a class called in ' + eva_class_name + ' in a file ' + + 'called ' + os.path.join(return_eva_path(), 'diagnostics', eva_module_name) + + '.py but no such file was found.') + + # Import class based on user selected task + # ---------------------------------------- + try: + eva_class = getattr(importlib.import_module("eva.diagnostics."+eva_module_name), + eva_class_name) + except Exception as e: + logger.abort('Expecting to find a class called in ' + eva_class_name + ' in a file ' + + 'called ' + os.path.join(return_eva_path(), 'diagnostics', eva_module_name) + + '.py but no such class was found or an error occurred.') + + # Return implementation of the class (calls base class constructor that is above) + # ------------------------------------------------------------------------------- + return eva_class(eva_class_name, config, eva_logger) + + +# -------------------------------------------------------------------------------------------------- + + +def eva(eva_config, eva_logger=None): + + # Create temporary logger + logger = Logger('EvaSetup') + + # Convert incoming config (either dictionary or file) to dictionary + if isinstance(eva_config, dict): + eva_dict = eva_config + else: + # Create dictionary from the input file + eva_dict = load_yaml_file(eva_config, logger) + + # Get the list of applications + try: + diagnostic_configs = eva_dict['diagnostics'] + except KeyError: + logger.abort('eva configuration must contain \'diagnostics\' and it should provide a ' + + 'list of diagnostics to be run.') + + # Loop over the applications and run + for diagnostic_config in diagnostic_configs: + + # Extract name for this diagnostic + eva_class_name = diagnostic_config['diagnostic name'] + + # Create the diagnostic object + creator = EvaFactory() + eva_object = creator.create_eva_object(eva_class_name, diagnostic_config, eva_logger) + + # Run the diagnostic + eva_object.execute() + + +# -------------------------------------------------------------------------------------------------- + + +def main(): + + # Arguments + # --------- + parser = argparse.ArgumentParser() + parser.add_argument('config_file', type=str, help='Configuration YAML file for driving ' + + 'the diagnostic. See documentation/examples for how to configure the YAML.') + + # Get the configuation file + args = parser.parse_args() + config_file = args.config_file + + assert os.path.exists(config_file), "File " + config_file + " not found" + + # Run the diagnostic(s) + eva(config_file) + + +# -------------------------------------------------------------------------------------------------- + + +if __name__ == "__main__": + main() + + +# -------------------------------------------------------------------------------------------------- diff --git a/src/eva/eva_path.py b/src/eva/eva_path.py new file mode 100644 index 00000000..dfc8f1f2 --- /dev/null +++ b/src/eva/eva_path.py @@ -0,0 +1,21 @@ +# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + + +# -------------------------------------------------------------------------------------------------- + + +import os + + +# -------------------------------------------------------------------------------------------------- + + +def return_eva_path(): + return os.path.split(__file__)[0] + + +# -------------------------------------------------------------------------------------------------- diff --git a/src/eva/plot_tools/figure.py b/src/eva/plot_tools/figure.py new file mode 100644 index 00000000..bceced2a --- /dev/null +++ b/src/eva/plot_tools/figure.py @@ -0,0 +1,660 @@ +# This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import cartopy.crs as ccrs +import cartopy.feature as cfeature +from scipy.interpolate import interpn +from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter +from eva.plot_tools.maps import Domain, MapProjection +from eva.utilities.stats import get_linear_regression + +__all__ = ['CreateFigure', 'CreatePlot'] + + +class CreateFigure: + + def __init__(self, nrows=1, ncols=1, figsize=(8, 6), + sharex=False, sharey=False): + + self.nrows = nrows + self.ncols = ncols + self.figsize = figsize + self.sharex = sharex + self.sharey = sharey + self.plot_list = [] + + def save_figure(self, filepath, **kwargs): + """ + Method to save figure to file + """ + self.fig.savefig(filepath, **kwargs) + + def create_figure(self): + """ + Driver method to create figure and subplots. + """ + + # Check to make sure plot_list == nrows*ncols + if len(self.plot_list) != self.nrows*self.ncols: + raise ValueError( + 'Number of plots does not match the number inputted rows' + 'and columns.') + + plot_dict = { + 'scatter': self._scatter, + 'histogram': self._histogram, + 'line_plot': self._lineplot, + 'vertical_line': self._verticalline, + 'horizontal_line': self._horizontalline, + 'bar_plot': self._barplot, + 'horizontal_bar': self._hbar, + 'map_scatter': self._map_scatter, + 'map_gridded': self._map_gridded, + 'map_contour': self._map_contour + } + + gs = gridspec.GridSpec(self.nrows, self.ncols) + self.fig = plt.figure(figsize=self.figsize) + + for i, plot_obj in enumerate(self.plot_list): + + # check if object has projection and domain attributes to determine ax + if hasattr(plot_obj, 'projection') and hasattr(plot_obj, 'domain'): + self.domain = Domain(plot_obj.domain) + self.projection = MapProjection(plot_obj.projection) + + # Set up axis specific things + ax = plt.subplot(gs[i], projection=self.projection.projection) + if str(self.projection) not in ['npstere', 'spstere']: + ax.set_extent(self.domain.extent) + if str(self.projection) not in ['lamconf']: + ax.set_xticks(self.domain.xticks, crs=ccrs.PlateCarree()) + ax.set_yticks(self.domain.yticks, crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=False) + lat_formatter = LatitudeFormatter() + ax.xaxis.set_major_formatter(lon_formatter) + ax.yaxis.set_major_formatter(lat_formatter) + + else: + ax = plt.subplot(gs[i]) + + # Loop through plot layers + for layer in plot_obj.plot_layers: + plot_dict[layer.plottype](layer, ax) + + # loop through all keys in an object and then call approriate + # method to plot the feature on the axis + for feat in vars(plot_obj).keys(): + self._plot_features(plot_obj, feat, ax) + + if self.sharex: + self._sharex(ax) + if self.sharey: + self._sharey(ax) + + gs.tight_layout(self.fig) + + def add_suptitle(self, text, **kwargs): + """ + Add super title to figure. Useful for subplots. + """ + if hasattr(self, 'fig'): + self.fig.suptitle(text, **kwargs) + + def _plot_features(self, plot_obj, feature, ax): + + feature_dict = { + 'title': self._plot_title, + 'xlabel': self._plot_xlabel, + 'ylabel': self._plot_ylabel, + 'colorbar': self._plot_colorbar, + 'stats': self._plot_stats, + 'legend': self._plot_legend, + 'text': self._plot_text, + 'grid': self._plot_grid, + 'xlim': self._set_xlim, + 'ylim': self._set_ylim, + 'xticks': self._set_xticks, + 'yticks': self._set_yticks, + 'xticklabels': self._set_xticklabels, + 'yticklabels': self._set_yticklabels, + 'invert_xaxis': self._invert_xaxis, + 'invert_yaxis': self._invert_yaxis, + 'yscale': self._set_yscale, + 'map_features': self._add_map_features + } + + if feature in feature_dict: + feature_dict[feature](ax, vars(plot_obj)[feature]) + + def _map_scatter(self, plotobj, ax): + + if plotobj.data is None: + skipvars = ['plottype', 'longitude', 'latitude', + 'markersize'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + cs = ax.scatter(plotobj.longitude, plotobj.latitude, + s=plotobj.markersize, **inputs, + transform=self.projection.projection) + else: + skipvars = ['plottype', 'longitude', 'latitude', + 'data', 'markersize'] + inputs = self._get_inputs_dict(skipvars, plotobj) + cs = ax.scatter(plotobj.longitude, plotobj.latitude, + c=plotobj.data, s=plotobj.markersize, + **inputs, transform=self.projection.projection) + if plotobj.colorbar: + self.cs = cs + + def _map_gridded(self, plotobj, ax): + + skipvars = ['plottype', 'longitude', 'latitude', + 'markersize'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + cs = ax.pcolormesh(plotobj.latitude, plotobj.longitude, + plotobj.data, **inputs, + transform=self.projection.projection) + + if plotobj.colorbar: + self.cs = cs + + def _map_contour(self, plotobj, ax): + + skipvars = ['plottype', 'longitude', 'latitude', + 'markersize'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + cs = ax.contour(plotobj.longitude, plotobj.latitude, + plot.data, **inputs, + transform=self.projection.projection) + + if plotobj.clabel: + plt.clabel(cs, levels=plotobj.levels, use_clabeltext=True) + + if plotobj.colorbar: + self.cs = cs + + def _density_scatter(self, plotobj, ax): + """ + Uses Scatter Object to plot density scatter colored by + 2d histogram. + """ + _idx = np.logical_and(~np.isnan(plotobj.x), ~np.isnan(plotobj.y)) + data, x_e, y_e = np.histogram2d(plotobj.x[_idx], plotobj.y[_idx], + bins=plotobj.density['bins'], + density=not plotobj.density['nsamples']) + if plotobj.density['nsamples']: + # compute percentage of total for each bin + data = data / np.count_nonzero(_idx) * 100. + z = interpn((0.5*(x_e[1:] + x_e[:-1]), 0.5*(y_e[1:]+y_e[:-1])), + data, np.vstack([plotobj.x, plotobj.y]).T, + method=plotobj.density['interp'], bounds_error=False) + # To be sure to plot all data + z[np.where(np.isnan(z))] = 0.0 + # Sort the points by density, so that the densest + # points are plotted last + if plotobj.density['sort']: + idx = z.argsort() + x, y, z = plotobj.x[idx], plotobj.y[idx], z[idx] + cs = ax.scatter(x, y, c=z, + s=plotobj.markersize, + cmap=plotobj.density['cmap'], + label=plotobj.label) + # below doing nothing? fix/remove in subsequent PR? + # norm = Normalize(vmin=np.min(z), vmax=np.max(z)) + + if plotobj.density['colorbar']: + self.cs = cs + + def _scatter(self, plotobj, ax): + """ + Uses Scatter object to plot on axis. + """ + # checks to see if density attribute is True + if hasattr(plotobj, 'density'): + self._density_scatter(plotobj, ax) + else: + skipvars = ['plottype', 'plot_ax', 'x', 'y', + 'markersize', 'linear_regression', + 'density'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + s = ax.scatter(plotobj.x, plotobj.y, s=plotobj.markersize, + **inputs) + + # checks to see if linear regression attribute + if hasattr(plotobj, 'linear_regression'): + y_pred, r_sq, intercept, slope = get_linear_regression(plotobj.x, + plotobj.y) + label = f"y = {slope:.4f}x + {intercept:.4f}\nR\u00b2 : {r_sq:.4f}" + ax.plot(plotobj.x, y_pred, **plotobj.linear_regression) + + def _lineplot(self, plotobj, ax): + """ + Uses LinePlot object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'x', 'y'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.plot(plotobj.x, plotobj.y, **inputs) + + def _histogram(self, plotobj, ax): + """ + Uses Histogram object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'data'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.hist(plotobj.data, **inputs) + + def _verticalline(self, plotobj, ax): + """ + Uses VerticalLine object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'x'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.axvline(plotobj.x, **inputs) + + def _horizontalline(self, plotobj, ax): + """ + Uses HorizontalLine object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'y'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.axhline(plotobj.y, **inputs) + + def _barplot(self, plotobj, ax): + """ + Uses BarPlot object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'x', 'height'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.bar(plotobj.x, plotobj.height, **inputs) + + def _hbar(self, plotobj, ax): + """ + Uses HorizontalBar object to plot on axis. + """ + skipvars = ['plottype', 'plot_ax', 'y', 'width'] + inputs = self._get_inputs_dict(skipvars, plotobj) + + ax.barh(plotobj.y, plotobj.width, **inputs) + + def _get_inputs_dict(self, skipvars, plotobj): + """ + Creates dictionary for plot inputs. Skips variables + in 'skipvars' list. + """ + inputs = {} + for v in [v for v in vars(plotobj) if v not in skipvars]: + inputs[v] = vars(plotobj)[v] + + return inputs + + def _plot_title(self, ax, title): + """ + Add title on specified ax. + """ + ax.set_title(**title) + + def _plot_xlabel(self, ax, xlabel): + """ + Add xlabel on specified ax. + """ + ax.set_xlabel(**xlabel) + + def _plot_ylabel(self, ax, ylabel): + """ + Add ylabel on specified ax. + """ + ax.set_ylabel(**ylabel) + + def _plot_colorbar(self, ax, colorbar): + """ + Add colorbar on specified ax or for total figure. + """ + + if hasattr(self, 'cs'): + if colorbar['single_cbar']: + # IMPORTANT NOTICE #### + # If using single colorbar option, this method grabs the color + # series from the subplot that is in last row and column. It + # is important to note that if comparing multiple subplots with + # the same colorbar, the vmin and vmax should all be the same to + # avoid comparison errors. + if ax.is_last_row() and ax.is_last_col(): + cbar_ax = self.fig.add_axes(colorbar['cbar_loc']) + cb = self.fig.colorbar(self.cs, cax=cbar_ax, **colorbar['kwargs']) + + else: + cb = self.fig.colorbar(self.cs, ax=ax, + **colorbar['kwargs']) + # Add labels + cb.set_label(colorbar['label'], fontsize=colorbar['fontsize']) + + def _plot_stats(self, ax, stats): + """ + Add annotated stats on specified ax. + """ + # loop through the dictionary and create the sting to annotate + outstr = '' + for key, value in stats['stats'].items(): + outstr = outstr + f'{key}: {value} ' + + ax.annotate(outstr, xy=(stats['xloc'], stats['yloc']), + xycoords='axes fraction', ha=stats['ha'], + **stats['kwargs']) + + def _plot_legend(self, ax, legend): + """ + Add legend on specified ax. + """ + leg = ax.legend(**legend) + + for i, key in enumerate(leg.legendHandles): + leg.legendHandles[i]._sizes = [20] + + def _plot_text(self, ax, text): + """ + Add text on specified ax. + """ + ax.text(text['xloc'], text['yloc'], + text['text'], **text['kwargs']) + + def _plot_grid(self, ax, grid): + """ + Add grid on specified ax. + """ + try: + ax.gridlines(crs=ccrs.PlateCarree(), **grid) + except AttributeError: + ax.grid(**grid) + + def _set_xlim(self, ax, xlim): + """ + Set x-limits on specified ax. + """ + ax.set_xlim(**xlim) + + def _set_ylim(self, ax, ylim): + """ + Set y-limits on specified ax. + """ + ax.set_ylim(**ylim) + + def _set_xticks(self, ax, xticks): + """ + Set x-ticks on specified ax. + """ + try: + ax.set_xticks(**xticks, crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True) + lat_formatter = LatitudeFormatter() + ax.xaxis.set_major_formatter(lon_formatter) + ax.yaxis.set_major_formatter(lat_formatter) + except AttributeError: + ax.set_xticks(**xticks) + + def _set_yticks(self, ax, yticks): + """ + Set y-ticks on specified ax. + """ + try: + ax.set_yticks(**yticks, crs=ccrs.PlateCarree()) + except AttributeError: + ax.set_yticks(**yticks) + + def _set_xticklabels(self, ax, xticklabels): + """ + Set x-tick labels on specified ax. + """ + if len(xticklabels['labels']) == len(ax.get_xticks()): + ax.set_xticklabels(xticklabels['labels'], + **xticklabels['kwargs']) + + else: + raise ValueError('Len of xtick labels does not equal ' + + 'len of xticks. Set xticks appropriately ' + + 'or change labels to be len of xticks.') + + def _set_yticklabels(self, ax, yticklabels): + """ + Set y-tick labels on specified ax. + """ + if len(yticklabels['labels']) == len(ax.get_yticks()): + ax.set_yticklabels(yticklabels['labels'], + **yticklabels['kwargs']) + + else: + raise ValueError('Len of ytick labels does not equal ' + + 'len of yticks. Set yticks appropriately ' + + 'or change labels to be len of yticks.') + + def _invert_xaxis(self, ax, invert_xaxis): + """ + Invert x-axis on specified ax. + """ + if invert_xaxis: + ax.invert_xaxis() + + def _invert_yaxis(self, ax, invert_yaxis): + """ + Invert y-axis on specified ax. + """ + if invert_yaxis: + ax.invert_yaxis() + + def _set_yscale(self, ax, yscale): + """ + Set y-scale on specified ax. + """ + ax.set_yscale(yscale) + + def _sharex(self, ax): + """ + If sharex axis is True, will find where to hide xticklabels. + """ + if not ax.is_last_row(): + plt.setp(ax.get_xticklabels(), visible=False) + + def _sharey(self, ax): + """ + If sharey axis is True, will find where to hide yticklabels. + """ + if not ax.is_first_col(): + plt.setp(ax.get_yticklabels(), visible=False) + + def _add_map_features(self, ax, map_features): + """ + Factory to add map features. + """ + feature_dict = { + 'coastline': cfeature.COASTLINE, + 'borders': cfeature.BORDERS, + 'states': cfeature.STATES, + 'lakes': cfeature.LAKES, + 'rivers': cfeature.RIVERS, + 'land': cfeature.LAND, + 'ocean': cfeature.OCEAN + } + + for feat in map_features: + try: + ax.add_feature(feature_dict[feat]) + except KeyError: + raise TypeError(f'{feat} is not a valid map feature.' + + 'Current map features supported are:\n' + + f'{" | ".join(feature_dict.keys())}"') + + +class CreatePlot(): + """ + Creates a figure to plot data as a scatter plot, + histogram, or line plot. + """ + def __init__(self, plot_layers=[], projection=None, + domain=None): + + self.plot_layers = plot_layers + + ############################################### + # Need a better way of doing this + if projection is not None and domain is not None: + self.projection = projection + self.domain = domain + ############################################### + + def add_title(self, label, loc='center', + pad=None, **kwargs): + + self.title = { + 'label': label, + 'loc': loc, + 'pad': pad, + **kwargs + } + + def add_xlabel(self, xlabel, labelpad=None, + loc='center', **kwargs): + + self.xlabel = { + 'xlabel': xlabel, + 'labelpad': labelpad, + 'loc': loc, + **kwargs + } + + def add_ylabel(self, ylabel, labelpad=None, + loc='center', **kwargs): + + self.ylabel = { + 'ylabel': ylabel, + 'labelpad': labelpad, + 'loc': loc, + **kwargs + } + + def add_colorbar(self, label=None, fontsize=12, single_cbar=False, + cbar_location=None, **kwargs): + + kwargs.setdefault('orientation', 'horizontal') + + pad = 0.15 if kwargs['orientation'] == 'horizontal' else 0.1 + fraction = 0.065 if kwargs['orientation'] == 'horizontal' else 0.085 + + kwargs.setdefault('pad', pad) + kwargs.setdefault('fraction', fraction) + + if not cbar_location: + h_loc = [0.14, -0.1, 0.8, 0.04] + v_loc = [1.02, 0.12, 0.04, 0.8] + cbar_location = h_loc if kwargs['orientation'] == 'horizontal' else v_loc + + self.colorbar = { + 'label': label, + 'fontsize': fontsize, + 'single_cbar': single_cbar, + 'cbar_loc': cbar_location, + 'kwargs': kwargs + } + + def add_stats_dict(self, stats_dict={}, xloc=0.5, + yloc=-0.1, ha='center', **kwargs): + + self.stats = { + 'stats': stats_dict, + 'xloc': xloc, + 'yloc': yloc, + 'ha': ha, + 'kwargs': kwargs + } + + def add_legend(self, **kwargs): + + self.legend = { + **kwargs + } + + def add_text(self, xloc, yloc, text, **kwargs): + + self.text = { + 'xloc': xloc, + 'yloc': yloc, + 'text': text, + 'kwargs': kwargs + } + + def add_grid(self, **kwargs): + + self.grid = { + **kwargs + } + + def add_map_features(self, feature_list=['coastline']): + + self.map_features = feature_list + + def set_xlim(self, left=None, right=None): + + self.xlim = { + 'left': left, + 'right': right + } + + def set_ylim(self, bottom=None, top=None): + + self.ylim = { + 'bottom': bottom, + 'top': top + } + + def set_xticks(self, ticks=list(), minor=False): + + self.xticks = { + 'ticks': ticks, + 'minor': minor + } + + def set_yticks(self, ticks=list(), minor=False): + + self.yticks = { + 'ticks': ticks, + 'minor': minor + } + + def set_xticklabels(self, labels=list(), **kwargs): + + self.xticklabels = { + 'labels': labels, + 'kwargs': kwargs + } + + def set_yticklabels(self, labels=list(), **kwargs): + + self.yticklabels = { + 'labels': labels, + 'kwargs': kwargs + } + + def invert_xaxis(self): + + self.invert_xaxis = True + + def invert_yaxis(self): + + self.invert_yaxis = True + + def set_yscale(self, scale): + + valid_scales = ['log', 'linear', 'symlog', 'logit'] + if scale not in valid_scales: + raise ValueError(f'requested scale {scale} is invalid. Valid ' + f'choices are: {" | ".join(valid_scales)}') + + self.yscale = scale diff --git a/src/eva/plot_tools/maps.py b/src/eva/plot_tools/maps.py new file mode 100644 index 00000000..934e727d --- /dev/null +++ b/src/eva/plot_tools/maps.py @@ -0,0 +1,376 @@ +# This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +import cartopy.crs as ccrs + + +class Domain: + + def __init__(self, domain='global', dd=dict()): + """ + Class constructor that stores extent, xticks, and + yticks for the domain given. + Args: + domain : (str; default='global') domain name to grab info + dd : (dict) dictionary to add custom xticks, yticks + """ + domain = domain.lower() + + map_domains = { + "global": self._global, + "north america": self._north_america, + "europe": self._europe, + "conus": self._conus, + "northeast": self._northeast, + "mid atlantic": self._mid_atlantic, + "southeast": self._southeast, + "ohio valley": self._ohio_valley, + "upper midwest": self._upper_midwest, + "north central": self._north_central, + "central": self._central, + "south central": self._south_central, + "northwest": self._northwest, + "colorado": self._colorado, + "boston nyc": self._boston_nyc, + "sf bay area": self._sf_bay_area, + "la vegas": self._la_vegas, + "custom": self._custom + } + + try: + map_domains[domain](dd=dd) + except KeyError: + raise TypeError(f'{domain} is not a valid domain.' + + 'Current domains supported are:\n' + + f'{" | ".join(map_domains.keys())}"') + + def _global(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a global domain. + """ + self.extent = (-180, 180, -90, 90) + self.xticks = dd.get('xticks', (-180, -120, -60, + 0, 60, 120, 180)) + self.yticks = dd.get('yticks', (-90, -60, -30, 0, + 30, 60, 90)) + + def _north_america(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a north american domain. + """ + self.extent = (-170, -50, 7.5, 75) + self.xticks = dd.get('xticks', (-170, -150, -130, -110, + -90, -70, -50)) + self.yticks = dd.get('yticks', (10, 30, 50, 70)) + + self.cenlon = dd.get('cenlon', -100) + self.cenlat = dd.get('cenlat', 41.25) + + def _conus(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a contiguous United States domain. + """ + self.extent = (-125.5, -63.5, 20, 51) + self.xticks = dd.get('xticks', (-125.5, -110, -94.5, + -79, -63.5)) + self.yticks = dd.get('yticks', (20, 27.5, 35, 42.5, 50)) + + self.cenlon = dd.get('cenlon', -94.5) + self.cenlat = dd.get('cenlat', 35.5) + + def _northeast(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Northeast region of U.S. + """ + self.extent = (-80, -66.5, 40, 48) + self.xticks = dd.get('xticks', (-80, -75.5, -71, -66.5)) + self.yticks = dd.get('yticks', (40, 42, 44, 46, 48)) + + self.cenlon = dd.get('cenlon', -76) + self.cenlat = dd.get('cenlat', 44) + + def _mid_atlantic(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Mid Atlantic region of U.S. + """ + self.extent = (-82, -73, 36.5, 42.5) + self.xticks = dd.get('xticks', (-82, -79, -76, -73)) + self.yticks = dd.get('yticks', (36.5, 38.5, 40.5, 42.5)) + + self.cenlon = dd.get('cenlon', -79) + self.cenlat = dd.get('cenlat', 36.5) + + def _southeast(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Southeast region of U.S. + """ + self.extent = (-92, -75, 24, 37) + self.xticks = dd.get('xticks', (-92, -87.75, -83.5, -79.25, -75)) + self.yticks = dd.get('yticks', (24, 27.25, 30.5, 33.75, 37)) + + self.cenlon = dd.get('cenlon', -89) + self.cenlat = dd.get('cenlat', 30.5) + + def _ohio_valley(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for an Ohio Valley region of U.S. + """ + self.extent = (-91.5, -80, 34.5, 43) + self.xticks = dd.get('xticks', (-91.5, -85.75, -80)) + self.yticks = dd.get('yticks', (34.5, 38.75, 43)) + + self.cenlon = dd.get('cenlon', -88) + self.cenlat = dd.get('cenlat', 38.75) + + def _upper_midwest(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for an Upper Midwest region of U.S. + """ + self.extent = (-97.5, -82, 40, 49.5) + self.xticks = dd.get('xticks', (-97.5, -89.75, -82)) + self.yticks = dd.get('yticks', (40, 44.75, 49.5)) + + self.cenlon = dd.get('cenlon', -92) + self.cenlat = dd.get('cenlat', 44.75) + + def _north_central(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a North Central region of U.S. + """ + self.extent = (-111.5, -94, 39, 49.5) + self.xticks = dd.get('xticks', (-111.5, -102.75, -94)) + self.yticks = dd.get('yticks', (39, 44.25, 49.5)) + + self.cenlon = dd.get('cenlon', -103) + self.cenlat = dd.get('cenlat', 44.25) + + def _central(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Central region of U.S. + """ + self.extent = (-103.5, -89, 32, 42) + self.xticks = dd.get('xticks', (-103.5, -96.25, -89)) + self.yticks = dd.get('yticks', (32, 37, 42)) + + self.cenlon = dd.get('cenlon', -99) + self.cenlat = dd.get('cenlat', 37) + + def _south_central(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a South Central region of U.S. + """ + self.extent = (-109, -88.5, 25, 37.5) + self.xticks = dd.get('xticks', (-109, -98.75, -88.5)) + self.yticks = dd.get('yticks', (25, 31.25, 37.5)) + + self.cenlon = dd.get('cenlon', -101) + self.cenlat = dd.get('cenlat', 31.25) + + def _northwest(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Northwest region of U.S. + """ + self.extent = (-125, -110, 40, 50) + self.xticks = dd.get('xticks', (-125, -117.5, -110)) + self.yticks = dd.get('yticks', (40, 45, 50)) + + self.cenlon = dd.get('cenlon', -116) + self.cenlat = dd.get('cenlat', 45) + + def _southwest(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Southwest region of U.S. + """ + self.extent = (-125, -108.5, 31, 42.5) + self.xticks = dd.get('xticks', (-125, -116.75, -108.5)) + self.yticks = dd.get('yticks', (31, 37.5, 42.5)) + + self.cenlon = dd.get('cenlon', -116) + self.cenlat = dd.get('cenlat', 36.75) + + def _colorado(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Colorado region of U.S. + """ + self.extent = (-110, -101, 35, 42) + self.xticks = dd.get('xticks', (-110, -105.5, -101)) + self.yticks = dd.get('yticks', (35, 38.5, 42)) + + self.cenlon = dd.get('cenlon', -106) + self.cenlat = dd.get('cenlat', 38.5) + + def _boston_nyc(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Boston-NYC region. + """ + self.extent = (-75.5, -69.5, 40, 43) + self.xticks = dd.get('xticks', (-75.5, -73.5, -71.5, -69.5)) + self.yticks = dd.get('yticks', (40, 41, 42, 43)) + + self.cenlon = dd.get('cenlon', -76) + self.cenlat = dd.get('cenlat', 41.5) + + def _seattle_portland(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Seattle-Portland region of U.S. + """ + self.extent = (-125, -119, 44.5, 49.5) + self.xticks = dd.get('xticks', (-125, -122, -119)) + self.yticks = dd.get('yticks', (44.5, 47, 49.5)) + + self.cenlon = dd.get('cenlon', -121) + self.cenlat = dd.get('cenlat', 47) + + def _sf_bay_area(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a San Francisco Bay area region of U.S. + """ + self.extent = (-123.5, -121, 37.25, 38.5) + self.xticks = dd.get('xticks', (-123.5, -122.25, -121)) + self.yticks = dd.get('yticks', (37.5, 38, 38.5)) + + self.cenlon = dd.get('cenlon', -121) + self.cenlat = dd.get('cenlat', 48.25) + + def _la_vegas(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Los Angeles and Las Vegas region of U.S. + """ + self.extent = (-121, -114, 32, 37) + self.xticks = dd.get('xticks', (-121, -117.5, -114)) + self.yticks = dd.get('yticks', (32, 34.5, 37)) + + self.cenlon = dd.get('cenlon', -114) + self.cenlat = dd.get('cenlat', 34.5) + + def _europe(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a European domain. + """ + self.extent = (-12.5, 40, 30, 70) + self.xticks = dd.get('xticks', (-10, 0, 10, 20, 30, 40)) + self.yticks = dd.get('yticks', (30, 40, 50, 60, 70)) + + self.cenlon = dd.get('cenlon', 25) + self.cenlat = dd.get('cenlat', 50) + + def _custom(self, dd=dict()): + """ + Sets extent, longitude xticks, and latitude yticks + for a Custom domain. + """ + try: + self.extent = dd.extent + self.xticks = dd.xticks + self.yticks = dd.yticks + + self.cenlon = dd.cenlon + self.cenlat = dd.cenlat + except AttributeError: + raise TypeError("Custom domain requires input dictionary " + + "with keys: 'extent', 'xticks', 'yticks', " + + "as tuples and 'cenlon' and 'cenlat' as floats.") + + +class MapProjection: + + def __init__(self, projection='plcarr', + cenlon=None, + cenlat=None, + globe=None): + """ + Class constructor that stores projection cartopy object + for the projection given. + Args: + projection : (str; default='plcarr') projection name to grab info + cenlon : (int, float; default=None) central longitude + cenlat : (int, float; default=None) central latitude + globe : (default=None) if ommited, creates a globe for map + """ + self.str_projection = projection + + self.cenlon = cenlon + self.cenlat = cenlat + self.globe = globe + + map_projections = { + "plcarr": self._platecarree, + "mill": self._miller, + "lambert": self._lambertconformal, + "npstere": self._npstereo, + "spstere": self._spstereo + } + + try: + map_projections[projection]() + except KeyError: + raise TypeError(f'{projection} is not a valid projection.' + + 'Current projections supported are:\n' + + f'{" | ".join(map_projections.keys())}"') + + def __str__(self): + return self.str_projection + + def _platecarree(self): + """Creates projection using PlateCarree from Cartopy.""" + self.cenlon = 0 if self.cenlon is None else self.cenlon + + self.projection = ccrs.PlateCarree(central_longitude=self.cenlon, + globe=self.globe) + + def _miller(self): + """Creates projection using Miller from Cartopy.""" + self.cenlon = 0 if self.cenlon is None else self.cenlon + + self.projection = ccrs.Miller(central_longitude=self.cenlon, + globe=self.globe) + + def _lambertconformal(self): + """Creates projection using Lambert Conformal from Cartopy.""" + + if self.cenlon is None or self.cenlat is None: + raise TypeError("Need 'cenlon' and cenlat to plot Lambert " + "Conformal projection. This projection also " + "does not work for a global domain.") + + self.projection = ccrs.LambertConformal(central_longitude=self.cenlon, + central_latitude=self.cenlat) + + def _npstereo(self): + """ + Creates projection using Orthographic from Cartopy and + orients it from central latitude 90 degrees. + """ + self.cenlon = -90 if self.cenlon is None else self.cenlon + + self.projection = ccrs.Orthographic(central_longitude=self.cenlon, + central_latitude=90, + globe=self.globe) + + def _spstereo(self): + """ + Creates projection using Orthographic from Cartopy and + orients it from central latitude -90 degrees. + """ + self.cenlon = 0 if self.cenlon is None else self.cenlon + + self.projection = ccrs.Orthographic(central_longitude=self.cenlon, + central_latitude=-90, + globe=self.globe) diff --git a/src/eva/plot_tools/plots.py b/src/eva/plot_tools/plots.py new file mode 100644 index 00000000..0ffdd625 --- /dev/null +++ b/src/eva/plot_tools/plots.py @@ -0,0 +1,294 @@ +# This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +import numpy as np + +__all__ = ['Scatter', 'Histogram', 'LinePlot' 'VerticalLine', + 'HorizontalLine', 'BarPlot' 'HorizontalBar', + 'MapScatter', 'MapGridded', 'MapContour'] + + +class Scatter(): + + def __init__(self, x, y): + """ + Constructor for Scatter. + Args: + x : (array type) + y : (array type) + """ + + super().__init__() + self.plottype = 'scatter' + + self.x = x + self.y = y + + self.markersize = 5 + self.color = 'darkgray' + self.marker = 'o' + self.vmin = None + self.vmax = None + self.alpha = None + self.linewidths = 1.5 + self.edgecolors = None + self.label = f'n={np.count_nonzero(~np.isnan(x))}' + + def add_linear_regression(self): + """ + Include linear regression line info as attributes. + """ + self.linear_regression = { + 'color': 'black', + 'linewidth': 1, + 'linestyle': '-' + } + + def density_scatter(self): + """ + Include density scatter plot info as attributes. + """ + self.density = { + 'sort': True, + 'cmap': 'nipy_spectral_r', + 'colorbar': True, + 'bins': [100, 100], + 'interp': 'linear', + 'nsamples': True + } + + +class Histogram(): + + def __init__(self, data): + """ + Constructor for Histogram. + Args: + data : (array type) + """ + + super().__init__() + self.plottype = 'histogram' + + self.data = data + + self.bins = 10 + self.range = None + self.density = False + self.weights = None + self.cumulative = False + self.bottom = None + self.histtype = 'bar' + self.align = 'mid' + self.orientation = 'vertical' + self.rwidth = None + self.log = False + self.color = 'tab:blue' + self.label = f'n={np.count_nonzero(~np.isnan(data))}' + self.stacked = False + self.alpha = None + + +class LinePlot(): + + def __init__(self, x, y): + """ + Constructor for LinePlot. + Args: + x : (array type) + y : (array type) + """ + super().__init__() + self.plottype = 'line_plot' + + self.x = x + self.y = y + + self.color = 'tab:blue' + self.linestyle = '-' + self.linewidth = 1.5 + self.marker = None + self.markersize = None + self.alpha = None + self.label = None + + +class VerticalLine(): + + def __init__(self, x): + """ + Constructor for VerticalLine + Args: + x : (int/float) x-value where vertical line + is to be plotted + """ + + super().__init__() + self.plottype = 'vertical_line' + + self.x = x + + self.color = 'black' + self.linestyle = '-' + self.linewidth = 1.5 + self.label = None + + +class HorizontalLine(): + + def __init__(self, y): + """ + Constructor for HorizontalLine + Args: + y : (int/float) y-value where horizontal + line is to be plotted + """ + + super().__init__() + self.plottype = 'horizontal_line' + + self.y = y + + self.color = 'black' + self.linestyle = '-' + self.linewidth = 1.5 + self.label = None + + +class BarPlot(): + + def __init__(self, x, height): + """ + Constructor for BarPlot. + Args: + x : (array type) x coordinate of bars + height : (array type) the height(s) of the bars + """ + + super().__init__() + self.plottype = 'bar_plot' + + self.x = x + self.height = height + + self.width = 0.8 + self.bottom = 0 + self.align = 'center' + self.color = 'tab:blue' + self.edgecolor = None + self.linewidth = 0 + self.tick_label = None + self.xerr = None + self.yerr = None + self.ecolor = 'black' + self.capsize = 0 + self.error_kw = {} + self.log = False + + +class HorizontalBar(): + + def __init__(self, y, width): + """ + Constructor to create a horizontal bar plot. + Args: + y : (array type) y coordinate of bars + width : (array type) the width(s) of the bars + """ + + super().__init__() + self.plottype = 'horizontal_bar' + + self.y = y + self.width = width + + self.height = 0.8 + self.left = 0 + self.align = 'center' + self.color = 'tab:blue' + self.edgecolor = None + self.linewidth = 0 + self.tick_label = None + self.xerr = None + self.yerr = None + self.ecolor = 'black' + self.capsize = 0 + self.error_kw = {} + self.log = False + + +class MapScatter: + + def __init__(self, latitude, longitude, data=None): + """ + Constructor for MapScatter. + Args: + latitude : (array type) Latitude data + longitude : (array type) Longitude data + data : (array type; default=None) data to be plotted + """ + self.plottype = 'map_scatter' + + self.latitude = latitude + self.longitude = longitude + self.data = data + + self.marker = 'o' + self.markersize = 5 + if data is None: + self.color = 'tab:blue' + else: + self.cmap = 'viridis' + self.linewidths = 1.5 + self.edgecolors = None + self.alpha = None + self.vmin = None + self.vmax = None + self.label = None + + +class MapGridded: + + def __init__(self, latitude, longitude, data): + """ + Constructor for MapGridded. + Args: + latitude : (array type) Latitude data + longitude : (array type) Longitude data + data : (array type) data to be plotted + """ + self.plottype = 'map_gridded' + + self.latitude = latitude + self.longitude = longitude + self.data = data + + self.cmap = 'viridis' + self.vmin = None + self.vmax = None + self.alpha = None + + +class MapContour: + + def __init__(self, latitude, longitude, data): + """ + Constructor for MapScatter. + Args: + latitude : (array type) Latitude data + longitude : (array type) Longitude data + data : (array type) data to be plotted + """ + self.plottype = 'map_contour' + + self.latitude = latitude + self.longitude = longitude + self.data = data + + self.levels = None + self.clabel = False + self.colors = 'black' + self.linewidths = 1.5 + self.linestyles = '-' + self.cmap = None + self.vmin = None + self.vmax = None + self.alpha = None diff --git a/src/eva/plot_tools/scatter_correlation.py b/src/eva/plot_tools/scatter_correlation.py deleted file mode 100644 index 9f1e465e..00000000 --- a/src/eva/plot_tools/scatter_correlation.py +++ /dev/null @@ -1,51 +0,0 @@ -# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the -# National Aeronautics and Space Administration. All Rights Reserved. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. - - -# -------------------------------------------------------------------------------------------------- - - -import numpy as np -import matplotlib.pyplot as plt - - -# -------------------------------------------------------------------------------------------------- - - -def scatter_correlation_plot(data_a, data_b, data_a_label, data_b_label, plot_title, output_name, - marker_size=2): - - # Compute limits for the figure - # ----------------------------- - data_min = min(min(data_a), min(data_b)) - data_max = max(max(data_a), max(data_b)) - data_dif = data_max - data_min - - # Compute correlation between two datasets - # ---------------------------------------- - correlation = np.corrcoef(data_a, data_b)[0, 1] - # TODO add this to the plot somewhere - - # Create figure - # ------------- - fig = plt.figure() - ax = fig.add_subplot(111) - plt.scatter(data_a, data_b, s=marker_size) - plt.title(plot_title) - - # Figure labeling - # --------------- - plt.xlabel(data_a_label) - plt.ylabel(data_b_label) - ax.set_aspect('equal', adjustable='box') - plt.xlim(data_min - 0.1*data_dif, data_max + 0.1*data_dif) - plt.ylim(data_min - 0.1*data_dif, data_max + 0.1*data_dif) - plt.axline((0, 0), slope=1.0, color='k') - - # Save figure - # ----------- - plt.savefig(output_name) - plt.close('all') diff --git a/src/eva/tests/__init__.py b/src/eva/tests/__init__.py new file mode 100644 index 00000000..ac1c0bc4 --- /dev/null +++ b/src/eva/tests/__init__.py @@ -0,0 +1,9 @@ +# (C) Copyright 2021-2022 United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + +import os + +repo_directory = os.path.dirname(__file__) diff --git a/src/eva/tests/config/ObsCorrelationScatterDriver.yaml b/src/eva/tests/config/ObsCorrelationScatterDriver.yaml new file mode 100644 index 00000000..98aa0aa9 --- /dev/null +++ b/src/eva/tests/config/ObsCorrelationScatterDriver.yaml @@ -0,0 +1,35 @@ +diagnostics: + - diagnostic name: ObsCorrelationScatter + comparisons: + - - hofx + - GsiHofXBc + - - hofx + - ObsValue + - - GsiHofXBc + - ObsValue + - - omb + - GsiombBc + figure file type: png + ioda experiment files: !ENVVAR ${EVA_TESTS_DIR}/data/amsua_n19.hofx.2020-12-14T21:00:00Z.nc4 + ioda reference files: !ENVVAR ${EVA_TESTS_DIR}/data/amsua_n19.hofx.2020-12-14T21:00:00Z.nc4 + marker size: 2 + output path: ./ + platforms: + - amsua_n19 + - diagnostic name: ObsCorrelationScatter + comparisons: + - - hofx + - GsiHofXBc + - - hofx + - ObsValue + - - GsiHofXBc + - ObsValue + - - omb + - GsiombBc + figure file type: png + ioda experiment files: !ENVVAR ${EVA_TESTS_DIR}/data/aircraft.hofx.2020-12-14T21:00:00Z.nc4 + ioda reference files: !ENVVAR ${EVA_TESTS_DIR}/data/aircraft.hofx.2020-12-14T21:00:00Z.nc4 + marker size: 2 + output path: ./ + platforms: + - aircraft diff --git a/src/eva/tests/data/aircraft.hofx.2020-12-14T21:00:00Z.nc4 b/src/eva/tests/data/aircraft.hofx.2020-12-14T21:00:00Z.nc4 new file mode 100644 index 00000000..1960a00d --- /dev/null +++ b/src/eva/tests/data/aircraft.hofx.2020-12-14T21:00:00Z.nc4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ee9dbbc82cd14208aff6db438215520fa578cb06fd12677a9e5f1aeb383ab0a +size 66304 diff --git a/src/eva/tests/data/amsua_n19.hofx.2020-12-14T21:00:00Z.nc4 b/src/eva/tests/data/amsua_n19.hofx.2020-12-14T21:00:00Z.nc4 new file mode 100644 index 00000000..f23c98ef --- /dev/null +++ b/src/eva/tests/data/amsua_n19.hofx.2020-12-14T21:00:00Z.nc4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a267fd5f54d166e25826593a2e58d12575e844899227d6a5e1f0b00ef68acc0 +size 162205 diff --git a/src/eva/tests/helpers.py b/src/eva/tests/helpers.py new file mode 100644 index 00000000..482e04ee --- /dev/null +++ b/src/eva/tests/helpers.py @@ -0,0 +1,41 @@ +import os +import re +import yaml + +# local imports +from eva.utilities.logger import Logger + + +def envvar_constructor(loader, node): + # method to help substitute parent directory for a yaml env var + return os.path.expandvars(node.value) + + +def load_yaml_file(yaml_file, logger): + # this yaml load function will allow a user to specify an environment + # variable to substitute in a yaml file if the tag '!ENVVAR' exists + # this will help developers create absolute system paths that are related + # to the install path of the eva package. + + if logger is None: + logger = Logger('EvaSetup') + + try: + loader = yaml.SafeLoader + loader.add_implicit_resolver( + '!ENVVAR', + re.compile(r'.*\$\{([^}^{]+)\}.*'), + None + ) + + loader.add_constructor('!ENVVAR', envvar_constructor) + yaml_dict = None + with open(yaml_file, 'r') as ymlfile: + yaml_dict = yaml.load(ymlfile, Loader=loader) + + except Exception as e: + logger.trace('Eva diagnostics is expecting a valid yaml file, but it encountered ' + + f'errors when attempting to load: {yaml_file}, error: {e}') + raise TypeError(f'Errors encountered loading yaml file: {yaml_file}, error: {e}') + + return yaml_dict diff --git a/src/eva/tests/test_obs_correlation_scatter.py b/src/eva/tests/test_obs_correlation_scatter.py new file mode 100644 index 00000000..927b3c96 --- /dev/null +++ b/src/eva/tests/test_obs_correlation_scatter.py @@ -0,0 +1,33 @@ +import os +import pathlib +import pytest +from eva import eva_base +from unittest.mock import patch + +# local imports +from eva.utilities.logger import Logger +from eva.tests import helpers + + +EVA_TESTS_DIR = pathlib.Path(__file__).parent.resolve() +PYTEST_ENVVARS = { + 'EVA_TESTS_DIR': str(EVA_TESTS_DIR) +} + +OBS_CORRELATION_SCATTER_YAML = os.path.join( + EVA_TESTS_DIR, 'config/ObsCorrelationScatterDriver.yaml') + +# redirect load_yaml_file from eva_base to helpers +eva_base.load_yaml_file = helpers.load_yaml_file + + +def test_obs_correlation_scatter(): + # unit test meant to assure that the loop_and_create_and_run + # method returns without errors when running the + # ObsCorrelationScatter plot method with test data defined in + # src/eva/tests/config/ObsCorrelationScatterDriver.yaml + try: + with patch.dict(os.environ, PYTEST_ENVVARS): + eva_base.eva(OBS_CORRELATION_SCATTER_YAML) + except Exception as e: + raise ValueError(f'Unexpected error encountered: {e}') diff --git a/src/eva/utilities/logger.py b/src/eva/utilities/logger.py index 5de03906..a82764cc 100644 --- a/src/eva/utilities/logger.py +++ b/src/eva/utilities/logger.py @@ -8,6 +8,7 @@ import os +import sys # -------------------------------------------------------------------------------------------------- @@ -42,8 +43,8 @@ def __init__(self, task_name): def send_message(self, level, message): - if self.loggerdict[level]: - print(level+" "+self.task_name+": "+message) + if level.upper() == 'ABORT' or self.loggerdict[level]: + print(level+' '+self.task_name+': '+message) # ---------------------------------------------------------------------------------------------- diff --git a/src/eva/utilities/stats.py b/src/eva/utilities/stats.py new file mode 100644 index 00000000..4257e839 --- /dev/null +++ b/src/eva/utilities/stats.py @@ -0,0 +1,198 @@ +# This work developed by NOAA/NWS/EMC under the Apache 2.0 license. +''' +stats.py contains statistics utility functions +''' + +__all__ = ['lregress', 'ttest', 'get_weights', 'get_weighted_mean', + 'get_linear_regression', 'bootstrap'] + +import numpy as _np +from scipy.stats import t as _t +from sklearn.linear_model import LinearRegression + + +def lregress(x, y, ci=95.0): + ''' + Function that computes the linear regression between two variables and + returns the regression coefficient and statistical significance + for a t-value at a desired confidence interval. + Args: + x : (array like) independent variable + y : (array like) dependent variable + ci : (float, optional, default=95) confidence interval percentage + Returns: + The linear regression coefficient (float), + the standard error on the linear regression coefficient (float), + and the statistical signficance of the linear regression + coefficient (bool). + ''' + + # make sure the two samples are of the same size + if (len(x) != len(y)): + raise ValueError('samples x and y are not of the same size') + else: + nsamp = len(x) + + pval = 1.0 - (1.0 - ci / 100.0) / 2.0 + tcrit = _t.ppf(pval, 2 * len(x) - 2) + + covmat = _np.cov(x, y=y, ddof=1) + cov_xx = covmat[0, 0] + cov_yy = covmat[1, 1] + cov_xy = covmat[0, 1] + + # regression coefficient (rc) + rc = cov_xy / cov_xx + # total standard error squared (se) + se = (cov_yy - (rc**2) * cov_xx) * (nsamp - 1) / (nsamp - 2) + # standard error on rc (sb) + sb = _np.sqrt(se / (cov_xx * (nsamp - 1))) + # error bar on rc + eb = tcrit * sb + + ssig = True if (_np.abs(rc) - _np.abs(eb)) > 0.0 else False + + return rc, sb, ssig + + +def ttest(x, y=None, ci=95.0, paired=True, scale=False): + ''' + Given two samples, perform the Student's t-test and return the errorbar. + The test assumes the sample size be the same between x and y. + Args: + x: (numpy array) control + y: (numpy array, optional, default=x )experiment + ci: (float, optional, default=95) confidence interval percentage + paired: (bool, optional, default=True) paired t-test + scale: (bool, optional, default=False) normalize with mean(x) and + return as a percentage + Returns: + The (normalized) difference in the sample means and + the (normalized) errorbar with respect to control. + To mask out statistically significant values:\n + `diffmask = numpy.ma.masked_where(numpy.abs(diffmean) + <=errorbar,diffmean).mask` + ''' + + nsamp = x.shape[0] + + if y is None: + y = x.copy() + + pval = 1.0 - (1.0 - ci / 100.0) / 2.0 + tcrit = _t.ppf(pval, 2*(nsamp-1)) + + xmean = _np.nanmean(x, axis=0) + ymean = _np.nanmean(y, axis=0) + + diffmean = ymean - xmean + + if paired: + # paired t-test + std_err = _np.sqrt(_np.nanvar(y-x, axis=0, ddof=1) / nsamp) + else: + # unpaired t-test + std_err = _np.sqrt((_np.nanvar(x, axis=0, ddof=1) + + _np.nanvar(y, axis=0, ddof=1)) / (nsamp-1.)) + + errorbar = tcrit * std_err + + # normalize (rescale) the diffmean and errorbar + if scale: + scale_fac = 100.0 / xmean + diffmean = diffmean * scale_fac + errorbar = errorbar * scale_fac + + return diffmean, errorbar + + +def get_weights(lats): + ''' + Get weights for latitudes to do weighted mean + Args: + lats: (array like) Latitudes + Returns: + An array of weights for latitudes + ''' + return _np.cos((_np.pi / 180.0) * lats) + + +def get_weighted_mean(data, weights, axis=None): + ''' + Given the weights for latitudes, compute weighted mean + of data in that direction + Note, `data` and `weights` must be same dimension + Uses `numpy.average` + Args: + data: (numpy array) input data array + weights: (numpy array) input weights + axis: (int) direction to compute weighted average + Returns: + An array of data weighted mean by weights + ''' + assert data.shape == weights.shape, ( + 'data and weights mis-match array size') + + return _np.average(data, weights=weights, axis=axis) + + +def get_linear_regression(x, y): + """ + Calculate linear regression between two sets of data. + Fits a linear model with coefficients to minumize the + residual sum of squares between the observed targets + in the dataset, and the targets predicted by the linear + approximation. + Args: + y, x : (array like) Data to calculate linear regression + Returns: + The predicted y values from calculation, + the R squared value, the intercept of the line, and the + slope of the line from the equation for the predicted + y values. + """ + x = x.reshape((-1, 1)) + model = LinearRegression().fit(x, y) + r_sq = model.score(x, y) + intercept = model.intercept_ + slope = model.coef_[0] + # This is the same as if you calculated y_pred + # by y_pred = slope * x + intercept + y_pred = model.predict(x) + return y_pred, r_sq, intercept, slope + + +def bootstrap(insample, level=.95, estimator='mean', nrepl=10000): + """ + Generate emprical bootstrap confidence intervals. + See https://ocw.mit.edu/courses/mathematics/ + 18-05-introduction-to-probability-and-statistics-spring-2014/ + readings/MIT18_05S14_Reading24.pdf for more information. + Args: + insample: (array like) is the array from which the estimator (u) + was derived (x_1, x_2,....x_n). + level: (float, default=0.95) desired confidence level for CI bounds + estimator: (char, default='mean') type of statistic obtained from + the sample (mean or median) + nrepl: (integer, default=1000) number of replicates + Returns: + Lower and upper bounds of confidence intervals + """ + if any(_np.isnan(insample)): + print('bootstrap_ci.py: NaN detected. Dropping NaN(s) input prior to bootstrap...') + sample = insample[~_np.isnan(insample)] + else: + sample = insample + + boot_dist = [_np.random.choice(sample, _np.size(sample)) for x in _np.arange(nrepl)] + if estimator.lower() == 'mean': + deltas = _np.sort(_np.mean(boot_dist, axis=1) - _np.mean(sample)) + elif estimator.lower() == 'median': + deltas = _np.sort(_np.median(boot_dist, axis=1) - _np.median(sample)) + + lower_pctile = 100*((1. - level)/2.) + upper_pctile = 100. - lower_pctile + ci_lower = _np.percentile(deltas, lower_pctile) + ci_upper = _np.percentile(deltas, upper_pctile) + + return ci_lower, ci_upper