KleidiAI int4 kernels not loading properly on aarch64 Linux #2143
cc @metascroy can you take a look?
Thanks @vctrmn for opening an issue. The TL;DR is that we have not enabled lowbit kernels on non-Mac platforms, so this is a feature request. If you want to help us add ARM Linux support, I don't think it will be too bad:

From the ao directory, you should be able to do: "BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install ."

This should build the torchao C++ ops library. If that library is not present, nothing will be loaded and you'll see missing ops. If the library is not found, try to find out why: check what inside this block is not running during install on Linux (https://github.com/pytorch/ao/blob/main/setup.py#L503-L534).
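A quick way to check after the install is a minimal sketch like this (my own diagnostic, not official torchao tooling; the op name is the one discussed later in this thread):

# Sketch: if the C++ library was built and loaded, the lowbit packing op
# should resolve via torch.ops; otherwise the lookup raises.
import torch
import torchao  # importing torchao triggers loading of its C++ ops, if built

try:
    torch.ops.torchao._pack_8bit_act_4bit_weight
    print("lowbit ops are registered")
except (AttributeError, RuntimeError) as exc:
    print("lowbit ops are missing:", exc)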
Hi @metascroy, I have successfully implemented the changes you suggested and got the KleidiAI int4 kernels working on my ARM Ubuntu instance; here is what I have done: #2169

FYI: I have tested with the Llama-3.2-1B-Instruct model and confirmed I am getting a 1.34x speedup compared to the standard PyTorch version, which is great! (This is not a proper benchmark, but it gives an initial idea.)
demo.py:

import copy
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    LlamaForCausalLM,
    LlamaTokenizer,
)

import torchao
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)

model_id = "meta-llama/Llama-3.2-1B-Instruct"


def load_model_and_tokenizer() -> tuple[LlamaTokenizer, LlamaForCausalLM]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def main() -> None:
    print(f"\ntorch v{torch.__version__}")
    print(f"torchao v{torchao.__version__}")

    print("Loading tokenizer and model ...")
    tokenizer, model = load_model_and_tokenizer()
    print(f"tokenizer and model loaded on {model.device}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Can you explain quantum computing in simple terms?",
        },
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    chat_input: BatchEncoding = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device,
    )

    # Unoptimized inference
    print("\n--- Running standard inference ---")
    start_time = time.time()
    with torch.no_grad():
        chat_outputs = model.generate(**chat_input, max_new_tokens=100)
    end_time = time.time()
    inference_time = end_time - start_time

    # Decode the generated tokens, excluding the prompt
    input_ids: torch.Tensor = chat_input["input_ids"]
    prompt_length: int = input_ids.shape[1]
    response = tokenizer.decode(
        chat_outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response}\n----------------------------------",
    )
    print(f"Inference time: {inference_time:.2f} seconds")

    # KleidiAI-optimized inference
    print("\n--- Attempting KleidiAI optimization ---")
    quantized_model = copy.deepcopy(model)
    quantize_(
        quantized_model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.KLEIDIAI
            ),
        ),
    )

    start_time = time.time()
    with torch.no_grad():
        chat_outputs_quantized = quantized_model.generate(
            **chat_input, max_new_tokens=100
        )
    end_time = time.time()
    inference_time_quantized = end_time - start_time

    response_quantized = tokenizer.decode(
        chat_outputs_quantized[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response_quantized}\n----------------------------------",
    )
    print(f"Quantized inference time: {inference_time_quantized:.2f} seconds")

    # Print speedup (guard against a zero denominator)
    if inference_time_quantized > 0:
        print(
            f"Speedup with KleidiAI optimization: {inference_time / inference_time_quantized:.2f}x",
        )


if __name__ == "__main__":
    main()
That's great! I'll take a look at your PR. Btw, you should also be able to specify Target.AUTO. And you can see which kernel is selected if you set TORCH_CPP_LOG_LEVEL=INFO before running your script.
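For example, something like this (a sketch adapting the config from demo.py above; my own variant, untested):

# Sketch: same config as demo.py, but Target.AUTO lets torchao pick the
# best available kernel. Run with the C++ log level raised to see which
# kernel gets selected, e.g.: TORCH_CPP_LOG_LEVEL=INFO python demo.py
import torch
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)


def quantize_auto(model: torch.nn.Module) -> None:
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.AUTO  # let torchao choose the kernel
            ),
        ),
    )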
Hello there!

I have been trying to follow your instructions to get the KleidiAI int4 kernels working on a Scaleway ARM instance (4x16), but I'm still encountering issues.

I've done the following:
- Built and installed the KleidiAI static library (/usr/local/lib/libkleidiai.a)
- Ran: USE_CPP=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .
However, when I try to run code that uses the KleidiAI kernels, I get an error indicating the _pack_8bit_act_4bit_weight operator is missing.
My CPU definitely has the required ARM features (verified with /proc/cpuinfo), including asimd (NEON), asimddp (Dot Product), etc.
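For reference, a minimal sketch of that feature check (Linux/aarch64 only; the "Features" key and flag names are as they appear in my /proc/cpuinfo):

# Sketch: confirm the CPU flags needed by the dotprod kernels are present.
required = {"asimd", "asimddp"}  # NEON and Dot Product

flags: set[str] = set()
with open("/proc/cpuinfo") as f:
    for line in f:
        if line.startswith("Features"):
            flags.update(line.split(":", 1)[1].split())

print("missing features:", required - flags or "none")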
I'm particularly interested in the optimizations mentioned in the recently merged PR #2000, which added the new KleidiAI kernels for ARM NEON dotprod.

When I run the KleidiAI benchmark, the kernels run successfully. So it seems like the KleidiAI kernels themselves are working, but for some reason the _pack_8bit_act_4bit_weight operator isn't being registered properly in torchao.

Is there something specific I need to do to get the _pack_8bit_act_4bit_weight operator registered? Are there any diagnostic steps I can take to debug this further?
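One debugging step I've been trying (a sketch under the assumption that the shared library lives somewhere under the installed torchao package; exact paths vary by version) is to load any built library by hand so that a dlopen/linker failure produces a concrete error instead of a silently missing op:

# Sketch: find and manually load any torchao shared libraries.
import glob
import os

import torch
import torchao

pkg_dir = os.path.dirname(torchao.__file__)
libs = glob.glob(os.path.join(pkg_dir, "**", "*.so"), recursive=True)
print("shared libraries found:", libs or "none (the C++ build step likely did not run)")
for lib in libs:
    torch.ops.load_library(lib)  # raises with a concrete error if loading fails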
#1721 (comment)