Using dask with climpred

This demo shows how climpred works together with dask, which enables out-of-memory and parallel computation for large datasets.

import warnings

%matplotlib inline
import numpy as np
import xarray as xr
import dask
import climpred

warnings.filterwarnings("ignore")
from dask.distributed import Client
import multiprocessing
ncpu = multiprocessing.cpu_count()
processes = False
nworker = 2
threads = ncpu // nworker
print(
    f"Number of CPUs: {ncpu}, number of threads: {threads}, number of workers: {nworker}, processes: {processes}",
)
client = Client(
    processes=processes,
    threads_per_worker=threads,
    n_workers=nworker,
    memory_limit="64GB",
)
client
Number of CPUs: 2, number of threads: 1, number of workers: 2, processes: False

Client: Client-e0044c35-61de-11ec-8b3a-0242ac110002
Connection method: Cluster object, Cluster type: distributed.LocalCluster
Dashboard: http://172.17.0.2:8787/status

Synthetic data

# generate a synthetic initialized ensemble and a control simulation
ny, nx = 256, 220          # spatial dimensions
nl, ni, nm = 20, 12, 10    # leads, initializations, members
init = xr.DataArray(
    np.random.random((nl, ni, nm, ny, nx)),
    dims=('lead', 'init', 'member', 'y', 'x'),
)
init.name = 'var'
init['init'] = np.arange(3000, 3300, 300 // ni)
init['lead'] = np.arange(1, 1 + init.lead.size)
control = xr.DataArray(np.random.random((300, ny, nx)), dims=('time', 'y', 'x'))
control.name = 'var'
control['time'] = np.arange(3000, 3300)

pm = climpred.PerfectModelEnsemble(init).add_control(control)

verify()

We compare PerfectModelEnsemble.verify() without and with dask.

kw = {'comparison':'m2e', 'metric':'rmse', 'dim':['init', 'member']}

without dask

%time s = pm.verify(**kw)
CPU times: user 10.8 s, sys: 15.8 s, total: 26.6 s
Wall time: 25.6 s
  • 2-core MacBook Pro 2018: CPU times: user 11.5 s, sys: 6.88 s, total: 18.4 s, Wall time: 19.6 s

  • 24-core mistral node: CPU times: user 9.22 s, sys: 10.3 s, total: 19.6 s, Wall time: 19.5 s

with dask

To use dask efficiently, we need to chunk the data appropriately. Processing chunks of data lazily with dask adds a small overhead per task, so chunking mostly pays off for large datasets.

It is important that the data is chunked along a dimension other than the dim passed to verify()!
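
As a quick sanity check (a minimal sketch using the kw dictionary defined above; the literal 'y' anticipates the chunking choice made in the next cell), one can confirm that the planned chunking dimension is not among the dimensions verify() reduces over:

# sketch: the dimension we chunk along must not be reduced by verify()
assert 'y' not in kw['dim']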

chunked_dim = 'y'
chunks = {chunked_dim: init[chunked_dim].size // nworker}  # one chunk per worker
pm_chunked = pm.chunk(chunks)
# if memory allows
# pm_chunked = pm_chunked.persist()
pm_chunked.get_initialized()['var'].data  # inspect the underlying dask array
        Array                    Chunk
Bytes   1.01 GiB                 515.62 MiB
Shape   (20, 12, 10, 256, 220)   (20, 12, 10, 128, 220)
Count   2 Tasks                  2 Chunks
Type    float64 numpy.ndarray
%%time
s_chunked = pm_chunked.verify(**kw)
assert dask.is_dask_collection(s_chunked)
s_chunked = s_chunked.compute()
  • 2-core MacBook Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s, Wall time: 2min 10s

  • 24-core mistral node: CPU times: user 26.2 s, sys: 1min 37s, total: 2min 3s, Wall time: 5.38 s

try:
    xr.testing.assert_allclose(s, s_chunked, atol=1e-6)
except AssertionError:
    # plot where the chunked and unchunked results differ
    for v in s.data_vars:
        (s - s_chunked)[v].plot(robust=True, col='lead')
  • The results s and s_chunked are identical within the requested tolerance (see the sketch after this list).

  • Chunking reduces the wall time from about 20 s to 5 s on the 24-core mistral node.
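
To quantify this, a minimal sketch (assuming s and s_chunked from the cells above) prints the largest absolute difference per variable:

# sketch: maximum absolute difference between unchunked and chunked skill
for v in s.data_vars:
    print(v, float(abs(s - s_chunked)[v].max()))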


bootstrap()

This speedup carries over to PerfectModelEnsemble.bootstrap(), where bootstrapped resamplings of initialized, uninitialized and persistence skill are computed and then converted into p values and confidence intervals; a short sketch for inspecting the output follows after the without-dask timings below.

kwp = kw.copy()
kwp['iterations'] = 4

without dask

%time s_p = pm.bootstrap(**kwp)
  • 2-core MacBook Pro 2018: CPU times: user 2min 3s, sys: 1min 22s, total: 3min 26s, Wall time: 3min 43s

  • 24-core mistral node: CPU times: user 1min 51s, sys: 1min 54s, total: 3min 45s, Wall time: 3min 25s
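
The bootstrap output s_p bundles the resampled statistics. A minimal sketch for inspecting it; the exact coordinate names (for example a results coordinate with entries such as 'verify skill' and 'p') depend on the climpred version and are an assumption here:

# sketch: inspect what the bootstrap output actually contains
print(s_p.coords)
# s_p.sel(results='p')  # hypothetical: select p values if a 'results' coordinate exists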

with dask

When the underlying data is chunked, PerfectModelEnsemble.bootstrap() performs all skill calculations on the resampled inputs in parallel.
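
Before calling bootstrap(), a minimal sketch (reusing the accessors shown above) can confirm that the underlying data is indeed dask-backed and chunked:

# sketch: confirm the initialized ensemble is chunked before bootstrapping
assert dask.is_dask_collection(pm_chunked.get_initialized()['var'].data)
print(pm_chunked.get_initialized()['var'].chunks)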

%time s_p_chunked = pm_chunked.bootstrap(**kwp).compute()
  • 2-core MacBook Pro 2018: CPU times: user 2min 35s, sys: 1min 4s, total: 3min 40s, Wall time: 2min 10s

  • 24-core mistral node: CPU times: user 2min 55s, sys: 8min 8s, total: 11min 3s, Wall time: 1min 53s