From 958fecb74039ce5dcd3bb0a83591ee81d66c31c4 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 7 Apr 2025 09:02:00 -0400
Subject: [PATCH 2/6] Additional 4bit CPU ops

---
 bitsandbytes/backends/cpu/ops.py | 41 ++++++++++++++++++++++++++++++++
 bitsandbytes/nn/modules.py       |  2 +-
 tests/test_ops.py                |  6 ++++-
 3 files changed, 47 insertions(+), 2 deletions(-)

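The new CPU dequantize_4bit kernel unpacks two 4-bit codes from every storage
byte, looks them up in the NF4 code book, and rescales each block by its
absmax. The standalone sketch below mirrors that flow; the 16-entry LUT is
only a stand-in for the real _NF4_QUANT_TABLE, and the helper name is made up
for illustration.

    import torch

    # Stand-in code book; the kernel indexes the 16-entry NF4 quantile table.
    LUT = torch.linspace(-1.0, 1.0, 16)

    def dequant_4bit_sketch(packed, absmax, blocksize, shape, dtype=torch.float32):
        packed = packed.view(-1, 1)                    # one uint8 per row
        upper = (packed >> 4).to(torch.int64)          # high nibble = first value
        lower = (packed & 0x0F).to(torch.int64)        # low nibble = second value
        codes = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
        values = LUT[codes] * absmax[:, None]          # per-block rescale
        return values.reshape(shape).to(dtype)

    # One block of 64 values packed into 32 bytes.
    packed = torch.randint(0, 256, (32, 1), dtype=torch.uint8)
    out = dequant_4bit_sketch(packed, torch.tensor([2.5]), 64, (4, 16))
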
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index b7513c4d3..b615d37a4 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 import ctypes as ct
 from typing import Optional
 
@@ -119,6 +120,10 @@ def _(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(
+        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
+    )
 
     n = A.numel()
 
@@ -140,3 +145,39 @@ def _(
         packed = packed.squeeze().view(quant_storage).unsqueeze(1)
 
     return packed, absmax.float()
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+    torch._check(
+        A.dtype == torch.uint8,
+        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
+    )
+
+    # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
+    upper = (A >> 4).to(torch.int64)
+    lower = (A & 0x0F).to(torch.int64)
+
+    # Expand to blocks
+    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
+
+    # Dequantize
+    blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
+
+    # Reshape to original shape
+    blocks = blocks.reshape(-1, *shape[1:])
+
+    return blocks.to(dtype)
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ea5451502..e0f866be7 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -486,7 +486,7 @@ def forward(self, x: torch.Tensor):
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 9869f51ef..70e368fea 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -171,7 +171,11 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "cpu":
-            pytest.skip("CPU implementation is not available")
+            if quant_type != "nf4":
+                pytest.skip("CPU implementation is only available for nf4")
+
+            if storage_dtype != torch.uint8:
+                pytest.skip("CPU implementation only supports uint8 storage")
 
         shape = (128, 128)
 

From 0410ec181abb84510ddc25503ba60ba6e6659546 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Thu, 24 Apr 2025 16:17:53 -0400
Subject: [PATCH 3/6] Implement additional device-agnostic ops and test updates

---
 bitsandbytes/backends/cpu/ops.py     | 34 ++++++++++++++
 bitsandbytes/backends/cuda/ops.py    | 39 ----------------
 bitsandbytes/backends/default/ops.py | 70 ++++++++++++++++++++++++++++
 bitsandbytes/nn/modules.py           | 31 +++++++-----
 tests/test_autograd.py               | 14 ++++--
 tests/test_functional.py             | 25 ++++++++--
 tests/test_linear4bit.py             |  4 +-
 tests/test_linear8bitlt.py           | 13 ++----
 tests/test_modules.py                | 28 +++++------
 tests/test_ops.py                    | 11 ++---
 10 files changed, 178 insertions(+), 91 deletions(-)

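Besides moving int8_mixed_scaled_mm from the CUDA backend to the default
backend, this adds a CPU gemv_4bit fallback (dequantize B, then F.linear) and
a device-agnostic int8_vectorwise_quant: per-row absmax, scale to int8, and
report which columns hold outliers above the threshold. A rough sketch of that
quantization step, assuming a 2-D input; the helper name is illustrative and,
unlike the registered kernel, it does not zero the outliers in place:

    import torch

    def int8_rowwise_quant_sketch(A, threshold=0.0):
        outlier_cols = None
        if threshold > 0.0:
            outliers = A.abs() >= threshold
            if outliers.any():
                # Columns with outliers are handled in higher precision by
                # int8_mixed_scaled_mm; exclude them from the row statistics.
                outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
                A = A.masked_fill(outliers, 0)
        row_stats = A.abs().amax(dim=1).float()               # per-row absmax
        out = torch.round(A * (127.0 / row_stats[:, None])).to(torch.int8)
        if outlier_cols is not None and A.shape[0] > 1:
            out[:, outlier_cols] = 0                          # drop outlier columns
        return out, row_stats, outlier_cols

    A = torch.randn(4, 8)
    A[0, 0] = 100.0
    q, stats, cols = int8_rowwise_quant_sketch(A, threshold=6.0)
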
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index b615d37a4..0da9eac94 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -167,6 +167,8 @@ def _(
         lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
     )
 
+    A = A.view(-1, 1)
+
     # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
     upper = (A >> 4).to(torch.int64)
     lower = (A & 0x0F).to(torch.int64)
@@ -181,3 +183,35 @@ def _(
     blocks = blocks.reshape(-1, *shape[1:])
 
     return blocks.to(dtype)
+
+
+@register_kernel("bitsandbytes::gemv_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+) -> torch.Tensor:
+    # TODO: We need to determine whether `code` is NF4, FP4, or other.
+    # Right now we assume NF4, as this is the only one supported on CPU.
+
+    B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
+        B,
+        absmax,
+        blocksize,
+        "nf4",
+        shape=shapeB,
+        dtype=A.dtype,
+    )
+
+    # User called gemv with B.t(), so we need to transpose it back.
+    # if B.shape[0] == 1:
+    #    B_dq = B_dq.t()
+
+    return torch.nn.functional.linear(
+        A,
+        B_dq,
+        bias=None,
+    )
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index 5ffcdb767..efdef2871 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
-@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
-def _(
-    A: torch.Tensor,
-    CA: torch.Tensor,
-    CB: torch.Tensor,
-    SCA: torch.Tensor,
-    SCB: torch.Tensor,
-    outlier_cols: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    subB = None
-
-    if outlier_cols is not None and outlier_cols.numel():
-        # Extract the inputs with outliers in original precision
-        subA = A[:, outlier_cols].contiguous()
-
-        # Dequantize the corresponding weight columns
-        subB = (
-            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
-            .to(A.dtype)
-            .t()
-        )
-
-        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
-
-    else:
-        # Needed for torch.compile when there are no outliers.
-        subA = torch.empty(0, device=A.device, dtype=A.dtype)
-
-    # Int8 Matmul + Dequant + Bias
-    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
-
-    if subB is not None:
-        # Add the outlier columns back to the output
-        output = output.addmm(subA, subB)
-
-    return output, subA
-
-
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A
 
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 6e581038d..653f87659 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -1,3 +1,4 @@
+from math import prod
 from typing import Optional
 
 import torch
@@ -5,6 +6,45 @@
 from ..._ops import register_kernel
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 @register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
@@ -41,3 +81,33 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
     if out is not None:
         result = out.copy_(result)
     return result
+
+
+@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
+def _(A: torch.Tensor, threshold=0.0):
+    rows = prod(A.shape[:-1])
+    outlier_cols = None
+
+    if threshold > 0.0:
+        outliers = A.abs() >= threshold
+
+        if outliers.any():
+            # Determine which columns contain outliers, and zero out the
+            # outliers ahead of quantization.
+            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+            A[outliers] = 0
+        else:
+            # Needed for torch.compile support.
+            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
+
+    # Get absmax for each row.
+    row_stats = torch.max(A.abs(), dim=1).values.float()
+
+    # Quantize row-wise to int8.
+    out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)
+
+    # Zero out values from outlier columns across all rows.
+    if rows > 1 and outlier_cols is not None:
+        out_row[:, outlier_cols] = 0
+
+    return out_row, row_stats, outlier_cols
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index e0f866be7..74277f65e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -585,19 +585,28 @@ def __new__(
         obj.has_fp16_weights = has_fp16_weights
         return obj
 
-    def cuda(self, device):
+    def _quantize(self, device):
         if self.has_fp16_weights:
-            return super().cuda(device)
-        else:
-            # We quantize the weight and store in 8bit row-major
-            B = self.data.contiguous().half().cuda(device)
-            CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
-            self.data = CB
-            self.CB = CB
-            self.SCB = SCB
+            return super().to(device)
+
+        # We quantize the weight and store in 8bit row-major
+        B = self.data.contiguous().to(device=device, dtype=torch.float16)
+        CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
+        self.data = CB
+        self.CB = CB
+        self.SCB = SCB
 
         return self
 
+    def cpu(self):
+        return self.to(device="cpu")
+
+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+
+    def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
+
     def __deepcopy__(self, memo):
         # adjust this if new arguments are added to the constructor
         new_instance = type(self).__new__(
@@ -627,8 +636,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
-            return self.cuda(device)
+        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            return self._quantize(device)
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 7c43cab80..b6ba284c9 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -32,9 +32,15 @@
 def test_matmullt(
     device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
 ):
-    if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
-        # TODO: Deprecate/remove?
-        pytest.skip("switchback_bnb only works on CUDA.")
+    if device != "cuda":
+        if funcs[1] == bnb.research.switchback_bnb:
+            # TODO: Deprecate/remove?
+            pytest.skip("switchback_bnb only works on CUDA.")
+
+        if req_grad[1]:
+            # This will be deprecated for CUDA in the future. We don't expect
+            # this to work on any other device.
+            pytest.skip("Deprecated feature with CUDA support only.")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -171,7 +177,7 @@ def test_matmul_4bit(
     quant_type,
 ):
     if device == "cpu" and quant_type == "fp4":
-        pytest.skip("Only nf4 is supported on CPU")
+        pytest.xfail("Only nf4 is supported on CPU")
 
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 5b9038288..ee2b52429 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -186,7 +186,7 @@ def test_few_bit_quant(self, device, bits, method):
             code = F.create_dynamic_map(True, bits - 0, bits).to(device)
         elif method == "quantile":
             if device != "cuda":
-                pytest.xfail("Quantile map only works on CUDA")
+                pytest.skip("Quantile map only works on CUDA")
             values = torch.randn(2048, 2048, device="cuda")
             code = F.create_quantile_map(values, bits).cuda()
         # for some data types we have no zero
@@ -593,7 +593,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
 
             A = A.view(-1, A.shape[-1])
 
-            CA, _, statsA, _, _ = F.int8_double_quant(A)
+            CA, statsA, _ = F.int8_vectorwise_quant(A)
             CB, statsB, _ = F.int8_vectorwise_quant(B)
             output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
 
@@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1134,6 +1137,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
     def test_4bit_compressed_stats(self, device, quant_type, blocksize):
+        if device == "cpu" and quant_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         errs1 = []
         errs2 = []
         for i in range(10):
@@ -1206,6 +1212,12 @@ def test_bench_4bit_dequant(self, quant_type):
     )
     @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
     def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
+        if device == "cpu":
+            if storage_type != "nf4":
+                pytest.xfail("fp4 quantization is not supported on CPU")
+            if quant_storage != torch.uint8:
+                pytest.xfail("Only uint8 storage is supported on CPU")
+
         errs1 = []
         errs2 = []
         errs3 = []
@@ -1216,7 +1228,11 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
         max_errs2 = []
         max_errs3 = []
 
-        for i in range(100):
+        # A large number of iterations is excessive and slow on CPU.
+        # Keep the full 100 iterations for CUDA for now.
+        iters = 100 if device == "cuda" else 10
+
+        for i in range(iters):
             if kind == "fc1":
                 A = torch.randn(1, dim, dtype=dtype, device=device)
                 B = torch.randn(dim * 4, dim, dtype=dtype, device=device) / math.sqrt(dim)
@@ -1337,6 +1353,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
     @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+        if device == "cpu" and storage_type != "nf4":
+            pytest.xfail("fp4 quantization is not supported on CPU")
+
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))
         dims = get_test_dims(0, 8192, n=dims)
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 669319298..d4123acf6 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -24,8 +24,8 @@
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
-    if device == "cpu":
-        pytest.xfail("Dequantization is not yet implemented for CPU")
+    if device == "cpu" and quant_type == "fp4":
+        pytest.xfail("FP4 is not supported for CPU")
 
     original_dtype = torch.float16
     compute_dtype = None
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 53a566cb9..8c08cfa2c 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -22,9 +22,6 @@
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
 @pytest.mark.parametrize("device", get_available_devices())
 def test_linear_no_igemmlt(device):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
-
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
     linear_custom = Linear8bitLt(
@@ -81,8 +78,8 @@ def test_linear_serialization(
     save_before_forward,
     load_before_cuda,
 ):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
+    if device != "cuda" and has_fp16_weights:
+        pytest.skip("has_fp16_weights is only supported on CUDA and is deprecated")
 
     linear = torch.nn.Linear(32, 96)
     # TODO: Fallback for bad shapes
@@ -111,7 +108,7 @@ def test_linear_serialization(
     if save_before_forward:
         bytes_8bit = torch_save_to_buffer(linear_custom)
 
-    x_first = x.clone().cuda().requires_grad_(True)
+    x_first = x.clone().to(device).requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
@@ -157,11 +154,11 @@ def test_linear_serialization(
     if not load_before_cuda:
         new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
 
-    x_second = x.clone().cuda().requires_grad_(True)
+    x_second = x.clone().to(device).requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
-    x_third = x.clone().cuda().requires_grad_(True)
+    x_third = x.clone().to(device).requires_grad_(True)
     fx_third = new_linear_custom2(x_third).float()
     (fx_third * grad_proj).mean().backward()
 
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 8ef0890ec..dc1d60e6c 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -55,9 +55,6 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
 def test_linear8bitlt_inference(device, threshold):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
-
     l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
     assert l1.weight.device.type == device
     assert l1.weight.dtype == torch.int8
@@ -120,9 +117,6 @@ def test_linear8bitlt_accumulated_gradient(device):
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
 def test_linear8bitlt_no_fp16_weights(device, threshold):
-    if device == "cpu":
-        pytest.xfail("Not yet supported on CPU")
-
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
@@ -211,7 +205,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
         has_fp16_weights=False,
     )
     w1, w2 = mlp.fc1.weight.clone().to(device), mlp.fc2.weight.clone().to(device)  # grab weights before quantization,
-    mlp = mlp.cuda().half()  # and this line triggers quantization
+    mlp = mlp.to(device).half()  # and this line triggers quantization
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
@@ -253,9 +247,6 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     ids=["Int8Lt", "NF4"],
 )
 def test_linear_kbit_fp32_bias(device, module):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
-
     # casts model to fp16 -> int8 automatically
     l1 = module(32, 64).to(device)
     assert l1.weight.dtype in [torch.int8, torch.uint8]
@@ -295,7 +286,7 @@ def test_linear_kbit_fp32_bias(device, module):
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
 def test_kbit_backprop(device, module):
     if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
+        pytest.xfail("Test is not yet supported on CPU")
 
     b = 16
     dim1 = 36
@@ -401,7 +392,10 @@ def test_fp8linear():
 )
 def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
     if device == "cpu":
-        pytest.xfail("Not yet supported on CPU")
+        if embedding_class is bnb.nn.EmbeddingFP4:
+            pytest.xfail("FP4 is not supported for CPU")
+        if quant_storage is not None and quant_storage != torch.uint8:
+            pytest.xfail("CPU only supports uint8 storage for 4bit")
 
     num_embeddings = 128
 
@@ -449,7 +443,10 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
 )
 def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
     if device == "cpu":
-        pytest.xfail("Not yet supported on CPU")
+        if embedding_class is bnb.nn.EmbeddingFP4:
+            pytest.xfail("FP4 is not supported for CPU")
+        if quant_storage is not None and quant_storage != torch.uint8:
+            pytest.xfail("CPU only supports uint8 storage for 4bit")
 
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
@@ -486,7 +483,7 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
 @pytest.mark.parametrize("device", get_available_devices())
 def test_4bit_linear_warnings(device):
     if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
+        pytest.xfail("gemv_4bit op is not yet implemented on CPU")
 
     dim1 = 64
 
@@ -525,9 +522,6 @@ def test_4bit_linear_warnings(device):
 
 @pytest.mark.parametrize("device", get_available_devices())
 def test_4bit_embedding_warnings(device):
-    if device == "cpu":
-        pytest.xfail("Not yet implemented on CPU")
-
     num_embeddings = 128
     default_block_size = 64
 
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 70e368fea..ea448f99b 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -37,9 +37,6 @@ def test_int8_linear_matmul_out(self, device):
     @pytest.mark.parametrize("threshold", [0.0, 6.0])
     @pytest.mark.parametrize("device", get_available_devices())
     def test_int8_vectorwise_quant(self, threshold, device):
-        if device == "cpu":
-            pytest.skip("CPU implementation is not available")
-
         A = torch.randn(10, 20, dtype=torch.float16, device=device)
         A[1][0] = 1000.0
 
@@ -147,7 +144,7 @@ class Test4bitBlockwiseQuantOps:
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "cpu" and quant_type != "nf4":
-            pytest.skip("CPU implementation is only available for nf4")
+            pytest.xfail("CPU implementation is only available for nf4")
 
         if storage_dtype != torch.uint8:
             pytest.xfail("Known issue with storage_dtype != uint8")
@@ -172,10 +169,10 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
     def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "cpu":
             if quant_type != "nf4":
-                pytest.skip("CPU implementation is only available for nf4")
+                pytest.xfail("CPU implementation is only available for nf4")
 
             if storage_dtype != torch.uint8:
-                pytest.skip("CPU implementation only supports uint8 storage")
+                pytest.xfail("CPU implementation only supports uint8 storage")
 
         shape = (128, 128)
 
@@ -208,7 +205,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "cpu":
-            pytest.skip("CPU implementation is not available")
+            pytest.xfail("CPU implementation is not available")
 
         out_features = 1024
         in_features = 256

From d02b536b197652dac925216854b0cadc3ed54e70 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Thu, 24 Apr 2025 18:22:14 -0400
Subject: [PATCH 4/6] More test fixes

---
 bitsandbytes/functional.py |  2 +-
 tests/test_linear4bit.py   | 12 ++++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

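The functional.py change pins the code book stored on QuantState to the input
tensor's device, so the state's code and absmax always live together with the
quantized data. A quick check of the intent, assuming quantize_blockwise keeps
its (tensor, QuantState) return:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(64, 64)                 # CPU tensor
    _, state = F.quantize_blockwise(A)
    assert state.code.device == A.device    # holds after this change
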
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c9341230f..d17ff2e88 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -779,7 +779,7 @@ def quantize_blockwise(
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
 
     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index d4123acf6..67b61cb05 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -24,8 +24,11 @@
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
 def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
-    if device == "cpu" and quant_type == "fp4":
-        pytest.xfail("FP4 is not supported for CPU")
+    if device == "cpu":
+        if quant_type == "fp4":
+            pytest.xfail("FP4 is not supported for CPU")
+        if quant_storage != "uint8":
+            pytest.xfail("Only uint8 storage is supported for CPU")
 
     original_dtype = torch.float16
     compute_dtype = None
@@ -144,8 +147,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
     linear_q3 = torch_load_from_buffer(bytes_4bit)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to("cpu")
-    linear_q2.to(device)
+    if device != "cpu":
+        linear_q2.to("cpu")
+        linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
     assert c.device == d.device

From f40e8aeaaaedc156824d026cd987194cb76b32b3 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Fri, 25 Apr 2025 10:07:13 -0400
Subject: [PATCH 5/6] int8 tests passing

---
 bitsandbytes/backends/default/ops.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

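The default int8_vectorwise_quant kernel zeroes outliers in A before taking
the row statistics, so without restoring them afterwards the caller's tensor
comes back silently modified. A small check of the intended behaviour,
assuming the functional wrapper keeps its (A, threshold=0.0) signature and
dispatches to this default kernel on CPU:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4, 8, dtype=torch.float16)
    A[0, 0] = 100.0                         # force an outlier
    A_before = A.clone()

    F.int8_vectorwise_quant(A, threshold=6.0)

    # With the restore step, quantization no longer mutates its input.
    torch.testing.assert_close(A, A_before)
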
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 653f87659..20e596f25 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -88,13 +88,17 @@ def _(A: torch.Tensor, threshold=0.0):
     rows = prod(A.shape[:-1])
     outlier_cols = None
 
+    outlier_restore = None
+
     if threshold > 0.0:
         outliers = A.abs() >= threshold
 
         if outliers.any():
             # Determine which columns contain outliers, and zero out the
-            # outliers ahead of quantization.
+            # outliers ahead of quantization. We need to keep a backup of these
+            # outliers to restore them after quantization.
             outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+            outlier_restore = A[outliers].clone()
             A[outliers] = 0
         else:
             # Needed for torch.compile support.
@@ -110,4 +114,8 @@ def _(A: torch.Tensor, threshold=0.0):
     if rows > 1 and outlier_cols is not None:
         out_row[:, outlier_cols] = 0
 
+    # Restore outliers.
+    if outlier_restore is not None:
+        A[outliers] = outlier_restore
+
     return out_row, row_stats, outlier_cols

From db5591956d2b2ef74fa7534dc749c6b5c87b1f28 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Fri, 25 Apr 2025 15:40:40 -0400
Subject: [PATCH 6/6] Fix feature flag for multi_backend

---
 bitsandbytes/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

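Integrations probe this set to detect multi-backend support, presumably with
the underscore spelling, hence the rename. A hypothetical integration-side
check (the attribute lookup is illustrative, not an API transformers/diffusers
is known to use verbatim):

    import bitsandbytes as bnb

    if "multi_backend" in getattr(bnb, "features", set()):
        # Safe to route non-CUDA devices through bitsandbytes ops.
        ...
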
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index b8dc5a5e1..917cd0b6a 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -21,7 +21,7 @@
 
 # This is a signal for integrations with transformers/diffusers.
 # Eventually we may remove this but it is currently required for compatibility.
-features = {"multi-backend"}
+features = {"multi_backend"}
 supported_torch_devices = {
     "cpu",
     "cuda",  # NVIDIA/AMD GPU