We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 12fa63c commit 74a7670Copy full SHA for 74a7670
scripts/train.py
@@ -64,7 +64,7 @@ def _init_streams(
64
state._unshard_stream = state._device_handle.Stream(priority=high_priority)
65
# Stream for overlapping gradient reduction with the backward pass gradient
66
# computation
67
- state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
+ state._post_backward_stream = state._unshard_stream
68
# Stream for pre-unshard logic, namely allocations and writes for CPU
69
# offloading (H2D copy) and mixed precision (low precision cast)
70
state._pre_unshard_stream = state._post_backward_stream
0 commit comments