Support query optimization with Dask expression arrays#11382

Open
mrocklin wants to merge 5 commits into pydata:main from mrocklin:codex/composite-expr-protocol

@mrocklin

A while ago I finished Dask expression arrays, which support query optimization. This PR adds support for them in Xarray, which required a few things:

  • Creating a new __dask_exprs__ protocol in Dask (see dask/dask#12457, "Support composite expressions with __dask_exprs__ protocol")
  • Implementing that protocol in Xarray (this does most of the lifting)
  • Building an array chunk manager (this was done in the dask-array project)
  • Some silliness around xarray's map_blocks function
  • Changing a few explicit uses of dask.array to instead use the chunk manager (these should probably be changed regardless)
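
To illustrate the last point, here is a toy sketch of why routing through a chunk manager is more general than calling one array library directly. Everything in it (ListChunks, ListChunkManager, MANAGERS, get_manager, rechunk_any) is hypothetical and written only for this illustration; it is not xarray's or dask-array's actual code, and the real chunk manager interface is richer than this.

```python
# Toy illustration only: all names below are hypothetical, not xarray's API.


class ListChunks:
    """A stand-in chunked array: a 1-D sequence stored as a list of chunks."""

    def __init__(self, chunks):
        self.chunks = chunks


class ListChunkManager:
    """A stand-in chunk manager that knows how to handle ListChunks."""

    array_cls = ListChunks

    def is_chunked_array(self, data):
        return isinstance(data, self.array_cls)

    def rechunk(self, data, size):
        # Flatten the existing chunks, then regroup into chunks of `size`.
        flat = [x for chunk in data.chunks for x in chunk]
        return ListChunks([flat[i : i + size] for i in range(0, len(flat), size)])


# Registry of available managers; a second array backend would add itself here.
MANAGERS = [ListChunkManager()]


def get_manager(data):
    """Pick the manager that recognizes `data` instead of hard-coding one library."""
    for manager in MANAGERS:
        if manager.is_chunked_array(data):
            return manager
    raise TypeError(f"no chunk manager recognizes {type(data).__name__}")


def rechunk_any(data, size):
    # Library code calls the manager, so it works for any registered backend.
    return get_manager(data).rechunk(data, size)


result = rechunk_any(ListChunks([[1], [2], [3], [4], [5]]), 2)
print(result.chunks)
```

The point of the sketch is that rechunk_any never names a specific array library, so registering a different manager changes the backend without touching the caller.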

Example

import dask
import dask_array
import xarray as xr
from xarray.namedarray.parallelcompat import get_chunked_array_type


ds = xr.tutorial.scatter_example_dataset(seed=42).chunk({"x": 1, "y": 1, "z": 2, "w": 2})

# The slice and rechunk start above the elementwise operation.  dask-array's
# optimizer can push them down so it only builds the small requested window.
window = (ds.A + ds.B).chunk({"y": 3}).isel(x=slice(0, 1), y=slice(0, 3))

tasks_before = len(window.__dask_graph__())
(optimized_window,) = dask.optimize(window)
optimized_data = window.data.optimize()
tasks_after = len(optimized_data.__dask_graph__())

manager = get_chunked_array_type(ds.A.data)

print(f"xarray chunk manager: {type(manager).__name__}")
print(f"dask.optimize result: {type(optimized_window).__name__}")
print(f"array type: {type(window.data).__module__}.{type(window.data).__name__}")
print(f"graph tasks before optimize: {tasks_before}")
print(f"graph tasks after optimize:  {tasks_after}")
print()
print("Before optimize:")
window.data.pprint()
print()
print("After optimize:")
optimized_data.pprint()

Output

xarray chunk manager: DaskArrayExprManager
dask.optimize result: DataArray
array type: dask_array._collection.Array
graph tasks before optimize: 448
graph tasks after optimize:  12

Before optimize:
  Operation                Shape    Bytes   Chunks
  Getitem           (1, 3, 4, 4)    384 B  1×3×2×2
  └ Rechunk        (3, 11, 4, 4)  4.1 kiB  1×3×2×2
    └ Add          (3, 11, 4, 4)  4.1 kiB  1×1×2×2
      ├ FromArray  (3, 11, 4, 4)  4.1 kiB  1×1×2×2
      └ FromArray  (3, 11, 4, 4)  4.1 kiB  1×1×2×2

After optimize:
  Operation           Shape  Bytes   Chunks
  Add          (1, 3, 4, 4)  384 B  1×3×2×2
  ├ FromArray  (1, 3, 4, 4)  384 B  1×3×2×2
  └ FromArray  (1, 3, 4, 4)  384 B  1×3×2×2
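
As a sanity check on the task counts, the numbers above can be reproduced from the two printed expression trees under one simplifying assumption: each operation contributes one task per output chunk. That assumption is mine, not something the optimizer documents, but it matches the reported 448 and 12. The helper n_chunks and the before/after lists are hypothetical names used only for this arithmetic.

```python
from math import ceil, prod


def n_chunks(shape, chunks):
    """Number of chunks for an array of `shape` split into blocks of `chunks`."""
    return prod(ceil(s / c) for s, c in zip(shape, chunks))


# (operation, shape, chunk shape) rows copied from the printed trees above.
before = [
    ("Getitem", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("Rechunk", (3, 11, 4, 4), (1, 3, 2, 2)),
    ("Add", (3, 11, 4, 4), (1, 1, 2, 2)),
    ("FromArray", (3, 11, 4, 4), (1, 1, 2, 2)),
    ("FromArray", (3, 11, 4, 4), (1, 1, 2, 2)),
]
after = [
    ("Add", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("FromArray", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("FromArray", (1, 3, 4, 4), (1, 3, 2, 2)),
]

# Assumption: one task per output chunk of each operation.
tasks_before = sum(n_chunks(shape, chunks) for _, shape, chunks in before)
tasks_after = sum(n_chunks(shape, chunks) for _, shape, chunks in after)

print(tasks_before, tasks_after)
```

Before optimization the two FromArray nodes and the Add each cover 132 chunks, the Rechunk covers 48, and the Getitem covers 4. After pushing the slice and rechunk down, each of the three remaining nodes covers only 4 chunks.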

@mrocklin mrocklin force-pushed the codex/composite-expr-protocol branch from 147a748 to 3501992 Compare June 12, 2026 02:53