Skip to content

Commit

Permalink
Initial prototype using Xarray SeasonGrouper`
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Nov 14, 2024
1 parent 8d156c2 commit 2e736ca
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray.coding.cftime_offsets import get_date_type
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.groupby import DataArrayGroupBy
from xarray.groupers import SeasonGrouper, UniqueGrouper

from xcdat import bounds # noqa: F401
from xcdat._logger import _setup_custom_logger
Expand Down Expand Up @@ -1091,7 +1092,10 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]]
f"Supported months include: {predefined_months}."
)

c_seasons = {"".join(months): months for months in custom_seasons}
c_seasons = {}
for season in custom_seasons:
key = "".join([month[0] for month in season])
c_seasons[key] = season

return c_seasons

Expand Down Expand Up @@ -1130,18 +1134,19 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
self._freq == "season"
and self._season_config.get("custom_seasons") is not None
):
# Get a flat list of all of the months included in the custom
# seasons to determine if the dataset needs to be subsetted
# on just those months. For example, if we define a custom season
# "NDJFM", we should subset the dataset for time coordinates
# belonging to those months.
months = self._season_config["custom_seasons"].values() # type: ignore
months = list(chain.from_iterable(months))

if len(months) != 12:
ds = self._subset_coords_for_custom_seasons(ds, months)

ds = self._shift_custom_season_years(ds)
# FIXME: This causes a bug when accessing `.groups` with
# SeasonGrouper(). Also shifting custom seasons is done for
# drop_incomplete_seasons and grouping for months that span the
# calendar year. The Xarray PR will handle both of these cases
# and this method will be removed.
# ds = self._shift_custom_season_years(ds)
pass

if self._freq == "season" and self._season_config.get("dec_mode") == "DJF":
ds = self._shift_djf_decembers(ds)
Expand All @@ -1153,11 +1158,11 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset:
):
ds = self._drop_incomplete_djf(ds)

if (
self._freq == "season"
and self._season_config["drop_incomplete_seasons"] is True
):
ds = self._drop_incomplete_seasons(ds)
# if (
# self._freq == "season"
# and self._season_config["drop_incomplete_seasons"] is True
# ):
# ds = self._drop_incomplete_seasons(ds)

return ds

Expand Down Expand Up @@ -1494,8 +1499,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

# Label the time coordinates for grouping weights and the data variable
# values.
self._labeled_time = self._label_time_coords(dv[self.dim])
dv = dv.assign_coords({self.dim: self._labeled_time})
dv_grouped = self._label_time_coords_for_grouping(dv)

if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
Expand All @@ -1514,13 +1518,14 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
# Perform weighted average using the formula
# WA = sum(data*weights) / sum(weights). The denominator must be
# included to take into account zero weight for missing data.
weights_gb = self._label_time_coords_for_grouping(weights)
with xr.set_options(keep_attrs=True):
dv = self._group_data(dv).sum() / self._group_data(weights).sum()
dv = dv_grouped.sum() / weights_gb.sum()

# Restore the data variable's name.
dv.name = data_var
else:
dv = self._group_data(dv).mean()
dv = dv_grouped.mean()

# After grouping and aggregating, the grouped time dimension's
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
Expand Down Expand Up @@ -1578,7 +1583,10 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:

time_lengths = time_lengths.astype(np.float64)

grouped_time_lengths = self._group_data(time_lengths)
grouped_time_lengths = self._label_time_coords_for_grouping(time_lengths)
# FIXME: File "/opt/miniconda3/envs/xcdat_dev_416_xr/lib/python3.12/site-packages/xarray/core/groupby.py", line 639, in _raise_if_not_single_group
# raise NotImplementedError(
# NotImplementedError: This method is not supported for grouping by multiple variables yet.
weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
weights.name = f"{self.dim}_wts"

Expand Down Expand Up @@ -1670,6 +1678,35 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray:

return time_grouped

def _label_time_coords_for_grouping(self, dv: xr.DataArray) -> DataArrayGroupBy:
# Use the TIME_GROUPS dictionary to determine which components
# are needed to form the labeled time coordinates.
dt_comps = TIME_GROUPS[self._mode][self._freq]
dt_comps_map: Dict[str, UniqueGrouper | SeasonGrouper] = {
comp: UniqueGrouper() for comp in dt_comps if comp != "season"
}

dv_new = dv.copy()
for comp in dt_comps_map.keys():
dv_new.coords[comp] = dv_new[self.dim][f"{self.dim}.{comp}"]

if self._freq == "season":
custom_seasons = self._season_config.get("custom_seasons")
# NOTE: SeasonGrouper() does not drop incomplete seasons yet.
# TODO: Add `drop_incomplete` arg once available.

if custom_seasons is not None:
season_keys = list(custom_seasons.keys())
season_grouper = SeasonGrouper(season_keys)
else:
season_keys = list(SEASON_TO_MONTH.keys())
season_grouper = SeasonGrouper(season_keys)

dt_comps_map[self.dim] = season_grouper
dv_gb = dv_new.groupby(**dt_comps_map) # type: ignore

return dv_gb

def _get_df_dt_components(
self, time_coords: xr.DataArray, drop_obsolete_cols: bool
) -> pd.DataFrame:
Expand Down

0 comments on commit 2e736ca

Please sign in to comment.