15
15
from xarray .core .indexes import Index
16
16
from xarray .core .merge import merge
17
17
from xarray .core .pycompat import is_dask_collection
18
+ from xarray .core .variable import Variable
18
19
19
20
if TYPE_CHECKING :
20
21
from xarray .core .types import T_Xarray
@@ -156,6 +157,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
156
157
return slice (None )
157
158
158
159
160
+ def subset_dataset_to_block (
161
+ graph : dict , gname : str , dataset : Dataset , input_chunk_bounds , chunk_index
162
+ ):
163
+ """
164
+ Creates a task that subsets an xarray dataset to a block determined by chunk_index.
165
+ Block extents are determined by input_chunk_bounds.
166
+ Also subtasks that subset the constituent variables of a dataset.
167
+ """
168
+ import dask
169
+
170
+ # this will become [[name1, variable1],
171
+ # [name2, variable2],
172
+ # ...]
173
+ # which is passed to dict and then to Dataset
174
+ data_vars = []
175
+ coords = []
176
+
177
+ chunk_tuple = tuple (chunk_index .values ())
178
+ chunk_dims_set = set (chunk_index )
179
+ variable : Variable
180
+ for name , variable in dataset .variables .items ():
181
+ # make a task that creates tuple of (dims, chunk)
182
+ if dask .is_dask_collection (variable .data ):
183
+ # get task name for chunk
184
+ chunk = (
185
+ variable .data .name ,
186
+ * tuple (chunk_index [dim ] for dim in variable .dims ),
187
+ )
188
+
189
+ chunk_variable_task = (f"{ name } -{ gname } -{ chunk [0 ]!r} " ,) + chunk_tuple
190
+ graph [chunk_variable_task ] = (
191
+ tuple ,
192
+ [variable .dims , chunk , variable .attrs ],
193
+ )
194
+ else :
195
+ assert name in dataset .dims or variable .ndim == 0
196
+
197
+ # non-dask array possibly with dimensions chunked on other variables
198
+ # index into variable appropriately
199
+ subsetter = {
200
+ dim : _get_chunk_slicer (dim , chunk_index , input_chunk_bounds )
201
+ for dim in variable .dims
202
+ }
203
+ if set (variable .dims ) < chunk_dims_set :
204
+ this_var_chunk_tuple = tuple (chunk_index [dim ] for dim in variable .dims )
205
+ else :
206
+ this_var_chunk_tuple = chunk_tuple
207
+
208
+ chunk_variable_task = (
209
+ f"{ name } -{ gname } -{ dask .base .tokenize (subsetter )} " ,
210
+ ) + this_var_chunk_tuple
211
+ # We are including a dimension coordinate,
212
+ # minimize duplication by not copying it in the graph for every chunk.
213
+ if variable .ndim == 0 or chunk_variable_task not in graph :
214
+ subset = variable .isel (subsetter )
215
+ graph [chunk_variable_task ] = (
216
+ tuple ,
217
+ [subset .dims , subset ._data , subset .attrs ],
218
+ )
219
+
220
+ # this task creates dict mapping variable name to above tuple
221
+ if name in dataset ._coord_names :
222
+ coords .append ([name , chunk_variable_task ])
223
+ else :
224
+ data_vars .append ([name , chunk_variable_task ])
225
+
226
+ return (Dataset , (dict , data_vars ), (dict , coords ), dataset .attrs )
227
+
228
+
159
229
def map_blocks (
160
230
func : Callable [..., T_Xarray ],
161
231
obj : DataArray | Dataset ,
@@ -280,6 +350,10 @@ def _wrapper(
280
350
281
351
result = func (* converted_args , ** kwargs )
282
352
353
+ merged_coordinates = merge (
354
+ [arg .coords for arg in args if isinstance (arg , (Dataset , DataArray ))]
355
+ ).coords
356
+
283
357
# check all dims are present
284
358
missing_dimensions = set (expected ["shapes" ]) - set (result .sizes )
285
359
if missing_dimensions :
@@ -295,12 +369,16 @@ def _wrapper(
295
369
f"Received dimension { name !r} of length { result .sizes [name ]} . "
296
370
f"Expected length { expected ['shapes' ][name ]} ."
297
371
)
298
- if name in expected ["indexes" ]:
299
- expected_index = expected ["indexes" ][name ]
300
- if not index .equals (expected_index ):
301
- raise ValueError (
302
- f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
303
- )
372
+
373
+ # ChainMap wants MutableMapping, but xindexes is Mapping
374
+ merged_indexes = collections .ChainMap (
375
+ expected ["indexes" ], merged_coordinates .xindexes # type: ignore[arg-type]
376
+ )
377
+ expected_index = merged_indexes .get (name , None )
378
+ if expected_index is not None and not index .equals (expected_index ):
379
+ raise ValueError (
380
+ f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
381
+ )
304
382
305
383
# check that all expected variables were returned
306
384
check_result_variables (result , expected , "coords" )
@@ -356,6 +434,8 @@ def _wrapper(
356
434
dataarray_to_dataset (arg ) if isinstance (arg , DataArray ) else arg
357
435
for arg in aligned
358
436
)
437
+ # rechunk any numpy variables appropriately
438
+ xarray_objs = tuple (arg .chunk (arg .chunksizes ) for arg in xarray_objs )
359
439
360
440
merged_coordinates = merge ([arg .coords for arg in aligned ]).coords
361
441
@@ -378,7 +458,7 @@ def _wrapper(
378
458
new_coord_vars = template_coords - set (merged_coordinates )
379
459
380
460
preserved_coords = merged_coordinates .to_dataset ()[preserved_coord_vars ]
381
- # preserved_coords contains all coordinates bariables that share a dimension
461
+ # preserved_coords contains all coordinates variables that share a dimension
382
462
# with any index variable in preserved_indexes
383
463
# Drop any unneeded vars in a second pass, this is required for e.g.
384
464
# if the mapped function were to drop a non-dimension coordinate variable.
@@ -403,6 +483,13 @@ def _wrapper(
403
483
" Please construct a template with appropriately chunked dask arrays."
404
484
)
405
485
486
+ new_indexes = set (template .xindexes ) - set (merged_coordinates )
487
+ modified_indexes = set (
488
+ name
489
+ for name , xindex in coordinates .xindexes .items ()
490
+ if not xindex .equals (merged_coordinates .xindexes .get (name , None ))
491
+ )
492
+
406
493
for dim in output_chunks :
407
494
if dim in input_chunks and len (input_chunks [dim ]) != len (output_chunks [dim ]):
408
495
raise ValueError (
@@ -443,63 +530,7 @@ def _wrapper(
443
530
dim : np .cumsum ((0 ,) + chunks_v ) for dim , chunks_v in output_chunks .items ()
444
531
}
445
532
446
def subset_dataset_to_block(
    graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
):
    """
    Create a task that subsets an xarray dataset to a block determined by chunk_index.

    Block extents are determined by input_chunk_bounds.
    Also creates subtasks that subset the constituent variables of a dataset.

    Parameters
    ----------
    graph : dict
        Mutable dask task graph; per-variable subset tasks are inserted in place.
    gname : str
        Name of the output layer, used to build unique task keys.
    dataset : Dataset
        Dataset to subset; may mix dask-backed and plain (e.g. numpy) variables.
    input_chunk_bounds
        Mapping from dimension name to chunk boundary offsets, consumed by
        ``_get_chunk_slicer`` to slice non-dask variables.
    chunk_index
        Mapping from dimension name to the integer chunk index of this block.

    Returns
    -------
    tuple
        A dask task ``(Dataset, (dict, data_vars), (dict, coords), dataset.attrs)``
        that reconstructs the per-block Dataset when executed.
    """

    # this will become [[name1, variable1],
    #                   [name2, variable2],
    #                   ...]
    # which is passed to dict and then to Dataset
    data_vars = []
    coords = []

    chunk_tuple = tuple(chunk_index.values())
    for name, variable in dataset.variables.items():
        # make a task that creates tuple of (dims, chunk)
        if dask.is_dask_collection(variable.data):
            # recursively index into dask_keys nested list to get chunk
            chunk = variable.__dask_keys__()
            for dim in variable.dims:
                chunk = chunk[chunk_index[dim]]

            chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
            graph[chunk_variable_task] = (
                tuple,
                [variable.dims, chunk, variable.attrs],
            )
        else:
            # non-dask array possibly with dimensions chunked on other variables
            # index into variable appropriately
            subsetter = {
                dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
                for dim in variable.dims
            }
            subset = variable.isel(subsetter)
            # tokenize the materialized subset so the task key is unique per block
            chunk_variable_task = (
                f"{name}-{gname}-{dask.base.tokenize(subset)}",
            ) + chunk_tuple
            graph[chunk_variable_task] = (
                tuple,
                [subset.dims, subset, subset.attrs],
            )

        # this task creates dict mapping variable name to above tuple
        if name in dataset._coord_names:
            coords.append([name, chunk_variable_task])
        else:
            data_vars.append([name, chunk_variable_task])

    return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
499
-
500
- # variable names that depend on the computation. Currently, indexes
501
- # cannot be modified in the mapped function, so we exclude thos
502
- computed_variables = set (template .variables ) - set (coordinates .xindexes )
533
+ computed_variables = set (template .variables ) - set (coordinates .indexes )
503
534
# iterate over all possible chunk combinations
504
535
for chunk_tuple in itertools .product (* ichunk .values ()):
505
536
# mapping from dimension name to chunk index
@@ -523,11 +554,12 @@ def subset_dataset_to_block(
523
554
},
524
555
"data_vars" : set (template .data_vars .keys ()),
525
556
"coords" : set (template .coords .keys ()),
557
+ # only include new or modified indexes to minimize duplication of data, and graph size.
526
558
"indexes" : {
527
559
dim : coordinates .xindexes [dim ][
528
560
_get_chunk_slicer (dim , chunk_index , output_chunk_bounds )
529
561
]
530
- for dim in coordinates . xindexes
562
+ for dim in ( new_indexes | modified_indexes )
531
563
},
532
564
}
533
565
@@ -541,14 +573,11 @@ def subset_dataset_to_block(
541
573
gname_l = f"{ name } -{ gname } "
542
574
var_key_map [name ] = gname_l
543
575
544
- key : tuple [Any , ...] = (gname_l ,)
545
- for dim in variable .dims :
546
- if dim in chunk_index :
547
- key += (chunk_index [dim ],)
548
- else :
549
- # unchunked dimensions in the input have one chunk in the result
550
- # output can have new dimensions with exactly one chunk
551
- key += (0 ,)
576
+ # unchunked dimensions in the input have one chunk in the result
577
+ # output can have new dimensions with exactly one chunk
578
+ key : tuple [Any , ...] = (gname_l ,) + tuple (
579
+ chunk_index [dim ] if dim in chunk_index else 0 for dim in variable .dims
580
+ )
552
581
553
582
# We're adding multiple new layers to the graph:
554
583
# The first new layer is the result of the computation on
0 commit comments