[1]:
#!conda create -n ML_gpu tensorflow-gpu pytorch-gpu xarray dask matplotlib nb_conda_kernels jupyterlab cudatoolkit cupy esmtools climpred -y
#!pip install git+https://github.com/jacobtomlinson/cupy-xarray.git
[2]:
!conda list cupy
# packages in environment at /work/mh0727/m300524/conda-envs/ML_gpu:
#
# Name                    Version                   Build  Channel
cupy                      8.6.0            py39h694feb1_0    conda-forge
cupy-xarray               0.1.0+2.g08c5779          pypi_0    pypi
[3]:
def nvidia_smi():
    """Run the ``nvidia-smi`` command and return its report as text.

    Returns:
        str: The command's standard output, with leading/trailing whitespace
        removed from the raw bytes and the remainder decoded as UTF-8.

    Raises:
        Whatever ``subprocess.check_output`` raises when the command cannot
        be started or exits with a non-zero status.
    """
    from subprocess import check_output

    raw_report = check_output('nvidia-smi')
    # Trim while still bytes, then decode, matching the byte-level strip semantics.
    trimmed_report = raw_report.strip()
    return trimmed_report.decode('utf-8')

climpred on CPU vs GPU

[4]:
import xarray as xr
import numpy as np
from climpred.tutorial import load_dataset
from climpred import PerfectModelEnsemble

CPU

[6]:
# Load one variable of the tutorial perfect-model ensemble and its control run.
v = "tos"
ds3d = load_dataset("MPI-PM-DP-3D")[v]
# climpred looks up the lead resolution under the attribute name "units"
# (plural); an attribute called "unit" is not read.
ds3d.lead.attrs["units"] = "years"
control3d = load_dataset("MPI-control-3D")[v]

# Build the ensemble from the initialized data and attach the control simulation.
pm_cpu = PerfectModelEnsemble(ds3d).add_control(control3d)
[7]:
# Show which array type backs `ds3d` before any GPU conversion.
type(ds3d.data)
[7]:
numpy.ndarray
[16]:
# Time a mean over the 'x' and 'y' dimensions on the NumPy-backed ensemble.
%timeit _ = pm_cpu.mean(['x','y'])
24.3 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
[18]:
# Time `verify` (metric "pearson_r", comparison "m2m", dim ["init", "member"])
# on the NumPy-backed ensemble and select variable `v` from the result.
%timeit skill_cpu = pm_cpu.verify(metric="pearson_r", comparison="m2m", dim=["init", "member"])[v]
667 ms ± 29.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

GPU

[10]:
import cupy_xarray
[11]:
# Load the same tutorial data again and convert each array with `.as_cupy()`
# so that the ensemble is backed by CuPy arrays instead of NumPy arrays.
v = "tos"
ds3d = load_dataset("MPI-PM-DP-3D")[v].as_cupy()
# climpred looks up the lead resolution under the attribute name "units"
# (plural); an attribute called "unit" is not read.
ds3d.lead.attrs["units"] = "years"
control3d = load_dataset("MPI-control-3D")[v].as_cupy()

# Build the ensemble from the initialized data and attach the control simulation.
pm_gpu = PerfectModelEnsemble(ds3d).add_control(control3d)
/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/utils.py:140: UserWarning: Assuming annual resolution due to numeric inits. Change init to a datetime if it is another resolution.
  warnings.warn(
[12]:
# Show which array type backs `ds3d` after the `.as_cupy()` conversion.
type(ds3d.data)
[12]:
cupy.core.core.ndarray
[15]:
# Time a mean over the 'x' and 'y' dimensions on the CuPy-backed ensemble.
# NOTE(review): CuPy normally launches GPU kernels asynchronously, so without an
# explicit device synchronization this may under-report the real GPU time — confirm.
%timeit _ = pm_gpu.mean(['x','y'])
423 µs ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
[17]:
# Time `verify` (metric "pearson_r", comparison "m2m", dim ["init", "member"])
# on the CuPy-backed ensemble and select variable `v` from the result.
# NOTE(review): CuPy normally launches GPU kernels asynchronously, so without an
# explicit device synchronization this may under-report the real GPU time — confirm.
%timeit skill_gpu = pm_gpu.verify(metric="pearson_r", comparison="m2m", dim=["init", "member"])[v]
64.1 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

🚀 GPU has the potential to speed up `verify` by a factor of about 10! 🚀

GPU Limitations

[20]:
# Demonstrates a current limitation: on CuPy-backed data this call ends in a
# TypeError, because the quantile step reaches numpy.quantile, which has no
# implementation for cupy arrays (see the traceback below).
pm_gpu.bootstrap(metric="pearson_r", comparison="m2m", dim=["init", "member"], iterations=10)[v]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-45b2445bf7cd> in <module>
----> 1 pm_gpu.bootstrap(metric="pearson_r", comparison="m2m", dim=["init", "member"], iterations=10)[v]

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/classes.py in bootstrap(self, metric, comparison, dim, reference, iterations, sig, pers_sig, **metric_kwargs)
    860             "init": True,
    861         }
