[Kernel][Hardware][AMD] Add support for GGUF quantization on ROCm #10254

Merged · 10 commits · Nov 23, 2024
Changes from 1 commit
initial port
kliuae committed Aug 30, 2024
commit 71cedee57c9b4f79c3ae3d28b78e9a6cf2e6fe0e
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -185,6 +185,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
@@ -213,7 +214,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -125,6 +125,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
#endif

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
@@ -135,6 +136,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);

#ifndef USE_ROCM
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
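
Taken together, the two hunks above restructure the preprocessor guards in csrc/ops.h so that the GGUF/GGML declarations fall outside the CUDA-only region while the Marlin entry points stay inside it. A condensed, illustrative view of the resulting layout (parameter lists abbreviated, unrelated declarations omitted):

```cpp
#ifndef USE_ROCM
// Marlin repack kernels remain CUDA-only.
torch::Tensor awq_marlin_repack(/* ... */);
#endif

// GGUF/GGML entry points are now declared unconditionally, so they are
// visible to both CUDA and ROCm builds.
torch::Tensor ggml_dequantize(/* ... */);
torch::Tensor ggml_mul_mat_vec_a8(/* ... */);
torch::Tensor ggml_mul_mat_a8(/* ... */);

#ifndef USE_ROCM
// fp8 Marlin GEMM stays CUDA-only, behind the re-opened guard.
torch::Tensor fp8_marlin_gemm(/* ... */);
#endif
```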
15 changes: 15 additions & 0 deletions csrc/quantization/gguf/ggml-common.h
@@ -966,4 +966,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}

static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
uint32_t neq = a^b;
return !(neq & 0xff000000) * 0xff000000 |
!(neq & 0x00ff0000) * 0x00ff0000 |
!(neq & 0x0000ff00) * 0x0000ff00 |
!(neq & 0x000000ff) * 0x000000ff;
}

static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
(static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
(static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
(static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
}
#endif // defined(USE_ROCM)
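
The two helpers added above are ROCm stand-ins for CUDA's per-byte SIMD intrinsics: __vcmpeq4 compares the four packed bytes of two 32-bit words and yields 0xff in every lane that matches, while __vsub4 performs four independent byte subtractions with per-lane wraparound. A minimal host-side sketch of the same arithmetic, included only to illustrate the intended lane semantics (the real versions above are __device__ functions):

```cpp
#include <cassert>
#include <cstdint>

// Host-side restatement of the polyfilled byte-wise intrinsics, for
// illustration only.
static uint32_t vcmpeq4(uint32_t a, uint32_t b) {
  uint32_t neq = a ^ b;
  // Each lane becomes 0xff when its bytes are equal, 0x00 otherwise.
  return !(neq & 0xff000000) * 0xff000000 |
         !(neq & 0x00ff0000) * 0x00ff0000 |
         !(neq & 0x0000ff00) * 0x0000ff00 |
         !(neq & 0x000000ff) * 0x000000ff;
}

static uint32_t vsub4(uint32_t a, uint32_t b) {
  // Four independent byte subtractions; each lane wraps modulo 256.
  return (uint32_t(uint8_t((a >> 24) - (b >> 24))) << 24) |
         (uint32_t(uint8_t((a >> 16) - (b >> 16))) << 16) |
         (uint32_t(uint8_t((a >> 8)  - (b >> 8)))  << 8)  |
          uint32_t(uint8_t(a - b));
}

int main() {
  // Byte lane 2 differs (0x22 vs 0xFF); the other three lanes match.
  assert(vcmpeq4(0x11223344u, 0x11FF3344u) == 0xff00ffffu);
  // 0x00 - 0x01 wraps to 0xFF within its own byte lane.
  assert(vsub4(0x05000304u, 0x01010101u) == 0x04FF0203u);
  return 0;
}
```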
6 changes: 4 additions & 2 deletions csrc/quantization/gguf/gguf_kernel.cu
@@ -10,6 +10,8 @@
#include "mmvq.cuh"
#include "mmq.cuh"

#include "cuda_compat.h"

// Q8 gemv
static __global__ void quantize_q8_1(const half* __restrict__ x,
void* __restrict__ vy, const int kx,
@@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,

#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32));
sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32);
}

const float d = amax / 127;
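
VLLM_SHFL_XOR_SYNC_WIDTH comes from the newly included cuda_compat.h and hides the CUDA/ROCm difference in warp-shuffle APIs. Its exact definition lives in csrc/cuda_compat.h; a plausible shape, shown here only as an assumption, is:

```cpp
// Hypothetical shape of the compatibility macro pulled in from cuda_compat.h
// (an assumption for illustration; the real definition may differ).
#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(0xffffffffu, var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif
```

Passing the width of 32 explicitly keeps the butterfly reduction confined to 32-lane groups, which matters on ROCm hardware whose wavefronts are 64 lanes wide.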
4 changes: 4 additions & 0 deletions csrc/quantization/gguf/mmvq.cuh
@@ -29,7 +29,11 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
#ifndef USE_ROCM
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#else
tmp += __shfl_xor(tmp, mask, 32);
#endif
}

if (threadIdx.x == 0) {
44 changes: 22 additions & 22 deletions csrc/quantization/gguf/vecdotq.cuh
@@ -30,7 +30,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
const int * v, const int * u, const float & d4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -55,7 +55,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -82,7 +82,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -115,7 +115,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -149,7 +149,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
const int * v, const int * u, const float & d8_0, const float & d8_1) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -163,7 +163,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM

int sumi = 0;

@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const half2 & dm2, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;

@@ -217,7 +217,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const half2 & dm2, const float & d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi_d = 0;
int sumi_m = 0;

@@ -254,7 +254,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
const int & scale_offset, const float & d3, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM

float sumf = 0.0f;

@@ -288,7 +288,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d3, const float & d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;

#pragma unroll
@@ -313,7 +313,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM

float sumf_d = 0.0f;
float sumf_m = 0.0f;
@@ -338,7 +338,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;

@@ -369,7 +369,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM

float sumf_d = 0.0f;
float sumf_m = 0.0f;
@@ -400,7 +400,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;

@@ -432,7 +432,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf = 0.0f;

#pragma unroll
@@ -452,7 +452,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
const float & d6, const float * __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;

#pragma unroll
@@ -1569,7 +1569,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(

static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq2_s * bq2 = (const block_iq2_s *) vbq;

const int ib32 = iqs;
@@ -1606,7 +1606,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(

static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;

const int ib32 = iqs;
@@ -1633,7 +1633,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(

static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;

const int ib32 = iqs;
@@ -1658,7 +1658,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(

static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

const int ib32 = iqs;
@@ -1698,7 +1698,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4

static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM

const block_iq4_nl * bq = (const block_iq4_nl *) vbq;

@@ -1723,7 +1723,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(

static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

2 changes: 2 additions & 0 deletions csrc/torch_bindings.cpp
@@ -159,6 +159,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// awq_marlin repack from AWQ.
ops.def("awq_marlin_repack", &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
#endif

// Dequantization for GGML.
ops.def("ggml_dequantize", &ggml_dequantize);
@@ -172,6 +173,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -267,7 +267,7 @@ def _parse_quant_hf_config(self):

def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8", "gguf"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",