Using dask
with climpred
¶
This demo demonstrates `climpred`'s capabilities with `dask`. This enables out-of-memory and parallel computation for large datasets with `climpred`.
import warnings
%matplotlib inline
import numpy as np
import xarray as xr
import dask
import climpred
warnings.filterwarnings("ignore")
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
/tmp/ipykernel_2881/146010171.py in <module>
5 import xarray as xr
6 import dask
----> 7 import climpred
8
9 warnings.filterwarnings("ignore")
~/checkouts/readthedocs.org/user_builds/climpred/checkouts/v2.2.0/climpred/__init__.py in <module>
2 from pkg_resources import DistributionNotFound, get_distribution
3
----> 4 from . import (
5 bias_removal,
6 bootstrap,
~/checkouts/readthedocs.org/user_builds/climpred/checkouts/v2.2.0/climpred/bias_removal.py in <module>
21 pass
22 try:
---> 23 from xclim import sdba
24 except ImportError:
25 pass
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/__init__.py in <module>
69 Other restrictions : ``map_blocks`` will remove any "auxiliary" coordinates before calling the wrapped function and will add them back on exit.
70 """
---> 71 from . import detrending, processing, utils
72 from .adjustment import *
73 from .base import Grouper
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/detrending.py in <module>
11 from .base import Grouper, ParametrizableWithDataset, map_groups, parse_group
12 from .loess import loess_smoothing
---> 13 from .utils import ADDITIVE, apply_correction, invert
14
15
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/utils.py in <module>
16
17 from .base import Grouper, parse_group
---> 18 from .nbutils import _extrapolate_on_quantiles
19
20 MULTIPLICATIVE = "*"
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/xclim/sdba/nbutils.py in <module>
135
136 @njit([float32(float32[:, :]), float64(float64[:, :])], fastmath=True)
--> 137 def _autocorrelation(X):
138 """Mean of the NxN pairwise distances of points in X of shape KxN.
139
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/decorators.py in wrapper(func)
224 with typeinfer.register_dispatcher(disp):
225 for sig in sigs:
--> 226 disp.compile(sig)
227 disp.disable_compile()
228 return disp
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in compile(self, sig)
977 with ev.trigger_event("numba:compile", data=ev_details):
978 try:
--> 979 cres = self._compiler.compile(args, return_type)
980 except errors.ForceLiteralArg as e:
981 def folded(args, kws):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in compile(self, args, return_type)
139
140 def compile(self, args, return_type):
--> 141 status, retval = self._compile_cached(args, return_type)
142 if status:
143 return retval
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in _compile_cached(self, args, return_type)
153
154 try:
--> 155 retval = self._compile_core(args, return_type)
156 except errors.TypingError as e:
157 self._failed_cache[key] = e
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/dispatcher.py in _compile_core(self, args, return_type)
166
167 impl = self._get_implementation(args, {})
--> 168 cres = compiler.compile_extra(self.targetdescr.typing_context,
169 self.targetdescr.target_context,
170 impl,
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
684 pipeline = pipeline_class(typingctx, targetctx, library,
685 args, return_type, flags, locals)
--> 686 return pipeline.compile_extra(func)
687
688
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in compile_extra(self, func)
426 self.state.lifted = ()
427 self.state.lifted_from = None
--> 428 return self._compile_bytecode()
429
430 def compile_ir(self, func_ir, lifted=(), lifted_from=None):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in _compile_bytecode(self)
490 """
491 assert self.state.func_ir is None
--> 492 return self._compile_core()
493
494 def _compile_ir(self):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler.py in _compile_core(self)
460 res = None
461 try:
--> 462 pm.run(self.state)
463 if self.state.cr is not None:
464 break
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in run(self, state)
332 pass_inst = _pass_registry.get(pss).pass_inst
333 if isinstance(pass_inst, CompilerPass):
--> 334 self._runPass(idx, pass_inst, state)
335 else:
336 raise BaseException("Legacy pass in use")
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
36 return _acquire_compile_lock
37
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in _runPass(self, index, pss, internal_state)
287 mutated |= check(pss.run_initialization, internal_state)
288 with SimpleTimer() as pass_time:
--> 289 mutated |= check(pss.run_pass, internal_state)
290 with SimpleTimer() as finalize_time:
291 mutated |= check(pss.run_finalizer, internal_state)
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/compiler_machinery.py in check(func, compiler_state)
260
261 def check(func, compiler_state):
--> 262 mangled = func(compiler_state)
263 if mangled not in (True, False):
264 msg = ("CompilerPass implementations should return True/False. "
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/typed_passes.py in run_pass(self, state)
423 # Insert native function for use by other jitted-functions.
424 # We also register its library to allow for inlining.
--> 425 cfunc = targetctx.get_executable(library, fndesc, env)
426 targetctx.insert_user_function(cfunc, fndesc, [library])
427 state['cr'] = _LowerResult(fndesc, call_helper,
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/cpu.py in get_executable(self, library, fndesc, env)
228 """
229 # Code generation
--> 230 baseptr = library.get_pointer_to_function(fndesc.llvm_func_name)
231 fnptr = library.get_pointer_to_function(fndesc.llvm_cpython_wrapper_name)
232
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in get_pointer_to_function(self, name)
986 - non-zero if the symbol is defined.
987 """
--> 988 self._ensure_finalized()
989 ee = self._codegen._engine
990 if not ee.is_symbol_defined(name):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in _ensure_finalized(self)
566 def _ensure_finalized(self):
567 if not self._finalized:
--> 568 self.finalize()
569
570 def create_ir_module(self, name):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in finalize(self)
761 # Optimize the module after all dependences are linked in above,
762 # to allow for inlining.
--> 763 self._optimize_final_module()
764
765 self._final_module.verify()
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/numba/core/codegen.py in _optimize_final_module(self)
681 with self._recorded_timings.record(full_name):
682 # The full optimisation suite is then run on the refop pruned IR
--> 683 self._codegen._mpm_full.run(self._final_module)
684
685 def _get_module_for_linking(self):
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/llvmlite/binding/passmanagers.py in run(self, module)
205 Run optimization passes on the given module.
206 """
--> 207 return ffi.lib.LLVMPY_RunPassManager(self, module)
208
209
~/checkouts/readthedocs.org/user_builds/climpred/conda/v2.2.0/lib/python3.9/site-packages/llvmlite/binding/ffi.py in __call__(self, *args, **kwargs)
149 def __call__(self, *args, **kwargs):
150 with self._lock:
--> 151 return self._cfn(*args, **kwargs)
152
153
KeyboardInterrupt:
Load a `Client` to use `dask.distributed`: stackoverflow
(Optionally) Use the `dask` dashboard to visualize performance
import multiprocessing

from dask.distributed import Client

# Size the local cluster from the machine: a fixed number of workers that
# split the available CPUs evenly between them as threads.
nworker = 2
processes = False
ncpu = multiprocessing.cpu_count()
threads = ncpu // nworker
print(
    f"Number of CPUs: {ncpu}, number of threads: {threads}, number of workers: {nworker}, processes: {processes}"
)

# Start the cluster with the settings reported above.
client = Client(
    n_workers=nworker,
    threads_per_worker=threads,
    processes=processes,
    memory_limit="64GB",
)
# Bare expression so the notebook renders the client summary.
client
Number of CPUs: 4, number of threads: 2, number of workers: 2, processes: False
Client
|
Cluster
|
Synthetic data¶
# Synthetic perfect-model setup: an initialized ensemble on a (y, x) grid
# plus a control run, both filled with uniform random numbers.
ny, nx = 256, 220  # grid size
nl, ni, nm = 20, 12, 10  # number of leads, initializations and members

init = xr.DataArray(
    np.random.random((nl, ni, nm, ny, nx)),
    dims=("lead", "init", "member", "y", "x"),
    coords={
        "init": np.arange(3000, 3300, 300 // ni),
        "lead": np.arange(1, nl + 1),
    },
    name="var",
)

control = xr.DataArray(
    np.random.random((300, ny, nx)),
    dims=("time", "y", "x"),
    coords={"time": np.arange(3000, 3300)},
    name="var",
)

# Bundle the initialized ensemble and its control run for verification.
pm = climpred.PerfectModelEnsemble(init).add_control(control)
verify()
¶
kw = {'comparison':'m2e', 'metric':'rmse', 'dim':['init', 'member']}
without dask
¶
%time s = pm.verify(**kw)
CPU times: user 12 s, sys: 7.08 s, total: 19.1 s
Wall time: 19.1 s
2 core Mac Book Pro 2018: CPU times: user 11.5 s, sys: 6.88 s, total: 18.4 s Wall time: 19.6 s
24 core mistral node: CPU times: user 9.22 s, sys: 10.3 s, total: 19.6 s Wall time: 19.5 s
with dask
¶
In order to use `dask` efficiently, we need to chunk the data appropriately. Processing chunks of data lazily with `dask` creates a tiny overhead per chunk, therefore chunking mostly makes sense when applying it to large data.
It is important that the data is chunked along a different dimension than the `dim` passed to `verify()`!
# Chunk along a spatial dimension, i.e. one that is not among the dimensions
# passed as `dim` to verify(); each worker gets one chunk.
chunked_dim = "y"
chunks = {chunked_dim: init.sizes[chunked_dim] // nworker}
pm_chunked = pm.chunk(chunks)
# if memory allows
# pm_chunked = pm_chunked.persist()
# Bare expression so the notebook displays the chunked initialized data.
pm_chunked.get_initialized()["var"].data
|
%%time
s_chunked = pm_chunked.verify(**kw)
assert dask.is_dask_collection(s_chunked)
s_chunked = s_chunked.compute()
CPU times: user 35.8 s, sys: 18.5 s, total: 54.3 s
Wall time: 19.5 s
2 core Mac Book Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s Wall time: 2min 10s
24 core mistral node: CPU times: user 26.2 s, sys: 1min 37s, total: 2min 3s Wall time: 5.38 s
# Check that the dask-backed result matches the eager one within an
# absolute tolerance of 1e-6.
try:
    xr.testing.assert_allclose(s, s_chunked, atol=1e-6)
except AssertionError:
    # On mismatch, plot the difference for every variable, one panel per lead.
    for v in s.data_vars:
        (s - s_chunked)[v].plot(robust=True, col="lead")
The results `s` and `s_chunked` are identical as requested. Chunking reduces wall time from 20s to 5s on the supercomputer.
bootstrap()
¶
This speedup translates into `PerfectModelEnsemble.bootstrap()`, where bootstrapped resamplings of initialized, uninitialized and persistence skill are computed and then translated into p values and confidence intervals.
kwp = kw.copy()
kwp['iterations'] = 4
without dask
¶
%time s_p = pm.bootstrap(**kwp)
2 core Mac Book Pro 2018: CPU times: user 2min 3s, sys: 1min 22s, total: 3min 26s Wall time: 3min 43s
24 core mistral node: CPU times: user 1min 51s, sys: 1min 54s, total: 3min 45s Wall time: 3min 25s
with dask
¶
When `ds` is chunked, `PerfectModelEnsemble.bootstrap()` performs all skill calculations on resampled inputs in parallel.
%time s_p_chunked = pm_chunked.bootstrap(**kwp).compute()
2 core Mac Book Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s Wall time: 2min 10s
24 core mistral node: CPU times: user 2min 55s, sys: 8min 8s, total: 11min 3s Wall time: 1min 53s