@@ -396,7 +396,43 @@ def read_async(self, ids: torch.Tensor = None):
             The read feature future.
         """
         assert torch.ops.graphbolt.detect_io_uring()
-        yield self._ondisk_npy_array.index_select(ids)
+        if ids.is_cuda:
+            ids_device = ids.device
+            current_stream = torch.cuda.current_stream()
+            device_to_host_stream = get_device_to_host_uva_stream()
+            device_to_host_stream.wait_stream(current_stream)
+            with torch.cuda.stream(device_to_host_stream):
+                ids.record_stream(torch.cuda.current_stream())
+                ids = ids.to(self._tensor.device, non_blocking=True)
+                ids_copy_event = torch.cuda.Event()
+                ids_copy_event.record()
+
+            yield  # first stage is done.
+
+            ids_copy_event.synchronize()  # ensure ids has arrived on the host.
+            values = self._ondisk_npy_array.index_select(ids)
+            yield  # second stage is done.
+
+            host_to_device_stream = get_host_to_device_uva_stream()
+            with torch.cuda.stream(host_to_device_stream):
+                values_cuda = values.wait().to(ids_device, non_blocking=True)
+                values_cuda.record_stream(current_stream)
+                values_copy_event = torch.cuda.Event()
+                values_copy_event.record()
+
+            class _Waiter:
+                @staticmethod
+                def wait():
+                    values_copy_event.wait()  # orders the current stream after the copy.
+                    return values_cuda
+
+            yield _Waiter()
+        else:
+            yield self._ondisk_npy_array.index_select(ids)
+
+    def read_async_num_stages(self, ids_device: torch.device):
+        """The number of stages of the read_async operation."""
+        return 3 if ids_device.type == "cuda" else 1
 
     def size(self):
         """Get the size of the feature.
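For context on how this patch is consumed: `read_async` follows GraphBolt's staged-generator protocol, where each `next()` call launches one asynchronous stage and the `read_async_num_stages(ids.device)`-th call yields a future whose `wait()` returns the read values. Below is a minimal sketch of a caller driving the generator; the `read` helper and the `feature` object are illustrative assumptions, not part of the GraphBolt API.

def read(feature, ids):
    # Illustrative driver, not GraphBolt API. `feature` is assumed to be a
    # DiskBasedFeature-like object exposing the methods added in this diff.
    generator = feature.read_async(ids)
    num_stages = feature.read_async_num_stages(ids.device)
    for _ in range(num_stages - 1):
        # Launch the intermediate stages; a real pipeline would interleave
        # other outstanding reads between these calls.
        next(generator)
    future = next(generator)  # the final stage yields the future.
    # For CUDA ids, wait() orders the caller's current stream after the
    # device copy rather than blocking the host; for CPU ids, it blocks
    # until the io_uring disk read completes.
    return future.wait()

Splitting the work across separate `next()` calls is what lets a higher-level scheduler keep many reads in flight while io_uring and the two UVA copy streams make progress in the background.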