
Commit 047f2b2

suyoggupta and kaiyux authored
perf: [AutoDeploy] Enable AutoDeploy as a backend in trtllm-bench (#3041)
* Enable AutoDeploy as a backend in trtllm-bench
* update how caches are resized
* fix: files permission from 100755 to 100644
* some comments
* lint (4 commits)
* Fix function name
* refactor
* Remove spurious change
* Add cursor generated doc strings
* re-enable ad test
* some perf cleanup
* debug ci
* ensure that overlap scheduler is enabled
* Reorder the tests

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1 parent 3e035f2 commit 047f2b2

File tree

14 files changed: +180 −51 lines

examples/auto_deploy/build_and_run_ad.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ def build_llm_from_config(config: SimpleConfig) -> LLM:
         model_kwargs=config.model_kwargs,
         attn_backend=config.attn_backend,
         skip_loading_weights=config.skip_loading_weights,
+        cuda_graph_max_batch_size=config.max_batch_size,
     )
     ad_logger.info(f"AutoDeploy Config: {ad_config}")

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_opt.py

Lines changed: 15 additions & 3 deletions
@@ -14,7 +14,9 @@


 class CompiledGraph(nn.Module):
-    def __init__(self, model: GraphModule, max_batch_size: int):
+    def __init__(
+        self, model: GraphModule, max_batch_size: int, cuda_graph_batch_sizes: List[int] = None
+    ):
         super().__init__()
         self._in_spec: TreeSpec = model._in_spec
         self._out_spec: TreeSpec = model._out_spec
@@ -24,6 +26,11 @@ def __init__(self, model: GraphModule, max_batch_size: int):
         self._input_buffer: torch.Tensor = torch.empty(0, 1)
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
+        self.cuda_graph_batch_sizes = (
+            cuda_graph_batch_sizes
+            if cuda_graph_batch_sizes is not None
+            else self._get_graph_batch_sizes(self.max_batch_size)
+        )

     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
         return tuple(hash(a) for a in flat_args)
@@ -90,7 +97,7 @@ def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
         assert out_spec == self._out_spec, "Output spec mismatch."

         # capture graph now for a range of batch sizes
-        for bs in self._get_graph_batch_sizes(self.max_batch_size):
+        for bs in self.cuda_graph_batch_sizes:
             ad_logger.info(f"Capturing graph for batch size: {bs}")

             # setup args, kwargs
@@ -131,7 +138,12 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchOptCompiler(BackendCompiler):
     @torch.inference_mode()
     def compile(self) -> CompiledGraph:
-        compiled_gm = CompiledGraph(self.gm, max_batch_size=self.max_batch_size)
+        cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
+        compiled_gm = CompiledGraph(
+            self.gm,
+            max_batch_size=self.max_batch_size,
+            cuda_graph_batch_sizes=cuda_graph_batch_sizes,
+        )

         # try capturing cudagraph
         if self.args is not None or self.kwargs is not None:
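
Note: CompiledGraph now captures one CUDA graph per entry in cuda_graph_batch_sizes and falls back to _get_graph_batch_sizes(max_batch_size) when no list is supplied. The snippet below is a standalone sketch of that bucketing idea only; the helper names and the powers-of-two default schedule are assumptions for illustration, not the implementation in this file.

    # Standalone sketch of CUDA-graph batch-size bucketing. The default schedule
    # (powers of two) and both helper names are assumptions for illustration; the
    # real CompiledGraph._get_graph_batch_sizes is not shown in this diff.
    from typing import List


    def default_graph_batch_sizes(max_batch_size: int) -> List[int]:
        """Assumed default: capture powers of two up to and including max_batch_size."""
        sizes, bs = [], 1
        while bs < max_batch_size:
            sizes.append(bs)
            bs *= 2
        sizes.append(max_batch_size)
        return sorted(set(sizes))


    def round_up_to_captured(batch_size: int, captured: List[int]) -> int:
        """Pick the smallest captured graph that can hold an incoming batch."""
        for bs in sorted(captured):
            if bs >= batch_size:
                return bs
        raise ValueError(f"batch size {batch_size} exceeds largest captured graph")


    print(default_graph_batch_sizes(64))                            # [1, 2, 4, 8, 16, 32, 64]
    print(round_up_to_captured(13, default_graph_batch_sizes(64)))  # 16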

tensorrt_llm/_torch/auto_deploy/compile/compiler.py

Lines changed: 6 additions & 2 deletions
@@ -55,12 +55,13 @@ def __init__(
         args: Tuple[Any, ...],
         kwargs: Optional[Dict[str, Any]] = None,
         dynamic_shapes=None,
+        compiler_kwargs: Optional[Dict[str, Any]] = None,
     ):
         self.gm = gm
         self.args = args
         self.kwargs = kwargs or {}
         self.dynamic_shapes = dynamic_shapes
-
+        self.compiler_kwargs = compiler_kwargs or {}
         # identify max_batch_size
         if self.dynamic_shapes is not None and 0 in self.dynamic_shapes[0]:
             self.max_batch_size = self.dynamic_shapes[0][0].max
@@ -79,13 +80,16 @@ def compile_and_capture(
     args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
     dynamic_shapes=None,
+    compiler_kwargs: Optional[Dict[str, Any]] = None,
 ) -> nn.Module:
     """Compile or capture graph for single-token generation."""
     elapsed_time = -time.time()
+    ad_logger.info("Fusion before compiling...")
+
     ad_logger.info(f"Compiling for {backend} backend...")

     compiler_cls = BackendRegistry.get(backend)
-    compiled_module = compiler_cls(gm, args, kwargs, dynamic_shapes).compile()
+    compiled_module = compiler_cls(gm, args, kwargs, dynamic_shapes, compiler_kwargs).compile()

     elapsed_time += time.time()
     ad_logger.info(f"Compile time with backend {backend}: {elapsed_time:.6f} seconds")
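
compiler_kwargs is an opaque dict that compile_and_capture simply forwards to the selected backend; in this commit only the torch-opt backend reads it (to pick up cuda_graph_batch_sizes). A minimal, self-contained sketch of that plumbing pattern, with hypothetical class names:

    # Hypothetical sketch of the compiler_kwargs plumbing (class names are made up):
    # the entrypoint forwards an opaque dict, and each backend pulls out only the
    # keys it understands, ignoring the rest.
    from typing import Any, Dict, Optional


    class SketchBackendCompiler:
        def __init__(self, compiler_kwargs: Optional[Dict[str, Any]] = None):
            self.compiler_kwargs = compiler_kwargs or {}


    class SketchTorchOptCompiler(SketchBackendCompiler):
        def compile(self) -> None:
            # only this backend consumes the key in this commit
            batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
            print(f"would capture CUDA graphs for batch sizes: {batch_sizes}")


    SketchTorchOptCompiler({"cuda_graph_batch_sizes": [1, 2, 4, 8]}).compile()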

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 28 additions & 8 deletions
@@ -6,6 +6,7 @@
 is also responsible for functionalizing information about the sequence and pass it on the the
 various attention interface. The AttentionDescriptor is the main interface to the attention operator
 and operates on a purely functional paradigm that is compatible with the torch custom op system.
+
 """

 from abc import ABC, abstractmethod
@@ -121,7 +122,9 @@ def __post_init__(self):
             self.page_size = self.max_seq_len
         if self.max_num_tokens < 1:
             self.max_num_tokens = self.max_batch_size * self.max_seq_len
-        total_tokens = self.max_batch_size * self.page_size
+        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
+        # we use the provided max_num_tokens to calculate the number of pages
+        total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len)
         self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
         self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
         self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
@@ -191,6 +194,12 @@ def is_generate(self) -> bool:
     def num_pages(self) -> int:
         return self._num_pages

+    @num_pages.setter
+    def num_pages(self, value):
+        self._num_pages = value
+        # update the cache_loc tensor
+        self.cache_loc.resize_(value)
+
     @property
     def is_paged(self) -> bool:
         return self.page_size < self.max_seq_len
@@ -306,6 +315,19 @@ def _set_example_sequence(self) -> None:
         self.nest_sequences(input_ids)
         self.input_ids = input_ids

+    def _set_max_num_tokens_sample(self) -> None:
+        """Set an example sequence with max_num_tokens."""
+        self.reset()
+        seq_len = self.max_num_tokens // self.max_batch_size
+        input_ids = torch.ones(
+            self.max_batch_size,
+            seq_len,
+            dtype=torch.int,
+            device=self.device,
+        )
+        self.pages_per_seq.fill_(seq_len // self.page_size)
+        self.nest_sequences(input_ids)
+
     def _set_generate_only_batch(self) -> None:
         """Set an example sequence for generate-only batch."""
         self.reset()
@@ -319,16 +341,14 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
         # set new sequence lengths
         seq_lens = [len(ids) for ids in input_ids]
         self.seq_len.zero_()
-        self.seq_len[: len(seq_lens)] = torch.tensor(seq_lens, device=self.device)
+        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)

         # set new input_ids as new tensor from flattened input_ids
         ids_tnsr_list = [
-            lst.detach().to(self.device)
-            if isinstance(lst, torch.Tensor)
-            else torch.tensor(lst, dtype=torch.int, device=self.device)
+            lst.detach() if isinstance(lst, torch.Tensor) else torch.tensor(lst, dtype=torch.int)
             for lst in input_ids
         ]
-        self.input_ids = torch.cat(ids_tnsr_list, dim=0)
+        self.input_ids = torch.cat(ids_tnsr_list, dim=0).to(self.device)

         # set derivative properties
         self._sequence_lengths = seq_lens
@@ -362,10 +382,10 @@ def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
         cache_loc_flat = torch.tensor(
             [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
         )
-        self.cache_loc[: len(cache_loc_flat)] = cache_loc_flat.to(self.device)
+        self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)

         pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
-        self.pages_per_seq[: len(pages_per_seq)] = pages_per_seq.to(self.device)
+        self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)


 Constant = Union[int, float, str, None]
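
A worked example of the updated page-count arithmetic in __post_init__ (values are made up): the user-supplied max_num_tokens can now cap the KV-cache page count, whereas the removed line derived it from max_batch_size * page_size alone.

    # Worked example (illustrative values) of the new page-count computation.
    max_batch_size, max_seq_len, page_size = 64, 4096, 64
    max_num_tokens = 8192  # provided cap, much smaller than 64 * 4096

    total_tokens = min(max_num_tokens, max_batch_size * max_seq_len)        # 8192
    num_pages = total_tokens // page_size + (total_tokens % page_size > 0)  # ceiling division
    print(total_tokens, num_pages)  # 8192 128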

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

File mode changed from 100755 to 100644

Lines changed: 3 additions & 9 deletions

@@ -62,6 +62,7 @@ def calculate_max_num_blocks(
         # TODO (lliebenwein): this is VERY hacky... Ideally, we want to compute the number of blocks
         # just like in the original implementation. However, let's wait for the layer-wise attention
         # implementation before over-optimizing the function here
+        ad_logger.info("Using fake cache manager with head_dim=0 and num pages:", self.num_blocks)
         return self.num_blocks, 0


@@ -86,6 +87,7 @@ def build_from_config(
         device: DeviceLikeType,
     ):
         """Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
+
         # construct model factory
         model_kwargs = {"max_position_embeddings": seq_info.max_seq_len, **ad_config.model_kwargs}
         factory = ModelFactoryRegistry.get("hf")(
@@ -95,15 +97,7 @@
         )

         # construct inference optimizer
-        # TODO (lliebenwein): let's split up the compile backend to separately handle cuda graph
-        # and torch compile so we can follow the PyTorchConfig here and enable it separately.
-        if ad_config.use_cuda_graph or ad_config.torch_compile_enabled:
-            compile_backend = "torch-opt"
-        else:
-            compile_backend = "torch-simple"
-        build_and_optimize = InferenceOptimizer(
-            factory=factory, attn_backend=ad_config.attn_backend, compile_backend=compile_backend
-        )
+        build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)

         # construct engine
         engine = cls(build_and_optimize, seq_info, device)

tensorrt_llm/_torch/auto_deploy/shim/interface.py

Lines changed: 20 additions & 0 deletions
@@ -45,6 +45,26 @@ def initialize_caches(self) -> None:
             name: get_cache(self.info) for name, get_cache in self._cache_initializers.items()
         }

+    def current_cache_size_bytes(self) -> int:
+        """Calculate and return the total size of all caches in bytes."""
+        total_size = 0
+        for name, cache in self._caches.items():
+            # this hack is needed since _caches also contains global buffers such as freqs_cis.
+            if "cache" in name:
+                total_size += cache.element_size() * cache.numel()
+        return total_size
+
+    def resize_cache(self, new_num_pages: int):
+        """Resize the cache to the new number of pages."""
+        # TODO: We should do some sanity check on the new number of pages.
+        self.info.num_pages = new_num_pages
+        for name, cache in self._caches.items():
+            # We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim)
+            if "cache" in name:
+                current_shape = cache.shape
+                new_shape = (new_num_pages, *current_shape[1:])
+                cache.resize_(new_shape)
+

 GetInferenceModel = Callable[[CachedSequenceInterface], nn.Module]

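The two helpers added above reduce to two tensor operations: element_size() * numel() for the byte count and Tensor.resize_() to grow the page dimension in place. A standalone illustration with assumed cache dimensions (the real shapes depend on the attention backend and model config):

    # Standalone illustration with assumed dimensions; only the two tensor ops
    # (element_size() * numel() and resize_()) mirror the helpers above.
    import torch

    num_pages, page_size, n_heads, head_dim = 128, 64, 8, 128
    cache = torch.zeros(num_pages, page_size, n_heads, head_dim, dtype=torch.float16)

    size_bytes = cache.element_size() * cache.numel()
    print(size_bytes)  # 128 * 64 * 8 * 128 * 2 bytes = 16777216

    new_num_pages = 512
    cache.resize_((new_num_pages, *cache.shape[1:]))  # may reallocate; new pages are uninitialized
    print(cache.shape)  # torch.Size([512, 64, 8, 128])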
tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py

Lines changed: 40 additions & 0 deletions
@@ -171,3 +171,43 @@ def insert_mha_with_kv_cache(
     egm = canonicalize_graph(egm, shape_prop=False)
     ad_logger.debug("After inserting MHA with KV cache: " + str(egm))
     return egm
+
+
+def resize_kv_cache(
+    egm: GraphModule, cm: CachedSequenceInterface, free_mem_ratio: float = 0.8
+) -> None:
+    """Inflate the kv cache to occupy the available GPU memory.
+
+    free_mem_ratio specifies the fraction of available memory to occupy.
+    """
+    free_mem, total_mem = torch.cuda.mem_get_info()
+    ad_logger.info(f"Free memory: {free_mem}, Total memory: {total_mem}")
+    current_cache_size = cm.current_cache_size_bytes()
+    current_num_pages = cm.info.num_pages
+    ad_logger.info(
+        f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
+    )
+
+    try:
+        # Let's run a forward pass to get the memory usage
+        cm.info._set_max_num_tokens_sample()
+        free_mem_pre, _ = torch.cuda.mem_get_info()
+        ad_logger.info(f"Free memory before forward pass: {free_mem_pre}")
+        egm(*cm.args)
+        free_mem_post, _ = torch.cuda.mem_get_info()
+        ad_logger.info(f"Free memory after forward pass: {free_mem_post}")
+
+        memory_for_forward_pass = free_mem_pre - free_mem_post
+        ad_logger.info(f"Memory for forward pass: {memory_for_forward_pass}")
+
+        new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
+        new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
+        ad_logger.info(f"New cache size: {new_cache_size}, New num pages: {new_num_pages}")
+        cm.resize_cache(new_num_pages)
+    except Exception as e:
+        ad_logger.warning(
+            f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
+        )
+
+    # Free memory
+    torch.cuda.empty_cache()
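
The sizing heuristic treats current_cache_size / current_num_pages as the per-page byte cost and budgets free_mem_ratio of the memory left after the probe forward pass, plus what the existing cache already occupies. A worked example with made-up numbers:

    # Worked example (made-up numbers) of the resize arithmetic used above.
    free_mem_post = 40 * 1024**3        # 40 GiB still free after the probe forward pass
    free_mem_ratio = 0.8
    current_cache_size = 2 * 1024**3    # 2 GiB currently allocated to KV caches
    current_num_pages = 2048

    bytes_per_page = current_cache_size // current_num_pages              # 1 MiB per page
    new_cache_size = free_mem_post * free_mem_ratio + current_cache_size  # ~34 GiB budget
    new_num_pages = int(new_cache_size // bytes_per_page)
    print(bytes_per_page, new_num_pages)  # 1048576 34816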

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 29 additions & 5 deletions
@@ -1,12 +1,15 @@
 """High-level entrypoint to transform a model into an efficient inference model."""

+import gc
+
+import torch
 from torch.fx import GraphModule

 from ..compile import compile_and_capture
 from ..custom_ops.attention_interface import AttentionRegistry
 from ..distributed import common as dist_ad
 from ..models.factory import ModelFactory
-from ..shim.interface import CachedSequenceInterface
+from ..shim.interface import AutoDeployConfig, CachedSequenceInterface
 from ..utils.logger import ad_logger
 from ._graph import move_to_device
 from .export import torch_export_to_gm
@@ -21,6 +24,7 @@
     insert_mha_with_kv_cache,
     match_moe_pattern,
     quantize,
+    resize_kv_cache,
 )


@@ -29,12 +33,18 @@ def __init__(
         self,
         factory: ModelFactory,
         *,  # TODO (lliebenwein): temporary until we have a better config system
-        attn_backend: str,
-        compile_backend: str,
+        ad_config: AutoDeployConfig,
         visualize: bool = False,
     ):
         self.factory = factory
-        self.attn_backend = attn_backend
+        self.attn_backend = ad_config.attn_backend
+        # TODO (lliebenwein): let's split up the compile backend to separately handle cuda graph
+        # and torch compile so we can follow the PyTorchConfig here and enable it separately.
+        self.ad_config = ad_config
+        if ad_config.use_cuda_graph or ad_config.torch_compile_enabled:
+            compile_backend = "torch-opt"
+        else:
+            compile_backend = "torch-simple"
         self.compile_backend = compile_backend
         self.visualize = visualize

@@ -103,6 +113,7 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
         # initialize caches, load weights, and map to correct device
         cm.initialize_caches()

+        # load weights
         self.factory.load_or_random_init(egm, mmap=True, map_location=cm.device)
         move_to_device(egm, cm.device)

@@ -135,14 +146,27 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
         except ImportError:
             pass

+        ############################################################################################
+        # RESIZE CACHE
+        ############################################################################################
+        # Free memory ratio is hardcoded to 0.8 for now to ensure we have enough memory for graph capture.
+        resize_kv_cache(egm, cm, free_mem_ratio=0.8)
+
         ############################################################################################
         # COMPILE MODEL
         ############################################################################################

         cm.info._set_generate_only_batch()
+        compiler_kwargs = {"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes}
         egm_compiled = compile_and_capture(
-            egm, self.compile_backend, args=cm.args, dynamic_shapes=cm.dynamic_shapes
+            egm,
+            self.compile_backend,
+            args=cm.args,
+            dynamic_shapes=cm.dynamic_shapes,
+            compiler_kwargs=compiler_kwargs,
         )
         cm.info.reset()

+        torch.cuda.empty_cache()
+        gc.collect()
         return egm_compiled
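
The backend choice that previously lived in ad_executor.py now happens in InferenceOptimizer.__init__. The mapping itself is small enough to state as a pure function; this is a sketch of the decision, not the actual code layout:

    # Sketch of the compile-backend selection now performed in InferenceOptimizer.__init__.
    def select_compile_backend(use_cuda_graph: bool, torch_compile_enabled: bool) -> str:
        # CUDA-graph capture and torch.compile both route to the optimizing backend.
        return "torch-opt" if (use_cuda_graph or torch_compile_enabled) else "torch-simple"


    assert select_compile_backend(True, False) == "torch-opt"
    assert select_compile_backend(False, True) == "torch-opt"
    assert select_compile_backend(False, False) == "torch-simple"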

tensorrt_llm/bench/benchmark/throughput.py

File mode changed from 100644 to 100755

Lines changed: 4 additions & 2 deletions

@@ -41,7 +41,7 @@
                  help="Path to a serialized TRT-LLM engine.",
                  )
 @optgroup.option("--backend",
-                 type=click.Choice(["pytorch"]),
+                 type=click.Choice(["pytorch", "autodeploy"]),
                  default=None,
                  help="Set to 'pytorch' for pytorch path. Default is cpp path.")
 @optgroup.option(
@@ -209,7 +209,7 @@ def throughput_command(
     logger.info(metadata.get_summary_for_print())

     # Engine configuration parsing
-    if backend and backend.lower() == "pytorch":
+    if backend and backend.lower() in ["pytorch", "autodeploy"]:
         exec_settings = get_settings(params, metadata, bench_env.model,
                                      bench_env.checkpoint_path)
         kwargs_max_sql = max_seq_len or metadata.max_sequence_length
@@ -262,6 +262,8 @@
     try:
         logger.info("Setting up throughput benchmark.")
         kwargs = kwargs | runtime_config.get_llm_args()
+        kwargs['backend'] = backend
+
         if runtime_config.backend == 'pytorch':
             llm = PyTorchLLM(**kwargs)
         else:
