
Commit b153177

ekagra-ranjan authored and lk-chen committed
[V1][Spec Decode] Make Eagle model arch config driven (vllm-project#17323)
1 parent 16c4427 commit b153177

File tree

3 files changed: +26 additions, -13 deletions

vllm/config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -2401,7 +2401,8 @@ def __post_init__(self):
                 pass
             else:
                 eagle_config = EAGLEConfig(
-                    self.draft_model_config.hf_config)
+                    self.draft_model_config.hf_config,
+                    method=self.method)
                 self.draft_model_config.hf_config = eagle_config
 
         if (self.num_speculative_tokens is not None
```
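
The effect of this one-line change is that the speculative method chosen on SpeculativeConfig now reaches EAGLEConfig, which uses it to pick the draft architecture (see the next file). A minimal sketch of the new call shape, assuming a plain dict with an architectures entry is an acceptable stand-in for the draft model's Hugging Face config (the signature below allows PretrainedConfig, dict, or None):

```python
from vllm.transformers_utils.configs.eagle import EAGLEConfig

# Hypothetical stand-in for self.draft_model_config.hf_config; in vLLM
# this would be the draft model's real Hugging Face config object.
draft_hf_config = {"architectures": ["LlamaForCausalLM"]}

# After this commit, the method string ("eagle" or "eagle3") is
# forwarded so EAGLEConfig can derive the drafter's architecture name.
eagle_config = EAGLEConfig(draft_hf_config, method="eagle")
```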

vllm/transformers_utils/configs/eagle.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@ class EAGLEConfig(PretrainedConfig):
     def __init__(self,
                  model: Union[PretrainedConfig, dict, None] = None,
                  truncated_vocab_size: Optional[int] = None,
+                 method: Optional[str] = 'eagle',
                  **kwargs):
 
         model_config: Union[PretrainedConfig, DeepseekV2Config, None]
@@ -45,7 +46,23 @@ def __init__(self,
         if not envs.VLLM_USE_V1:
             kwargs["architectures"] = ["EAGLEModel"]
         else:
-            kwargs["architectures"] = ["EagleLlamaForCausalLM"]
+            # Eagle model name should follow naming convention of
+            # LlamaForCausalLM -> EagleLlamaForCausalLM
+            if method == "eagle":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle"
+                kwargs["architectures"] = [
+                    f"Eagle{arch}" for arch in self.model.architectures
+                ]
+            elif method == "eagle3":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle3"
+                kwargs["architectures"] = [
+                    f"Eagle3{arch}" for arch in self.model.architectures
+                ]
+            else:
+                raise ValueError(f"Invalid method {method}. \
+                    Supported methods are eagle and eagle3.")
 
         super().__init__(**kwargs)
```
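
The naming convention in the new comment (LlamaForCausalLM -> EagleLlamaForCausalLM, or Eagle3LlamaForCausalLM for eagle3) is a pure string transform over the target model's architectures list. A standalone sketch of that mapping, with a hypothetical helper name chosen for illustration:

```python
def derive_draft_architectures(method, architectures):
    """Mirror EAGLEConfig's convention: prefix each target architecture
    with "Eagle" (method="eagle") or "Eagle3" (method="eagle3")."""
    prefixes = {"eagle": "Eagle", "eagle3": "Eagle3"}
    if method not in prefixes:
        raise ValueError(f"Invalid method {method}. "
                         "Supported methods are eagle and eagle3.")
    return [f"{prefixes[method]}{arch}" for arch in architectures]


# LlamaForCausalLM -> Eagle3LlamaForCausalLM
assert derive_draft_architectures(
    "eagle3", ["LlamaForCausalLM"]) == ["Eagle3LlamaForCausalLM"]
```

Driving the names from self.model.architectures rather than hard-coding EagleLlamaForCausalLM is what lets non-Llama targets pick up a matching drafter without further edits to this file.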

vllm/v1/spec_decode/eagle.py

Lines changed: 6 additions & 11 deletions
```diff
@@ -9,8 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models import ModelRegistry
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -225,15 +224,11 @@ def load_model(self, target_model: nn.Module) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            if self.vllm_config.speculative_config.method == "eagle":
-                self.model = EagleLlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
-            else:
-                assert self.vllm_config.speculative_config.method == "eagle3"
-                self.model = Eagle3LlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
+            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
+                draft_model_config.architectures)
+            self.model = draft_model_cls(
+                model_config=draft_model_config,
+                start_layer_id=target_layer_num).to(target_device)
 
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(
```
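
With the architecture list now produced by EAGLEConfig, load_model can look the drafter class up in the registry instead of importing EagleLlamaForCausalLM / Eagle3LlamaForCausalLM directly. A sketch of that lookup in isolation (requires a vLLM install; the hard-coded architecture name here is just an example input):

```python
from vllm.model_executor.models import ModelRegistry

# resolve_model_cls takes a list of candidate architecture names and
# returns the first registered (model class, architecture name) pair,
# exactly as the diff above uses it.
draft_model_cls, arch = ModelRegistry.resolve_model_cls(
    ["Eagle3LlamaForCausalLM"])
print(draft_model_cls.__name__, arch)
```

The registry indirection means a new Eagle-style drafter only needs to be registered under a name following the Eagle{Arch} convention; this loader then picks it up with no code changes here.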
