Skip to content

Commit d87ba61

Browse files
authored
Minimize duplication in map_blocks task graph (#8412)
* Adapt map_blocks to use new Coordinates API * cleanup * typing fixes * Minimize duplication in `map_blocks` task graph Closes #8409 * Some more optimization * Refactor inserting of in memory data * [WIP] De-duplicate in expected["indexes"] * Revert "[WIP] De-duplicate in expected["indexes"]" This reverts commit 7276cbf. * Revert "Refactor inserting of in memory data" This reverts commit f6557f7. * Be more clever about scalar broadcasting * Small speedup * Small improvement * Trim some more. * Restrict numpy code path only for scalars and indexes * Small cleanup * Add test * typing fixes * optimize * reorder * better test * cleanup + whats-new
1 parent 41d33f5 commit d87ba61

File tree

4 files changed

+132
-74
lines changed

4 files changed

+132
-74
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ Documentation
5252

5353
Internal Changes
5454
~~~~~~~~~~~~~~~~
55-
55+
- The implementation of :py:func:`map_blocks` has changed to minimize graph size and duplication of data.
56+
This should be a strict improvement even though the graphs are not always embarrassingly parallel any more.
57+
Please open an issue if you spot a regression. (:pull:`8412`, :issue:`8409`).
58+
By `Deepak Cherian <https://github.com/dcherian>`_.
5659
- Remove null values before plotting. (:pull:`8535`).
5760
By `Jimmy Westling <https://github.com/illviljan>`_.
5861

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ module = [
9191
"cf_units.*",
9292
"cfgrib.*",
9393
"cftime.*",
94+
"cloudpickle.*",
9495
"cubed.*",
9596
"cupy.*",
9697
"dask.types.*",

xarray/core/parallel.py

Lines changed: 102 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from xarray.core.indexes import Index
1616
from xarray.core.merge import merge
1717
from xarray.core.pycompat import is_dask_collection
18+
from xarray.core.variable import Variable
1819

1920
if TYPE_CHECKING:
2021
from xarray.core.types import T_Xarray
@@ -156,6 +157,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
156157
return slice(None)
157158

158159

160+
def subset_dataset_to_block(
161+
graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
162+
):
163+
"""
164+
Creates a task that subsets an xarray dataset to a block determined by chunk_index.
165+
Block extents are determined by input_chunk_bounds.
166+
Also creates subtasks that subset the constituent variables of a dataset.
167+
"""
168+
import dask
169+
170+
# this will become [[name1, variable1],
171+
# [name2, variable2],
172+
# ...]
173+
# which is passed to dict and then to Dataset
174+
data_vars = []
175+
coords = []
176+
177+
chunk_tuple = tuple(chunk_index.values())
178+
chunk_dims_set = set(chunk_index)
179+
variable: Variable
180+
for name, variable in dataset.variables.items():
181+
# make a task that creates tuple of (dims, chunk)
182+
if dask.is_dask_collection(variable.data):
183+
# get task name for chunk
184+
chunk = (
185+
variable.data.name,
186+
*tuple(chunk_index[dim] for dim in variable.dims),
187+
)
188+
189+
chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
190+
graph[chunk_variable_task] = (
191+
tuple,
192+
[variable.dims, chunk, variable.attrs],
193+
)
194+
else:
195+
assert name in dataset.dims or variable.ndim == 0
196+
197+
# non-dask array possibly with dimensions chunked on other variables
198+
# index into variable appropriately
199+
subsetter = {
200+
dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
201+
for dim in variable.dims
202+
}
203+
if set(variable.dims) < chunk_dims_set:
204+
this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims)
205+
else:
206+
this_var_chunk_tuple = chunk_tuple
207+
208+
chunk_variable_task = (
209+
f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
210+
) + this_var_chunk_tuple
211+
# We are including a dimension coordinate,
212+
# minimize duplication by not copying it in the graph for every chunk.
213+
if variable.ndim == 0 or chunk_variable_task not in graph:
214+
subset = variable.isel(subsetter)
215+
graph[chunk_variable_task] = (
216+
tuple,
217+
[subset.dims, subset._data, subset.attrs],
218+
)
219+
220+
# this task creates dict mapping variable name to above tuple
221+
if name in dataset._coord_names:
222+
coords.append([name, chunk_variable_task])
223+
else:
224+
data_vars.append([name, chunk_variable_task])
225+
226+
return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
227+
228+
159229
def map_blocks(
160230
func: Callable[..., T_Xarray],
161231
obj: DataArray | Dataset,
@@ -280,6 +350,10 @@ def _wrapper(
280350

281351
result = func(*converted_args, **kwargs)
282352

353+
merged_coordinates = merge(
354+
[arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))]
355+
).coords
356+
283357
# check all dims are present
284358
missing_dimensions = set(expected["shapes"]) - set(result.sizes)
285359
if missing_dimensions:
@@ -295,12 +369,16 @@ def _wrapper(
295369
f"Received dimension {name!r} of length {result.sizes[name]}. "
296370
f"Expected length {expected['shapes'][name]}."
297371
)
298-
if name in expected["indexes"]:
299-
expected_index = expected["indexes"][name]
300-
if not index.equals(expected_index):
301-
raise ValueError(
302-
f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
303-
)
372+
373+
# ChainMap wants MutableMapping, but xindexes is Mapping
374+
merged_indexes = collections.ChainMap(
375+
expected["indexes"], merged_coordinates.xindexes # type: ignore[arg-type]
376+
)
377+
expected_index = merged_indexes.get(name, None)
378+
if expected_index is not None and not index.equals(expected_index):
379+
raise ValueError(
380+
f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
381+
)
304382

305383
# check that all expected variables were returned
306384
check_result_variables(result, expected, "coords")
@@ -356,6 +434,8 @@ def _wrapper(
356434
dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg
357435
for arg in aligned
358436
)
437+
# rechunk any numpy variables appropriately
438+
xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs)
359439

