Skip to content

Commit ee20edb

Browse files
authored Jul 28, 2024
[GraphBolt][CUDA] Enable non_blocking copy_to in gb.DataLoader. (#7603)
1 parent a004a25 commit ee20edb

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed
 

‎python/dgl/graphbolt/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,6 @@ def apply_to(x, device, non_blocking=False):
319319
return x
320320
if not non_blocking:
321321
return x.to(device)
322-
# The copy is non blocking only if the objects are pinned.
323-
assert x.is_pinned(), f"{x} should be pinned."
324322
return x.to(device, non_blocking=True)
325323

326324

@@ -373,6 +371,9 @@ def __init__(self, datapipe, device, non_blocking=False):
373371

374372
def __iter__(self):
375373
for data in self.datapipe:
374+
if self.non_blocking:
375+
# The copy is non blocking only if contents of data are pinned.
376+
assert data.is_pinned(), f"{data} should be pinned."
376377
yield recursive_apply(
377378
data, apply_to, self.device, self.non_blocking
378379
)

‎python/dgl/graphbolt/dataloader.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,23 @@ def __init__(
224224
),
225225
)
226226

227-
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
228-
# data pipeline up to the CopyTo operation to run in a separate thread.
229-
datapipe_graph = _find_and_wrap_parent(
230-
datapipe_graph,
231-
CopyTo,
232-
dp.iter.Prefetcher,
233-
buffer_size=2,
234-
)
227+
# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
228+
# before it. This enables non_blocking copies to the device.
229+
# Prefetching enables the data pipeline up to the CopyTo to run in a
230+
# separate thread.
231+
if torch.cuda.is_available():
232+
copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
233+
for copier in copiers:
234+
if copier.device.type == "cuda":
235+
datapipe_graph = dp_utils.replace_dp(
236+
datapipe_graph,
237+
copier,
238+
copier.datapipe.transform(
239+
lambda x: x.pin_memory()
240+
).prefetch(2)
241+
# After the data gets pinned, we can copy non_blocking.
242+
.copy_to(copier.device, non_blocking=True),
243+
)
235244

236245
# The stages after feature fetching are still done in the main process.
237246
# So we set num_workers to 0 here.

0 commit comments

Comments
 (0)
Please sign in to comment.