Source code for climpred.prediction

"""Prediction module: _apply_metric_at_given_lead and compute functions."""

from typing import List, Optional, Set, Tuple, Union

from dask import is_dask_collection

from .checks import is_in_list
from .comparisons import (
    COMPARISON_ALIASES,
    HINDCAST_COMPARISONS,
    PM_COMPARISONS,
    PROBABILISTIC_HINDCAST_COMPARISONS,
    PROBABILISTIC_PM_COMPARISONS,
)
from .constants import M2M_MEMBER_DIM, PM_CALENDAR_STR
from .exceptions import DimensionError
from .logging import log_hindcast_verify_inits_and_verifs
from .metrics import HINDCAST_METRICS, METRIC_ALIASES, PM_METRICS
from .reference import (
    _adapt_member_for_reference_forecast,
    climatology,
    persistence,
    uninitialized,
)
from .utils import (
    convert_time_index,
    get_comparison_class,
    get_lead_cftime_shift_args,
    get_metric_class,
    shift_cftime_singular,
)


def _apply_metric_at_given_lead(
    verif,
    verif_dates,
    lead,
    initialized=None,
    hist=None,
    inits=None,
    reference=None,
    metric=None,
    comparison=None,
    dim=None,
    **metric_kwargs,
):
    """Apply a metric between two time series at a given lead.

    Args:
        verif (xr.Dataset): Verification data.
        verif_dates (dict): Lead-dependent verification dates for alignment.
        lead (int): Given lead to score.
        initialized (xr.Dataset): Initialized hindcast. Not required in a persistence
            forecast.
        hist (xr.Dataset): Uninitialized/historical simulation. Required when
            ``reference='uninitialized'``.
        inits (dict): Lead-dependent initialization dates for alignment.
        reference (str): If not ``None``, return score for this reference forecast.
            * 'persistence'
            * 'uninitialized'
        metric (Metric): Metric class for scoring.
        comparison (Comparison): Comparison class.
        dim (str): Dimension to apply metric over.

    Returns:
        result (xr.Dataset): Metric results for the given lead for the initialized
            forecast or reference forecast.
    """
    # naming:: lforecast: forecast at lead; lverif: verification at lead
    if reference is None:
        # Use `.where()` instead of `.sel()` to account for resampled inits when
        # bootstrapping.
        lforecast = initialized.sel(lead=lead).where(
            initialized["time"].isin(inits[lead]), drop=True
        )
        lverif = verif.sel(time=verif_dates[lead])
    elif reference == "persistence":
        lforecast, lverif = persistence(verif, inits, verif_dates, lead)
    elif reference == "uninitialized":
        lforecast, lverif = uninitialized(hist, verif, verif_dates, lead)
    elif reference == "climatology":
        lforecast, lverif = climatology(verif, inits, verif_dates, lead)
    if reference is not None:
        lforecast, dim = _adapt_member_for_reference_forecast(
            lforecast, lverif, metric, comparison, dim
        )
    assert lforecast.time.size == lverif.time.size, print(
        lforecast.time.to_index(), lverif.time.to_index(), reference
    )
    lforecast["time"] = lverif[
        "time"
    ]  # a bit dangerous: what if different? more clear once implemented
    # https://github.com/pangeo-data/climpred/issues/523#issuecomment-728951645
    dim = _rename_dim(
        dim, initialized, verif
    )  # dim should be much clearer once time_valid in initialized.coords
    if metric.normalize or metric.allows_logical:
        metric_kwargs["comparison"] = comparison

    result = metric.function(lforecast, lverif, dim=dim, **metric_kwargs)

    log_hindcast_verify_inits_and_verifs(dim, lead, inits, verif_dates, reference)
    # push time (later renamed to init) back by lead
    if "time" in result.dims:
        n, freq = get_lead_cftime_shift_args(initialized.lead.attrs["units"], lead)
        result = result.assign_coords(time=shift_cftime_singular(result.time, -n, freq))
    if "valid_time" in result.coords:
        if is_dask_collection(result.coords["valid_time"]):
            result.coords["valid_time"] = result.coords["valid_time"].compute()
    return result


def _rename_dim(dim, forecast, verif):
    """Rename `dim` to `time` or `init` if forecast and verif dims requires."""
    if "init" in dim and "time" in forecast.dims and "time" in verif.dims:
        dim = dim.copy()
        dim.remove("init")
        dim = dim + ["time"]
    elif "time" in dim and "init" in forecast.dims and "init" in verif.dims:
        dim = dim.copy()
        dim.remove("time")
        dim = dim + ["init"]
    elif "init" in dim and "time" in forecast.dims and "time" in verif.dims:
        dim = dim.copy()
        dim.remove("init")
        dim = dim + ["time"]
    return dim


