diff --git a/docker/Dockerfile b/docker/Dockerfile
index 9803ff8..4e3703e 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -35,12 +35,11 @@ RUN pip install --no-cache-dir \
     ruff
 
 # Then install PyTorch-dependent packages with constraint to use existing torch
-RUN pip install --no-cache-dir \
-    --extra-index-url https://download.pytorch.org/whl/xpu \
-    -C torch==2.6.0+xpu \
-    transformers \
-    accelerate \
-    bitsandbytes
+RUN pip install transformers accelerate bitsandbytes
+
+# Copy the bitsandbytes-intel repository into /workspace/src/bnb and install it.
+COPY .. ${WORKSPACE}/src/bnb
+RUN cd ${WORKSPACE}/src/bnb && pip install .
 
 COPY --chmod=755 docker/entrypoint.sh /entrypoint.sh
 ENTRYPOINT ["/entrypoint.sh"]
diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py
index 13d20ee..d9af7fa 100644
--- a/src/bitsandbytes_intel/cpu_xpu_common.py
+++ b/src/bitsandbytes_intel/cpu_xpu_common.py
@@ -1,10 +1,11 @@
 import subprocess
-from typing import Optional
+from typing import Optional, Tuple
 import warnings
 
 import torch
 import torch.nn.functional as F
 
+from bitsandbytes.utils import QuantState
 from bitsandbytes.functional import (
     QuantState,
     create_dynamic_map,
@@ -57,6 +58,17 @@ def _ipex_xpu_version_prereq(major, minor):
     return False
 
 
+str2optimizer8bit_blockwise = {}
+if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
+    str2optimizer8bit_blockwise = {
+        "adam": (
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
+            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
+        ),
+    }
+
+
 def _maybe_torch_compile(func):
     # torch.compile requires g++ and pytorch >= 2.0
     if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
@@ -77,8 +89,32 @@ def reverse_4bit_compress_format(weight):
     return out
 
 
+def transform(
+    A: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    transpose=False,
+    state: Optional[Tuple[torch.Size, str]] = None,
+):
+    """
+    Transform tensor A. This op was originally designed for the CUDA layout transforms (to_order).
+    For CPU/XPU, it returns the original tensor if transpose=False.
+    Otherwise, it returns the transpose of A.
+    """
+    if transpose:
+        if out is not None:
+            out.copy_(A.T)
+        else:
+            out = A.T
+    else:
+        if out is not None:
+            out.copy_(A)
+        else:
+            out = A
+    return out, state
+
+
 @_maybe_torch_compile
-def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+def int8_double_quant_impl(A, threshold=0.0, col_stats=None, row_stats=None, out_col=None, out_row=None):
     """
     Find absolute max values of each row/column of a tensor, and symmetrically
     quantize it to int8. If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in
@@ -157,6 +193,26 @@ def quant_to_int8(A, stats):
     return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols
 
 
+def int8_vectorwise_quant_impl(A: torch.Tensor, threshold=0.0):
+    # TODO: We can optimize this as we don't actually need column-wise quant.
+    out, _, stats, _, outlier_cols = int8_double_quant_impl(A, threshold=threshold)
+    return out, stats, outlier_cols
+
+
+def int8_vectorwise_dequant_impl(A: torch.Tensor, stats: torch.Tensor):
+    """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`.
+
+    Args:
+        A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor.
+        stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics.
+
+    Returns:
+        `torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
+ """ + # To dequantize we divide by 127, or multiply by the reciprocal. + return A * stats.view(-1, 1) * 7.874015718698502e-3 + + def int8_linear_matmul_impl( A: torch.Tensor, B: torch.Tensor, @@ -227,10 +283,10 @@ def int8_mm_dequant_impl( A: torch.Tensor, row_stats: torch.Tensor, col_stats: torch.Tensor, - out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, compute_dtype=torch.float32, output_dtype=torch.float32, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Dequant and add bias @@ -303,11 +359,11 @@ def int8_mm_dequant_impl( def quantize_4bit_impl( A: Tensor, absmax: Tensor = None, - out: Tensor = None, blocksize=64, compress_statistics=False, quant_type="nf4", quant_storage=torch.uint8, + out: Tensor = None, ) -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -443,9 +499,9 @@ def dequantize_4bit_impl( A: Tensor, quant_state=None, absmax: Tensor = None, - out: Tensor = None, blocksize: int = 64, quant_type="nf4", + out: Tensor = None, ) -> Tensor: """ Dequantizes 4-bit blockwise quantized values. @@ -471,6 +527,11 @@ def dequantize_4bit_impl( torch.Tensor: Dequantized tensor. """ + # For NF4, ipex have dequant kernel. + if quant_type == "nf4" and getattr(quant_state, "ipex", False): + out = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() + return out + transpose = True if A.shape[0] == 1 else False A = A.reshape(-1) device = A.device @@ -545,10 +606,8 @@ def dequantize_4bit_impl( def gemm_4bit_impl( A: torch.Tensor, B: torch.Tensor, - out: Optional[torch.Tensor] = None, - transposed_A=False, - transposed_B=False, state: QuantState = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Matrix-matrix multiplication with 4-bit quantization. 
@@ -598,3 +657,118 @@ def gemm_4bit_impl(
     else:
         out = output
     return out
+
+
+def dequantize_blockwise(
+    A: torch.Tensor,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize: int = 4096,
+) -> torch.Tensor:
+    if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+        raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for the 8bit optimizer backend on XPU devices.")
+
+    # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
+    if out.dtype == torch.float16:
+        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
+    elif out.dtype == torch.bfloat16:
+        ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
+    elif out.dtype == torch.float32:
+        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
+    else:
+        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
+
+
+def quantize_blockwise(
+    A: torch.Tensor,
+    code: Optional[torch.Tensor] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=4096,
+    nested=False,
+) -> Tuple[torch.Tensor, QuantState]:
+    raise NotImplementedError
+
+def optimizer_update_8bit_blockwise(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    optim_func = None
+    if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+        raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for the 8bit optimizer backend on XPU devices.")
+
+    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
+    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
+    elif (
+        g.dtype == torch.bfloat16
+        and state1.dtype == torch.uint8
+        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
+    ):
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+    optim_func(
+        p,
+        g,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+        g.numel()
+    )
+
+
+def optimizer_update_32bit(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    beta1: float,
+    eps: float,
+    step: int,
+    lr: float,
+    state2: Optional[torch.Tensor] = None,
+    beta2: float = 0.0,
+    beta3: float = 0.0,
+    alpha: float = 0.0,
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    unorm_vec: Optional[torch.Tensor] = None,
+    max_unorm: float = 0.0,
+    skip_zeros=False,
+) -> None:
+    raise NotImplementedError
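
Note: the row-wise int8 scheme implemented by int8_vectorwise_quant_impl / int8_vectorwise_dequant_impl above reduces to the following standalone PyTorch sketch. It is an illustration only, not part of the patch, and rowwise_int8_roundtrip is a hypothetical helper name.

import torch

def rowwise_int8_roundtrip(A: torch.Tensor) -> torch.Tensor:
    # Quantize: scale each row so that its absolute maximum maps to 127.
    stats = A.abs().amax(dim=1, keepdim=True).float()
    q = torch.round(A * (127.0 / stats)).to(torch.int8)
    # Dequantize: multiply by absmax/127, i.e. the reciprocal-of-127 constant used above.
    return q.float() * stats * (1.0 / 127.0)

x = torch.randn(4, 8)
print((rowwise_int8_roundtrip(x) - x).abs().max())  # small round-trip quantization error
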
diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py
index 8ffe13b..cdba9f8 100644
--- a/src/bitsandbytes_intel/ops.py
+++ b/src/bitsandbytes_intel/ops.py
@@ -1,51 +1,377 @@
 from collections.abc import Sequence
+from typing import Optional
 import math
 
 import torch
 
-from .cpu_xpu_common import int8_linear_matmul_impl
+from bitsandbytes.utils import QuantState
+from .cpu_xpu_common import (
+    int8_linear_matmul_impl,
+    int8_double_quant_impl,
+    int8_vectorwise_quant_impl,
+    int8_mm_dequant_impl,
+    quantize_4bit_impl,
+    dequantize_4bit_impl,
+    gemm_4bit_impl,
+    dequantize_blockwise,
+    optimizer_update_8bit_blockwise,
+    ipex_xpu,
+    ipex_cpu_only,
+)
 
 print("Loading ops module")
 
 
-def register_ops():
+def register_xpu_ops():
     print("Registering XPU implementations")
 
-    # Check if the operator exists
-    if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"):
-        raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed")
-
+    # Register the int8_linear_matmul implementation
     @torch.library.impl("bitsandbytes::int8_linear_matmul", "XPU")
     def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor):
-        print("int8_linear_matmul_xpu called with tensors of shape:", A.shape, B.shape)
-        return int8_linear_matmul_impl(A, B)
-
+        return int8_linear_matmul_impl(A, B)
 
     @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "XPU")
     def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
-        print("int8_linear_matmul_xpu_out called with tensors of shape:", A.shape, B.shape)
         return int8_linear_matmul_impl(A, B, out)
-
-    @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU")
+
+    # Register the int8_double_quant implementation
+    @torch.library.impl("bitsandbytes::int8_double_quant", "XPU")
+    def int8_double_quant_xpu(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+        col_stats: torch.Tensor = None,
+        row_stats: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_double_quant_impl(A, threshold, col_stats, row_stats)
+    @torch.library.impl("bitsandbytes::int8_double_quant.out", "XPU")
+    def int8_double_quant_xpu_out(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+        col_stats: torch.Tensor = None,
+        row_stats: torch.Tensor = None,
+        out_col: torch.Tensor = None,
+        out_row: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row)
+
+    # Register the int8_vectorwise_quant implementation
+    @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "XPU")
+    def int8_vectorwise_quant_xpu(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_vectorwise_quant_impl(A, threshold)
+
+    # Register the int8_mm_dequant implementation
+    @torch.library.impl("bitsandbytes::int8_mm_dequant", "XPU")
+    def int8_mm_dequant_xpu(
+        A: torch.Tensor,
+        row_stats: torch.Tensor,
+        col_stats: torch.Tensor,
+        bias: torch.Tensor = None,
+        compute_dtype=torch.float32,
+        output_dtype=torch.float32,
+    ) -> torch.Tensor:
+        return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype)
+    @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "XPU")
+    def int8_mm_dequant_xpu_out(
+        A: torch.Tensor,
+        row_stats: torch.Tensor,
+        col_stats: torch.Tensor,
+        bias: torch.Tensor = None,
+        compute_dtype=torch.float32,
+        output_dtype=torch.float32,
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out)
+
+    # Register the quantize_4bit implementation
+    @torch.library.impl("bitsandbytes::quantize_4bit", "XPU")
+    def quantize_4bit_xpu(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="nf4",
+        quant_storage=torch.uint8,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return quantize_4bit_impl(
+            A,
+            absmax,
+            blocksize,
+            compress_statistics,
+            quant_type,
+            quant_storage,
+        )
+    @torch.library.impl("bitsandbytes::quantize_4bit.out", "XPU")
+    def quantize_4bit_xpu_out(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="nf4",
+        quant_storage=torch.uint8,
+        out: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return quantize_4bit_impl(
+            A,
+            absmax,
+            blocksize,
+            compress_statistics,
+            quant_type,
+            quant_storage,
+            out,
+        )
+
+    # Register the dequantize_4bit implementation
+    @torch.library.impl("bitsandbytes::dequantize_4bit", "XPU")
     def dequantize_4bit_xpu(
         A: torch.Tensor,
-        absmax: torch.Tensor,
-        blocksize: int,
-        quant_type: str,
-        shape: Sequence[int],
-        dtype: torch.dtype,
-        out: torch.Tensor,
+        quant_state=None,
+        absmax: torch.Tensor = None,
+        blocksize: int = 64,
+        quant_type="nf4",
     ) -> torch.Tensor:
-        # TODO
-        # if quant_type == "nf4" and getattr(quant_state, "ipex", False):
-        #     output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t()
-        # else:
-        #     output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
-
-        # return output
-        raise NotImplementedError
+        return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type)
+    @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU")
+    def dequantize_4bit_xpu_out(
+        A: torch.Tensor,
+        quant_state=None,
+        absmax: torch.Tensor = None,
+        blocksize: int = 64,
+        quant_type="nf4",
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out)
+
+    # Register the gemv_4bit implementation
+    @torch.library.impl("bitsandbytes::gemv_4bit", "XPU")
+    def gemv_4bit_xpu(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        state: QuantState = None,
+    ) -> torch.Tensor:
+        return gemm_4bit_impl(A, B, state=state)
+    @torch.library.impl("bitsandbytes::gemv_4bit.out", "XPU")
+    def gemv_4bit_xpu_out(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        state: QuantState = None,
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return gemm_4bit_impl(A, B, state=state, out=out)
+
+    # Register the dequantize_blockwise implementation
+    @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU")
+    def dequantize_blockwise_xpu(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        code: torch.Tensor = None,
+        out: torch.Tensor = None,
+        blocksize: int = 4096,
+    ) -> torch.Tensor:
+        return dequantize_blockwise(A, absmax, code, out, blocksize)
+
+    # Register the optimizer_update_8bit_blockwise implementation
+    @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "XPU")
+    def optimizer_update_8bit_blockwise_xpu(
+        optimizer_name: str,
+        g: torch.Tensor,
+        p: torch.Tensor,
+        state1: torch.Tensor,
+        state2: Optional[torch.Tensor],
+        beta1: float,
+        beta2: float,
+        beta3: float,
+        alpha: float,
+        eps: float,
+        step: int,
+        lr: float,
+        qmap1: torch.Tensor,
+        qmap2: Optional[torch.Tensor],
+        absmax1: torch.Tensor,
+        absmax2: Optional[torch.Tensor],
+        weight_decay: float = 0.0,
+        gnorm_scale: float = 1.0,
+        skip_zeros=False,
+    ) -> None:
+        optimizer_update_8bit_blockwise(
+            optimizer_name,
+            g,
+            p,
+            state1,
+            state2,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            step,
+            lr,
+            qmap1,
+            qmap2,
+            absmax1,
+            absmax2,
+            weight_decay,
+            gnorm_scale,
+            skip_zeros,
+        )
 
     print("Successfully registered XPU implementation")
 
+
+def register_cpu_ops():
+    print("Registering CPU implementations")
+
+    # Register the int8_linear_matmul implementation
+    @torch.library.impl("bitsandbytes::int8_linear_matmul", "CPU")
+    def int8_linear_matmul_cpu(A: torch.Tensor, B: torch.Tensor):
+        return int8_linear_matmul_impl(A, B)
+    @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "CPU")
+    def int8_linear_matmul_cpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+        return int8_linear_matmul_impl(A, B, out)
+
+    # Register the int8_double_quant implementation
+    @torch.library.impl("bitsandbytes::int8_double_quant", "CPU")
+    def int8_double_quant_cpu(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+        col_stats: torch.Tensor = None,
+        row_stats: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_double_quant_impl(A, threshold, col_stats, row_stats)
+    @torch.library.impl("bitsandbytes::int8_double_quant.out", "CPU")
+    def int8_double_quant_cpu_out(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+        col_stats: torch.Tensor = None,
+        row_stats: torch.Tensor = None,
+        out_col: torch.Tensor = None,
+        out_row: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row)
+
+    # Register the int8_vectorwise_quant implementation
+    @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "CPU")
+    def int8_vectorwise_quant_cpu(
+        A: torch.Tensor,
+        threshold: float = 0.0,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return int8_vectorwise_quant_impl(A, threshold)
+
+    # Register the int8_mm_dequant implementation
+    @torch.library.impl("bitsandbytes::int8_mm_dequant", "CPU")
+    def int8_mm_dequant_cpu(
+        A: torch.Tensor,
+        row_stats: torch.Tensor,
+        col_stats: torch.Tensor,
+        bias: torch.Tensor = None,
+        compute_dtype=torch.float32,
+        output_dtype=torch.float32,
+    ) -> torch.Tensor:
+        return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype)
+    @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "CPU")
+    def int8_mm_dequant_cpu_out(
+        A: torch.Tensor,
+        row_stats: torch.Tensor,
+        col_stats: torch.Tensor,
+        bias: torch.Tensor = None,
+        compute_dtype=torch.float32,
+        output_dtype=torch.float32,
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out)
+
+    # Register the quantize_4bit implementation
+    @torch.library.impl("bitsandbytes::quantize_4bit", "CPU")
+    def quantize_4bit_cpu(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="nf4",
+        quant_storage=torch.uint8,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return quantize_4bit_impl(
+            A,
+            absmax,
+            blocksize,
+            compress_statistics,
+            quant_type,
+            quant_storage,
+        )
+    @torch.library.impl("bitsandbytes::quantize_4bit.out", "CPU")
+    def quantize_4bit_cpu_out(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="nf4",
+        quant_storage=torch.uint8,
+        out: torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return quantize_4bit_impl(
+            A,
+            absmax,
+            blocksize,
+            compress_statistics,
+            quant_type,
+            quant_storage,
+            out,
+        )
+
+    # Register the dequantize_4bit implementation
+    @torch.library.impl("bitsandbytes::dequantize_4bit", "CPU")
+    def dequantize_4bit_cpu(
+        A: torch.Tensor,
+        quant_state=None,
+        absmax: torch.Tensor = None,
+        blocksize: int = 64,
+        quant_type="nf4",
+    ) -> torch.Tensor:
+        return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type)
+    @torch.library.impl("bitsandbytes::dequantize_4bit.out", "CPU")
+    def dequantize_4bit_cpu_out(
+        A: torch.Tensor,
+        quant_state=None,
+        absmax: torch.Tensor = None,
+        blocksize: int = 64,
+        quant_type="nf4",
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out)
+
+    # Register the gemv_4bit implementation
+    @torch.library.impl("bitsandbytes::gemv_4bit", "CPU")
+    def gemv_4bit_cpu(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        state: QuantState = None,
+    ) -> torch.Tensor:
+        return gemm_4bit_impl(A, B, state=state)
+    @torch.library.impl("bitsandbytes::gemv_4bit.out", "CPU")
+    def gemv_4bit_cpu_out(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        state: QuantState = None,
+        out: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return gemm_4bit_impl(A, B, state=state, out=out)
+
+    # Register the dequantize_blockwise implementation
+    @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU")
+    def dequantize_blockwise_cpu(
+        A: torch.Tensor,
+        absmax: torch.Tensor = None,
+        code: torch.Tensor = None,
+        out: torch.Tensor = None,
+        blocksize: int = 4096,
+    ) -> torch.Tensor:
+        return dequantize_blockwise(A, absmax, code, out, blocksize)
+
+    print("Successfully registered CPU implementation")
+
+
+def register_hpu_ops():
     print("Registering HPU implementations")
 
     @torch.library.impl("bitsandbytes::dequantize_4bit", "HPU")
@@ -77,4 +403,18 @@ def quantize_4bit_hpu(
     print("Successfully registered HPU implementations")
 
 
+def register_ops():
+    # Check if the operator exists
+    if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"):
+        raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed")
+
+    if ipex_xpu:
+        register_xpu_ops()
+    elif ipex_cpu_only:
+        register_cpu_ops()
+    # TODO: Need to check HPU
+    else:
+        register_hpu_ops()
+
+
 print("ops module loaded")
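
A quick smoke test for the registrations above could look like the following sketch. It assumes that importing bitsandbytes defines the bitsandbytes::* operator schemas, that this package is installed as bitsandbytes_intel, and that the local PyTorch build exposes torch.xpu when an Intel GPU is present; the tensor shapes are arbitrary.

import torch
import bitsandbytes  # expected to define the bitsandbytes::* operator schemas
from bitsandbytes_intel.ops import register_ops

register_ops()  # chooses register_xpu_ops / register_cpu_ops / register_hpu_ops based on the detected backend

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
A = torch.randint(-128, 127, (4, 8), dtype=torch.int8, device=device)
B = torch.randint(-128, 127, (16, 8), dtype=torch.int8, device=device)

# Dispatches to whichever implementation was registered for this device
# (int8 matmul against the transposed B, as in bitsandbytes).
C = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
print(C.shape, C.dtype)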