#!conda create -n ML_gpu tensorflow-gpu pytorch-gpu xarray dask matplotlib nb_conda_kernels jupyterlab cudatoolkit cupy esmtools climpred -y
#!pip install git+https://github.com/jacobtomlinson/cupy-xarray.git
!conda list cupy
# packages in environment at /home/docs/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0:
#
# Name                    Version                   Build  Channel
def nvidia_smi():
    """Run the ``nvidia-smi`` command and return its output as text.

    Returns:
        The command's standard output with surrounding whitespace removed,
        decoded as UTF-8.
    """
    from subprocess import check_output

    # Strip the raw bytes first, then decode them to ``str``.
    raw_output = check_output('nvidia-smi')
    return raw_output.strip().decode('utf-8')

climpred on CPU vs GPU

import xarray as xr
import numpy as np
from climpred.tutorial import load_dataset
from climpred import PerfectModelEnsemble
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_2855/2258333576.py in <module>
      1 import xarray as xr
      2 import numpy as np
----> 3 from climpred.tutorial import load_dataset
      4 from climpred import PerfectModelEnsemble

~/checkouts/readthedocs.org/user_builds/climpred/checkouts/v2.2.0/climpred/__init__.py in <module>
      2 from pkg_resources import DistributionNotFound, get_distribution
      3 
----> 4 from . import (
      5     bias_removal,
      6     bootstrap,

~/checkouts/readthedocs.org/user_builds/climpred/checkouts/v2.2.0/climpred/bias_removal.py in <module>
     21     pass
     22 try:
---> 23     from xclim import sdba
     24 except ImportError:
     25     pass

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/__init__.py in <module>
     69 Other restrictions : ``map_blocks`` will remove any "auxiliary" coordinates before calling the wrapped function and will add them back on exit.
     70 """
---> 71 from . import detrending, processing, utils
     72 from .adjustment import *
     73 from .base import Grouper

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/detrending.py in <module>
     11 from .base import Grouper, ParametrizableWithDataset, map_groups, parse_group
     12 from .loess import loess_smoothing
---> 13 from .utils import ADDITIVE, apply_correction, invert
     14 
     15 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/utils.py in <module>
     16 
     17 from .base import Grouper, parse_group
---> 18 from .nbutils import _extrapolate_on_quantiles
     19 
     20 MULTIPLICATIVE = "*"

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/nbutils.py in <module>
    135 
    136 @njit([float32(float32[:, :]), float64(float64[:, :])], fastmath=True)
--> 137 def _autocorrelation(X):
    138     """Mean of the NxN pairwise distances of points in X of shape KxN.
    139 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/decorators.py in wrapper(func)
    224             with typeinfer.register_dispatcher(disp):
    225                 for sig in sigs:
--> 226                     disp.compile(sig)
    227                 disp.disable_compile()
    228         return disp

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in compile(self, sig)
    977                 with ev.trigger_event("numba:compile", data=ev_details):
    978                     try:
--> 979                         cres = self._compiler.compile(args, return_type)
    980                     except errors.ForceLiteralArg as e:
    981                         def folded(args, kws):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in compile(self, args, return_type)
    139 
    140     def compile(self, args, return_type):
--> 141         status, retval = self._compile_cached(args, return_type)
    142         if status:
    143             return retval

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in _compile_cached(self, args, return_type)
    153 
    154         try:
--> 155             retval = self._compile_core(args, return_type)
    156         except errors.TypingError as e:
    157             self._failed_cache[key] = e

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in _compile_core(self, args, return_type)
    166 
    167         impl = self._get_implementation(args, {})
--> 168         cres = compiler.compile_extra(self.targetdescr.typing_context,
    169                                       self.targetdescr.target_context,
    170                                       impl,

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    684     pipeline = pipeline_class(typingctx, targetctx, library,
    685                               args, return_type, flags, locals)
--> 686     return pipeline.compile_extra(func)
    687 
    688 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in compile_extra(self, func)
    426         self.state.lifted = ()
    427         self.state.lifted_from = None
--> 428         return self._compile_bytecode()
    429 
    430     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in _compile_bytecode(self)
    490         """
    491         assert self.state.func_ir is None
--> 492         return self._compile_core()
    493 
    494     def _compile_ir(self):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in _compile_core(self)
    460                 res = None
    461                 try:
--> 462                     pm.run(self.state)
    463                     if self.state.cr is not None:
    464                         break

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in run(self, state)
    332                 pass_inst = _pass_registry.get(pss).pass_inst
    333                 if isinstance(pass_inst, CompilerPass):
--> 334                     self._runPass(idx, pass_inst, state)
    335                 else:
    336                     raise BaseException("Legacy pass in use")

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     33         def _acquire_compile_lock(*args, **kwargs):
     34             with self:
---> 35                 return func(*args, **kwargs)
     36         return _acquire_compile_lock
     37 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in _runPass(self, index, pss, internal_state)
    287             mutated |= check(pss.run_initialization, internal_state)
    288         with SimpleTimer() as pass_time:
--> 289             mutated |= check(pss.run_pass, internal_state)
    290         with SimpleTimer() as finalize_time:
    291             mutated |= check(pss.run_finalizer, internal_state)

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in check(func, compiler_state)
    260 
    261         def check(func, compiler_state):
--> 262             mangled = func(compiler_state)
    263             if mangled not in (True, False):
    264                 msg = ("CompilerPass implementations should return True/False. "

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/typed_passes.py in run_pass(self, state)
    423                 # Insert native function for use by other jitted-functions.
    424                 # We also register its library to allow for inlining.
--> 425                 cfunc = targetctx.get_executable(library, fndesc, env)
    426                 targetctx.insert_user_function(cfunc, fndesc, [library])
    427                 state['cr'] = _LowerResult(fndesc, call_helper,

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/cpu.py in get_executable(self, library, fndesc, env)
    228         """
    229         # Code generation
--> 230         baseptr = library.get_pointer_to_function(fndesc.llvm_func_name)
    231         fnptr = library.get_pointer_to_function(fndesc.llvm_cpython_wrapper_name)
    232 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in get_pointer_to_function(self, name)
    986             - non-zero if the symbol is defined.
    987         """
--> 988         self._ensure_finalized()
    989         ee = self._codegen._engine
    990         if not ee.is_symbol_defined(name):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in _ensure_finalized(self)
    566     def _ensure_finalized(self):
    567         if not self._finalized:
--> 568             self.finalize()
    569 
    570     def create_ir_module(self, name):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in finalize(self)
    761         # Optimize the module after all dependences are linked in above,
    762         # to allow for inlining.
--> 763         self._optimize_final_module()
    764 
    765         self._final_module.verify()

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in _optimize_final_module(self)
    681         with self._recorded_timings.record(full_name):
    682             # The full optimisation suite is then run on the refop pruned IR
--> 683             self._codegen._mpm_full.run(self._final_module)
    684 
    685     def _get_module_for_linking(self):

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/llvmlite/binding/passmanagers.py in run(self, module)
    205         Run optimization passes on the given module.
    206         """
--> 207         return ffi.lib.LLVMPY_RunPassManager(self, module)
    208 
    209 

~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/llvmlite/binding/ffi.py in __call__(self, *args, **kwargs)
    149     def __call__(self, *args, **kwargs):
    150         with self._lock:
--> 151             return self._cfn(*args, **kwargs)
    152 
    153 

KeyboardInterrupt: 

CPU

# Name of the variable selected from both tutorial datasets below.
v = "tos"
# Load the "MPI-PM-DP-3D" tutorial dataset and keep only variable `v`.
ds3d = load_dataset("MPI-PM-DP-3D")[v]
# Record the unit of the `lead` coordinate as years.
ds3d.lead.attrs["unit"] = "years"
# Load the "MPI-control-3D" tutorial dataset and keep only variable `v`.
control3d = load_dataset("MPI-control-3D")[v]

# Build the ensemble object from `ds3d`, then attach the control data.
# `add_control` returns the object that is used from here on, so rebind it.
pm_cpu = PerfectModelEnsemble(ds3d)
pm_cpu = pm_cpu.add_control(control3d)
type(ds3d.data)
numpy.ndarray
%timeit _ = pm_cpu.mean(['x','y'])
24.3 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit skill_cpu = pm_cpu.verify(metric="pearson_r", comparison="m2m", dim=["init", "member"])[v]
667 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

GPU

# NOTE(review): importing cupy_xarray presumably makes `.as_cupy()` available
# on xarray objects — confirm against the cupy-xarray documentation.
import cupy_xarray
# Name of the variable selected from both tutorial datasets below.
v = "tos"
# Load the same tutorial dataset as in the CPU section, then convert it with
# `.as_cupy()` so that `ds3d.data` is a cupy array instead of a numpy array.
ds3d = load_dataset("MPI-PM-DP-3D")[v].as_cupy()
# Record the unit of the `lead` coordinate as years.
ds3d.lead.attrs["unit"] = "years"
# Load the control dataset and convert it with `.as_cupy()` as well.
control3d = load_dataset("MPI-control-3D")[v].as_cupy()

# Build the ensemble object from the converted data and attach the control.
pm_gpu = PerfectModelEnsemble(ds3d)
pm_gpu = pm_gpu.add_control(control3d)
/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/utils.py:140: UserWarning: Assuming annual resolution due to numeric inits. Change init to a datetime if it is another resolution.
  warnings.warn(
type(ds3d.data)
cupy.core.core.ndarray
%timeit _ = pm_gpu.mean(['x','y'])
423 µs ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit skill_gpu = pm_gpu.verify(metric="pearson_r", comparison="m2m", dim=["init", "member"])[v]
64.1 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

🚀 Running on the GPU can speed up `verify` by roughly a factor of 10 (667 ms on the CPU vs. 64.1 ms on the GPU in the timings above)! 🚀

GPU Limitations

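`PerfectModelEnsemble.bootstrap()` does not work with cupy-backed data: it calls xarray's `quantile`, which dispatches to `numpy.quantile`, and no implementation of that function is found for cupy arrays (see the `TypeError` below).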

pm_gpu.bootstrap(metric="pearson_r", comparison="m2m", dim=["init", "member"], iterations=10)[v]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-45b2445bf7cd> in <module>
----> 1 pm_gpu.bootstrap(metric="pearson_r", comparison="m2m", dim=["init", "member"], iterations=10)[v]

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/classes.py in bootstrap(self, metric, comparison, dim, reference, iterations, sig, pers_sig, **metric_kwargs)
    860             "init": True,
    861         }
--> 862         return self._apply_climpred_function(
    863             bootstrap_perfect_model,
    864             input_dict=input_dict,

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/classes.py in _apply_climpred_function(self, func, input_dict, **kwargs)
    545         if control:
    546             control = control.drop_vars(ctrl_vars)
--> 547         return func(ensemble, control, **kwargs)
    548 
    549     def _vars_to_drop(self, init=True):

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in bootstrap_perfect_model(init_pm, control, metric, comparison, dim, reference, resample_dim, sig, iterations, pers_sig, reference_compute, **metric_kwargs)
   1259     )
   1260     lead_units_equal_control_time_stride(init_pm, control)
-> 1261     return bootstrap_compute(
   1262         init_pm,
   1263         control,

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in bootstrap_compute(hind, verif, hist, alignment, metric, comparison, dim, reference, resample_dim, sig, iterations, pers_sig, compute, resample_uninit, reference_compute, **metric_kwargs)
    899 
    900     # get confidence intervals CI
--> 901     init_ci = _distribution_to_ci(bootstrapped_init_skill, ci_low, ci_high)
    902     if "uninitialized" in reference:
    903         uninit_ci = _distribution_to_ci(bootstrapped_uninit_skill, ci_low, ci_high)

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in _distribution_to_ci(ds, ci_low, ci_high, dim)
    211         if np.issubdtype(ds.dtype, np.bool_):
    212             ds = ds.astype(np.float_)  # fails on py>36 if boolean dtype
--> 213     return ds.quantile(q=[ci_low, ci_high], dim=dim, skipna=False)
    214 
    215 

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/dataset.py in quantile(self, q, dim, interpolation, numeric_only, keep_attrs, skipna)
   5819                             # the former is often more efficient
   5820                             reduce_dims = None
-> 5821                         variables[name] = var.quantile(
   5822                             q,
   5823                             dim=reduce_dims,

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/variable.py in quantile(self, q, dim, interpolation, keep_attrs, skipna)
   1951 
   1952         axis = np.arange(-1, -1 * len(dim) - 1, -1)
-> 1953         result = apply_ufunc(
   1954             _wrapper,
   1955             self,

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1136     # feed Variables directly through apply_variable_ufunc
   1137     elif any(isinstance(a, Variable) for a in args):
-> 1138         return variables_vfunc(*args)
   1139     else:
   1140         # feed anything else through apply_array_ufunc

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    722             )
    723 
--> 724     result_data = func(*input_data)
    725 
    726     if signature.num_outputs == 1:

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/variable.py in _wrapper(npa, **kwargs)
   1948         def _wrapper(npa, **kwargs):
   1949             # move quantile axis to end. required for apply_ufunc
-> 1950             return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1)
   1951 
   1952         axis = np.arange(-1, -1 * len(dim) - 1, -1)

<__array_function__ internals> in quantile(*args, **kwargs)

TypeError: no implementation found for 'numpy.quantile' on types that implement __array_function__: [<class 'cupy.core.core.ndarray'>, <class 'numpy.ndarray'>]