--> 862         return self._apply_climpred_function(
    863             bootstrap_perfect_model,
    864             input_dict=input_dict,

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/classes.py in _apply_climpred_function(self, func, input_dict, **kwargs)
    545         if control:
    546             control = control.drop_vars(ctrl_vars)
--> 547         return func(ensemble, control, **kwargs)
    548
    549     def _vars_to_drop(self, init=True):

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in bootstrap_perfect_model(init_pm, control, metric, comparison, dim, reference, resample_dim, sig, iterations, pers_sig, reference_compute, **metric_kwargs)
   1259     )
   1260     lead_units_equal_control_time_stride(init_pm, control)
-> 1261     return bootstrap_compute(
   1262         init_pm,
   1263         control,

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in bootstrap_compute(hind, verif, hist, alignment, metric, comparison, dim, reference, resample_dim, sig, iterations, pers_sig, compute, resample_uninit, reference_compute, **metric_kwargs)
    899
    900     # get confidence intervals CI
--> 901     init_ci = _distribution_to_ci(bootstrapped_init_skill, ci_low, ci_high)
    902     if "uninitialized" in reference:
    903         uninit_ci = _distribution_to_ci(bootstrapped_uninit_skill, ci_low, ci_high)

/mnt/lustre01/pf/zmaw/m300524/climpred/climpred/bootstrap.py in _distribution_to_ci(ds, ci_low, ci_high, dim)
    211         if np.issubdtype(ds.dtype, np.bool_):
    212             ds = ds.astype(np.float_)  # fails on py>36 if boolean dtype
--> 213     return ds.quantile(q=[ci_low, ci_high], dim=dim, skipna=False)
    214
    215

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/dataset.py in quantile(self, q, dim, interpolation, numeric_only, keep_attrs, skipna)
   5819                             # the former is often more efficient
   5820                             reduce_dims = None
-> 5821                         variables[name] = var.quantile(
   5822                             q,
   5823                             dim=reduce_dims,

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/variable.py in quantile(self, q, dim, interpolation, keep_attrs, skipna)
   1951
   1952         axis = np.arange(-1, -1 * len(dim) - 1, -1)
-> 1953         result = apply_ufunc(
   1954             _wrapper,
   1955             self,

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1136     # feed Variables directly through apply_variable_ufunc
   1137     elif any(isinstance(a, Variable) for a in args):
-> 1138         return variables_vfunc(*args)
   1139     else:
   1140         # feed anything else through apply_array_ufunc

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    722             )
    723
--> 724     result_data = func(*input_data)
    725
    726     if signature.num_outputs == 1:

/work/mh0727/m300524/conda-envs/ML_gpu/lib/python3.9/site-packages/xarray/core/variable.py in _wrapper(npa, **kwargs)
   1948         def _wrapper(npa, **kwargs):
   1949             # move quantile axis to end. required for apply_ufunc
-> 1950             return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1)
   1951
   1952         axis = np.arange(-1, -1 * len(dim) - 1, -1)

<__array_function__ internals> in quantile(*args, **kwargs)

TypeError: no implementation found for 'numpy.quantile' on types that implement __array_function__: [<class 'cupy.core.core.ndarray'>, <class 'numpy.ndarray'>]