Using dask with climpred

This demo shows how to use climpred with dask, which enables out-of-memory and parallel computation for large datasets.

import warnings

%matplotlib inline
import numpy as np
import xarray as xr
import dask
import climpred

warnings.filterwarnings("ignore")
from dask.distributed import Client
import multiprocessing
ncpu = multiprocessing.cpu_count()
processes = False
nworker = 2
threads = ncpu // nworker
print(
    f"Number of CPUs: {ncpu}, number of threads: {threads}, number of workers: {nworker}, processes: {processes}",
)
client = Client(
    processes=processes,
    threads_per_worker=threads,
    n_workers=nworker,
    memory_limit="64GB",
)
client
Number of CPUs: 4, number of threads: 2, number of workers: 2, processes: False

Client / Cluster

  • Workers: 2
  • Cores: 4
  • Memory: 34.36 GB
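
While the cells below run, the dask dashboard (a standard feature of the distributed Client created above) is handy for watching the task stream and worker memory; a minimal way to get its URL:

# Print the URL of the dask dashboard served by the Client above
# (the exact address depends on your setup).
print(client.dashboard_link)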

Synthetic data

# generic synthetic data
ny, nx = 256, 220          # spatial grid
nl, ni, nm = 20, 12, 10    # leads, initializations, members
init = xr.DataArray(
    np.random.random((nl, ni, nm, ny, nx)),
    dims=('lead', 'init', 'member', 'y', 'x'),
)
init.name = 'var'
init['init'] = np.arange(3000, 3300, 300 // ni)
init['lead'] = np.arange(1, 1 + init.lead.size)
control = xr.DataArray(np.random.random((300, ny, nx)), dims=('time', 'y', 'x'))
control.name = 'var'
control['time'] = np.arange(3000, 3300)

pm = climpred.PerfectModelEnsemble(init).add_control(control)

verify()

PerfectModelEnsemble.verify()

kw = {'comparison': 'm2e', 'metric': 'rmse', 'dim': ['init', 'member']}

without dask

%time s = pm.verify(**kw)
CPU times: user 12 s, sys: 7.08 s, total: 19.1 s
Wall time: 19.1 s
  • 2-core MacBook Pro 2018: CPU times: user 11.5 s, sys: 6.88 s, total: 18.4 s; Wall time: 19.6 s

  • 24-core mistral node: CPU times: user 9.22 s, sys: 10.3 s, total: 19.6 s; Wall time: 19.5 s

with dask

In order to use dask efficiently, we need to chunk the data appropriately. Processing chunks lazily adds a small overhead per dask task, so chunking mostly pays off for large datasets.

It is important to chunk the data along a dimension other than the dim passed to verify()!
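
As a plain xarray/dask illustration (not climpred-specific), chunking along the spatial dimension y keeps each reduction over init and member inside a single chunk, whereas chunking along member would split a reduced dimension across chunks:

# Illustration with plain xarray: verify() below reduces over 'init' and 'member',
# so those dimensions should stay unchunked.
lazy_ok = init.chunk({'y': 128})      # chunk a spatial dimension: chunks stay independent
lazy_bad = init.chunk({'member': 5})  # avoid: a reduced dimension is split across chunks
print(lazy_ok.chunks)
print(lazy_bad.chunks)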

chunked_dim = 'y'
chunks = {chunked_dim: init[chunked_dim].size // nworker}
pm_chunked = pm.chunk(chunks)
# if memory allows
# pm_chunked = pm_chunked.persist()
pm_chunked.get_initialized()['var'].data
        Array                    Chunk
Bytes   1.08 GB                  540.67 MB
Shape   (20, 12, 10, 256, 220)   (20, 12, 10, 128, 220)
Count   2 Tasks                  2 Chunks
Type    float64                  numpy.ndarray
%%time
s_chunked = pm_chunked.verify(**kw)
assert dask.is_dask_collection(s_chunked)
s_chunked = s_chunked.compute()
CPU times: user 35.8 s, sys: 18.5 s, total: 54.3 s
Wall time: 19.5 s
  • 2-core MacBook Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s; Wall time: 2min 10s

  • 24-core mistral node: CPU times: user 26.2 s, sys: 1min 37s, total: 2min 3s; Wall time: 5.38 s

try:
    xr.testing.assert_allclose(s, s_chunked, atol=1e-6)
except AssertionError:
    for v in s.data_vars:
        (s - s_chunked)[v].plot(robust=True, col='lead')
  • The results s and s_chunked are identical within the given tolerance.

  • Chunking reduces wall time from roughly 20 s to 5 s on the 24-core mistral node.


bootstrap()

This speedup carries over to PerfectModelEnsemble.bootstrap(), where bootstrapped resamplings of initialized, uninitialized and persistence skill are computed and then translated into p values and confidence intervals.
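
As a generic, illustrative sketch (not climpred's internal implementation; boot_skill below is made up), a confidence interval and p value can be read off a distribution of resampled skill like this:

# Hypothetical resampled skill values standing in for bootstrap() output.
rng = np.random.default_rng(0)
boot_skill = rng.normal(loc=0.2, scale=0.1, size=500)
ci_low, ci_high = np.percentile(boot_skill, [2.5, 97.5])  # 95% confidence interval
p = (boot_skill <= 0).mean()  # fraction of resamples showing no skill improvement
print(f"95% CI: [{ci_low:.3f}, {ci_high:.3f}], p = {p:.3f}")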

kwp = kw.copy()
kwp['iterations'] = 4

without dask

%time s_p = pm.bootstrap(**kwp)
  • 2-core MacBook Pro 2018: CPU times: user 2min 3s, sys: 1min 22s, total: 3min 26s; Wall time: 3min 43s

  • 24-core mistral node: CPU times: user 1min 51s, sys: 1min 54s, total: 3min 45s; Wall time: 3min 25s

with dask

When the underlying data is chunked, PerfectModelEnsemble.bootstrap() performs all skill calculations on the resampled inputs in parallel.
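
As an optional sketch (assuming the distributed Client from above and that the chunked bootstrap result is a dask-backed xarray object, as the compute() call below suggests), you can also start the computation with persist() and monitor it with a progress bar instead of blocking immediately:

# Start the bootstrap computation on the cluster without blocking,
# then display a progress bar (variable names here are illustrative).
from dask.distributed import progress

lazy_boot = pm_chunked.bootstrap(**kwp)  # stays lazy because the inputs are chunked
persisted = lazy_boot.persist()          # computation starts in the background
progress(persisted)                      # progress bar in the notebook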

%time s_p_chunked = pm_chunked.bootstrap(**kwp).compute()
  • 2-core MacBook Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s; Wall time: 2min 10s

  • 24-core mistral node: CPU times: user 2min 55s, sys: 8min 8s, total: 11min 3s; Wall time: 1min 53s