[Model] Implement DualChunkAttention for Qwen2 Models #6139
base: main
Conversation
Can you discuss why this is the case? If possible, I would really appreciate it if we could get the first iteration working with CUDA graph. @WoosukKwon: Also, IMO the best place to put this is actually in FlashAttentionBackend, accepting extra arguments for the chunked config, but I would like to hear your thoughts.
@simon-mo The brute-force implementation of DCA (https://github.com/HKUNLP/ChunkLlama/blob/main/chunkllama_attn_replace.py#L169-L199) splits the entire sequence into several chunks and invokes the flash attention kernel 3*chunk_num times, so the number of kernel calls depends on the sequence length and the resulting computation graph is dynamic.
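A rough sketch of that brute-force structure (illustrative only, not the code in this PR): tensors are assumed to be `[seq_len, num_heads, head_dim]`, torch's `scaled_dot_product_attention` stands in for the flash attention kernel, and the softmax-lse merge of the partial outputs is omitted. The loop's trip count, and hence the number of kernel launches, depends on the sequence length, which is what makes CUDA graph capture difficult.

```python
import torch
import torch.nn.functional as F

def dca_prefill_sketch(q, q_succ, q_inter, k, v, chunk_size):
    """Illustrative only: shows the chunk loop and the up-to-three attention
    calls per chunk (intra / succ / inter); merging the partial outputs via
    their softmax lse is omitted."""
    seq_len = q.shape[0]
    num_chunks = (seq_len + chunk_size - 1) // chunk_size  # data-dependent trip count
    partial_outputs = []
    for i in range(num_chunks):
        s, e = i * chunk_size, min((i + 1) * chunk_size, seq_len)
        # 1) intra-chunk: causal attention within chunk i, using q
        calls = [(q[s:e], k[s:e], v[s:e], True)]
        # 2) successive chunk: q_succ attends to the immediately preceding chunk
        if i >= 1:
            calls.append((q_succ[s:e], k[s - chunk_size:s], v[s - chunk_size:s], False))
        # 3) inter chunks: q_inter attends to all earlier chunks
        if i >= 2:
            calls.append((q_inter[s:e], k[:s - chunk_size], v[:s - chunk_size], False))
        for qq, kk, vv, causal in calls:  # up to 3 kernel invocations per chunk
            out = F.scaled_dot_product_attention(
                qq.transpose(0, 1), kk.transpose(0, 1), vv.transpose(0, 1),
                is_causal=causal)
            partial_outputs.append(out.transpose(0, 1))
    return partial_outputs  # the real implementation merges these using their softmax lse
```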
@WoosukKwon The primary issue here lies in the argument list of the attention layer's forward method.
Is this specific to Qwen2? If not, I would strongly prefer that the logic for this feature be implemented inside of the attention backend rather than the model code.
Looking forward to the implementation of the CUDA kernel 😊
@robertgshaw2-neuralmagic No, it can be applied to almost all models using RoPE.
@nanmi I'm actually seeking help from the community to implement the CUDA kernel. I'm not an expert. 😅
For instance, we have the following code:

```python
if dual_chunk_attention_config is not None:
    self.attn = DualChunkAttention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        dual_chunk_attention_config=dual_chunk_attention_config)
else:
    self.attn = Attention(self.num_heads,
                          self.head_dim,
                          self.scaling,
                          num_kv_heads=self.num_kv_heads,
                          cache_config=cache_config,
                          quant_config=quant_config)
```

And:

```python
if self.dual_chunk_attention_config is None:
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
else:
    q, q_succ, q_inter, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, q_succ, q_inter, k, v, kv_cache, attn_metadata)
```

If we could push this branching into the attention layer, the model code would not need to change.
@robertgshaw2-neuralmagic DCA requires three queries (q, q_succ, and q_inter) from the rotary embedding, so the attention layer's argument list differs from the standard Attention.
Can you elaborate on your ideas? I implemented DCA based on the Python code logic in CUDA C++, but I think it can be made more elegant and efficient. I am mainly worried about how to optimize the segmentation processing and the final merging.
@nanmi The functions … Taking …
I'll follow your idea and think about how to use CUDA to implement it.
@nanmi Really appreciate your help! If you run into any questions, feel free to ask here.
```
@@ -682,6 +699,8 @@ def _prepare_model_input_tensors(
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)
        prefill_original_seq_lens_tensor = torch.tensor(
```
Is it possible to push the creation of this tensor into `attn_backend.make_metadata`?
```
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None
        if getattr(self.model_config.hf_config, "dual_chunk_attention_config",
```
We should add `dual_chunk_attention` to `ModelConfig` to avoid accessing the `hf_config` from here. We ran into problems with accessing the `hf_config` with Sliding Window in the past.
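A minimal sketch of that suggestion (the accessor name is hypothetical, not an existing vLLM method):

```python
from typing import Optional

# Hypothetical sketch: let ModelConfig own the lookup so callers never touch
# hf_config directly (mirroring how sliding window is handled).
class ModelConfig:
    def __init__(self, hf_config):
        self.hf_config = hf_config

    def get_dual_chunk_attention_config(self) -> Optional[dict]:
        """Return the dual chunk attention config dict, or None if the model has none."""
        return getattr(self.hf_config, "dual_chunk_attention_config", None)
```

The worker would then call `self.model_config.get_dual_chunk_attention_config()` instead of reaching into `hf_config`.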
```
        ) if num_attn_heads else None
        if getattr(self.model_config.hf_config, "dual_chunk_attention_config",
                   None):
            self.attn_backend = get_dual_chunk_attn_backend(
```
Is there a reason this has to be a different function vs. passing a `self.model_config.get_dual_chunk_attention` to `get_attn_backend`?
Apologies for the delay in reviewing here. @simon-mo, @WoosukKwon and I discussed, and while we do not love implementing features in specific models, I think that we do not have the right abstractions in place yet to enable simultaneously modifying the rotary embedding and the attention backend. I left a few comments on the code. I think that this PR also needs some tests.
```
        seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
        max_seq_len_inter = seq_lens_inter.max().item()
        if max_seq_len_inter:
            inter_output, succ_softmax_lse = (
```
Should be `inter_softmax_lse` instead of `succ_softmax_lse`.
Thanks for pointing out the typo. Will revise in the next commit.
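For readers following along: both the output and the softmax lse are kept because partial results computed over disjoint key sets (intra/succ/inter) must be renormalized before they can be combined. A minimal sketch of that merge, assuming `[num_heads, seq_len, head_dim]` outputs and `[num_heads, seq_len]` lse tensors (the PR's actual helper and layouts may differ):

```python
import torch

def merge_attn_outputs(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention outputs computed over disjoint key sets.

    out_a, out_b: [num_heads, seq_len, head_dim] partial outputs
    lse_a, lse_b: [num_heads, seq_len] log-sum-exp of the corresponding attention logits
    """
    lse = torch.logaddexp(lse_a, lse_b)           # combined normalizer (log scale)
    weight_a = torch.exp(lse_a - lse).unsqueeze(-1)
    weight_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return out_a * weight_a + out_b * weight_b    # properly renormalized combination
```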
```
        chunk_len = chunk_size - local_size
        if chunk_len % block_size != 0:
            raise ValueError("chunk_len must be divisible by block_size.")
        chunk_num_curr = (cache_seqlens - 1) // chunk_len
```
This line makes the actual length of the intra-chunk smaller than `chunk_len` in the variable-length situation. It differs from the algorithm in the prefill stage, which uses a fixed length of `chunk_len`. Could you explain the reason for this difference?
@rejoicesyc The implementation is consistent with the original code (https://github.com/HKUNLP/ChunkLlama/blob/8a28f1464b2a5def03eb07e8d91f5bf4d00f667d/chunkllama_attn_replace.py#L167); see L167, L171, L173, and L197. Since the sequence length may not be divisible by `chunk_len`, there has to be one chunk with a smaller length. In the DCA algorithm, the last chunk (i.e., the intra chunk) may have length <= `chunk_len`, while the preceding chunks have a fixed length of `chunk_len`.
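A small worked example of these decode-side quantities (toy numbers, not taken from the PR), showing how the intra chunk ends up with length <= `chunk_len` while earlier chunks stay full:

```python
import torch

chunk_len = 16  # toy value; in the real code chunk_len = chunk_size - local_size
cache_seqlens = torch.tensor([5, 16, 17, 40])

chunk_num_curr = (cache_seqlens - 1) // chunk_len              # tensor([0, 0, 1, 2])
intra_len = cache_seqlens - chunk_num_curr * chunk_len         # tensor([5, 16, 1, 8])
seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len  # tensor([0, 0, 0, 16])
```

The sequence of length 17 illustrates the question above: its intra chunk holds only one token, while its single preceding chunk keeps the full `chunk_len`.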
```
                      1.0).clip(min=1)
        query = (query * mscale.view(-1, 1, 1, 1)).to(
            query.dtype
        )  # possible numerical issue; needs to be fused into the kernel
```
hi @hzhwcmhf, could you explain this numerical issue and why kernel fusion can solve it?
@rejoicesyc We want to compute `(query @ key) * softmax_scale * mscale` as the attention weight, where `query` is represented in `float16` or `bfloat16`, and `mscale` is represented in `float32`. Better precision can be achieved if the scaling is computed in `float32`, rather than multiplying `query` by `mscale` before the attention operation. The flash attention kernel provides an argument `softmax_scale`, which is close to our requirement. However, `softmax_scale` is a constant for the entire query, whereas `mscale` is a vector that specifies a different `softmax_scale` value for each query vector at different positions.
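A toy comparison of the two orderings (illustrative only; the real kernel operates on attention logits and the layouts differ): scaling the `bfloat16` query by `mscale` before the matmul rounds the product to `bfloat16`, whereas applying `mscale` to `float32` scores keeps the scaling exact.

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 64, dtype=torch.bfloat16)        # 4 query positions
k = torch.randn(8, 64, dtype=torch.bfloat16)        # 8 key positions
mscale = torch.rand(4, dtype=torch.float32) + 0.5   # per-position scale
softmax_scale = 64 ** -0.5

# current approach: multiply the low-precision query by mscale first
scores_pre = ((q * mscale.view(-1, 1).to(q.dtype)) @ k.T).float() * softmax_scale

# fused idea: keep q as-is and fold mscale into the float32 score scaling
scores_fused = (q @ k.T).float() * softmax_scale * mscale.view(-1, 1)

print((scores_pre - scores_fused).abs().max())      # small rounding discrepancy
```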
This pull request has merge conflicts that must be resolved before it can be merged.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
Overview
Dual Chunk Attention is a training-free method to extend model context length. It splits the entire sequence into chunks and uses three distinct query vectors to capture relative information within the same chunk, between successive chunks, and between distant chunks. The original implementation (using Hugging Face's Transformers) can be found here.
Qwen2 models now integrate this method, supporting 1M context length processing on 8x H100-80G GPUs. Qwen2-72B-Instruct achieves ~75% accuracy in the Needle in A Haystack test with 1M tokens input.
Features:

- Brute-force implementation of DCA based on the flash attention kernel (invoked `3*chunk_num` times).

Limitations:

- `enforce_eager` must be set to True due to the dynamic graph created by the brute-force implementation.

Changes

- Add `DualChunkRotaryEmbedding` in `vllm/model_executor/layers/rotary_embedding.py`. The function `DualChunkRotaryEmbedding.forward` returns `query, query_succ, query_inter` instead of a single `query`. These three query vectors are used for computing intra-/succ-/inter-attention in Dual Chunk Attention.
- Add `DualChunkAttentionBackend` in `vllm/attention/backends/abstract.py` and implement `DualChunkFlashAttentionBackend` in `vllm/attention/backends/dual_chunk_flash_attn.py`. Note that we add an extra variable `prefill_original_seq_lens_tensor` in `DualChunkFlashAttentionMetadata`, which stores the whole prefill sequences' lengths. To obtain the value, we insert a few lines in `vllm/model_runner.py`.
- Add `DualChunkAttention` in `vllm/attention/layer.py`, which simply calls `DualChunkAttentionBackend`.
- Support `dual_chunk_attention_config` in `Qwen2Model` and `Qwen2MoeModel`.

How to use

Add `dual_chunk_attention_config` to the model's `config.json` file following:

You will see outputs like:
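As a rough illustration of launching the model with this feature (the model name, context length, and prompt are placeholders, and the checkpoint's `config.json` is assumed to already contain a `dual_chunk_attention_config` entry):

```python
from vllm import LLM, SamplingParams

# Placeholder model path; its config.json is assumed to already include
# a dual_chunk_attention_config entry as described above.
llm = LLM(
    model="Qwen/Qwen2-72B-Instruct",
    tensor_parallel_size=8,      # 8x H100-80G in the setup above
    enforce_eager=True,          # required: CUDA graph is not supported yet
    max_model_len=1048576,       # 1M-token context
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
outputs = llm.generate(["<a very long document> ... Summarize the document."],
                       sampling_params)
print(outputs[0].outputs[0].text)
```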