
Commit 631cfc2

Tighten compilation cache invariants around eagle
I'm recording my understanding of how eagle and the compilation cache work after discussing vllm-project#17211 with @luyuzhe111 and @WoosukKwon. In the future we will likely want to torch.compile multiple pieces of code (e.g. the decoder and encoder separately), and at that point we'll need to refactor the system to support it (each compiled region needs its own cache directory with its own hash). Until then, the current design seems fine.

Signed-off-by: rzou <zou3519@gmail.com>
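To make the future refactor mentioned above concrete, here is a minimal sketch, assuming hypothetical names (region_cache_dir, base_dir, region_name, region_config) that do not exist in vLLM today, of what a per-region cache directory keyed by its own hash could look like:

# Hypothetical sketch only: vLLM currently distinguishes compiled graphs by
# compilation_counter.num_graphs_seen, not by per-region hashes.
import hashlib
import os


def region_cache_dir(base_dir: str, region_name: str, region_config: str) -> str:
    # Derive a cache dir for one torch.compile'd region (e.g. "encoder",
    # "decoder", "eagle3_head") from a hash of that region's own config.
    region_hash = hashlib.sha256(
        f"{region_name}:{region_config}".encode()).hexdigest()[:10]
    return os.path.join(base_dir, f"{region_name}-{region_hash}")


# Example: region_cache_dir("/tmp/torch_compile_cache", "decoder", "tp=1")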
1 parent f62cad6 commit 631cfc2

File tree

1 file changed: +12 −0 lines changed


vllm/compilation/backends.py

Lines changed: 12 additions & 0 deletions
@@ -415,6 +415,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             self.compilation_config.cache_dir = cache_dir

         if compilation_counter.num_graphs_seen > 0:
+            # NOTE: Eagle3 compilation
+            # The eagle3 head is a separate model that gets run, so it needs
+            # its own cache dir (each cache dir is 1:1 with a model.forward).
+            # The eagle3 head does not need its own hash; the hash of the
+            # original model entirely determines the config of the eagle3 head.
+            #
+            # If you are here because you are using multiple torch.compile
+            # calls in a single model, please open an issue and let's discuss.
+            speculative_config = self.vllm_config.speculative_config
+            assert speculative_config is not None
+            assert speculative_config.method == "eagle3"
             cache_dir = self.compilation_config.cache_dir + \
                 f'-{compilation_counter.num_graphs_seen}'
         else:
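For readers skimming the hunk, here is a standalone sketch of the cache-dir rule it enforces, with base_cache_dir, num_graphs_seen, and speculative_method standing in for vLLM's real config and counter objects (a simplification, not the actual API):

# Simplified sketch of the cache-dir selection above. Graph 0 (the target
# model) keeps the base cache dir; any further graph must be the eagle3
# draft head and gets a "-{num_graphs_seen}" suffix so its compiled
# artifacts never collide with the target model's.
from typing import Optional


def select_cache_dir(base_cache_dir: str,
                     num_graphs_seen: int,
                     speculative_method: Optional[str]) -> str:
    if num_graphs_seen > 0:
        # Mirrors the new asserts: a second compiled graph is only expected
        # when eagle3 speculative decoding is enabled.
        assert speculative_method == "eagle3"
        return f"{base_cache_dir}-{num_graphs_seen}"
    return base_cache_dir


assert select_cache_dir("/tmp/cache/abc", 0, None) == "/tmp/cache/abc"
assert select_cache_dir("/tmp/cache/abc", 1, "eagle3") == "/tmp/cache/abc-1"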
