Add support for KleidiAI int4 kernels on aarch64 Linux #2169
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2169
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 6 Pending, 2 Unrelated Failures as of commit 9575890 with merge base 94e2e05.
NEW FAILURE - the following job has failed.
FLAKY - the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
In my testing environment, I manually created a symbolic link to libomp.so inside torch's lib directory. Here it is:

```bash
apt install libomp-dev -y
ln -s /usr/lib/llvm-*/lib/libomp.so [path-to-virtualenv]/lib/python*/site-packages/torch/lib/libomp.so
```

Setting up demo.py:

```python
import copy
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    LlamaForCausalLM,
    LlamaTokenizer,
)

import torchao
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)
from torchao.quantization.granularity import PerGroup

model_id = "meta-llama/Llama-3.2-1B-Instruct"


def load_model_and_tokenizer() -> tuple[LlamaTokenizer, LlamaForCausalLM]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def main() -> None:
    print(f"\ntorch v{torch.__version__}")
    print(f"torchao v{torchao.__version__}")

    print("Loading tokenizer and model ...")
    tokenizer, model = load_model_and_tokenizer()
    print(f"tokenizer and model loaded on {model.device}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Can you explain quantum computing in simple terms?",
        },
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    chat_input: BatchEncoding = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device,
    )

    # Baseline (unquantized) inference
    print("\n--- Running standard inference ---")
    start_time = time.time()
    with torch.no_grad():
        chat_outputs = model.generate(**chat_input, max_new_tokens=100)
    end_time = time.time()
    inference_time = end_time - start_time

    # Decode the generated tokens, excluding the prompt
    input_ids: torch.Tensor = chat_input["input_ids"]
    prompt_length: int = input_ids.shape[1]
    response = tokenizer.decode(
        chat_outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response}\n----------------------------------",
    )
    print(f"Inference time: {inference_time:.2f} seconds")

    # KleidiAI-optimized (int4 quantized) inference
    print("\n--- Attempting KleidiAI optimization ---")
    quantized_model = copy.deepcopy(model)
    quantize_(
        quantized_model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.KLEIDIAI
            ),
        ),
    )

    start_time = time.time()
    with torch.no_grad():
        chat_outputs_quantized = quantized_model.generate(
            **chat_input, max_new_tokens=100
        )
    end_time = time.time()
    inference_time_quantized = end_time - start_time

    response_quantized = tokenizer.decode(
        chat_outputs_quantized[0][prompt_length:],
        skip_special_tokens=True,
    )
    print(
        f"----------------------------------\n{response_quantized}\n----------------------------------",
    )
    print(f"Quantized inference time: {inference_time_quantized:.2f} seconds")

    # Print speedup (guard against a zero denominator)
    if inference_time_quantized > 0:
        print(
            f"Speedup with KleidiAI optimization: {inference_time / inference_time_quantized:.2f}x",
        )


if __name__ == "__main__":
    main()
```
On discussion point 1, can we guard `-march=armv8.4-a+dotprod` behind `TORCHAO_ENABLE_ARM_NEON_DOT`? There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.
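For reference, a minimal sketch of what such a guard could look like in the CMakeLists.txt. The option name and the non-dotprod fallback are assumptions for illustration, not torchao's actual build logic:

```cmake
# Hypothetical sketch: make the dot-product -march flag opt-in instead of
# hardcoding it. Option name and default are assumptions, not torchao's API.
option(TORCHAO_ENABLE_ARM_NEON_DOT "Build aarch64 kernels with +dotprod" OFF)

if(TORCHAO_ENABLE_ARM_NEON_DOT)
  add_compile_options(-march=armv8.4-a+dotprod)
else()
  add_compile_options(-march=armv8-a)
endif()
```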
Is libomp not bundled with torch on Linux? Is there nothing at site-packages/torch/lib/libomp.so? If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without the manual link?
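A quick way to check both points (a rough sketch; the exact paths depend on the virtualenv layout, and the build env vars are the ones used elsewhere in this thread):

```bash
# Look for an OpenMP runtime shipped inside the torch wheel
ls "$(python -c 'import torch, os; print(os.path.dirname(torch.__file__))')/lib" | grep -i omp

# Rebuild with the OpenMP backend instead of relying on torch's bundled libomp
TORCHAO_PARALLEL_BACKEND=OPENMP BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .
```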
Thanks for the PR @vctrmn! It mostly looks good, but let's guard on the NEON dot flag and resolve the compile issue in CI.
@metascroy I confirm that libomp is not bundled with torch on Linux aarch64.
Looking into the directories:
Would you recommend adding logic in the CMakeLists.txt to create this symlink?
I had the same compile error on my Ubuntu instance. I solved it by installing the nightly version of torch. I will take a look at this.
Hi @metascroy! I feel that I need to clarify the torchao build options before diving into the changes. Correct me if I am wrong, for ARM64 platforms:
When
Also, an important note is that on Linux aarch64, we need
Yes, I think that summarizes the required changes.
No, I don't think we should create a symlink in CMakeLists.txt. When you set TORCHAO_PARALLEL_BACKEND=OPENMP, did you see a message about "Building with TORCHAO_PARALLEL_BACKEND=OPENMP" during compilation: https://www.internalfb.com/code/fbsource/[c58edd4c506c386869789e9db5aa191ad34a3742]/fbcode/pytorch/ao/torchao/experimental/Utils.cmake?lines=35. When TORCHAO_PARALLEL_BACKEND is set to OPENMP, it shouldn't need to link against OMP in PyTorch, so that is what I'm curious about. It's an argument on the CMakeLists, so we might need to define it with -DTORCHAO_PARALLEL_BACKEND=OPENMP on Linux (https://www.internalfb.com/code/fbsource/[c58edd4c506c386869789e9db5aa191ad34a3742]/fbcode/pytorch/ao/torchao/experimental/CMakeLists.txt?lines=27-29)
Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes:
The new build command is now:

```bash
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
```

In my testing, using OPENMP, the quantized inference time is around 7.5 seconds, which represents a 2x speedup!
On discussion point 1: looking at the GCC docs (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march), perhaps we should create a new argument in the CMakeLists.txt for this.
What do you think @digantdesai?
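To make the idea concrete, here is a rough sketch of such an argument. The variable name TORCHAO_AARCH64_MARCH and its default are assumptions for illustration, not existing torchao options:

```cmake
# Hypothetical sketch of a user-overridable -march argument.
set(TORCHAO_AARCH64_MARCH "armv8.4-a+dotprod" CACHE STRING
    "Value passed to -march= when building aarch64 kernels")

if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$")
  add_compile_options(-march=${TORCHAO_AARCH64_MARCH})
endif()
```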
Looking good! Let's wait for CI and see if @digantdesai has any feedback on the various ARM versions.
It looks like the cpuinfo_uarch_zen5 error is still failing in CI. Let me try checking out your PR on my Mac and see if I can debug.
This PR adds support for using KleidiAI int4 kernels on aarch64 Linux systems. Previously, these kernels were only enabled on macOS ARM platforms, but with these changes, they can be properly built and loaded on any ARM64 Linux system with the appropriate features (NEON, dot product, etc.).
Changes
- `setup.py`: allow explicit building of ARM kernels via the `BUILD_TORCHAO_CPU` environment variable
- `op_lib.py`: search in multiple potential installation paths (see the sketch after this list)
- `CMakeLists.txt`
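For illustration, a minimal sketch of the kind of multi-path library search described above. This is not the actual op_lib.py code; the candidate directories and the library filename pattern are assumptions:

```python
# Hypothetical sketch, not torchao's actual op_lib.py logic: probe a few likely
# install locations for the experimental shared library and load the first hit.
import glob
import os

import torch
import torchao


def _load_experimental_ops() -> bool:
    candidate_dirs = [
        # Inside the installed torchao package (assumed layout)
        os.path.join(os.path.dirname(torchao.__file__), "lib"),
        # Inside the experimental build tree (assumed layout)
        os.path.join(os.path.dirname(torchao.__file__), "experimental", "lib"),
    ]
    for lib_dir in candidate_dirs:
        for lib_path in glob.glob(os.path.join(lib_dir, "libtorchao_ops_aten.*")):
            torch.ops.load_library(lib_path)  # registers the custom ops with PyTorch
            return True
    return False
```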
How to build
Users can build torchao with KleidiAI support on aarch64 Linux using:
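```bash
# Build command as shared in the discussion (OpenMP parallel backend on aarch64 Linux)
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
```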
Testing
Scaleway Ubuntu VM (4 CPU x 16 GB RAM - COPARM1-4C-16G)
Discussion Points
- `-march=armv8.4-a+dotprod` is currently hardcoded in the `CMakeLists.txt`. While this works for my Ubuntu VM, we may want to implement a more flexible solution that detects the specific ARM features available.

Related issue
#2143