Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ba4ee3b

Browse files
committedFeb 3, 2024
[GraphBolt][CUDA] Fix link prediction early-stop.
1 parent 15695ed commit ba4ee3b

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed
 

‎examples/sampling/graphbolt/link_prediction.py

+2
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ def train(args, model, graph, features, train_set):
340340

341341
total_loss += loss.item()
342342
if step + 1 == args.early_stop:
343+
# Early stopping requires a new dataloader to reset its state.
344+
dataloader = create_dataloader(args, graph, features, train_set)
343345
break
344346

345347
end_epoch_time = time.time()

‎python/dgl/graphbolt/feature_fetcher.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,5 @@ def _read(self, data):
174174
with torch.cuda.stream(self.stream):
175175
data = self._read_data(data, current_stream)
176176
if self.stream is not None:
177-
event = torch.cuda.current_stream().record_event()
178-
179-
def _wait():
180-
event.wait()
181-
182-
data.wait = _wait
177+
data.wait = torch.cuda.current_stream().record_event().wait
183178
return data

0 commit comments

Comments
 (0)
Please sign in to comment.