Commit a004a25

Authored Jul 28, 2024
[GraphBolt][CUDA] Add non_blocking option to CopyTo. (#7602)
1 parent 781cc50

File tree: 4 files changed, +105 -19 lines

python/dgl/graphbolt/base.py

+55 -9

@@ -20,7 +20,11 @@
 from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
 
-from .internal_utils import recursive_apply
+from .internal_utils import (
+    get_nonproperty_attributes,
+    recursive_apply,
+    recursive_apply_reduce_all,
+)
 
 __all__ = [
     "CANONICAL_ETYPE_DELIMITER",
@@ -306,10 +310,32 @@ def seed_type_str_to_ntypes(seed_type, seed_size):
     return ntypes
 
 
-def apply_to(x, device):
+def apply_to(x, device, non_blocking=False):
     """Apply `to` function to object x only if it has `to`."""
 
-    return x.to(device) if hasattr(x, "to") else x
+    if device == "pinned" and hasattr(x, "pin_memory"):
+        return x.pin_memory()
+    if not hasattr(x, "to"):
+        return x
+    if not non_blocking:
+        return x.to(device)
+    # The copy is non-blocking only if the objects are pinned.
+    assert x.is_pinned(), f"{x} should be pinned."
+    return x.to(device, non_blocking=True)
+
+
+def is_object_pinned(obj):
+    """Recursively check all members of the object and return True if and
+    only if all are pinned."""
+
+    for attr in get_nonproperty_attributes(obj):
+        member_result = recursive_apply_reduce_all(
+            getattr(obj, attr),
+            lambda x: x is None or x.is_pinned(),
+        )
+        if not member_result:
+            return False
+    return True
 
 
 @functional_datapipe("copy_to")
@@ -334,17 +360,22 @@ class CopyTo(IterDataPipe):
         The DataPipe.
     device : torch.device
         The PyTorch CUDA device.
+    non_blocking : bool
+        Whether the copy should be performed without blocking. All elements have
+        to be already in pinned system memory if enabled. Default is False.
     """
 
-    def __init__(self, datapipe, device):
+    def __init__(self, datapipe, device, non_blocking=False):
         super().__init__()
         self.datapipe = datapipe
-        self.device = device
+        self.device = torch.device(device)
+        self.non_blocking = non_blocking
 
     def __iter__(self):
         for data in self.datapipe:
-            data = recursive_apply(data, apply_to, self.device)
-            yield data
+            yield recursive_apply(
+                data, apply_to, self.device, self.non_blocking
+            )
 
 
 @functional_datapipe("mark_end")
@@ -460,7 +491,9 @@ def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
     def __repr__(self) -> str:
         return _csc_format_base_str(self)
 
-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
+    def to(  # pylint: disable=invalid-name
+        self, device: torch.device, non_blocking=False
+    ) -> None:
         """Copy `CSCFormatBase` to the specified device using reflection."""
 
         for attr in dir(self):
@@ -470,12 +503,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
                 self,
                 attr,
                 recursive_apply(
-                    getattr(self, attr), lambda x: apply_to(x, device)
+                    getattr(self, attr),
+                    apply_to,
+                    device,
+                    non_blocking=non_blocking,
                 ),
             )
 
         return self
 
+    def pin_memory(self):
+        """Copy `CSCFormatBase` to the pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `CSCFormatBase` is pinned using reflection."""
+
+        return is_object_pinned(self)
+
 
 def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:
     final_str = "CSCFormatBase("

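For reference, a minimal sketch of the new `apply_to` semantics on plain tensors (assumes a CUDA-enabled PyTorch build; names and values are illustrative):

    import torch

    from dgl.graphbolt.base import apply_to

    x = torch.arange(8)
    # device == "pinned" stages the tensor in page-locked host memory.
    pinned = apply_to(x, "pinned")
    assert pinned.is_pinned()
    # A non-blocking copy is only legal from pinned memory; apply_to asserts
    # x.is_pinned() before issuing the asynchronous transfer.
    on_gpu = apply_to(pinned, torch.device("cuda"), non_blocking=True)
    # apply_to(x, "cuda", non_blocking=True) on the pageable tensor `x`
    # would trip that assertion rather than silently fall back to blocking.
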
python/dgl/graphbolt/minibatch.py

+22 -5

@@ -5,7 +5,13 @@
 
 import torch
 
-from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
+from .base import (
+    apply_to,
+    CSCFormatBase,
+    etype_str_to_tuple,
+    expand_indptr,
+    is_object_pinned,
+)
 from .internal_utils import (
     get_attributes,
     get_nonproperty_attributes,
@@ -350,20 +356,31 @@ def to_pyg_data(self):
         )
         return pyg_data
 
-    def to(self, device: torch.device):  # pylint: disable=invalid-name
+    def to(
+        self, device: torch.device, non_blocking=False
+    ):  # pylint: disable=invalid-name
         """Copy `MiniBatch` to the specified device using reflection."""
 
-        def _to(x):
-            return x.to(device) if hasattr(x, "to") else x
+        copy_fn = lambda x: apply_to(x, device, non_blocking=non_blocking)
 
         transfer_attrs = get_nonproperty_attributes(self)
 
         for attr in transfer_attrs:
             # Only copy member variables.
-            setattr(self, attr, recursive_apply(getattr(self, attr), _to))
+            setattr(self, attr, recursive_apply(getattr(self, attr), copy_fn))
 
         return self
 
+    def pin_memory(self):
+        """Copy `MiniBatch` to the pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `MiniBatch` is pinned using reflection."""
+
+        return is_object_pinned(self)
+
 
 def _minibatch_str(minibatch: MiniBatch) -> str:
     final_str = ""

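Together with the base.py changes, this enables a pipeline of the following shape (a sketch mirroring the updated test below; requires a GPU build of DGL/PyTorch):

    import torch
    import dgl.graphbolt as gb

    # Batch 20 seed IDs into MiniBatches of 4, pin each batch on the host,
    # then let CopyTo issue non-blocking host-to-device copies.
    dp = gb.ItemSampler(gb.ItemSet(torch.arange(20), names="seeds"), 4)
    dp = dp.transform(lambda minibatch: minibatch.pin_memory())
    dp = dp.copy_to("cuda", non_blocking=True)

    for minibatch in dp:
        assert minibatch.seeds.device.type == "cuda"
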
python/dgl/graphbolt/sampled_subgraph.py

+18 -2

@@ -10,6 +10,7 @@
     CSCFormatBase,
     etype_str_to_tuple,
     expand_indptr,
+    is_object_pinned,
     isin,
 )
 
@@ -232,7 +233,9 @@ def exclude_edges(
         )
         return calling_class(*_slice_subgraph(self, index))
 
-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
+    def to(
+        self, device: torch.device, non_blocking=False
+    ) -> None:  # pylint: disable=invalid-name
         """Copy `SampledSubgraph` to the specified device using reflection."""
 
         for attr in dir(self):
@@ -242,12 +245,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
                 self,
                 attr,
                 recursive_apply(
-                    getattr(self, attr), lambda x: apply_to(x, device)
+                    getattr(self, attr),
+                    apply_to,
+                    device,
+                    non_blocking=non_blocking,
                 ),
             )
 
         return self
 
+    def pin_memory(self):
+        """Copy `SampledSubgraph` to the pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `SampledSubgraph` is pinned using reflection."""
+
+        return is_object_pinned(self)
+
 
 def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
     indptr = node_pair.indptr

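`pin_memory`/`is_pinned` rely on the same reflection contract as `is_object_pinned` in base.py: every non-property member must, recursively, be either None or pinned. A toy illustration with a hypothetical `Pair` container (not part of GraphBolt; assumes `get_nonproperty_attributes` enumerates plain instance attributes, as it does for simple classes):

    import torch

    from dgl.graphbolt.base import is_object_pinned

    class Pair:  # hypothetical container, for illustration only
        def __init__(self, a, b):
            self.a = a
            self.b = b

    pair = Pair(torch.ones(3).pin_memory(), None)
    assert is_object_pinned(pair)  # None members count as pinned

    pair.b = torch.zeros(2)  # a pageable (unpinned) tensor
    assert not is_object_pinned(pair)
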
tests/python/pytorch/graphbolt/test_base.py

+10 -3

@@ -12,19 +12,26 @@
 from . import gb_test_utils
 
 
-@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
-def test_CopyTo():
+@unittest.skipIf(F._default_context_str != "gpu", "CopyTo needs GPU to test")
+@pytest.mark.parametrize("non_blocking", [False, True])
+def test_CopyTo(non_blocking):
     item_sampler = gb.ItemSampler(
         gb.ItemSet(torch.arange(20), names="seeds"), 4
     )
+    if non_blocking:
+        item_sampler = item_sampler.transform(lambda x: x.pin_memory())
 
     # Invoke CopyTo via class constructor.
     dp = gb.CopyTo(item_sampler, "cuda")
     for data in dp:
         assert data.seeds.device.type == "cuda"
 
+    dp = gb.CopyTo(item_sampler, "cuda", non_blocking)
+    for data in dp:
+        assert data.seeds.device.type == "cuda"
+
     # Invoke CopyTo via functional form.
-    dp = item_sampler.copy_to("cuda")
+    dp = item_sampler.copy_to("cuda", non_blocking)
     for data in dp:
         assert data.seeds.device.type == "cuda"

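With the parametrization, the test covers both the blocking and non-blocking paths; the non-blocking variant differs only in pinning every batch up front, which is exactly the precondition `apply_to` asserts. On a GPU build it can be run with, e.g., `pytest tests/python/pytorch/graphbolt/test_base.py -k test_CopyTo`.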