Skip to content

Commit

Permalink
use standalone metadata object
Browse files Browse the repository at this point in the history
  • Loading branch information
cfkanesan committed Nov 26, 2024
1 parent 94c973b commit ab3ec0e
Show file tree
Hide file tree
Showing 24 changed files with 1,632 additions and 1,080 deletions.
2,504 changes: 1,572 additions & 932 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ authors = [
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
click = "^8.1.7"
earthkit-data = "^0.5.6"
earthkit-data = ">=0.5.6,<1"
eccodes = "^1.5.0"
numpy = "^1.26.4"
polytope-client = "^0.7.4"
Expand Down
12 changes: 6 additions & 6 deletions src/meteodatalab/grib_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def load(self, field: GribField) -> None:
self.values[key] = field.to_numpy(dtype=np.float32)

if not self.metadata:
md = field.metadata()
md = field.metadata().override()
self.metadata = {
"message": md,
"metadata": md,
**metadata.extract(md),
}

Expand Down Expand Up @@ -360,7 +360,7 @@ def load(
if extract_pv not in requests:
msg = f"{extract_pv=} was not a key of the given requests."
raise RuntimeError(msg)
return result | metadata.extract_pv(result[extract_pv].message)
return result | metadata.extract_pv(result[extract_pv].metadata)
return result

def load_fieldnames(
Expand Down Expand Up @@ -394,11 +394,11 @@ def save(
If the field does not have a metadata attribute.
"""
if not hasattr(field, "message"):
msg = "The message attribute is required to write to the GRIB format."
if not hasattr(field, "metadata"):
msg = "The metadata attribute is required to write to the GRIB format."
raise ValueError(msg)

md = field.message
md = field.metadata

idx = {
dim: field.coords[key]
Expand Down
83 changes: 19 additions & 64 deletions src/meteodatalab/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

# Standard library
import dataclasses as dc
import io
import typing

# Third-party
import earthkit.data as ekd # type: ignore
import numpy as np
import xarray as xr
from earthkit.data.core.metadata import Metadata
Expand All @@ -21,7 +19,7 @@
}


def extract(metadata):
def extract(metadata: Metadata):
if metadata.get("gridType") == "unstructured_grid":
vref_flag = False
else:
Expand All @@ -41,37 +39,37 @@ def extract(metadata):
}


def override(message: Metadata, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata contained in message.
def override(metadata: Metadata, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata.
Note that no special consideration is made for maintaining consistency when
overriding template definition keys such as productDefinitionTemplateNumber.
Note that the origin components in x and y are left untouched.
Parameters
----------
message : bytes
Byte string of the input GRIB message
metadata : Metadata
Metadata of the input GRIB message
kwargs : Any
Keyword arguments forwarded to earthkit-data GribMetadata override method
Returns
-------
dict[str, Any]
Updated message byte string along with the geography and parameter namespaces
Updated metadata along with the geography and parameter namespaces
"""

if message["editionNumber"] == 1:
if metadata["editionNumber"] == 1:
return {
"message": message,
**extract(message),
"metadata": metadata,
**extract(metadata),
}

md = message.override(**kwargs)
md = metadata.override(**kwargs)

return {
"message": md,
"metadata": md,
**extract(md),
}

Expand All @@ -98,7 +96,7 @@ def load_grid_reference(metadata) -> Grid:
Parameters
----------
metadata : bytes
metadata : Metadata
GRIB metadata defining the reference grid.
Returns
Expand Down Expand Up @@ -162,29 +160,26 @@ def set_origin_xy(ds: dict[str, xr.DataArray], ref_param: str) -> None:
if ref_param not in ds:
raise KeyError(f"ref_param {ref_param} not present in dataset.")

ref_grid = load_grid_reference(ds[ref_param].message)
ref_grid = load_grid_reference(ds[ref_param].metadata)
for field in ds.values():
field.attrs |= compute_origin(ref_grid, field)


def extract_pv(message: bytes) -> dict[str, xr.DataArray]:
def extract_pv(metadata: Metadata) -> dict[str, xr.DataArray]:
"""Extract hybrid level coefficients.
Parameters
----------
message : bytes
GRIB message containing the pv metadata.
metadata : Metadata
GRIB metadata containing the pv metadata.
Returns
-------
dict[str, xarray.DataArray]
Hybrid level coefficients.
"""
stream = io.BytesIO(message)
[grib_field] = ekd.from_source("stream", stream)

pv = grib_field.metadata("pv")
pv = metadata.get("pv")

if pv is None:
return {}
Expand All @@ -196,12 +191,12 @@ def extract_pv(message: bytes) -> dict[str, xr.DataArray]:
}


def extract_hcoords(metadata) -> dict[str, xr.DataArray]:
def extract_hcoords(metadata: Metadata) -> dict[str, xr.DataArray]:
"""Extract horizontal coordinates.
Parameters
----------
metadata : bytes
metadata : Metadata
GRIB metadata containing the grid definition.
Returns
Expand All @@ -217,43 +212,3 @@ def extract_hcoords(metadata) -> dict[str, xr.DataArray]:
dims=("y", "x"), data=geo.longitudes().reshape(geo.shape())
),
}


def extract_keys(message: bytes, keys: typing.Any, single: bool = True) -> typing.Any:
"""Extract keys from the GRIB message.
Parameters
----------
message : bytes
The GRIB message.
keys : Any
Keys for which to extract values from the message.
single : bool, optional
Whether a single GRIB message should be expected.
Raises
------
ValueError
if keys is None because the resulting metadata would point
to an eccodes handle that no longer exists resulting in a
possible segmentation fault
Returns
-------
Any
Single value if keys is a single value, tuple of values if
keys is a tuple, list of values if keys is a list. The type of
the value depends on the default type for the given key in eccodes.
If single is false, the above is returned within a list.
"""
if keys is None:
raise ValueError("keys must be specified.")
stream = io.BytesIO(message)
source = ekd.from_source("stream", stream)

if single:
[grib_field] = source
return grib_field.metadata(keys)

return [grib_field.metadata(keys) for grib_field in source]
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/brn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def fbrn(

return xr.DataArray(
data=brn,
attrs=metadata.override(p.message, shortName="BRN"),
attrs=metadata.override(p.metadata, shortName="BRN"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def crop(field: xr.DataArray, bounds: Bounds) -> xr.DataArray:
return xr.DataArray(
field.isel(x=slice(xmin, xmax + 1), y=slice(ymin, ymax + 1)),
attrs=metadata.override(
field.message,
field.metadata,
longitudeOfFirstGridPoint=lon_min,
longitudeOfLastGridPoint=lon_max,
Ni=ni,
Expand Down
8 changes: 4 additions & 4 deletions src/meteodatalab/operators/destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
lon_max = np.round(geo["longitudeOfLastGridPointInDegrees"] * 1e6)
dx = np.round(geo["iDirectionIncrementInDegrees"] * 1e6)
return metadata.override(
field.message,
field.metadata,
longitudeOfFirstGridPoint=lon_min - dx / 2,
longitudeOfLastGridPoint=lon_max - dx / 2,
)
Expand All @@ -90,7 +90,7 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
lat_max = np.round(geo["latitudeOfLastGridPointInDegrees"] * 1e6)
dy = np.round(geo["jDirectionIncrementInDegrees"] * 1e6)
return metadata.override(
field.message,
field.metadata,
latitudeOfFirstGridPoint=lat_min - dy / 2,
latitudeOfLastGridPoint=lat_max - dy / 2,
)
Expand All @@ -100,7 +100,7 @@ def _update_vertical(field) -> dict[str, Any]:
if field.vcoord_type != "model_level":
raise ValueError("typeOfLevel must equal generalVertical")
return metadata.override(
field.message,
field.metadata,
typeOfLevel="generalVerticalLayer",
)

Expand Down Expand Up @@ -151,7 +151,7 @@ def destagger(
)
.transpose(*dims)
.assign_attrs({f"origin_{dim}": 0.0}, **attrs)
.assign_coords(metadata.extract_hcoords(attrs["message"]))
.assign_coords(metadata.extract_hcoords(attrs["metadata"]))
)
elif dim == "z":
if field.origin_z != -0.5:
Expand Down
4 changes: 2 additions & 2 deletions src/meteodatalab/operators/gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,13 @@ def vref_rot2geolatlon(
xr.DataArray(
u_g,
attrs=metadata.override(
u.message, resolutionAndComponentFlags=resolution_components_flags
u.metadata, resolutionAndComponentFlags=resolution_components_flags
),
),
xr.DataArray(
v_g,
attrs=metadata.override(
v.message, resolutionAndComponentFlags=resolution_components_flags
v.metadata, resolutionAndComponentFlags=resolution_components_flags
),
),
)
Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/hzerocl.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ def fhzerocl(

return xr.DataArray(
data=hzerocl.where(hzerocl > 0),
attrs=metadata.override(t.message, shortName="HZEROCL"),
attrs=metadata.override(t.metadata, shortName="HZEROCL"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/omega_slope.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def omega_slope(
data=res,
attrs=metadata.override(
# Eta-coordinate vertical velocity
etadot.message,
etadot.metadata,
discipline=0,
parameterCategory=2,
parameterNumber=32,
Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/pot_vortic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def compute_pot_vortic(
dt_dlam * curl1 + dt_dphi * (curl2 + cor2) - dt_dzeta * (curl3 + cor3)
) / rho_tot

out.attrs = metadata.override(theta.message, shortName="POT_VORTIC")
out.attrs = metadata.override(theta.metadata, shortName="POT_VORTIC")

return out
4 changes: 2 additions & 2 deletions src/meteodatalab/operators/radiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def compute_athd_s(athb_s: xr.DataArray, tsurf: xr.DataArray) -> xr.DataArray:
"""
return xr.DataArray(
data=athb_s / pc.emissivity_surface + pc.boltzman_cst * tsurf**4,
attrs=metadata.override(athb_s.message, shortName="ATHD_S"),
attrs=metadata.override(athb_s.metadata, shortName="ATHD_S"),
)


Expand All @@ -48,5 +48,5 @@ def compute_swdown(diffuse: xr.DataArray, direct: xr.DataArray) -> xr.DataArray:
"""
return xr.DataArray(
data=(diffuse + direct).clip(min=0),
attrs=metadata.override(diffuse.message, shortName="ASOD_S"),
attrs=metadata.override(diffuse.metadata, shortName="ASOD_S"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def reproject_layer(field):

attrs = field.attrs
if md := _get_metadata(dst):
attrs = attrs | metadata.override(field.message, **md)
attrs = attrs | metadata.override(field.metadata, **md)

return xr.DataArray(data, attrs=attrs)

Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/relhum.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ def relhum(
return xr.DataArray(
data=result,
attrs=metadata.override(
t.message, shortName=phase_conditions[phase]["shortName"]
t.metadata, shortName=phase_conditions[phase]["shortName"]
),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/rho.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ def compute_rho_tot(

return xr.DataArray(
data=p / (pc.r_d * t * (1.0 + pc.rvd_o * qv - q)),
attrs=metadata.override(p.message, shortName="DEN"),
attrs=metadata.override(p.metadata, shortName="DEN"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/support_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_field_with_vcoord(
# in the interface
# attrs
attrs = parent.attrs | metadata.override(
parent.message, typeOfLevel=vcoord.type_of_level
parent.metadata, typeOfLevel=vcoord.type_of_level
)
# dims
sizes = dict(parent.sizes.items()) | {"z": vcoord.size}
Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def compute_theta(p: xr.DataArray, t: xr.DataArray) -> xr.DataArray:

return xr.DataArray(
data=(p0 / p) ** pc.rdocp * t,
attrs=metadata.override(p.message, shortName="PT"),
attrs=metadata.override(p.metadata, shortName="PT"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/thetav.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def fthetav(p: xr.DataArray, t: xr.DataArray, qv: xr.DataArray) -> xr.DataArray:

return xr.DataArray(
data=(p0 / p) ** pc.rdocp * t * (1.0 + (pc.rvd_o * qv / (1.0 - qv))),
attrs=metadata.override(t.message, shortName="THETA_V"),
attrs=metadata.override(t.metadata, shortName="THETA_V"),
)
Loading

0 comments on commit ab3ec0e

Please sign in to comment.