Description
Upon calling one of dask.persist(), dask.compute(), distributed.Client.persist(), or distributed.Client.compute() on an xarray.DataArray or xarray.Dataset, the indices and all other non-Dask variables are serialized and sent to the Dask scheduler, then sent back to the client when the computation is done. This is unnecessary and can be extremely costly: the whole data is serialized, sent over the network, and stored on the scheduler for the entire duration of the computation. For object-dtype string variables, there is also a substantial impact on the GIL.
This issue also impacts any other third-party library that defines Dask collections with non-trivial contents outside of their Dask graph.
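To see why this is costly without involving xarray at all, here is a minimal pure-Python sketch. The `coords` bytestring is a hypothetical stand-in for the non-Dask index data that xarray embeds in its postpersist return value; the tuple below is illustrative, not xarray's actual return value:

```python
import pickle

# Hypothetical stand-in for ~10 MB of coordinate/index data that is
# not part of the Dask graph.
coords = bytes(10_000_000)

# __dask_postpersist__ returns (rebuild, args); in xarray, the callback
# and its args reference every non-Dask variable. Sketched here as a tuple.
postpersist_payload = ("rebuild", (coords,))

graph_size = len(pickle.dumps({"chunk-0": "task"}))    # tiny: just the graph
payload_size = len(pickle.dumps(postpersist_payload))  # drags the coords along

print(graph_size, payload_size)
```

Serializing the graph alone costs a few dozen bytes; serializing anything that references the non-Dask variables costs the full size of those variables.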
Reproducer
```python
import dask
import dask.array as da
import numpy as np
import distributed
import xarray
import pickle
from dask.utils import format_bytes

if __name__ == "__main__":
    a = xarray.DataArray(
        da.random.random(10_000_000),
        dims=["x"],
        coords={"x": np.arange(0, 100_000_000, 10)},
    )
    print("__dask_graph__()", format_bytes(len(pickle.dumps(a.__dask_graph__()))))
    print("__dask_keys__()", format_bytes(len(pickle.dumps(a.__dask_keys__()))))
    print("__dask_postpersist__()", format_bytes(len(pickle.dumps(a.__dask_postpersist__()))))
    print("__dask_postcompute__()", format_bytes(len(pickle.dumps(a.__dask_postcompute__()))))

    with distributed.Client() as client:
        b = client.persist(a)  # sends 152 MiB
        # b, = dask.persist(a)  # sends 152 MiB
        # b = a.persist()  # sends 3 kiB; see below
        c = b.sum()
        del b
        c.compute()
```

Output:
```
__dask_graph__() 3.10 kiB
__dask_keys__() 71 B
__dask_postpersist__() 152.59 MiB
__dask_postcompute__() 152.59 MiB
/home/crusaderky/github/enkiext/.pixi/envs/default/lib/python3.14/site-packages/distributed/client.py:3387: UserWarning: Sending large graph of size 152.59 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
```
Analysis
Upon calling persist(), the Dask Client should dismember the xarray collection by calling its hooks __dask_graph__(), __dask_keys__(), and __dask_postpersist__(). Upon calling compute(), Dask should instead use __dask_graph__(), __dask_keys__(), and __dask_postcompute__().
Crucially, in xarray the return value of __dask_postpersist__() and __dask_postcompute__() embeds all non-Dask variables. The output of these functions is supposed to be stored on the Client and never serialized.
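The intended flow can be sketched with a toy collection implementing the Dask collection protocol (illustrative names, not xarray's actual internals): the client ships only the graph and keys, and invokes the postpersist callback locally once the scheduler replies:

```python
import pickle

# Toy Dask collection: the graph is safe to ship to the scheduler,
# while `coords` stands in for non-Dask data that must stay on the client.
class ToyCollection:
    def __init__(self, graph, keys, coords):
        self.graph = graph
        self.keys = keys
        self.coords = coords

    def __dask_graph__(self):
        return self.graph

    def __dask_keys__(self):
        return self.keys

    def __dask_postpersist__(self):
        # (rebuild, args): args embed the non-Dask variables
        return rebuild, (self.coords,)

def rebuild(dsk, coords):
    return ToyCollection(dsk, list(dsk), coords)

coll = ToyCollection({"x-0": ("task",)}, ["x-0"], bytes(1_000_000))

# What should cross the network: only the graph and keys.
sent = pickle.dumps((coll.__dask_graph__(), coll.__dask_keys__()))

# What should happen client-side once the scheduler replies with futures:
func, args = coll.__dask_postpersist__()
persisted = func({"x-0": "future"}, *args)

print(len(sent))
```

Here `sent` is tiny regardless of how large `coords` is, because the rebuild callback and its arguments are never serialized.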
Instead, distributed.Client.persist calls collections_to_expr:
distributed/client.py, line 3811 at commit 991fda7:

```python
expr = collections_to_expr(collections, optimize_graph)
```
which in turn calls HLGExpr.from_collection:
https://github.com/dask/dask/blob/714eaf4e3515ec5c65cbc61af22a932c84d156c6/dask/base.py#L445
which in turn embeds the whole contents of __dask_postpersist__ in its return value, which is then sent to the scheduler:
https://github.com/dask/dask/blob/714eaf4e3515ec5c65cbc61af22a932c84d156c6/dask/_expr.py#L1020
Workarounds
- Use xarray.DataArray.persist and the other xarray methods, which are unaffected due to the apparently convoluted way they are implemented: https://github.com/pydata/xarray/blob/37f2d49b5cfbf5ca7e24dbad6347f99f5a24a368/xarray/core/dataset.py#L800-L815
  This, however, can result in inefficiency when persisting, and even more so when computing, multiple collections at once.
- Drop the non-Dask variables before persisting and merge them back afterwards:
```python
tmp = a.drop_vars(a.coords)
tmp, = dask.persist(tmp)
b = xarray.merge([a.coords.to_dataset(), tmp])
```