Skip to content

Commit 41f215f

Browse files
committed
Tighten compilation cache invariants around eagle
I'm recording down my understanding of how eagle and the compilation cache works after discussing vllm-project#17211 with @luyuzhe111 and @WoosukKwon. In the future we likely will have a situation where we want to torch.compile multiple pieces of code (e.g. decoder and encoder separately) and then we'll need to refactor the system to support it (each compiled region needs its own cache directory with its own hash) But until then the current design seems fine. Signed-off-by: rzou <zou3519@gmail.com>
1 parent f62cad6 commit 41f215f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

vllm/compilation/backends.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,22 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
415415
self.compilation_config.cache_dir = cache_dir
416416

417417
if compilation_counter.num_graphs_seen > 0:
418+
# NOTE: Eagle compilation
419+
# The eagle head is a separate model that gets run, so it needs
420+
# its own cache dir (each cache dir is 1:1 with a model.forward).
421+
#
422+
# We currently assume that the eagle head does not need its own
423+
# hash: in the vLLM repo, the hash of the original model currently
424+
# entirely determines the config of the eagle head.
425+
# It's very possible that this assumption will change in the
426+
# future and we'll need to update this code.
427+
#
428+
# If you are here because you are using multiple torch.compile
429+
# calls in a single model, please open an issue and let's discuss.
430+
speculative_config = self.vllm_config.speculative_config
431+
assert speculative_config is not None
432+
assert speculative_config.method.use_eagle()
433+
418434
cache_dir = self.compilation_config.cache_dir + \
419435
f'-{compilation_counter.num_graphs_seen}'
420436
else:

0 commit comments

Comments
 (0)