Xarray indices and other non-Dask variables are round-tripped to the Dask scheduler #9200

@crusaderky

Description

Upon calling one of

  • dask.persist()
  • dask.compute()
  • distributed.Client.persist()
  • distributed.Client.compute()

on an xarray.DataArray or xarray.Dataset, the indices and all other non-Dask variables are serialized and sent to the Dask scheduler, then sent back to the client when the computation is done. This is unnecessary and can be extremely costly: the whole data is serialized, sent over the network, and stored on the scheduler for the entire duration of the computation. In the case of object-dtype string variables, there is also a substantial impact on the GIL.

This issue also impacts any other third-party library that defines Dask collections with non-trivial contents outside of their Dask graph.

Reproducer

import dask
import dask.array as da
import numpy as np
import distributed
import xarray
import pickle
from dask.utils import format_bytes

if __name__ == "__main__":
    a = xarray.DataArray(
        da.random.random(10_000_000),
        dims=["x"],
        coords={"x": np.arange(0, 100_000_000, 10)},
    )

    print("__dask_graph__()", format_bytes(len(pickle.dumps(a.__dask_graph__()))))
    print("__dask_keys__()", format_bytes(len(pickle.dumps(a.__dask_keys__()))))
    print("__dask_postpersist__()", format_bytes(len(pickle.dumps(a.__dask_postpersist__()))))
    print("__dask_postcompute__()", format_bytes(len(pickle.dumps(a.__dask_postcompute__()))))

    with distributed.Client() as client:
        b = client.persist(a)  # sends 152 MiB
        # b, = dask.persist(a)  # sends 152 MiB
        # b = a.persist()  # sends 3 kiB; see below
        c = b.sum()
        del b
        c.compute()

Output:

__dask_graph__() 3.10 kiB
__dask_keys__() 71 B
__dask_postpersist__() 152.59 MiB
__dask_postcompute__() 152.59 MiB
/home/crusaderky/github/enkiext/.pixi/envs/default/lib/python3.14/site-packages/distributed/client.py:3387: UserWarning: Sending large graph of size 152.59 MiB.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.

Analysis

Upon calling persist(), the Dask Client should dismember the xarray collection by calling its hooks __dask_graph__(), __dask_keys__(), and __dask_postpersist__(). Upon calling compute(), Dask should instead use __dask_graph__(), __dask_keys__(), and __dask_postcompute__().

Crucially, in xarray the return value of __dask_postpersist__() and __dask_postcompute__() embeds all non-Dask variables. The output of these functions is supposed to be stored on the Client and never serialized.
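As a toy sketch of why this matters (illustrative only; this is not xarray's actual implementation, and ToyCollection, finalize, and the key names are made up), a __dask_postcompute__ hook returns a (finalize, extra_args) tuple, and any eager data embedded in extra_args gets carried along whenever that tuple is serialized:

```python
import pickle

import numpy as np


def finalize(results, coords):
    # Toy finalizer: reattach eager coordinate data to the computed chunks.
    return {"data": np.concatenate(results), "coords": coords}


class ToyCollection:
    """Minimal stand-in for a Dask collection holding non-Dask variables."""

    def __init__(self, keys, coords):
        self._keys = keys
        self._coords = coords  # eager, non-Dask data, e.g. an index

    def __dask_keys__(self):
        return self._keys

    def __dask_postcompute__(self):
        # The args tuple embeds the eager coords; pickling this tuple
        # pickles the coords with it.
        return finalize, (self._coords,)


coll = ToyCollection(keys=["chunk-0", "chunk-1"], coords=np.arange(10_000_000))
_, args = coll.__dask_postcompute__()
print("keys pickle to:", len(pickle.dumps(coll.__dask_keys__())), "bytes")
print("postcompute args pickle to:", len(pickle.dumps(args)) // 2**20, "MiB")
```

The keys and graph stay small; it is only the postcompute/postpersist payload that balloons, which matches the sizes printed by the reproducer above.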

Instead, distributed.Client.persist calls collections_to_expr:

expr = collections_to_expr(collections, optimize_graph)

which in turn calls HLGExpr.from_collection:
https://github.com/dask/dask/blob/714eaf4e3515ec5c65cbc61af22a932c84d156c6/dask/base.py#L445

which in turn embeds the whole contents of __dask_postpersist__ in its return value, which is then sent to the scheduler:
https://github.com/dask/dask/blob/714eaf4e3515ec5c65cbc61af22a932c84d156c6/dask/_expr.py#L1020

Workarounds

  1. use xarray.DataArray.persist and other xarray methods, which are unaffected due to the apparently convoluted way they are implemented: https://github.com/pydata/xarray/blob/37f2d49b5cfbf5ca7e24dbad6347f99f5a24a368/xarray/core/dataset.py#L800-L815
    This, however, can be inefficient when persisting, and even more so when computing, multiple collections at once.

  2. Manually strip the non-Dask variables before persisting and merge them back afterwards:

tmp = a.drop_vars(a.coords)
tmp, = dask.persist(tmp)
b = xarray.merge([a.coords.to_dataset(), tmp])
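A minimal runnable variant of the same idea (variable names are illustrative, and assign_coords is used here instead of merge so the result stays a DataArray rather than a Dataset):

```python
import dask
import dask.array as da
import numpy as np
import xarray

a = xarray.DataArray(
    da.random.random(1_000),
    dims=["x"],
    coords={"x": np.arange(0, 10_000, 10)},
)

# Persist only the Dask-backed data; the coordinates never leave the client.
stripped = a.drop_vars(list(a.coords))
(persisted,) = dask.persist(stripped)

# Re-attach the eager coordinates locally, without a scheduler round-trip.
b = persisted.assign_coords(a.coords)
```

Only the Dask graph of `stripped` is shipped; the coordinate array is re-attached client-side after the persist returns.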


Labels: memory, p2 (Affects more than a few users but doesn't prevent core functions), performance
