Skip to content

Commit 74a7670

Browse files
committed
3 streams
1 parent 12fa63c commit 74a7670

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

scripts/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _init_streams(
6464
state._unshard_stream = state._device_handle.Stream(priority=high_priority)
6565
# Stream for overlapping gradient reduction with the backward pass gradient
6666
# computation
67-
state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
67+
state._post_backward_stream = state._unshard_stream
6868
# Stream for pre-unshard logic, namely allocations and writes for CPU
6969
# offloading (H2D copy) and mixed precision (low precision cast)
7070
state._pre_unshard_stream = state._post_backward_stream

0 commit comments

Comments
 (0)