"""Spatial/temporal smoothing implemented in PredictionEnsemble.smooth()."""
from typing import Dict, Optional
import numpy as np
import xarray as xr
try:
import xesmf as xe
except ImportError:
xe = None
[docs]
def spatial_smoothing_xesmf(
ds: xr.Dataset,
d_lon_lat_kws: Dict[str, float] = {"lon": 5, "lat": 5},
method: str = "bilinear",
periodic: bool = False,
filename: Optional[str] = None,
reuse_weights: bool = False,
tsmooth_kws: Optional[Dict[str, int]] = None,
how: Optional[str] = None,
) -> xr.Dataset:
"""Quick regridding function.
Adapted from https://github.com/JiaweiZhuang/xESMF/pull/27/files#diff-b537ef68c98c2ec11e64e4803fe4a113R105. # noqa: E501
Args:
ds: Contain input and output grid coordinates.
Look for coordinates ``lon``, ``lat``, and optionally ``lon_b``,
``lat_b`` for conservative method. Also any coordinate which is C/F
compliant, .i.e. standard_name in ["longitude", "latitude"] is allowed.
Shape can be 1D (Nlon,) and (Nlat,) for rectilinear grids,
or 2D (Ny, Nx) for general curvilinear grids.
Shape of bounds should be (N+1,) or (Ny+1, Nx+1).
d_lon_lat_kws: Longitude/Latitude step size (grid resolution); if not provided,
lon will equal 5 and lat will equal lon
method: Regridding method. Options are:
- "bilinear"
- "conservative", **requires grid corner information**
- "patch"
- "nearest_s2d"
- "nearest_d2s"
periodic: Periodic in longitude? Defaults to ``False``.
Only useful for global grids with non-conservative regridding.
Will be forced to False for conservative regridding.
filename: Name for the weight file.
The default naming scheme is "{method}_{Ny_in}x{Nx_in}_{Ny_out}x{Nx_out}.nc"
e.g. "bilinear_400x600_300x400.nc"
reuse_weights: Whether to read existing weight file to save
computing time. Defaults to ``False``.
tsmooth_kws: leads nowhere but consistent with ``temporal_smoothing``.
how: leads nowhere but consistent with ``temporal_smoothing``.
Returns:
regridded
"""
if xe is None:
raise ImportError(
"xesmf is not installed; see "
"https://pangeo-xesmf.readthedocs.io/en/latest/installation.html"
)
def _regrid_it(da, d_lon, d_lat, **kwargs):
"""Global 2D rectilinear grid centers and bounds.
Args:
da (xarray.DataArray): Contain input and output grid coords.
Look for variables ``lon``, ``lat``, ``lon_b``, ``lat_b`` for
conservative method, and ``TLAT``, ``TLON`` for CESM POP grid
Shape can be 1D (Nlon,) and (Nlat,) for rectilinear grids,
or 2D (Ny, Nx) for general curvilinear grids.
Shape of bounds should be (N+1,) or (Ny+1, Nx+1).
d_lon (float): Longitude step size, i.e. grid resolution
d_lat (float): Latitude step size, i.e. grid resolution
Returns:
da : xarray.DataArray with coordinate values
"""
if "lon" in da.coords:
lon = da.lon
else:
try:
lon = da.cf["longitude"]
except KeyError:
raise KeyError(
"Could not find `lon` as coordinate or any C/F compliant"
"`latitude` coordinate, see https://pangeo-xesmf.readthedocs.io "
"and https://cf-xarray.readthedocs.io"
)
if "lat" in da.coords:
lat = da.lat
else:
try:
lat = da.cf["latitude"]
except KeyError:
raise KeyError(
"C/F compliant or `lat` as coordinate, see "
"https://pangeo-xesmf.readthedocs.io"
)
grid_out = xr.Dataset(
{
"lat": (["lat"], np.arange(lat.min(), lat.max() + d_lat, d_lat)),
"lon": (["lon"], np.arange(lon.min(), lon.max() + d_lon, d_lon)),
}
)
regridder = xe.Regridder(da, grid_out, **kwargs)
return regridder(da, keep_attrs=True)
# check if lon or/and lat missing
if ("lon" in d_lon_lat_kws) and ("lat" in d_lon_lat_kws):
pass
elif ("lon" not in d_lon_lat_kws) and ("lat" in d_lon_lat_kws):
d_lon_lat_kws["lon"] = d_lon_lat_kws["lat"]
elif ("lat" not in d_lon_lat_kws) and ("lon" in d_lon_lat_kws):
d_lon_lat_kws["lat"] = d_lon_lat_kws["lon"]
else:
raise ValueError("please provide either `lon` or/and `lat` in d_lon_lat_kws.")
kwargs = {
"d_lon": d_lon_lat_kws["lon"],
"d_lat": d_lon_lat_kws["lat"],
"method": method,
"periodic": periodic,
"filename": filename,
"reuse_weights": reuse_weights,
}
ds = _regrid_it(ds, **kwargs)
return ds
[docs]
def temporal_smoothing(
ds: xr.Dataset,
tsmooth_kws: Optional[Dict[str, int]] = None,
how: str = "mean",
d_lon_lat_kws: Optional[Dict[str, float]] = None,
) -> xr.Dataset:
"""Apply temporal smoothing by creating rolling smooth-timestep means.
Args:
ds: input to be smoothed.
tsmooth_kws: length of smoothing of timesteps.
Defaults to ``{"time": 4}`` (see :cite:t:`Goddard2013`).
how: aggregation type for smoothing. Allowed: ``["mean", "sum"]``.
Default: ``"mean"``.
d_lon_lat_kws: leads nowhere but consistent with ``spatial_smoothing_xesmf``.
Returns:
input with ``smooth`` timesteps less and
labeling ``"1-(smooth-1)", "...", ...`` .
References:
:cite:t:`Goddard2013`
"""
# unpack dict
if not isinstance(tsmooth_kws, dict):
raise ValueError(
"Please provide `tsmooth_kws` as dict, found ", type(tsmooth_kws)
)
if not ("time" in tsmooth_kws or "lead" in tsmooth_kws):
raise ValueError(
'`tsmooth_kws` doesnt contain a `time` dimension \
(either "lead" or "time").',
tsmooth_kws,
)
smooth = list(tsmooth_kws.values())[0]
if smooth == 1:
return ds
dim = list(tsmooth_kws.keys())[0]
# fix to smooth either lead or time depending
time_dims = ["time", "lead"]
if dim not in ds.dims:
time_dims.remove(dim)
dim = time_dims[0]
tsmooth_kws = {dim: smooth}
# aggreate based on how
ds_smoothed = getattr(ds.rolling(tsmooth_kws, center=False), how)()
# remove first all-nans
ds_smoothed = ds_smoothed.isel({dim: slice(smooth - 1, None)})
ds_smoothed[dim] = ds.isel({dim: slice(None, -smooth + 1)})[dim]
return ds_smoothed
def _reset_temporal_axis(
ds_smoothed: xr.Dataset,
tsmooth_kws: Dict[str, int],
dim: str = "lead",
set_lead_center: bool = True,
) -> xr.Dataset:
"""Reduce and reset temporal axis. See temporal_smoothing.
Should be used after calculation of skill to maintain readable labels for skill
computation.
Args:
ds_smoothed: Smoothed dataset.
tsmooth_kws: Keywords smoothing is performed over.
dim: Dimension smoothing is performed over. Defaults to ``"lead"``.
set_center: Whether to set new coord `{dim}_center`. Defaults to ``True``.
Returns
Smoothed Dataset with updated labels for smoothed temporal dimension.
"""
# bugfix: actually tsmooth_kws should only dict
if tsmooth_kws is None or callable(tsmooth_kws):
return ds_smoothed
if not ("time" in tsmooth_kws.keys() or "lead" in tsmooth_kws.keys()):
raise ValueError("tsmooth_kws does not contain a time dimension.", tsmooth_kws)
for c in ["time", "lead"]:
if c in tsmooth_kws.keys():
smooth = tsmooth_kws[c]
ds_smoothed[dim] = [f"{t}-{t + smooth - 1}" for t in ds_smoothed[dim].values]
if set_lead_center:
_set_center_coord(ds_smoothed, dim)
return ds_smoothed
def _set_center_coord(ds: xr.Dataset, dim: str = "lead") -> xr.Dataset:
"""Set lead_center as a new coordinate."""
new_dim = []
old_dim = ds[dim].values
for i in old_dim:
new_dim.append(eval(i.replace("-", "+")) / 2)
ds.coords[f"{dim}_center"] = (dim, np.array(new_dim))
return ds
def smooth_goddard_2013(
ds: xr.Dataset,
tsmooth_kws: Dict[str, int] = {"lead": 4},
d_lon_lat_kws: Dict[str, float] = {"lon": 5, "lat": 5},
how: str = "mean",
**xesmf_kwargs: str,
) -> xr.Dataset:
"""Wrap to smooth as suggested by :cite:t:`Goddard2013`.
- 4-year composites
- 5x5 degree regridding
Args:
ds: input to be smoothed.
tsmooth_kws: length of smoothing of timesteps (applies to ``lead``
in forecast and ``time`` in verification data).
Default: ``{"time": 4}`` (see :cite:t:`Goddard2013`).
d_lon_lat_kws: target grid for regridding.
Default: ``{"lon":5 , "lat": 5}``.
how: aggregation type for smoothing. Allowed: ``["mean", "sum"]``.
Default: ``"mean"``.
**xesmf_kwargs: kwargs passed to ``spatial_smoothing_xesmf``.
Returns:
input with ``smooth`` timesteps less and labeling "1-(smooth-1)", "..." .
References:
:cite:t:`Goddard2013`
"""
# first temporal smoothing
ds_smoothed = temporal_smoothing(ds, tsmooth_kws=tsmooth_kws)
ds_smoothed_regridded = spatial_smoothing_xesmf(
ds_smoothed, d_lon_lat_kws=d_lon_lat_kws, **xesmf_kwargs # type: ignore
)
return ds_smoothed_regridded