
Enable ipex op #2


Draft · wants to merge 2 commits into main
11 changes: 5 additions & 6 deletions docker/Dockerfile
@@ -35,12 +35,11 @@
RUN pip install --no-cache-dir \
    ruff

# Then install PyTorch-dependent packages with constraint to use existing torch
-RUN pip install --no-cache-dir \
-    --extra-index-url https://download.pytorch.org/whl/xpu \
-    -C torch==2.6.0+xpu \
-    transformers \
-    accelerate \
-    bitsandbytes
+RUN pip install transformers accelerate bitsandbytes

# Copy the bitsandbytes-intel repository into /workspace/src/bnb and install it.
COPY .. ${WORKSPACE}/src/bnb
RUN cd ${WORKSPACE}/src/bnb && pip install .

COPY --chmod=755 docker/entrypoint.sh /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
190 changes: 182 additions & 8 deletions src/bitsandbytes_intel/cpu_xpu_common.py
@@ -1,10 +1,11 @@
import subprocess
-from typing import Optional
+from typing import Optional, Tuple
import warnings

import torch
import torch.nn.functional as F

from bitsandbytes.utils import QuantState
from bitsandbytes.functional import (
    QuantState,
    create_dynamic_map,
@@ -57,6 +58,17 @@ def _ipex_xpu_version_prereq(major, minor):
    return False


# Dispatch table: optimizer name -> (fp32, fp16, bf16) 8-bit blockwise update kernels.
# Only populated when an IPEX XPU build (>= 2.7) provides these ops.
str2optimizer8bit_blockwise = {}
if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
    str2optimizer8bit_blockwise = {
        "adam": (
            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
            ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
        ),
    }


def _maybe_torch_compile(func):
# torch.compile requires g++ and pytorch >= 2.0
if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
@@ -77,8 +89,32 @@ def reverse_4bit_compress_format(weight):
    return out


def transform(
    A: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    transpose=False,
    state: Optional[Tuple[torch.Size, str]] = None,
):
    """
    Transform tensor A to a target memory order (`to_order` in the CUDA implementation).
    For CPU/XPU, it returns the original tensor if transpose=False.
    Otherwise, it returns the transpose of A.
    """
    if transpose:
        if out is not None:
            out.copy_(A.T)
        else:
            out = A.T
    else:
        if out is not None:
            out.copy_(A)
        else:
            out = A
    return out, state
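
For clarity, a minimal sketch of the CPU/XPU passthrough semantics (shapes are arbitrary):

A = torch.randn(2, 3)
same, _ = transform(A)                        # no layout change: returns A itself
transposed, _ = transform(A, transpose=True)  # returns A.T, shape (3, 2)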


@_maybe_torch_compile
-def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+def int8_double_quant_impl(A, threshold=0.0, col_stats=None, row_stats=None, out_col=None, out_row=None):
    """
    Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8.
    If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in
@@ -157,6 +193,26 @@ def quant_to_int8(A, stats):
    return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols
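
For reference, a hedged call sketch of the renamed entry point (input values are arbitrary; `threshold=6.0` mirrors the usual LLM.int8() convention):

A = torch.randn(8, 16)
out_row, out_col, row_stats, col_stats, outlier_cols = int8_double_quant_impl(A, threshold=6.0)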


def int8_vectorwise_quant_impl(A: torch.Tensor, threshold=0.0):
    # TODO: We can optimize this as we don't actually need column-wise quant.
    out, _, stats, _, outlier_cols = int8_double_quant_impl(A, threshold=threshold)
    return out, stats, outlier_cols


def int8_vectorwise_dequant_impl(A: torch.Tensor, stats: torch.Tensor):
    """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`.

    Args:
        A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor.
        stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics.

    Returns:
        `torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
    """
    # To dequantize we divide by 127, or multiply by the reciprocal.
    return A * stats.view(-1, 1) * 7.874015718698502e-3
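
As a quick sanity check on the reciprocal constant, a minimal sketch (the quantization step here is hand-rolled for illustration, not this module's API):

A = torch.randn(4, 8)
stats = A.abs().amax(dim=1).float()                            # row-wise absmax
q = (A * (127.0 / stats.view(-1, 1))).round().clamp(-127, 127).to(torch.int8)
deq = int8_vectorwise_dequant_impl(q, stats)                   # ~A, per-row error <= stats / 254
assert torch.allclose(deq, A, atol=stats.max().item() / 127)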


def int8_linear_matmul_impl(
    A: torch.Tensor,
    B: torch.Tensor,
@@ -227,10 +283,10 @@ def int8_mm_dequant_impl(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    compute_dtype=torch.float32,
    output_dtype=torch.float32,
+    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequant and add bias
@@ -303,11 +359,11 @@ def int8_mm_dequant_impl(
def quantize_4bit_impl(
    A: Tensor,
    absmax: Tensor = None,
-    out: Tensor = None,
    blocksize=64,
    compress_statistics=False,
    quant_type="nf4",
    quant_storage=torch.uint8,
+    out: Tensor = None,
) -> Tensor:
    """
    Quantize tensor A in blocks of 4-bit values.
@@ -443,9 +499,9 @@ def dequantize_4bit_impl(
    A: Tensor,
    quant_state=None,
    absmax: Tensor = None,
-    out: Tensor = None,
    blocksize: int = 64,
    quant_type="nf4",
+    out: Tensor = None,
) -> Tensor:
    """
    Dequantizes 4-bit blockwise quantized values.
@@ -471,6 +527,11 @@ def dequantize_4bit_impl(
    torch.Tensor:
        Dequantized tensor.
    """
    # For NF4, IPEX provides a dedicated dequantization kernel.
    if quant_type == "nf4" and getattr(quant_state, "ipex", False):
        out = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
        return out

    transpose = True if A.shape[0] == 1 else False
    A = A.reshape(-1)
    device = A.device
@@ -545,10 +606,8 @@ def dequantize_4bit_impl(
def gemm_4bit_impl(
    A: torch.Tensor,
    B: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    transposed_A=False,
-    transposed_B=False,
    state: QuantState = None,
+    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Matrix-matrix multiplication with 4-bit quantization.
@@ -598,3 +657,118 @@
    else:
        out = output
    return out
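
A hedged usage sketch (assuming, as with bitsandbytes.functional.quantize_4bit, that quantize_4bit_impl returns the packed tensor together with its QuantState; shapes are illustrative):

W = torch.randn(64, 32)
qW, q_state = quantize_4bit_impl(W, blocksize=64, quant_type="nf4")
x = torch.randn(1, 32)
y = gemm_4bit_impl(x, qW, state=q_state)   # approx. x @ W.T after dequantization (assumption)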


def dequantize_blockwise(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    code: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: int = 4096,
) -> torch.Tensor:
    if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
        raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for the 8-bit optimizer backend on XPU devices.")

    # Blockwise 8-bit quantization stores one byte per element, so `out` shares A's shape.
    # Allocate a default fp32 output if the caller did not provide one.
    if out is None:
        out = torch.empty_like(A, dtype=torch.float32)

    # The IPEX ops mirror the CUDA kernel signature:
    # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
    if out.dtype == torch.float16:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
    elif out.dtype == torch.bfloat16:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
    elif out.dtype == torch.float32:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

    return out


def quantize_blockwise(
    A: torch.Tensor,
    code: Optional[torch.Tensor] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=4096,
    nested=False,
) -> Tuple[torch.Tensor, QuantState]:
    raise NotImplementedError


def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
        raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for the 8-bit optimizer backend on XPU devices.")

    # Pick the kernel matching the gradient dtype; the optimizer state must be quantized (uint8).
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
    elif (
        g.dtype == torch.bfloat16
        and state1.dtype == torch.uint8
        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
    ):
        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )

    optim_func(
        p,
        g,
        state1,
        state2,
        beta1,
        beta2,
        beta3,
        alpha,
        eps,
        step,
        lr,
        qmap1,
        qmap2,
        absmax1,
        absmax2,
        weight_decay,
        gnorm_scale,
        skip_zeros,
        g.numel(),
    )
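
For illustration, a hedged sketch of driving the Adam kernel directly (all tensors and hyperparameters are placeholders; in practice bitsandbytes' optimizer classes own this state):

n = 4096
g = torch.randn(n, device="xpu")                              # fp32 grad -> kernel [0]
p = torch.randn(n, device="xpu")
state1 = torch.zeros(n, dtype=torch.uint8, device="xpu")      # quantized exp_avg
state2 = torch.zeros(n, dtype=torch.uint8, device="xpu")      # quantized exp_avg_sq
absmax1 = torch.zeros(n // 256, device="xpu")
absmax2 = torch.zeros(n // 256, device="xpu")
optimizer_update_8bit_blockwise(
    "adam", g, p, state1, state2,
    beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0, eps=1e-8, step=1, lr=1e-3,
    qmap1=create_dynamic_map(signed=True).to("xpu"),
    qmap2=create_dynamic_map(signed=False).to("xpu"),
    absmax1=absmax1, absmax2=absmax2,
)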


def optimizer_update_32bit(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Optional[torch.Tensor] = None,
    beta2: float = 0.0,
    beta3: float = 0.0,
    alpha: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Optional[torch.Tensor] = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    raise NotImplementedError