
[Model] Implement DualChunkAttention for Qwen2 Models #6139

Open. Wants to merge 3 commits into base: main.

Conversation

@hzhwcmhf (Contributor) commented Jul 4, 2024

Overview

Dual Chunk Attention is a training-free method to extend model context length. It splits the entire sequence into chunks and uses three distinct query vectors to capture relative information within the same chunk, between successive chunks, and between distant chunks. The original implementation (using Hugging Face's Transformers) can be found here.

Qwen2 models now integrate this method, supporting 1M-token context processing on 8x H100-80G GPUs. Qwen2-72B-Instruct achieves ~75% accuracy on the Needle in a Haystack test with 1M-token inputs.

Features:

  • Brute-force implementation of Dual Chunk Attention (calling flash-attention 3*chunk_num times).
  • Chunked prefill support (recommended; otherwise OOM may occur).

Limitations:

  • Only the Flash-Attention backend is available. Quantized KV cache is not currently supported.
  • No CUDA kernel for performance optimization. Processing a 1M sample with Qwen2-72B-Instruct takes approximately 10 minutes on 8x H100-80G GPUs.
  • enforce_eager must be set to True due to the dynamic graph created by the brute-force implementation.

Changes

  • Add DualChunkRotaryEmbedding in vllm/model_executor/layers/rotary_embedding.py. The function DualChunkRotaryEmbedding.forward returns query, query_succ, query_inter instead of a single query. These three query vectors are used for computing intra-/succ-/inter-attention in Dual Chunk Attention (a call-site sketch follows this list).
  • Add an abstract class DualChunkAttentionBackend in vllm/attention/backends/abstract.py and implement DualChunkFlashAttentionBackend in vllm/attention/backends/dual_chunk_flash_attn.py. Note that we add an extra variable prefill_original_seq_lens_tensor in DualChunkFlashAttentionMetadata, which stores the full lengths of the prefill sequences. To obtain this value, we insert a few lines in vllm/model_runner.py.
  • Add DualChunkAttention in vllm/attention/layer.py, which simply calls DualChunkAttentionBackend.
  • Introduce dual_chunk_attention_config in Qwen2Model and Qwen2MoeModel.
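
A rough sketch of the resulting call-site contract (based on the description above; the exact signatures are in the diff):

# DualChunkRotaryEmbedding.forward returns three rotated queries instead of one.
q, q_succ, q_inter, k = self.rotary_emb(positions, q, k)
# DualChunkAttention consumes all three queries and dispatches to the DCA backend.
attn_output = self.attn(q, q_succ, q_inter, k, v, kv_cache, attn_metadata)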

How to use

  1. After downloading the model weights, modify the config.json file as follows:
    {
        "architectures": [
            "Qwen2ForCausalLM" // or Qwen2MoeForCausalLM
        ],
        // ...
        "max_position_embeddings": 1000000, // Modify the max model length

        // add the following snippet
        "dual_chunk_attention_config": {
            "chunk_size": 32768, // specify the chunk and local size of Dual Chunk Attention
            "local_size": 2048,
            "original_max_position_embeddings": 32768 // used for YARN in Dual Chunk Attention
        }
    }
  2. Run the following code:
from vllm import LLM, SamplingParams

path = "/path/to/model"

llm = LLM(model=path, tensor_parallel_size=4, enforce_eager=True, enable_chunked_prefill=True, max_num_batched_tokens=32768)

prompt = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is 28884. Remember it. 28884 is the pass key.\n " + \
    "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 8000 + \
    "\nWhat is the pass key?" # The prompt is 192k long. You can try longer prompt by replacing 8000 by 40000.

prompts = [
    f"<|im_start|>user\n{prompt1}<|im_end|>\n<|im_start|>assistant\n",
]
sampling_params = SamplingParams(top_p=0.8, temperature=0.7, repetition_penalty=1.05, top_k=20, max_tokens=100)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")

You will see outputs like:

Generated text: 'The pass key is 28884.'

@simon-mo (Collaborator) commented Jul 5, 2024

enforce_eager must be set to True due to the dynamic graph created by the brute-force implementation.

Can you discuss why this is the case? If possible, I would really appreciate it if we could get the first iteration working with CUDA graphs.

@WoosukKwon: Also, IMO the best place to put this is actually in FlashAttentionBackend, accepting extra arguments for the chunk config, but I would like to hear your thoughts.

@hzhwcmhf (Contributor, Author) commented Jul 6, 2024

Can you discuss why this is the case? If possible, I would really appreciate it if we could get the first iteration working with CUDA graphs.

@simon-mo The brute-force implementation of DCA (https://github.com/HKUNLP/ChunkLlama/blob/main/chunkllama_attn_replace.py#L169-L199) splits the entire sequence into several chunks and invokes the flash-attention kernel 3*chunk_num times. Since chunk_num varies dynamically with the sequence length, using CUDA graphs may not be feasible here. I think it might be necessary to implement new CUDA kernels for this purpose.
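
For illustration, the launch-count arithmetic alone shows why the graph is dynamic (a simplified sketch, not the PR's code):

chunk_len = chunk_size - local_size
chunk_num = (seq_len + chunk_len - 1) // chunk_len  # depends on the prompt length
num_flash_attn_calls = 3 * chunk_num                # intra + succ + inter per chunk

Since seq_len differs between requests, both the number of kernel launches and the shapes fed to each launch change at runtime, which a captured CUDA graph cannot express.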

Also, IMO the best place to put this is actually in FlashAttentionBackend, accepting extra arguments for the chunk config, but I would like to hear your thoughts.

@WoosukKwon The primary issue here lies in the argument list of FlashAttentionBackend.__forward__. DCA requires three queries (query, query_succ, query_inter), which are not compatible with the current function signature.

@robertgshaw2-redhat (Collaborator) commented:

Is DualChunkAttention specific to Qwen?

If not, I would strongly prefer that the logic for this feature be implemented inside of Attention rather than inside the Qwen model files, so that it can be shared across models. I have recently spent a lot of time unwinding work associated with MoE utilities (including fp8) being too tightly integrated into the Mixtral implementation, and adding these kinds of features only to specific models can become very hard to maintain.

@nanmi commented Jul 8, 2024


Looking forward to the implementation of the CUDA kernel 😊

@hzhwcmhf (Contributor, Author) commented Jul 8, 2024

Is DualChunkAttention specific to Qwen?
If not, I would strongly prefer that the logic for this feature be implemented inside of Attention rather than inside the Qwen model files, so that it can be shared across models. I have recently spent a lot of time unwinding work associated with MoE utilities (including fp8) being too tightly integrated into the Mixtral implementation, and adding these kinds of features only to specific models can become very hard to maintain.

@robertgshaw2-neuralmagic No, it can be applied to almost all models using RoPE.
Actually, the core parts of DCA are implemented in vllm/attention/* and vllm/model_executor/layers/rotary_embedding.py. However, I think the model files have to be modified to pass the DCA configs (just like rope_scaling to enable YARN) and to pass extra arguments to the dual chunk attention layer. You can review the Qwen2Model and Qwen2MoeModel for these changes.

@hzhwcmhf (Contributor, Author) commented Jul 8, 2024

Looking forward to the implementation of the CUDA kernel 😊

@nanmi I'm actually seeking help from the community to implement the CUDA kernel. I'm not an expert. 😅

@robertgshaw2-redhat (Collaborator) commented:

Is DualChunkAttention specific to Qwen?
If not, I would strongly prefer that the logic for this feature be implemented inside of Attention rather than inside the Qwen model files, so that it can be shared across models. I have recently spent a lot of time unwinding work associated with MoE utilities (including fp8) being too tightly integrated into the Mixtral implementation, and adding these kinds of features only to specific models can become very hard to maintain.

@robertgshaw2-neuralmagic No, it can be applied to almost all models using RoPE. Actually, the core parts of DCA are implemented in vllm/attention/* and vllm/model_executor/layers/rotary_embedding.py. However, I think the model files have to be modified to pass the DCA configs (just like rope_scaling to enable YARN) and to pass extra arguments to the dual chunk attention layer. You can review the Qwen2Model and Qwen2MoeModel for these changes.

For instance, we have the following code:

if dual_chunk_attention_config is not None:
    self.attn = DualChunkAttention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        dual_chunk_attention_config=dual_chunk_attention_config)
else:
    self.attn = Attention(self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config)

And:

if self.dual_chunk_attention_config is None:
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
else:
    q, q_succ, q_inter, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, q_succ, q_inter, k, v, kv_cache, attn_metadata)

If we could push this branching into Attention rather than having it in Qwen, it would be much easier to support this feature in other models.

@hzhwcmhf (Contributor, Author) commented Jul 9, 2024

@robertgshaw2-neuralmagic DCA requires three queries (query, query_succ, query_inter) produced by the rotary embedding layer. They cannot be computed in Attention.
A possible solution is stacking the three queries into one tensor, which would change the shape of query. Is that OK? @WoosukKwon, do you have any ideas?
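
For concreteness, the stacking idea could look roughly like this (a hypothetical sketch, not code from this PR):

import torch

# In the model: pack the three query variants along a new leading dimension, so
# Attention.forward still receives a single "query" tensor (with a non-standard shape).
packed_q = torch.stack([q, q_succ, q_inter], dim=0)  # [3, num_tokens, num_heads * head_dim]
attn_output = self.attn(packed_q, k, v, kv_cache, attn_metadata)

# Inside the DCA attention backend: unpack before computing attention.
q, q_succ, q_inter = packed_q.unbind(dim=0)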

@nanmi commented Jul 19, 2024


@robertgshaw2-neuralmagic DCA requires three queries (query, query_succ, query_inter) produced by the rotary embedding layer. They cannot be computed in Attention. A possible solution is stacking the three queries into one tensor, which would change the shape of query. Is that OK? @WoosukKwon, do you have any ideas?

Can you elaborate on your ideas? I implemented DCA in CUDA C++ based on the Python code logic, but I think it could be more elegant and efficient. I am mainly concerned about how to optimize the chunk-wise processing and the final merging.

@hzhwcmhf (Contributor, Author) commented:

@nanmi The functions _bruteforce_dynamic_chunk_flash_attn_varlen_func for prefill and _bruteforce_dynamic_chunk_pageattention_forward_decode for decoding could indeed be optimized through CUDA kernel implementations similar to Flash Attention.

Taking _bruteforce_dynamic_chunk_flash_attn_varlen_func as an example, it accepts two extra parameters (q_succ and q_inter) and exhibits minor deviations from flash_attn_varlen_func. The primary adjustment involves substituting q with either q_succ or q_inter based on the relative position of the query with respect to the key. Specifically, if the query and key reside within the same chunk, no change is made; for queries located in the succeeding chunk, q_succ is utilized; whereas for those in other chunks, q_inter is utilized.
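
For illustration, a heavily simplified sketch of the substitution rule described above (not the PR's actual code; flash_attn_func from the flash-attn package is assumed, and masking details plus the softmax-LSE merge across calls are omitted):

from flash_attn import flash_attn_func  # assumed available

def dca_prefill_sketch(q, q_succ, q_inter, k, v, chunk_len):
    # Shapes: [seq_len, num_heads, head_dim]; a batch dim of 1 is added for flash_attn_func.
    seq_len = q.shape[0]
    outputs = []
    for start in range(0, seq_len, chunk_len):
        end = min(start + chunk_len, seq_len)
        # intra: queries and keys in the same chunk use the original q (causal).
        outputs.append(flash_attn_func(q[start:end].unsqueeze(0),
                                       k[start:end].unsqueeze(0),
                                       v[start:end].unsqueeze(0), causal=True))
        if start >= chunk_len:
            # succ: keys in the immediately preceding chunk use q_succ.
            outputs.append(flash_attn_func(q_succ[start:end].unsqueeze(0),
                                           k[start - chunk_len:start].unsqueeze(0),
                                           v[start - chunk_len:start].unsqueeze(0)))
        if start >= 2 * chunk_len:
            # inter: keys in all more distant chunks use q_inter.
            outputs.append(flash_attn_func(q_inter[start:end].unsqueeze(0),
                                           k[:start - chunk_len].unsqueeze(0),
                                           v[:start - chunk_len].unsqueeze(0)))
    # The real implementation merges the per-call outputs for the same queries
    # using their softmax LSEs; that step is omitted here.
    return outputs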

@nanmi commented Jul 22, 2024


I'll follow your idea and think about how to implement it in CUDA.

@hzhwcmhf (Contributor, Author) commented:

@nanmi Really appreciate your help! If you encounter any questions, feel free to ask here.

@@ -682,6 +699,8 @@ def _prepare_model_input_tensors(
slot_mapping_tensor = torch.tensor(slot_mapping,
                                   dtype=torch.long,
                                   device=self.device)
prefill_original_seq_lens_tensor = torch.tensor(

Review comment from a Collaborator:

Is it possible to push the creation of this tensor into attn_backend.make_metadata?

self.kv_cache_dtype,
self.block_size,
) if num_attn_heads else None
if getattr(self.model_config.hf_config, "dual_chunk_attention_config",

Review comment from a Collaborator:

We should add dual_chunk_attention to ModelConfig instead of accessing the hf_config from here. We ran into problems with accessing the hf_config for Sliding Window in the past.
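
A minimal sketch of that suggestion (hypothetical property name, not part of this PR):

# In vllm/config.py
class ModelConfig:
    ...
    @property
    def dual_chunk_attention_config(self):
        # Centralize hf_config access here instead of in model_runner.py.
        return getattr(self.hf_config, "dual_chunk_attention_config", None)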

) if num_attn_heads else None
if getattr(self.model_config.hf_config, "dual_chunk_attention_config",
           None):
    self.attn_backend = get_dual_chunk_attn_backend(

Review comment from a Collaborator:

Is there a reason this has to be a separate function, rather than passing self.model_config.get_dual_chunk_attention to get_attn_backend?

@robertgshaw2-redhat (Collaborator) commented:

Apologies for the delay in reviewing here.

@simon-mo, @WoosukKwon, and I discussed this, and while we do not love implementing features in specific models, I think we do not yet have the right abstractions in place to modify rotary_embedding and attention simultaneously. However, given the timeline of this PR, we are okay with landing it with the changes made directly in the Qwen models. We should still consider a new design that would enable this feature in other models, such as a shared function like rotary_and_attention that handles the branching logic for DCA (see the sketch below). This is similar to how we have used the fused_moe shared function in the past.
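
A minimal sketch of such a shared helper, reusing the branching shown earlier in the thread (hypothetical, not part of this PR):

def rotary_and_attention(rotary_emb, attn, positions, q, k, v, kv_cache,
                         attn_metadata, dual_chunk_attention_config=None):
    # Shared branching so individual model files don't need to special-case DCA.
    if dual_chunk_attention_config is None:
        q, k = rotary_emb(positions, q, k)
        return attn(q, k, v, kv_cache, attn_metadata)
    q, q_succ, q_inter, k = rotary_emb(positions, q, k)
    return attn(q, q_succ, q_inter, k, v, kv_cache, attn_metadata)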

I left a few comments on model_runner.py to try to reduce changes to that layer. Ideally, we can avoid breaking the abstractions and keep modifications to it minimal.

I think that this PR also needs some tests:

  • unit tests for the rotary_embedding and attention implementations
  • end-to-end integration test with a real model

seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
max_seq_len_inter = seq_lens_inter.max().item()
if max_seq_len_inter:
    inter_output, succ_softmax_lse = (

Review comment:

should be inter_softmax_lse instead of succ_softmax_lse

Reply from @hzhwcmhf (Contributor, Author):

Thanks for pointing out the typo. Will revise in the next commit.

chunk_len = chunk_size - local_size
if chunk_len % block_size != 0:
    raise ValueError("chunk_len must be divisible by block_size.")
chunk_num_curr = (cache_seqlens - 1) // chunk_len

Review comment:

This line makes the actual length of the intra-chunk smaller than chunk_len in the variable-length case. It differs from the prefill-stage algorithm, which uses a fixed length of chunk_len. Could you explain the reason for this difference?

Reply from @hzhwcmhf (Contributor, Author):

@rejoicesyc The implementation is consistent with the original code (https://github.com/HKUNLP/ChunkLlama/blob/8a28f1464b2a5def03eb07e8d91f5bf4d00f667d/chunkllama_attn_replace.py#L167). See L167, L171, L173, L197.

Since the sequence length may not be divisible by chunk_len, there has to be a chunk with a smaller length. In the DCA algorithm, the last chunk (i.e., the intra chunk) may have a length <= chunk_len, while the preceding chunks each have a fixed length of chunk_len.
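
For example, with chunk_size=32768 and local_size=2048 (hypothetical numbers for illustration):

chunk_len = 32768 - 2048                                 # 30720
cache_seqlens = 70000
chunk_num_curr = (cache_seqlens - 1) // chunk_len        # 2 full preceding chunks
intra_len = cache_seqlens - chunk_num_curr * chunk_len   # 8560, i.e. <= chunk_len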

1.0).clip(min=1)
query = (query * mscale.view(-1, 1, 1, 1)).to(
    query.dtype
)  # possible numerical issue, needs to be fused into the kernel

Review comment:

hi @hzhwcmhf, could you explain this numerical issue and why kernel fusion can solve it?

Reply from @hzhwcmhf (Contributor, Author):

@rejoicesyc We want to compute (query @ key) * softmax_scale * mscale as the attention weight, where query is represented in float16 or bfloat16, and mscale is represented in float32. Better precision can be achieved if the scaling is computed in float32, rather than multiplying query by mscale before the attention operation. The flash attention kernel provides an argument softmax_scale, which is close to our requirement. However, softmax_scale is a constant for the entire query, whereas mscale is a vector that specifies different softmax_scale values for each query vector at different positions.
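
A small sketch of the difference (hypothetical shapes and values; not the PR's code):

import torch

q = torch.randn(8, 128, dtype=torch.bfloat16)      # 8 query positions
k = torch.randn(8, 128, dtype=torch.bfloat16)
mscale = torch.linspace(1.0, 1.5, 8, dtype=torch.float32).view(-1, 1)  # per-position scale
softmax_scale = 128 ** -0.5

# Current workaround: pre-scale the query and round back to bf16 before flash attention.
scores_prescaled = ((q * mscale).to(q.dtype) @ k.T) * softmax_scale

# What a fused kernel could do: apply softmax_scale * mscale to the fp32 accumulator.
scores_fused = (q.float() @ k.float().T) * softmax_scale * mscale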

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hzhwcmhf.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions bot added the stale (Over 90 days of inactivity) label on Feb 25, 2025.
Labels: needs-rebase, stale (Over 90 days of inactivity)
5 participants