def _sanitize_to_list(dim: Union[str, Set, Tuple, List, None]) -> Optional[List[str]]:
    """Ensure dim is List, raises ValueError if not str, set, tuple or None."""
    if isinstance(dim, str):
        dim = [dim]
    elif isinstance(dim, set):
        dim = list(dim)
    elif isinstance(dim, tuple):
        dim = list(dim)
    elif isinstance(dim, list) or dim is None:
        pass
    else:
        raise ValueError(
            f"Expected `dim` as `str`, `list` or None, found {dim} as type {type(dim)}."
        )
    return dim


def _get_metric_comparison_dim(
    initialized,
    metric,
    comparison,
    dim: Optional[Union[List[str], str]],
    kind,
):
    """Return `metric`, `comparison` and `dim` for compute functions.

    Args:
        initialized (xr.Dataset): initialized dataset
        metric (str): metric or alias string
        comparison (str): Description of parameter `comparison`.
        dim (list of str or str): dimension to apply metric to.
        kind (str): experiment type from ['hindcast', 'PM'].

    Returns:
        metric (Metric): metric class
        comparison (Comparison): comparison class.
        dim (list of str or str): corrected dimension to apply metric to.
    """
    dim = _sanitize_to_list(dim)

    # check kind allowed
    is_in_list(kind, ["hindcast", "perfect"], "kind")

    if dim is None:  # take all dimension from initialized except lead
        dim = list(initialized.dims)
        if "lead" in dim:
            dim.remove("lead")
        # adjust dim for e2c comparison when member not in forecast or verif
        if comparison in ["e2c"] and "member" in dim:
            dim.remove("member")

    # check that initialized contains all dims from dim
    if not set(dim).issubset(initialized.dims):
        raise DimensionError(
            f"`dim`={dim} is expected to be a subset of "
            f"`initialized.dims`={initialized.dims}."
        )

    # get metric and comparison strings incorporating alias
    metric = METRIC_ALIASES.get(metric, metric)
    comparison = COMPARISON_ALIASES.get(comparison, comparison)

    METRICS = HINDCAST_METRICS if kind == "hindcast" else PM_METRICS
    COMPARISONS = HINDCAST_COMPARISONS if kind == "hindcast" else PM_COMPARISONS
    metric = get_metric_class(metric, METRICS)
    comparison = get_comparison_class(comparison, COMPARISONS)

    # check whether combination of metric and comparison works
    PROBABILISTIC_COMPARISONS = (
        PROBABILISTIC_HINDCAST_COMPARISONS
        if kind == "hindcast"
        else PROBABILISTIC_PM_COMPARISONS
    )
    if metric.probabilistic:
        if not comparison.probabilistic and "member" in initialized.dims:
            raise ValueError(
                f"Probabilistic metric `{metric.name}` requires comparison "
                f"accepting multiple members e.g. `{PROBABILISTIC_COMPARISONS}`, "
                f"found `{comparison.name}`."
            )
        if "member" not in dim and "member" in initialized.dims:
            raise ValueError(
                f"Probabilistic metric {metric.name} requires to be "
                f"computed over dimension `member`, which is not found in {dim}."
            )
    else:  # determinstic metric
        if kind == "hindcast":
            # for thinking in real time as compute_hindcast renames init to time
            dim = ["time" if d == "init" else d for d in dim]
    return metric, comparison, dim


[docs] def compute_perfect_model( initialized, control=None, metric="pearson_r", comparison="m2e", dim=["member", "init"], **metric_kwargs, ): """ Compute a predictability skill score in a perfect-model framework. Args: initialized (xr.Dataset): ensemble with dims ``lead``, ``init``, ``member``. control (xr.Dataset): NOTE that this is a legacy argument from a former release. ``control`` is not used in ``compute_perfect_model`` anymore. metric (str): `metric` name, see :py:func:`climpred.utils.get_metric_class` and (see :ref:`Metrics`). comparison (str): `comparison` name defines what to take as forecast and verification (see :py:func:`climpred.utils.get_comparison_class` and :ref:`Comparisons`). dim (str or list of str): dimension to apply metric over. default: ['member', 'init'] ** metric_kwargs (dict): additional keywords to be passed to metric. (see the arguments required for a given metric in metrics.py) Returns: skill (xr.Dataset): skill score with dimensions as input `ds` without `dim`. """ # Check that init is int, cftime, or datetime; convert ints or datetime to cftime initialized = convert_time_index( initialized, "init", "initialized[init]", calendar=PM_CALENDAR_STR ) # check args compatible with each other metric, comparison, dim = _get_metric_comparison_dim( initialized, metric, comparison, dim, kind="perfect" ) forecast, verif = comparison.function(initialized, metric=metric) if metric.normalize or metric.allows_logical: metric_kwargs["comparison"] = comparison skill = metric.function(forecast, verif, dim=dim, **metric_kwargs) if comparison.name == "m2m" and M2M_MEMBER_DIM in skill.dims: skill = skill.mean(M2M_MEMBER_DIM) return skill