
Add support for KleidiAI int4 kernels on aarch64 Linux #2169


Open · wants to merge 16 commits into main

Conversation


@vctrmn vctrmn commented May 4, 2025

This PR adds support for using KleidiAI int4 kernels on aarch64 Linux systems. Previously, these kernels were only enabled on macOS ARM platforms, but with these changes, they can be properly built and loaded on any ARM64 Linux system with the appropriate features (NEON, dot product, etc.).

Changes

  • Modified setup.py to allow explicitly building the ARM kernels via the BUILD_TORCHAO_CPU environment variable
  • Updated library detection in op_lib.py to search multiple potential installation paths (see the sketch after this list)
  • Fixed compiler warnings
  • Added appropriate compiler flags for aarch64 in CMakeLists.txt
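
A hedged sketch of the multi-path lookup idea (the helper name and candidate directories are illustrative, not the exact op_lib.py code):

```python
# Illustrative sketch only: probe a few plausible install locations for the
# torchao ops shared library and load the first match. The real op_lib.py
# logic may differ.
import glob
import os

import torch


def find_torchao_ops_lib() -> str | None:
    here = os.path.dirname(os.path.abspath(__file__))
    candidates = [
        here,                                                  # editable / in-tree install
        os.path.join(here, "lib"),                             # packaged wheel layout
        os.path.join(os.path.dirname(torch.__file__), "lib"),  # next to torch's own libs
    ]
    for directory in candidates:
        matches = glob.glob(os.path.join(directory, "libtorchao_ops_aten.*"))
        if matches:
            return matches[0]
    return None


lib_path = find_torchao_ops_lib()
if lib_path is not None:
    torch.ops.load_library(lib_path)
```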

How to build

Users can build torchao with KleidiAI support on aarch64 Linux using:

BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .

Testing

Scaleway Ubuntu VM (4 CPU x 16 GB RAM - COPARM1-4C-16G)

~# lscpu
Architecture:             aarch64
  CPU op-mode(s):         32-bit, 64-bit
  Byte Order:             Little Endian
CPU(s):                   4
  On-line CPU(s) list:    0-3
Vendor ID:                ARM
  BIOS Vendor ID:         QEMU
  Model name:             Neoverse-N1
    BIOS Model name:      virt-6.2  CPU @ 2.0GHz
    BIOS CPU family:      1
    Model:                1
    Thread(s) per core:   1
    Core(s) per socket:   4
    Socket(s):            1
    Stepping:             r3p1
    BogoMIPS:             50.00
    Flags:                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp ssbs

Discussion Points

  • I'm unsure about hard-coding the architecture flag to -march=armv8.4-a+dotprod in CMakeLists.txt. While this works for my Ubuntu VM, we may want a more flexible solution that detects the specific ARM features available (see the feature-detection sketch after this list).
  • This is a quick implementation to get things working; we might want to discuss the right way to define and use build environment variables in setup.py for a more robust solution.
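
As a starting point for that feature detection, a minimal Linux-only sketch (illustrative, not part of this PR) that checks /proc/cpuinfo for the NEON dot-product flag (asimddp, visible in the lscpu output above):

```python
# Minimal sketch: detect the NEON dot-product extension on aarch64 Linux by
# scanning /proc/cpuinfo for the "asimddp" feature flag.
def has_neon_dotprod() -> bool:
    try:
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("Features"):
                    return "asimddp" in line.split()
    except OSError:
        pass  # not Linux, or /proc unavailable
    return False


if __name__ == "__main__":
    print("NEON dotprod available:", has_neon_dotprod())
```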

Related issue

#2143


pytorch-bot bot commented May 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2169

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 6 Pending, 2 Unrelated Failures

As of commit 9575890 with merge base 94e2e05:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on May 4, 2025
@vctrmn
Author

vctrmn commented May 4, 2025

In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

      [ 48%] Built target torchao_ops_linear_8bit_act_xbit_weight_aten
      gmake[2]: *** No rule to make target '/root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so', needed by '/root/ao/build/lib.linux-aarch64-cpython-312/torchao/libtorchao_ops_aten.so'.  Stop.
      gmake[1]: *** [CMakeFiles/Makefile2:287: CMakeFiles/torchao_ops_aten.dir/all] Error 2
      gmake: *** [Makefile:136: all] Error 2
      Traceback (most recent call last):

Here is the fix:

apt install libomp-dev -y
ln -s /usr/lib/llvm-*/lib/libomp.so [path-to-virtualenv]/lib/python*/site-packages/torch/lib/libomp.so
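
To avoid hard-coding the LLVM version in that symlink, a small hedged Python alternative (paths are illustrative and assume libomp-dev is installed):

```python
# Sketch: find a system libomp.so and symlink it into torch/lib, rather than
# hard-coding a specific llvm-NN path.
import glob
import os

import torch

torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib")
target = os.path.join(torch_lib, "libomp.so")
matches = sorted(glob.glob("/usr/lib/llvm-*/lib/libomp.so"))
if matches and not os.path.exists(target):
    os.symlink(matches[-1], target)  # pick the newest LLVM version found
print("libomp symlink:", target if os.path.exists(target) else "not created")
```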
Setting up:
apt update
apt install gcc g++ cmake ninja-build build-essential python3-pip python3-venv google-perftools -y

git clone https://github.com/vctrmn/ao.git
cd ao

python3 -m venv venv
source venv/bin/activate
pip install wheel setuptools
pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install numpy

apt install libomp-dev -y
ln -s /usr/lib/llvm-18/lib/libomp.so /root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so

BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .

pip install transformers
huggingface-cli login
python demo.py
demo.py
import copy
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    LlamaForCausalLM,
    LlamaTokenizer,
)

import torchao
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    Target,
    quantize_,
)
from torchao.quantization.granularity import PerGroup

model_id = "meta-llama/Llama-3.2-1B-Instruct"


def load_model_and_tokenizer() -> tuple[LlamaTokenizer, LlamaForCausalLM]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


def main() -> None:
    print(f"\ntorch v{torch.__version__}")
    print(f"torchao v{torchao.__version__}")

    print("Loading tokenizer and model ...")
    tokenizer, model = load_model_and_tokenizer()
    print(f"tokenizer and model loaded on {model.device}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "Can you explain quantum computing in simple terms?",
        },
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    chat_input: BatchEncoding = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device,
    )

    # No optim inference
    print("\n--- Running standard inference ---")
    start_time = time.time()

    with torch.no_grad():
        chat_outputs = model.generate(**chat_input, max_new_tokens=100)

    end_time = time.time()
    inference_time = end_time - start_time

    # Decode the generated tokens, excluding the prompt
    input_ids: torch.Tensor = chat_input["input_ids"]
    prompt_length: int = input_ids.shape[1]
    response = tokenizer.decode(
        chat_outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    print(
        f"----------------------------------\n{response}\n----------------------------------",
    )
    print(f"Inference time: {inference_time:.2f} seconds")

    # KleidiAI optim
    print("\n--- Attempting KleidiAI optimization ---")
    quantized_model = copy.deepcopy(model)

    quantize_(
        quantized_model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                target=Target.KLEIDIAI
            ),
        ),
    )

    start_time = time.time()

    with torch.no_grad():
        chat_outputs_quantized = quantized_model.generate(
            **chat_input, max_new_tokens=100
        )

    end_time = time.time()
    inference_time_quantized = end_time - start_time

    response_quantized = tokenizer.decode(
        chat_outputs_quantized[0][prompt_length:],
        skip_special_tokens=True,
    )

    print(
        f"----------------------------------\n{response_quantized}\n----------------------------------",
    )
    print(f"Quantized inference time: {inference_time_quantized:.2f} seconds")
    print(f"Speedup: {inference_time / inference_time_quantized:.2f}x")

    # Print speedups
    if inference_time > 0:
        print(
            f"Speedup with KleidiAI optimization: {inference_time / inference_time_quantized:.2f}x",
        )


if __name__ == "__main__":
    main()

@metascroy
Contributor

On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.

@metascroy
Contributor

metascroy commented May 6, 2025

> In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

Is libomp not bundled with torch on Linux? Is there nothing in site-packages/torch/lib/libomp.so?

If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without doing the manual link?

@metascroy
Contributor

Thanks for the PR @vctrmn! It mostly looks good, but let's guard on the NEON dot flag and resolve the compile issue in CI.

@vctrmn
Author

vctrmn commented May 6, 2025

> In my testing environment, I manually created a symbolic link to libomp.so to prevent this error:

> Is libomp not bundled with torch on Linux? Is there nothing in site-packages/torch/lib/libomp.so?

> If so, does setting TORCHAO_PARALLEL_BACKEND=OPENMP fix the issue without doing the manual link?

@metascroy I confirm that libomp.so is indeed not bundled with torch on my aarch64 Ubuntu instance, and setting TORCHAO_PARALLEL_BACKEND=OPENMP doesn't solve the issue:

~/ao# BUILD_TORCHAO_CPU=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
...
      [ 48%] Built target torchao_ops_linear_8bit_act_xbit_weight_aten
      gmake[2]: *** No rule to make target '/root/ao/venv/lib/python3.12/site-packages/torch/lib/libomp.so', needed by '/root/ao/build/lib.linux-aarch64-cpython-312/torchao/libtorchao_ops_aten.so'.  Stop.
      gmake[1]: *** [CMakeFiles/Makefile2:287: CMakeFiles/torchao_ops_aten.dir/all] Error 2
      gmake: *** [Makefile:136: all] Error 2

Looking into the directories:

~/ao# ls /root/ao/venv/lib/python3.12/site-packages/torch/lib/
libc10.so  libshm  libshm.so  libshm_windows  libtorch.so  libtorch_cpu.so  libtorch_global_deps.so  libtorch_python.so

~/ao# ls /root/ao/venv/lib/python3.12/site-packages/torch.libs/
libarm_compute-d924ca35.so  libarm_compute_graph-17c2200a.so  libgfortran-8a9a71bc.so.5.0.0  libgomp-947d5fa1.so.1.0.0  libopenblasp-r0-0d78ce56.3.29.so

Would you recommend adding logic in CMakeLists.txt to automatically handle this (finding libomp.so and creating the symlink)? Or should this be addressed at the PyTorch repo level?

@vctrmn
Author

vctrmn commented May 6, 2025

> On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

> There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.

I had the same compile error on my Ubuntu instance. I solved it by installing the nightly version of torch: pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu

I will take a look at this.

@vctrmn
Author

vctrmn commented May 6, 2025

Hi @metascroy! I feel I need to clarify the torchao build options before diving into setup.py and CMakeLists.txt.

Correct me if I'm wrong; for ARM64 platforms:

  • On macOS: TORCHAO_BUILD_CPU_AARCH64 is enabled by default
  • On Linux: TORCHAO_BUILD_CPU_AARCH64 must be explicitly enabled

When TORCHAO_BUILD_CPU_AARCH64 is enabled, the following additional options become available:

  • TORCHAO_BUILD_KLEIDIAI (must be passed explicitly)
  • TORCHAO_ENABLE_ARM_NEON_DOT (this option guards -march=armv8.4-a+dotprod on Linux, but should be enabled by default on macOS)

Also, an important note: on Linux aarch64, we need BUILD_TORCHAO_CPU=1.
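
For concreteness, a minimal sketch (illustrative wiring, not the actual setup.py) of how setup.py could translate these environment variables into CMake defines:

```python
# Hypothetical sketch: read the build flags from the environment and forward
# them to CMake as -D defines. Option names match the summary above; the
# helper itself is illustrative.
import os


def _env_flag(name: str, default: str = "0") -> bool:
    return os.environ.get(name, default) == "1"


def torchao_cmake_args() -> list[str]:
    args = []
    if _env_flag("TORCHAO_BUILD_CPU_AARCH64"):
        args.append("-DTORCHAO_BUILD_CPU_AARCH64=ON")
    if _env_flag("TORCHAO_BUILD_KLEIDIAI"):
        args.append("-DTORCHAO_BUILD_KLEIDIAI=ON")
    if _env_flag("TORCHAO_ENABLE_ARM_NEON_DOT"):
        args.append("-DTORCHAO_ENABLE_ARM_NEON_DOT=ON")
    return args
```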

@metascroy
Contributor

> Hi @metascroy! I feel I need to clarify the torchao build options before diving into setup.py and CMakeLists.txt. […]

Yes, I think that summarizes the required changes.

@metascroy
Contributor

> @metascroy I confirm that libomp.so is indeed not bundled with torch on my aarch64 Ubuntu instance, and setting TORCHAO_PARALLEL_BACKEND=OPENMP doesn't solve the issue. […]

> Would you recommend adding logic in CMakeLists.txt to automatically handle this (finding libomp.so and creating the symlink)? Or should this be addressed at the PyTorch repo level?

No, I don't think we should create a symlink in CMakeLists.txt. When you set it, did you see a message about "Building with TORCHAO_PARALLEL_BACKEND=OPENMP" during compilation: https://www.internalfb.com/code/fbsource/[c58edd4c506c386869789e9db5aa191ad34a3742]/fbcode/pytorch/ao/torchao/experimental/Utils.cmake?lines=35

When TORCHAO_PARALLEL_BACKEND is set to OPENMP, it shouldn't need to link against OMP in PyTorch, so that is what I'm curious about.

It's an argument on the CMakeLists, so we might need to define it with -DTORCHAO_PARALLEL_BACKEND=OPENMP on Linux (https://www.internalfb.com/code/fbsource/[c58edd4c506c386869789e9db5aa191ad34a3742]/fbcode/pytorch/ao/torchao/experimental/CMakeLists.txt?lines=27-29)
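
For example, a hedged sketch of that forwarding in setup.py (illustrative, not the current code):

```python
# Hypothetical sketch: pass TORCHAO_PARALLEL_BACKEND through to CMake so the
# OPENMP backend is selected instead of linking torch's bundled libomp.
import os

cmake_args = []
backend = os.environ.get("TORCHAO_PARALLEL_BACKEND")
if backend:  # e.g. "OPENMP"
    cmake_args.append(f"-DTORCHAO_PARALLEL_BACKEND={backend}")
```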

@vctrmn
Author

vctrmn commented May 7, 2025

Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes:

  • Renamed build_torchao_experimental to build_macos_arm_auto for better clarity
  • Added BUILD_TORCHAO_EXPERIMENTAL to enable the experimental features on non-macOS platforms

The new build command is now:

BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .

In my testing, using OPENMP, the quantized inference time is around 7.5 seconds, which represents a 2x speedup!

@vctrmn
Author

vctrmn commented May 7, 2025

> On discussion point 1, can we guard -march=armv8.4-a+dotprod behind TORCHAO_ENABLE_ARM_NEON_DOT?

> There is a compile error in the unit test ("/Users/runner/work/ao/ao/build/temp.macosx-11.1-arm64-cpython-310/_deps/cpuinfo-src/tools/cpu-info.c:135:8: error: use of undeclared identifier 'cpuinfo_uarch_zen5'"). It's not immediately clear to me what in your changes would cause that, though.

On discussion point 1: looking at the GCC docs (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march), perhaps we should create a new argument in setup.py to handle the architecture specification more flexibly? A sketch follows the table below.

arch value Architecture Includes by default
‘armv8-a’ Armv8-A ‘+fp’, ‘+simd’
‘armv8.1-a’ Armv8.1-A ‘armv8-a’, ‘+crc’, ‘+lse’, ‘+rdma’
‘armv8.2-a’ Armv8.2-A ‘armv8.1-a’
‘armv8.3-a’ Armv8.3-A ‘armv8.2-a’, ‘+pauth’, ‘+fcma’, ‘+jscvt’
‘armv8.4-a’ Armv8.4-A ‘armv8.3-a’, ‘+flagm’, ‘+fp16fml’, ‘+dotprod’, ‘+rcpc2’
‘armv8.5-a’ Armv8.5-A ‘armv8.4-a’, ‘+sb’, ‘+ssbs’, ‘+predres’, ‘+frintts’, ‘+flagm2’
‘armv8.6-a’ Armv8.6-A ‘armv8.5-a’, ‘+bf16’, ‘+i8mm’
‘armv8.7-a’ Armv8.7-A ‘armv8.6-a’, ‘+wfxt’, ‘+xs’
‘armv8.8-a’ Armv8.8-a ‘armv8.7-a’, ‘+mops’
‘armv8.9-a’ Armv8.9-a ‘armv8.8-a’
‘armv9-a’ Armv9-A ‘armv8.5-a’, ‘+sve’, ‘+sve2’
‘armv9.1-a’ Armv9.1-A ‘armv9-a’, ‘+bf16’, ‘+i8mm’
‘armv9.2-a’ Armv9.2-A ‘armv9.1-a’, ‘+wfxt’, ‘+xs’
‘armv9.3-a’ Armv9.3-A ‘armv9.2-a’, ‘+mops’
‘armv9.4-a’ Armv9.4-A ‘armv9.3-a’, ‘+sve2p1’
‘armv9.5-a’ Armv9.5-A ‘armv9.4-a’, ‘+cpa’, ‘+faminmax’, ‘+lut’
‘armv8-r’ Armv8-R ‘armv8-r’
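
For illustration, a hedged sketch of what such a setup.py argument could look like (TORCHAO_AARCH64_ARCH and the validation list are hypothetical, not existing torchao options):

```python
# Hypothetical sketch: let an environment variable choose the -march value
# instead of hard-coding armv8.4-a+dotprod. TORCHAO_AARCH64_ARCH is an
# illustrative name, not an existing torchao build flag.
import os

_KNOWN_ARCHES = {
    "armv8-a", "armv8.1-a", "armv8.2-a", "armv8.3-a", "armv8.4-a",
    "armv8.5-a", "armv8.6-a", "armv8.7-a", "armv8.8-a", "armv8.9-a",
    "armv9-a", "armv9.1-a", "armv9.2-a", "armv9.3-a", "armv9.4-a", "armv9.5-a",
}


def aarch64_march_flag() -> str:
    # Default preserves the PR's current behavior.
    arch = os.environ.get("TORCHAO_AARCH64_ARCH", "armv8.4-a+dotprod")
    base = arch.split("+", 1)[0]
    if base not in _KNOWN_ARCHES:
        raise ValueError(f"Unsupported AArch64 -march value: {arch}")
    return f"-march={arch}"
```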

@metascroy
Contributor

> On discussion point 1: looking at the GCC docs (https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#index-march), perhaps we should create a new argument in setup.py to handle the architecture specification more flexibly? […]

What do you think, @digantdesai?

@metascroy
Contributor

> Thank you @metascroy for the feedback! I believe the PR is ready now with all the requested changes. […]

> In my testing, using OPENMP, the quantized inference time is around 7.5 seconds, which represents a 2x speedup!

Looking good! Let's wait for CI + to see if @digantdesai has any feedback on various arm versions.

@metascroy
Contributor

It looks like cpuinfo_uarch_zen5 is still failing in CI. Let me try checking out your PR on my Mac and see if I can debug.