360440
merged_coordinates = merge([arg.coords for arg in aligned]).coords
361441

@@ -378,7 +458,7 @@ def _wrapper(
378458
new_coord_vars = template_coords - set(merged_coordinates)
379459

380460
preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
381-
# preserved_coords contains all coordinates bariables that share a dimension
461+
# preserved_coords contains all coordinates variables that share a dimension
382462
# with any index variable in preserved_indexes
383463
# Drop any unneeded vars in a second pass, this is required for e.g.
384464
# if the mapped function were to drop a non-dimension coordinate variable.
@@ -403,6 +483,13 @@ def _wrapper(
403483
" Please construct a template with appropriately chunked dask arrays."
404484
)
405485

486+
new_indexes = set(template.xindexes) - set(merged_coordinates)
487+
modified_indexes = set(
488+
name
489+
for name, xindex in coordinates.xindexes.items()
490+
if not xindex.equals(merged_coordinates.xindexes.get(name, None))
491+
)
492+
406493
for dim in output_chunks:
407494
if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
408495
raise ValueError(
@@ -443,63 +530,7 @@ def _wrapper(
443530
dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
444531
}
445532

446-
def subset_dataset_to_block(
447-
graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
448-
):
449-
"""
450-
Creates a task that subsets an xarray dataset to a block determined by chunk_index.
451-
Block extents are determined by input_chunk_bounds.
452-
Also subtasks that subset the constituent variables of a dataset.
453-
"""
454-
455-
# this will become [[name1, variable1],
456-
# [name2, variable2],
457-
# ...]
458-
# which is passed to dict and then to Dataset
459-
data_vars = []
460-
coords = []
461-
462-
chunk_tuple = tuple(chunk_index.values())
463-
for name, variable in dataset.variables.items():
464-
# make a task that creates tuple of (dims, chunk)
465-
if dask.is_dask_collection(variable.data):
466-
# recursively index into dask_keys nested list to get chunk
467-
chunk = variable.__dask_keys__()
468-
for dim in variable.dims:
469-
chunk = chunk[chunk_index[dim]]
470-
471-
chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
472-
graph[chunk_variable_task] = (
473-
tuple,
474-
[variable.dims, chunk, variable.attrs],
475-
)
476-
else:
477-
# non-dask array possibly with dimensions chunked on other variables
478-
# index into variable appropriately
479-
subsetter = {
480-
dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
481-
for dim in variable.dims
482-
}
483-
subset = variable.isel(subsetter)
484-
chunk_variable_task = (
485-
f"{name}-{gname}-{dask.base.tokenize(subset)}",
486-
) + chunk_tuple
487-
graph[chunk_variable_task] = (
488-
tuple,
489-
[subset.dims, subset, subset.attrs],
490-
)
491-
492-
# this task creates dict mapping variable name to above tuple
493-
if name in dataset._coord_names:
494-
coords.append([name, chunk_variable_task])
495-
else:
496-
data_vars.append([name, chunk_variable_task])
497-
498-
return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
499-
500-
# variable names that depend on the computation. Currently, indexes
501-
# cannot be modified in the mapped function, so we exclude those
502-
computed_variables = set(template.variables) - set(coordinates.xindexes)
533+
computed_variables = set(template.variables) - set(coordinates.indexes)
503534
# iterate over all possible chunk combinations
504535
for chunk_tuple in itertools.product(*ichunk.values()):
505536
# mapping from dimension name to chunk index
@@ -523,11 +554,12 @@ def subset_dataset_to_block(
523554
},
524555
"data_vars": set(template.data_vars.keys()),
525556
"coords": set(template.coords.keys()),
557+
# only include new or modified indexes to minimize duplication of data, and graph size.
526558
"indexes": {
527559
dim: coordinates.xindexes[dim][
528560
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
529561
]
530-
for dim in coordinates.xindexes
562+
for dim in (new_indexes | modified_indexes)
531563
},
532564
}
533565

@@ -541,14 +573,11 @@ def subset_dataset_to_block(
541573
gname_l = f"{name}-{gname}"
542574
var_key_map[name] = gname_l
543575

544-
key: tuple[Any, ...] = (gname_l,)
545-
for dim in variable.dims:
546-
if dim in chunk_index:
547-
key += (chunk_index[dim],)
548-
else:
549-
# unchunked dimensions in the input have one chunk in the result
550-
# output can have new dimensions with exactly one chunk
551-
key += (0,)
576+
# unchunked dimensions in the input have one chunk in the result
577+
# output can have new dimensions with exactly one chunk
578+
key: tuple[Any, ...] = (gname_l,) + tuple(
579+
chunk_index[dim] if dim in chunk_index else 0 for dim in variable.dims
580+
)
552581

553582
# We're adding multiple new layers to the graph:
554583
# The first new layer is the result of the computation on

xarray/tests/test_dask.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,3 +1746,28 @@ def test_new_index_var_computes_once():
17461746
data = dask.array.from_array(np.array([100, 200]))
17471747
with raise_if_dask_computes(max_computes=1):
17481748
Dataset(coords={"z": ("z", data)})
1749+
1750+
1751+
def test_minimize_graph_size():
1752+
# regression test for https://github.com/pydata/xarray/issues/8409
1753+
ds = Dataset(
1754+
{
1755+
"foo": (
1756+
("x", "y", "z"),
1757+
dask.array.ones((120, 120, 120), chunks=(20, 20, 1)),
1758+
)
1759+
},
1760+
coords={"x": np.arange(120), "y": np.arange(120), "z": np.arange(120)},
1761+
)
1762+
1763+
mapped = ds.map_blocks(lambda x: x)
1764+
graph = dict(mapped.__dask_graph__())
1765+
1766+
numchunks = {k: len(v) for k, v in ds.chunksizes.items()}
1767+
for var in "xyz":
1768+
actual = len([key for key in graph if var in key[0]])
1769+
# assert that each chunk of an index variable
1770+
# is only included once, not the product of number of chunks of
1771+
# all the other dimensions.
1772+
# e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
1773+
assert actual == numchunks[var], (actual, numchunks[var])

0 commit comments

Comments
 (0)