feat: TensorRT AOT Plugin #3504
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/aot_plugin.py 2025-05-05 05:52:23.878918+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/aot_plugin.py 2025-05-05 05:52:44.176344+00:00
@@ -23,13 +23,11 @@
output = x + 1
tl.store(y_ptr + offsets, output, mask=mask)
@torch.library.custom_op("my::add_one", mutates_args=()) # type: ignore[misc]
-def add_one(
- X: torch.Tensor
-) -> torch.Tensor:
+def add_one(X: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda
# Create output tensor
Y = torch.empty_like(X)
@@ -53,19 +51,22 @@
# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
# "my::add_one"
# )
+
@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
return X.like()
+
@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-
+) -> Tuple[
+ Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
type_str = "fp32" if X.dtype == trt.float32 else "fp16"
block_size = 256
src = triton.compiler.ASTSource(
@@ -101,10 +102,11 @@
compiled_kernel.asm["ptx"],
launch_params,
extra_args,
)
+
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"my::add_one",
supports_dynamic_shapes=False,
requires_output_allocator=False,
aot=True,
@@ -127,18 +129,15 @@
parser.add_argument(
"--aot", action="store_true", help="Try to use AOT compilation", default=False
)
args = parser.parse_args()
-
-
my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
# This works!
assert my_model(X=m)[0][0] == 3.0
-
with torch_tensorrt.logging.debug():
trt_inputs = [m]
model_trt = torch_tensorrt.compile(
my_model,
@@ -151,6 +150,6 @@
for i in range(10):
res = model_trt(m)
assert torch.allclose(res, my_model(m)), "Results do not match!"
print("Inference successful!")
- print(res)
\ No newline at end of file
+ print(res)
# "my::add_one"
# )

@trtp.register("my::add_one")
Can we not use torch_tensorrt.dynamo.conversion.custom_op here?
"my::add_one",
supports_dynamic_shapes=False,
requires_output_allocator=False,
aot=True,
I think we need two things: (1) a flag, something like use_aot_if_available, and (2) in generate_plugin_converter, a function that checks whether an aot_impl is registered for the op.
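A minimal sketch of what this suggestion could look like, offered as an illustration rather than the eventual implementation. It assumes a plain dictionary as a stand-in registry; `_AOT_IMPLS`, `register_aot_impl`, `has_aot_impl`, and `resolve_aot` are hypothetical names and do not correspond to the real TensorRT or Torch-TensorRT APIs. Only the flag name use_aot_if_available comes from the comment above.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in registry: maps an op name such as "my::add_one"
# to its AOT implementation. The real plugin registry is not modelled here.
_AOT_IMPLS: Dict[str, Callable[..., Any]] = {}


def register_aot_impl(op_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator that records an AOT implementation for op_name."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _AOT_IMPLS[op_name] = fn
        return fn

    return decorator


def has_aot_impl(op_name: str) -> bool:
    """Utility check: is an AOT implementation registered for op_name?"""
    return op_name in _AOT_IMPLS


def resolve_aot(op_name: str, use_aot_if_available: bool = True) -> bool:
    """Decide whether a converter should take the AOT path.

    AOT is used only when the caller allows it and an implementation
    is actually registered; otherwise the caller falls back to JIT.
    """
    return use_aot_if_available and has_aot_impl(op_name)


@register_aot_impl("my::add_one")
def add_one_aot_impl(*args: Any, **kwargs: Any) -> Any:
    # Placeholder body; a real implementation would return compiled kernel data.
    raise NotImplementedError("placeholder for a real AOT implementation")
```

With this shape, a converter generator could call `resolve_aot("my::add_one", use_aot_if_available=...)` once and pass the resulting boolean along, instead of requiring the user to hard-code `aot=True`.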
@@ -80,7 +81,7 @@ def custom_kernel_converter(
if isinstance(v, torch.fx.immutable_collections.immutable_list):
kwargs[k] = np.array(v)

- layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
+ layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
There should be a utility function that checks for aot_impl registrations.
Description
This PR demonstrates how to use an AOT plugin in Torch-TensorRT.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: