This repository implements the sparse attention mechanism introduced in the paper *Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention* and provides an efficient training implementation based on Triton.

🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). A toy model is provided at `model.ToyNSALlama`, which supports a `forward` function for training and a `generate` function for inference. Welcome to try it out!
Ensure the following dependencies are installed:
- PyTorch >= 2.1.0
- triton >= 3.0.0
- einops >= 0.7.0
- flash_attn >= 2.6.3
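As a quick optional check (not part of the official instructions), you can print the installed versions to confirm they meet these minimums:

```python
# Verify that the required packages are installed and recent enough.
import torch, triton, einops, flash_attn

print("torch:", torch.__version__)            # needs >= 2.1.0
print("triton:", triton.__version__)          # needs >= 3.0.0
print("einops:", einops.__version__)          # needs >= 0.7.0
print("flash_attn:", flash_attn.__version__)  # needs >= 2.6.3
```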
- PyTorch implementations (`ops.torch`) are intended for debugging only.
- For production use, prefer the Triton operators (`ops.triton`).
- All implementations follow the varlen approach, similar to `flash_attn_func_varlen`; concatenate the sequences of a batch before use (see the packing sketch below).
- Only attention head dimensions of at most 128 are supported for now.
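Since batches must be packed before calling the ops, here is a minimal sketch of the varlen layout. The three hypothetical sequence lengths are chosen only to reproduce the `cu_seqlens` used in the example below:

```python
import torch

# Three variable-length sequences, e.g. key tensors of shape (seq_len, num_kv_heads, head_dim).
seq_lens = [1024, 7168, 8192]
keys = [torch.randn(n, 4, 128, dtype=torch.bfloat16, device="cuda") for n in seq_lens]

# Concatenate along the token dimension: (total_tokens, num_kv_heads, head_dim).
packed_key = torch.cat(keys, dim=0)

# cu_seqlens marks the cumulative sequence boundaries, starting at 0:
# here it becomes [0, 1024, 8192, 16384], matching the example below.
cu_seqlens = torch.tensor([0] + seq_lens, device="cuda").cumsum(0, dtype=torch.int32)
```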
You can install `native_sparse_attention` using pip:

```bash
pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
```
The `ops` module implements several functions required for native sparse attention. For detailed usage instructions, please see this link.

You can import those functions from the `ops` module:
```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention

# input example
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
kernel_size = 32
kernel_stride = 16
block_size = 64
topk = 16
cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()
query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()

# weight example
w = (
    torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)
    .to(torch.bfloat16)
    .cuda()
)
pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()

# 1. key value compression
compressed_key, compressed_cu_seqlens = linear_compress(
    key, w, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
    value, w, cu_seqlens, kernel_size, kernel_stride, None
)

# 2. attention between query and compressed key value
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    init_blocks=1,
    local_blocks=2,
)

# 3. topk sparse attention
sparse_attn_output = topk_sparse_attention(
    query,
    key,
    value,
    topk_idx,
    block_size,
    cu_seqlens,
)
```
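Because the point of the repository is natively *trainable* sparse attention, these ops can be used inside the training graph. Below is a minimal, hedged sketch of a backward pass continuing from the example above; the plain sum of the two branch outputs is only a placeholder objective, not the learned gated combination used by the full NSA layer:

```python
# Enable gradients on the inputs and compression parameters, then rerun the pipeline.
for t in (query, key, value, w, pe):
    t.requires_grad_()

compressed_key, compressed_cu_seqlens = linear_compress(
    key, w, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
    value, w, cu_seqlens, kernel_size, kernel_stride, None
)
compressed_attn_output, topk_idx = compressed_attention(
    query, compressed_key, compressed_value,
    kernel_size, kernel_stride, block_size, topk,
    cu_seqlens, compressed_cu_seqlens,
    init_blocks=1, local_blocks=2,
)
sparse_attn_output = topk_sparse_attention(
    query, key, value, topk_idx, block_size, cu_seqlens
)

# Placeholder scalar loss just to show that gradients flow through the Triton kernels.
loss = (compressed_attn_output + sparse_attn_output).float().mean()
loss.backward()
print(query.grad.shape, w.grad.shape)
```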
The `modules` directory also provides implementations based on `torch.nn.Module` for easy integration into models.
```python
from native_sparse_attention.modules import NativeSparseAttention, RopeConfig

NSA_Layer = NativeSparseAttention(
    compress_type="linear",
    hidden_size=4096,
    num_q_heads=64,
    num_kv_heads=4,
    head_dim=128,
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
    rope_config=RopeConfig(
        max_position_embeddings=32768,
        head_dim=128,
        rope_theta=500000,
        rope_scaling={
            "factor": 4.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
    ),
)
```
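A usage sketch follows. The forward signature is not shown in this README, so the call below assumes the layer takes packed hidden states plus `cu_seqlens`, mirroring the varlen convention of the ops above; check the module's source or docstring for the authoritative interface:

```python
import torch

# Packed hidden states for two hypothetical sequences of 1024 and 3072 tokens.
cu_seqlens = torch.tensor([0, 1024, 4096], dtype=torch.int32, device="cuda")
x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # (total_tokens, hidden_size)

nsa = NSA_Layer.cuda().bfloat16()

# Assumed call convention; the output should keep the packed shape of the input.
out = nsa(x, cu_seqlens)
print(out.shape)
```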
We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to this link.
```python
from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama

config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)

inference_config = InferenceConfig(
    max_batch_size=4,
    max_length=8192,
    max_new_tokens=128,
)

model = ToyNSALlama(config, inference_config).cuda().bfloat16()
```
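As noted above, the toy model exposes `forward` for training and `generate` for inference. The sketch below only illustrates the intent; the argument names (`input_ids`, `cu_seqlens`, `max_new_tokens`) and the vocabulary size are assumptions, so consult the model code for the exact signatures:

```python
import torch

vocab_size = 128256  # hypothetical; use the value from the model config in practice

# Training-style forward on a packed (varlen) batch of token ids (assumed interface).
input_ids = torch.randint(0, vocab_size, (4096,), device="cuda")
cu_seqlens = torch.tensor([0, 1024, 4096], dtype=torch.int32, device="cuda")
logits = model(input_ids, cu_seqlens)

# Short greedy generation from a single prompt (assumed interface).
prompt_ids = torch.randint(0, vocab_size, (1, 32), device="cuda")
output_ids = model.generate(prompt_ids, max_new_tokens=64)
```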
Some test scripts are available in the `test` folder and can be run directly for unit testing. For example:

```bash
python test/test_topk_sparse_attention.py
python test/test_nsa_module.py
python test/test_nsa_model.py
```
Here are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function:
A100 GPU speed benchmarks:
**forward with block size 64**:

| N (sequence length) | Flash | Triton-Flash | Triton-Top8 | Triton-Top16 |
|--------------------:|------:|-------------:|------------:|-------------:|
| 2048 | 0.414144 | 0.635648 | 0.633440 | 1.009184 |
| 4096 | 1.400304 | 2.267552 | 1.179808 | 1.916736 |
| 8192 | 5.223776 | 8.528160 | 2.266816 | 3.723168 |
| 16384 | 20.225697 | 32.745537 | 4.468128 | 7.359168 |
| 32768 | 79.587715 | 128.951065 | 8.517440 | 14.142848 |
| 65536 | 321.240479 | 511.652100 | 17.249599 | 30.991360 |
| 131072 | 1349.810425 | 2063.245605 | 36.400482 | 67.884544 |
**backward with block size 64**:

| N (sequence length) | Flash | Triton-Flash | Triton-Top8 | Triton-Top16 |
|--------------------:|------:|-------------:|------------:|-------------:|
| 2048 | 1.315440 | 2.348560 | 1.941568 | 2.691040 |
| 4096 | 4.271584 | 8.553184 | 3.647744 | 5.032160 |
| 8192 | 15.323984 | 32.665440 | 5.650144 | 9.066112 |
| 16384 | 58.753281 | 127.675964 | 11.160832 | 17.113279 |
| 32768 | 227.770462 | 504.572693 | 21.723392 | 34.715614 |
| 65536 | 899.181274 | 2059.718506 | 44.517181 | 76.309441 |
| 131072 | 3587.918701 | 8530.726562 | 105.344734 | 182.970169 |
H100 GPU speed benchmarks:

**forward with block size 64**:

| N (sequence length) | Flash | Triton-Flash | Triton-Top8 | Triton-Top16 |
|--------------------:|------:|-------------:|------------:|-------------:|
| 2048 | 0.259552 | 0.293888 | 0.584544 | 0.917664 |
| 4096 | 0.846848 | 1.029904 | 1.094976 | 1.745136 |
| 8192 | 3.043744 | 3.843392 | 2.128256 | 3.396880 |
| 16384 | 11.743568 | 14.791360 | 4.190528 | 6.704192 |
| 32768 | 45.968513 | 57.532478 | 7.614496 | 12.417440 |
| 65536 | 187.234375 | 228.093948 | 14.840048 | 24.511856 |
| 131072 | 810.890381 | 914.693970 | 29.470400 | 48.990192 |
**backward with block size 64**:

| N (sequence length) | Flash | Triton-Flash | Triton-Top8 | Triton-Top16 |
|--------------------:|------:|-------------:|------------:|-------------:|
| 2048 | 0.798976 | 1.096096 | 1.117312 | 1.380016 |
| 4096 | 2.545680 | 3.826336 | 1.669760 | 2.214880 |
| 8192 | 9.029760 | 14.411633 | 2.772096 | 3.947456 |
| 16384 | 34.144016 | 58.945698 | 5.201344 | 7.538912 |
| 32768 | 135.718369 | 233.369247 | 9.968864 | 15.154192 |
| 65536 | 541.053894 | 929.337646 | 21.089870 | 33.818878 |
| 131072 | 2139.974854 | 3785.540527 | 54.918144 | 93.750717 |
Here is another speed benchmark, for the `compressed_attention` function on a single NVIDIA A100 GPU or H100 GPU:
A100 GPU speed benchmarks:
**forward with kernel 32 and stride 16**:

| N (sequence length) | Flash | Triton-Flash | Compressed | Compressed-wo-Score |
|--------------------:|------:|-------------:|-----------:|--------------------:|
| 2048 | 0.413664 | 0.635488 | 0.655024 | 0.170816 |
| 4096 | 1.396416 | 2.247648 | 1.132304 | 0.377152 |
| 8192 | 5.234656 | 8.526400 | 2.879200 | 0.977952 |
| 16384 | 19.988865 | 32.755199 | 9.426448 | 2.943024 |
| 32768 | 79.419907 | 128.955170 | 30.284096 | 9.901120 |
| 65536 | 321.590210 | 511.615509 | 112.260544 | 36.001602 |
| 131072 | 1346.996338 | 2069.837891 | 423.099518 | 136.820038 |
**backward with kernel 32 and stride 16**:

| N (sequence length) | Flash | Triton-Flash | Compressed |
|--------------------:|------:|-------------:|-----------:|
| 2048 | 1.322560 | 2.352000 | 0.486784 |
| 4096 | 4.270832 | 8.552608 | 0.971392 |
| 8192 | 15.515680 | 32.671329 | 2.603744 |
| 16384 | 59.345055 | 128.377472 | 8.499456 |
| 32768 | 230.626144 | 506.581238 | 30.064833 |
| 65536 | 919.260498 | 2068.642578 | 113.466560 |
| 131072 | 3646.603760 | 8498.374023 | 439.623444 |
H100 GPU speed benchmarks:
**forward with kernel 32 and stride 16**:

| N (sequence length) | Flash | Triton-Flash | Compressed | Compressed-wo-Score |
|--------------------:|------:|-------------:|-----------:|--------------------:|
| 2048 | 0.259488 | 0.297152 | 0.485920 | 0.103232 |
| 4096 | 0.847376 | 1.030400 | 0.710208 | 0.217760 |
| 8192 | 3.044016 | 3.875840 | 1.607360 | 0.516016 |
| 16384 | 11.823104 | 14.829360 | 4.970272 | 1.440288 |
| 32768 | 46.204750 | 57.527809 | 15.004992 | 4.584736 |
| 65536 | 187.324249 | 227.909958 | 53.009087 | 16.134224 |
| 131072 | 810.707214 | 910.106873 | 191.245728 | 60.154270 |
**backward with kernel 32 and stride 16**:

| N (sequence length) | Flash | Triton-Flash | Compressed |
|--------------------:|------:|-------------:|-----------:|
| 2048 | 0.797728 | 1.090640 | 0.283104 |
| 4096 | 2.547088 | 3.834592 | 0.550464 |
| 8192 | 9.021520 | 14.421088 | 1.249184 |
| 16384 | 34.159508 | 58.793377 | 3.743440 |
| 32768 | 136.844070 | 233.447708 | 12.640032 |
| 65536 | 537.559814 | 929.360229 | 46.054817 |
| 131072 | 2135.629883 | 3782.351562 | 175.587296 |
All the speed benchmarks above report latency in milliseconds and were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.
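If you want to reproduce a comparable measurement yourself, here is a rough sketch using `triton.testing.do_bench`. This is not the repository's official benchmark script; the setup (a single 8192-token sequence, top-8 blocks) and the use of `flash_attn_varlen_func` as the dense baseline are assumptions for illustration:

```python
import torch
import triton
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention

num_q_heads, num_kv_heads, head_dim = 64, 4, 128
kernel_size, kernel_stride, block_size, topk = 32, 16, 64, 8
seqlen = 8192
cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")

q = torch.randn(seqlen, num_q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seqlen, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seqlen, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16, device="cuda")
pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16, device="cuda")

# Build the top-k block indices once, outside the timed region.
ck, c_cu = linear_compress(k, w, cu_seqlens, kernel_size, kernel_stride, pe)
cv, _ = linear_compress(v, w, cu_seqlens, kernel_size, kernel_stride, None)
_, topk_idx = compressed_attention(
    q, ck, cv, kernel_size, kernel_stride, block_size, topk,
    cu_seqlens, c_cu, init_blocks=1, local_blocks=2,
)

# Latency in ms, as reported by Triton's benchmarking helper.
t_dense = triton.testing.do_bench(
    lambda: flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=True)
)
t_sparse = triton.testing.do_bench(
    lambda: topk_sparse_attention(q, k, v, topk_idx, block_size, cu_seqlens)
)
print(f"flash: {t_dense:.3f} ms, topk-sparse: {t_sparse:.3f} ms")
```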
Contributions are welcome! Please open an issue to discuss major changes.
For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.
```bibtex
@inproceedings{Yuan2025NativeSA,
    title  = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
    author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
    year   = {2025},
    url    = {https://api.semanticscholar.org/CorpusID:276408911}
}
```