From eaa4aab2834de302bcf05d0471efafab294fd405 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Wed, 2 Apr 2025 14:53:43 +0000 Subject: [PATCH 1/3] fix: Destroy cuda graphs before setting weight streaming --- core/runtime/TRTEngine.cpp | 4 ++++ core/runtime/TRTEngine.h | 1 + core/runtime/register_jit_hooks.cpp | 1 + .../runtime/_CudaGraphsTorchTensorRTModule.py | 13 +++++++------ .../dynamo/runtime/_PythonTorchTensorRTModule.py | 11 +++++++---- .../dynamo/runtime/_TorchTensorRTModule.py | 3 +++ .../dynamo/runtime/meta_ops/register_meta_ops.py | 3 +++ py/torch_tensorrt/runtime/_cudagraphs.py | 9 +++++++-- py/torch_tensorrt/runtime/_weight_streaming.py | 7 +++++-- 9 files changed, 38 insertions(+), 14 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 9f93fe4b4e..c57a66ac43 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -453,6 +453,10 @@ std::vector TRTEngine::serialize() { return serialized_info; } +void TRTEngine::reset_cudagraph() { + cudagraph.reset(); +} + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index e9b1905610..d6e4fc62d5 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); + void reset_cudagraph(); #ifndef NDEBUG bool profile_execution = true; #else diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index c05be4e8aa..6d6824dca6 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) + .def("reset_cudagraph", &TRTEngine::reset_cudagraph) 
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 5af9b11a4b..c47a1b3bd2 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return False - def __del__(self) -> None: + def reset_cudagraph(self) -> None: if self.cudagraph: self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self.reset_cudagraph() def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable @@ -119,8 +123,7 @@ def forward( shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set if need_cudagraphs_record: - if self.cudagraph: - self.cudagraph.reset() + self.reset_cudagraph() self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False @@ -196,7 +199,5 @@ def forward( return outputs[0] return outputs else: - if self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + self.reset_cudagraph() return self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 891d063ed3..acb83460fb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result - def __del__(self) -> None: + def reset_cudagraph(self) -> None: if self.cudagraph: 
self.cudagraph.reset() + self.cudagraph = None + + def __del__(self) -> None: + self.reset_cudagraph() def setup_input_tensors( self, @@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed ) - if need_cudagraphs_reset and self.cudagraph: - self.cudagraph.reset() - self.cudagraph = None + if need_cudagraphs_reset: + self.reset_cudagraph() if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e6b6a21421..e4adca58e3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: return budget_bytes + def reset_cudagraph(self) -> None: + self.engine.reset_cudagraph() + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. 
diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index f481c5b2b8..2cb5f4166d 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -142,6 +142,9 @@ def automatic_device_memory_budget_getter(self) -> Any: def infer_outputs(self, input_shapes: List[Any]) -> Any: pass + def reset_cudagraph(self) -> Any: + pass + def __setstate__(self, serialized_state: List[str]) -> Any: pass diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index c771564826..66c185096c 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Union +from typing import Any, Optional, Union import torch import torch_tensorrt @@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: global _PY_RT_CUDAGRAPHS self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module + self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None def __enter__(self) -> torch.nn.Module: global _PY_RT_CUDAGRAPHS @@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module: logger.debug( "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" ) - return CudaGraphsTorchTensorRTModule(self.compiled_module) + self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) + return self.cudagraphs_module else: if num_trt_module > 0: logger.debug("No graph breaks detected, using runtime cudagraphs mode") @@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module: def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode set_cudagraphs_mode(self.old_mode) + # __del__ is not entirely predictable, so we reset cudagraph here + if self.cudagraphs_module: + self.cudagraphs_module.reset_cudagraph() def 
enable_cudagraphs( diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 3b11087fcb..093a347760 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int: int(streamable_bytes / total_bytes * requested_budget) for streamable_bytes in self.streamable_budget ] + if self.cuda_graphs_module: + self.cuda_graphs_module.is_weight_streaming_set = True + self.cuda_graphs_module.reset_cudagraph() + for i, (name, rt_mod) in enumerate(self.rt_mods): + rt_mod.reset_cudagraph() ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}") - if self.cuda_graphs_module: - self.cuda_graphs_module.is_weight_streaming_set = True return ws_budget_bytes def __setattr__(self, name: str, value: Any) -> None: From 48a7d7e69643d445b9c80a53699e7fbb9da42e4e Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Thu, 3 Apr 2025 02:14:41 +0000 Subject: [PATCH 2/3] chore: Rename to reset_captured_graph --- core/runtime/TRTEngine.cpp | 2 +- core/runtime/TRTEngine.h | 2 +- core/runtime/register_jit_hooks.cpp | 2 +- .../dynamo/runtime/_CudaGraphsTorchTensorRTModule.py | 8 ++++---- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 6 +++--- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 4 ++-- .../dynamo/runtime/meta_ops/register_meta_ops.py | 2 +- py/torch_tensorrt/runtime/_cudagraphs.py | 2 +- py/torch_tensorrt/runtime/_weight_streaming.py | 4 ++-- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c57a66ac43..9a04aba6de 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -453,7 +453,7 @@ std::vector TRTEngine::serialize() { return serialized_info; } -void TRTEngine::reset_cudagraph() { +void 
TRTEngine::reset_captured_graph() { cudagraph.reset(); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index d6e4fc62d5..2db640b6b1 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -185,7 +185,7 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); - void reset_cudagraph(); + void reset_captured_graph(); #ifndef NDEBUG bool profile_execution = true; #else diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 6d6824dca6..cbe19b0af6 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -88,7 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) - .def("reset_cudagraph", &TRTEngine::reset_cudagraph) + .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index c47a1b3bd2..167bdb14ef 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -103,13 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return False - def reset_cudagraph(self) -> None: + def reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() self.cudagraph = None def __del__(self) -> None: - self.reset_cudagraph() + self.reset_captured_graph() def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable @@ -123,7 
+123,7 @@ def forward( shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set if need_cudagraphs_record: - self.reset_cudagraph() + self.reset_captured_graph() self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False @@ -199,5 +199,5 @@ def forward( return outputs[0] return outputs else: - self.reset_cudagraph() + self.reset_captured_graph() return self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index acb83460fb..43598a2289 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -333,13 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result - def reset_cudagraph(self) -> None: + def reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() self.cudagraph = None def __del__(self) -> None: - self.reset_cudagraph() + self.reset_captured_graph() def setup_input_tensors( self, @@ -431,7 +431,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ) if need_cudagraphs_reset: - self.reset_cudagraph() + self.reset_captured_graph() if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e4adca58e3..5a364e7a39 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -209,8 +209,8 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: return budget_bytes - def reset_cudagraph(self) -> None: - self.engine.reset_cudagraph() + def reset_captured_graph(self) -> None: + self.engine.reset_captured_graph() def 
setup_engine(self) -> None: """ diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index 2cb5f4166d..1b6963fa50 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -142,7 +142,7 @@ def automatic_device_memory_budget_getter(self) -> Any: def infer_outputs(self, input_shapes: List[Any]) -> Any: pass - def reset_cudagraph(self) -> Any: + def reset_captured_graph(self) -> Any: pass def __setstate__(self, serialized_state: List[str]) -> Any: diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 66c185096c..8cc77a9c0b 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -117,7 +117,7 @@ def __exit__(self, *args: Any) -> None: set_cudagraphs_mode(self.old_mode) # __del__ is not entirely predictable, so we reset cudagraph here if self.cudagraphs_module: - self.cudagraphs_module.reset_cudagraph() + self.cudagraphs_module.reset_captured_graph() def enable_cudagraphs( diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 093a347760..914ebd8cb5 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -78,10 +78,10 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int: ] if self.cuda_graphs_module: self.cuda_graphs_module.is_weight_streaming_set = True - self.cuda_graphs_module.reset_cudagraph() + self.cuda_graphs_module.reset_captured_graph() for i, (name, rt_mod) in enumerate(self.rt_mods): - rt_mod.reset_cudagraph() + rt_mod.reset_captured_graph() ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}") From 82eea13d0c83458ab9c8673bccb95654de6a1dd3 Mon Sep 17 00:00:00 
2001 From: kee hyun an Date: Fri, 4 Apr 2025 00:58:00 +0000 Subject: [PATCH 3/3] chore: Rename to _reset_captured_graph() --- .../dynamo/runtime/_CudaGraphsTorchTensorRTModule.py | 8 ++++---- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 6 +++--- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 2 +- py/torch_tensorrt/runtime/_cudagraphs.py | 2 +- py/torch_tensorrt/runtime/_weight_streaming.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 167bdb14ef..9e54fbac3d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -103,13 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: return False - def reset_captured_graph(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() self.cudagraph = None def __del__(self) -> None: - self.reset_captured_graph() + self._reset_captured_graph() def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable @@ -123,7 +123,7 @@ def forward( shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set if need_cudagraphs_record: - self.reset_captured_graph() + self._reset_captured_graph() self._input_buffers = [None] * len(inputs) self.is_weight_streaming_set = False @@ -199,5 +199,5 @@ def forward( return outputs[0] return outputs else: - self.reset_captured_graph() + self._reset_captured_graph() return self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 43598a2289..6415ce11c3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ 
b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -333,13 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result - def reset_captured_graph(self) -> None: + def _reset_captured_graph(self) -> None: if self.cudagraph: self.cudagraph.reset() self.cudagraph = None def __del__(self) -> None: - self.reset_captured_graph() + self._reset_captured_graph() def setup_input_tensors( self, @@ -431,7 +431,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ) if need_cudagraphs_reset: - self.reset_captured_graph() + self._reset_captured_graph() if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 5a364e7a39..c3fe925eee 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -209,7 +209,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: return budget_bytes - def reset_captured_graph(self) -> None: + def _reset_captured_graph(self) -> None: self.engine.reset_captured_graph() def setup_engine(self) -> None: diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 8cc77a9c0b..346132145e 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -117,7 +117,7 @@ def __exit__(self, *args: Any) -> None: set_cudagraphs_mode(self.old_mode) # __del__ is not entirely predictable, so we reset cudagraph here if self.cudagraphs_module: - self.cudagraphs_module.reset_captured_graph() + self.cudagraphs_module._reset_captured_graph() def enable_cudagraphs( diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 914ebd8cb5..0874d31d11 100755 --- 
a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -78,10 +78,10 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int: ] if self.cuda_graphs_module: self.cuda_graphs_module.is_weight_streaming_set = True - self.cuda_graphs_module.reset_captured_graph() + self.cuda_graphs_module._reset_captured_graph() for i, (name, rt_mod) in enumerate(self.rt_mods): - rt_mod.reset_captured_graph() + rt_mod._reset_captured_graph() ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i]) logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")