diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a6a311e97633..928c309252016 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/fp_eXmY/fp_eXmY_linear.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") diff --git a/csrc/ops.h b/csrc/ops.h index c50eb39a3dacc..97c647b873565 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -142,6 +142,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional<torch::Tensor> const& azp, c10::optional<torch::Tensor> const& bias); + +torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA, + torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK = 1); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/fp_eXmY/configs.h b/csrc/quantization/fp_eXmY/configs.h new file mode 100644 index 0000000000000..856955c95ba63 --- /dev/null +++ b/csrc/quantization/fp_eXmY/configs.h @@ -0,0 +1,73 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h + +#ifndef CONFIGS_H +#define CONFIGS_H + +// #define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... 
+#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 \ + 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 \ + 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() + // for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template <int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_> +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int SMEM_SIZE_B_TILE = + TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * + PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = + TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +#endif // CONFIGS_H \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/cp_async.cuh b/csrc/quantization/fp_eXmY/cp_async.cuh new file mode 100644 index 0000000000000..69c2fea0940e6 --- /dev/null +++ b/csrc/quantization/fp_eXmY/cp_async.cuh @@ -0,0 +1,82 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from +// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +// Extended from CUTLASS's source code + +#ifndef CP_ASYNC_CUH +#define CP_ASYNC_CUH + +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> + +template <int SizeInBytes> +__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, + bool pred_guard = true) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile( + "{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes)); +#endif +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does +/// not block. +__device__ __forceinline__ void cp_async_group_commit() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +/// Blocks until all but <N> previous cp.async.commit_group operations have +/// committed. +template <int N> +__device__ __forceinline__ void cp_async_wait_group() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#endif +} + +#endif \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/fp_eXmY_linear.cu b/csrc/quantization/fp_eXmY/fp_eXmY_linear.cu new file mode 100644 index 0000000000000..9652fdaed59cc --- /dev/null +++ b/csrc/quantization/fp_eXmY/fp_eXmY_linear.cu @@ -0,0 +1,305 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// This file is adapted from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu + +#include "quant_matmul.cuh" +#include "quant_reduction.cuh" + +#include <stdio.h> +#include <assert.h> + +#include <torch/all.h> +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> +#include <torch/library.h> + +namespace vllm { + +template <typename TilingConfig, typename OutputDataType, int EXPONENT, + int MANTISSA> +static void Kernel_Ex(cudaStream_t stream, const uint4* Weight, + const half* Scales, const half* B, OutputDataType* C, + const size_t M_Global, const size_t N_Global, + const size_t K_Global, int Split_K) { +#ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, + Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, + TilingConfig::TILE_K, TilingConfig::TILE_N); +#endif + static size_t SHMEM_SZ = + max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE, + TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute( + QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>, + cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); +// +#ifdef DEBUG_MODE + printf( + "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, " + "BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n", + GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, + SHMEM_SZ); + printf("\n"); +#endif + QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA> + <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global, + N_Global, K_Global, Split_K); +} + +template <int EXPONENT, int MANTISSA> +cudaError_t fpx_linear_kernel( + cudaStream_t stream, const uint4* Weight, const half* Scales, const half* B, + half* C, const size_t M_Global, const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K) { + TORCH_CHECK(M_Global % 256 == 0, "M_Global must be a multiple of 256."); + TORCH_CHECK(K_Global % 64 == 0, "K_Global must be a multiple of 64."); + TORCH_CHECK(N_Global > 0, "N_Global must be greater than zero."); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; + if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; + if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; + if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; + if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; + if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: + Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + Split_K); + break; + case 16: + Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + Split_K); + break; + case 32: + Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + Split_K); + break; + case 64: + Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + 
Split_K); + break; + case 128: + Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", + N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, + Split_K); + break; + } + } else { + switch (N_PowerOf2) { + case 8: + Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + case 16: + Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + case 32: + Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + case 64: + Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + case 128: + Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", + N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>( + stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, + K_Global, Split_K); + break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, + 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>( + C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} +} // namespace vllm + +// MODIFICATION NOTE: dtype of _weights is changed to uint8 +/* +Computes FPx-FP16 GEMM (PyTorch interface). +[Mathematical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and +W are stored in row-major. After Equivalent transformation : trans(Out) = +W * trans(In). Note that we do not perform "transpose" during runtime, we +instead interpret the In/Out as column-major matrices when calling our CUDA +kernel. [Inputs] _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words +contains 8 FPx weights. _scales: tensor of shape [OC]; // +half splitK: splitting the MatMul problem along K dimension for higher GPU +utilization, default 1. [Outputs] _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA, + torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK = 1) { + const int64_t NBITS = 1 + EXPONENT + MANTISSA; + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + TORCH_CHECK(num_in_channels % 64 == 0, + "Expected in_features to be a multiple of 64, but received ", + num_in_channels); + TORCH_CHECK((num_in_channels / 8 * NBITS) == + _weights.size(1)); // Making sure the K dimension is matched. 
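+  // e.g. (illustrative numbers): for FP6 weights (E3M2, NBITS = 6) with
+  // in_features = 4096, each output channel of _weights must hold
+  // 4096 / 8 * 6 = 3072 uint8 words, since every group of 8 weights packs
+  // into NBITS bytes; _scales carries one FP16 scale per output channel.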
+ // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight = reinterpret_cast<const uint4*>( + _weights.data_ptr<uint8_t>()); // weights is [OC, IC] but in FP6. + auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>()); + auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>()); + // Output Tensors + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>()); + + options = + torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = + torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast<float*>( + _workspace.data_ptr<float>()); // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + + // NOTE(alpin): use at::cuda::getCurrentCUDAStream() instead of default + // stream (0) this fixes problem with CUDA graphs when used with + // torch.compile() + auto dev = _in_feats.device().index(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + /* + The heuristic is weight_bit - exponent_bit - 1 = mantissa_bit + + NOTE(alpin): Using switch-statement here probably doesn't matter, + the compiler will likely optimize it to a jump table. + */ + + // FP4 + if (EXPONENT == 1 && MANTISSA == 2) + vllm::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 0) + vllm::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 1) + vllm::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + // FP5 + else if (EXPONENT == 1 && MANTISSA == 3) + vllm::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 2) + vllm::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 1) + vllm::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 4 && MANTISSA == 0) + vllm::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + + // FP6 + else if (EXPONENT == 1 && MANTISSA == 4) + vllm::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 3) + vllm::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 2) + vllm::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 4 && MANTISSA == 1) + vllm::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 5 && MANTISSA == 0) + vllm::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + // FP7 + else if (EXPONENT == 1 && MANTISSA == 5) + vllm::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats, 
out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 4) + vllm::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 3) + vllm::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 4 && MANTISSA == 2) + vllm::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 5 && MANTISSA == 1) + vllm::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats, out_feats, + M, N, K, Reduction_Workspace, splitK); + + else + TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, + " is not supported."); + + return _out_feats; +} \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/mma.cuh b/csrc/quantization/fp_eXmY/mma.cuh new file mode 100644 index 0000000000000..7f66207a9399a --- /dev/null +++ b/csrc/quantization/fp_eXmY/mma.cuh @@ -0,0 +1,108 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is modified from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +#ifndef MMA_CUH +#define MMA_CUH + +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> + +#include <assert.h> +#include "configs.h" + +// MODIFICATION NOTE: to support MSVC +// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ +// Reg)[4] +// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ +// read_SPTR) +template <typename TilingConfig> +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t (*__restrict__ Reg)[4], + half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int slice_id) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = + TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory + #ifdef DEBUG_MODE + assert(warp_start_col == 0); + #endif + + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared( + &read_SPTR[warp_start_col + col][slice_id * MMA_16 + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += + 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +#endif +} + +// MODIFICATION NOTE: to support MSVC, the function signature is changed from +// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, +// uint32_t __restrict__ *b). +__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c, + uint32_t* __restrict__ a, + uint32_t* __restrict__ b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +#endif +} + +#endif // MMA_CUH \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/quant_matmul.cuh b/csrc/quantization/fp_eXmY/quant_matmul.cuh new file mode 100644 index 0000000000000..8516e0dd8592d --- /dev/null +++ b/csrc/quantization/fp_eXmY/quant_matmul.cuh @@ -0,0 +1,354 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is modified from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh + +#include "configs.h" +#include "utils/gmem.cuh" +#include "utils/core.cuh" + +/************************** Bitwidth of Weight Segments + * ************************/ +#define BIT_WIDTH_1 1 +#define BIT_WIDTH_2 2 +#define BIT_WIDTH_4 4 +/*************************** 64*64 Weghts of Weight Matrix + * *********************/ +#define WEIGHT_PER_WARP (WARP_M * WARP_K) // 64*64 = 4096 +#define SMEM_SIZE_PER_WARP_1BIT \ + (WEIGHT_PER_WARP * BIT_WIDTH_1 / \ + 8) // 512 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_2BIT \ + (WEIGHT_PER_WARP * BIT_WIDTH_2 / \ + 8) // 1024 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_4BIT \ + (WEIGHT_PER_WARP * BIT_WIDTH_4 / \ + 8) // 2048 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_TB_1BIT \ + (SMEM_SIZE_PER_WARP_1BIT * TilingConfig::BLOCK_WARPS * \ + PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A + // = 6 KB; double buffer for 2-level pipeline A= 4 + // KB. +#define SMEM_SIZE_PER_TB_2BIT \ + (SMEM_SIZE_PER_WARP_2BIT * TilingConfig::BLOCK_WARPS * \ + PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A + // = 12 KB; double buffer for 2-level pipeline A= 8 + // KB. +#define SMEM_SIZE_PER_TB_4BIT \ + (SMEM_SIZE_PER_WARP_4BIT * TilingConfig::BLOCK_WARPS * \ + PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A + // = 24 KB; double buffer for 2-level pipeline A= 16 + // KB. +#define SMEM_SIZE_PER_TB_A_TILE \ + (SMEM_SIZE_PER_TB_1BIT + SMEM_SIZE_PER_TB_2BIT + \ + SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex(). +/******************** Global Memory Layout For QUANTIZED DATA + * *******************/ +#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP * BIT_WIDTH_1 / 128) // 32 +#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP * BIT_WIDTH_2 / 128) // 64 +#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP * BIT_WIDTH_4 / 128) // 128 + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ +template <typename TilingConfig, typename OutputDataType, int EXPONENT, + int MANTISSA> +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half* B, OutputDataType* C, + const size_t M_Global, const size_t N_Global, + const size_t K_Global, int Split_K) { +#ifdef DEBUG_MODE + assert(K_Global % TilingConfig::TILE_K == 0); + assert(M_Global % TilingConfig::TILE_M == 0); + assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M)); +#endif + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + const uint4* Weight_1bit = Weight; + const uint4* Weight_2bit = + Weight_1bit + + (USE_SEG_1BIT ? M_Global * K_Global * BIT_WIDTH_1 / 128 : 0); + const uint4* Weight_4bit = + Weight_2bit + + (USE_SEG_2BIT ? 
M_Global * K_Global * BIT_WIDTH_2 / 128 : 0); + // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + extern __shared__ __align__(128) half smem[]; + half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + reinterpret_cast<half(*)[WARP_K + PADDING_SHARED_MEM_FOR_B_8]>( + smem + SMEM_SIZE_PER_TB_A_TILE / + 2); // Dynamic shared memory for FP16 B tiles + __shared__ half + QuantScales[64 * + TilingConfig::BLOCK_WARPS]; // static shared memory for + // quantization scales, 64 row + // per warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M); + const size_t x = + blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = + blockIdx.y % + (M_Global / TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; + // BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = + (N_Global - Tile_Start_N) < TilingConfig::TILE_N + ? (N_Global - Tile_Start_N) + : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global / TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K / Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + size_t StartBlockID_K = AverageNumBlock_K * BatchID; + if (BatchID < ExtraNumBlock_K) { + NumIter++; + StartBlockID_K += BatchID; + } else + StartBlockID_K += ExtraNumBlock_K; + // Warp ID. + const int warpId = threadIdx.x / WARP_SIZE; + int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS; // WARP_i: row number; + // WARP_j: column number + // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + // Global Memory Address for Matrix A (Weight) + // ///////////////////////////////////////////////////////////////////////// + // StartPTR for each ThreadBlock(TB) + const uint4* TB_StartGPTR_A_1BIT = + Weight_1bit + + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_1BIT; + const uint4* TB_StartGPTR_A_2BIT = + Weight_2bit + + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_2BIT; + const uint4* TB_StartGPTR_A_4BIT = + Weight_4bit + + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_4BIT; + // StartPTR for each WARP. 
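+  // Within each bit-width segment, a warp's 64x64 A slices for all K blocks
+  // are stored contiguously, so warps that are adjacent along M are
+  // NumBlock_K * NUM_INT4_PER_WARP_xBIT uint4 apart (e.g. for the 4-bit
+  // segment with K_Global = 4096: (4096 / 64) * 128 = 8192 uint4 per warp).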
+ const uint4* WARP_StartGPTR_A_1BIT = + TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT; + const uint4* WARP_StartGPTR_A_2BIT = + TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT; + const uint4* WARP_StartGPTR_A_4BIT = + TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT; + // StartPTR for each WARP, considering SplitK + const size_t WARP_Start_UnitID_K = StartBlockID_K; + WARP_StartGPTR_A_1BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT; + WARP_StartGPTR_A_2BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT; + WARP_StartGPTR_A_4BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT; + // Copying A tile from Global to Shared, using double-buffer + // ////////////////////////////////////////////////////////// StartSPTR for + // each ThreadBlock + uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem); + uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT / 4; + uint32_t* AFrag_4BIT_SPTR = + AFrag_2BIT_SPTR + + SMEM_SIZE_PER_TB_2BIT / + 4; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT / 4; + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT / 4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT / 4; + // Pre-fetch of A tile + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + if (USE_SEG_1BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>( + AFrag_1BIT_SPTR + i * SMEM_SIZE_PER_WARP_1BIT / 4 * 4, + WARP_StartGPTR_A_1BIT); + if (USE_SEG_2BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>( + AFrag_2BIT_SPTR + i * SMEM_SIZE_PER_WARP_2BIT / 4 * 4, + WARP_StartGPTR_A_2BIT); + if (USE_SEG_4BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>( + AFrag_4BIT_SPTR + i * SMEM_SIZE_PER_WARP_4BIT / 4 * 4, + WARP_StartGPTR_A_4BIT); + WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT / 16; + WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT / 16; + WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT / 16; + } + // Global Memory Address for Matrix A (QuantScale) + // ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = + Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64, + WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK + // ///////////////////////////////////////////////////////////// + const half* BTile_GPTR = + B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>( + smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global, + NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros + // ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA + // block + constexpr int NumRegSets_b = + (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM] + [4]; // double/Trible buffer is used // Registers to store + // decompressed FP6 + uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM] + [4]; // double/Triple buffer is used // Register to store FP16 B + // matrix (a slice) + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++) + for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f; + // + cp_async_wait_all(); + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64); + // Initializing the Software Pipeline: writing registers. + // //////////////////////////////////////////////////////////////////////////////////////////////// + initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>( + a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, + Scales_RPTR); +// The outer loop. +// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag_1bit = + AFrag_1BIT_SPTR + + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT / 4 * + 4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_2bit = + AFrag_2BIT_SPTR + + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_4bit = + AFrag_4BIT_SPTR + + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read2_SPTR_Frag_1bit = + AFrag_1BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_1BIT / 4 * 4; + uint32_t* __restrict__ read2_SPTR_Frag_2bit = + AFrag_2BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_2BIT / 4 * 4; + uint32_t* __restrict__ read2_SPTR_Frag_4bit = + AFrag_4BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_4BIT / 4 * 4; + uint32_t* __restrict__ write_SPTR_Frag_1bit = + AFrag_1BIT_SPTR + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_1BIT / 4 * + 4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_2bit = + AFrag_2BIT_SPTR + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_2BIT / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_4bit = + AFrag_4BIT_SPTR + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_PER_WARP_4BIT / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is + // changed to below. similarly for read2_SPTR and write_SPTR. 
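+    // With PIPELINE_LEVEL_GMEM == 2 this is plain double buffering:
+    // read_SPTR selects the buffer filled for the current tile, write_SPTR
+    // the buffer being filled for tile tile_id_k + 1, and read2_SPTR (used
+    // after the cp.async barrier below) prefetches slice 0 of that next tile
+    // from the freshly written buffer.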
+ half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + if (USE_SEG_1BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>( + write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); + if (USE_SEG_2BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>( + write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); + if (USE_SEG_4BIT) + CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>( + write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>( + write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); + core_mma_slice<TilingConfig, EXPONENT, MANTISSA>( + c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, + read_SPTR, Scales_RPTR, + 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each + // WARP; read_SPTR is shared among WARPs + core_mma_slice<TilingConfig, EXPONENT, MANTISSA>( + c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, + read_SPTR, Scales_RPTR, 2); + core_mma_slice<TilingConfig, EXPONENT, MANTISSA>( + c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, + read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group<PIPELINE_LEVEL_GMEM - 2>(); + __syncthreads(); + core_mma_slice<TilingConfig, EXPONENT, MANTISSA>( + c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, + read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A_1BIT += + SMEM_SIZE_PER_WARP_1BIT / 16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_2BIT += + SMEM_SIZE_PER_WARP_2BIT / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_4BIT += + SMEM_SIZE_PER_WARP_4BIT / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast< + float(*)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4]>(smem); + StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global + // memory. 
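+  // C is written in column major: element (row j, column i) of this tile goes
+  // to BlockGlobalPTR[j + i * M_Global]. For Split_K == 1 the FP32
+  // accumulators are rounded to FP16 here; for Split_K > 1 they are stored as
+  // FP32 in the per-split workspace and summed later by SplitK_Reduction().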
+ OutputDataType* BlockGlobalPTR = C + BatchID * (M_Global * N_Global) + + Tile_Start_M + Tile_Start_N * M_Global; + for (size_t i = warpId; i < NumColumnToCopy; + i += TilingConfig::BLOCK_WARPS) // i-th column +#pragma unroll + for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M; + j += WARP_SIZE) // j-th row + { + if constexpr (std::is_same<OutputDataType, half>::value) + BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]); + else + BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j]; + } +} \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/quant_reduction.cuh b/csrc/quantization/fp_eXmY/quant_reduction.cuh new file mode 100644 index 0000000000000..8059a177dd222 --- /dev/null +++ b/csrc/quantization/fp_eXmY/quant_reduction.cuh @@ -0,0 +1,70 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is copied from +// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh + +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +// Used for the reduction of result matrix if Split-K is used +// Reduction_Workspace: (Split_K, M_Global, N_Global), column major +// C: (M_Global, N_Global), column major +// Each thread deals with 8 output elements, each elements is the sum of Split_K +// elements +// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 +// float_per_thread (256bit) -> 256 float per warp Write Global: Each +// Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> +// 256 half per warp +// GridSize = (M_Global*N_Global) / 256 + +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, + size_t M_Global, size_t N_Global, + int Split_K) { + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = + Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { +#pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } +// Writing to global memory +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) + THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/utils/core.cuh b/csrc/quantization/fp_eXmY/utils/core.cuh new file mode 100644 index 0000000000000..7e8338ae34482 --- /dev/null +++ b/csrc/quantization/fp_eXmY/utils/core.cuh @@ -0,0 +1,188 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is modified from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh + +#ifndef UTILS_CORE_CUH +#define UTILS_CORE_CUH + +#include <assert.h> + +#include "../configs.h" +#include "../mma.cuh" +#include "parallel_dequant.cuh" + +template <int NUM_INT_PER_THREAD> +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], + uint32_t* SPTR, + int slice_id) { + SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { + Reg[i] = SPTR[lane_id + i * WARP_SIZE]; + } +} + +// MODIFICATION NOTE: to support MSVC, half __restrict__ +// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
+template <typename TilingConfig, int EXPONENT, int MANTISSA> +__device__ __forceinline__ void initialize_mma_slice( + uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read, + uint32_t* __restrict__ A_2BIT_SPTR_read, + uint32_t* __restrict__ A_4BIT_SPTR_read, + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 + // per thread => 6 register per thread; + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if (USE_SEG_1BIT) + CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0); + if (USE_SEG_2BIT) + CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0); + if (USE_SEG_4BIT) + CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0); + Dequant_32FP6_4Way<EXPONENT, MANTISSA>( + a, a_1bit, a_2bit, a_4bit, + RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register + // level, dequantizing a slice each time + B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, + 0); // Loading B from shared to registers +} + +// MODIFICATION NOTE: to support MSVC, half __restrict__ +// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. +template <typename TilingConfig, int EXPONENT, int MANTISSA> +__device__ __forceinline__ void core_mma_slice( + float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], + uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read, + uint32_t* __restrict__ A_2bit_SPTR_read, + uint32_t* __restrict__ A_4bit_SPTR_read, + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 + // for prefetching +{ + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + +#ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded + // to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA + // block + const int NumRegSets_b = + (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>( + c); // GlobalRegisters for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t(*a_read)[4] = a; + uint32_t(*a_write)[4] = a; + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + if (slice_id % 2 == 1) { + b_write += NumRegSets_b; + a_write += NumRegSets_a; + } else { + b_read += NumRegSets_b; + a_read += NumRegSets_a; + } + +// Reading registers and issuing core tensor core computations (a slice of A and +// B tile in shared memory) +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]); + } else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], + b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a_read[i], b_read[j] + 2); // c+4; b+2 + } + } + } + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 + // per thread => 6 register per thread; + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if (USE_SEG_1BIT) + CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id); + if (USE_SEG_2BIT) + CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id); + if (USE_SEG_4BIT) + CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id); + Dequant_32FP6_4Way<EXPONENT, MANTISSA>( + a_write, a_1bit, a_2bit, a_4bit, + RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register + // level, dequantizing a slice each time + B_FromSharedToReg<TilingConfig>( + b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +template <typename TilingConfig> +__device__ __forceinline__ void StoreToSharedMemoryFromRegister( + float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) { + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; + j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS; + int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; +#pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r % 2 == 1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset] + [Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/utils/gmem.cuh b/csrc/quantization/fp_eXmY/utils/gmem.cuh new file mode 100644 index 0000000000000..6eb9c61b1c7ef --- /dev/null +++ b/csrc/quantization/fp_eXmY/utils/gmem.cuh @@ -0,0 +1,94 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is modified from +// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh + +#ifndef UTILS_GMEM_CUH +#define UTILS_GMEM_CUH + +#include <assert.h> +#include "../configs.h" +#include "../cp_async.cuh" + +/* + * Copying A1/A2 from global memory to shared memory. + * Usually 1024 or 2048 Bytes + */ +template <int SMEM_SIZE_IN_BYTES_PER_WARP> +__device__ __forceinline__ void CopyFromGlobalToShared_A( + uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) { +#ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0); +#endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast<half*>(SPTR); + const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR); + SPTR_HALF += lane_id * 8; + GPTR_HALF += lane_id * 8; +#pragma unroll + for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) { + cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } +} + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales( + half* SPTR_QuantScales, const half* GPTR_A_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id * 2; + int Offset_Global = lane_id / 4 + (lane_id % 4) * 16; + for (int i = 0; i < 2; i++) + SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8]; +} + +// MODIFICATION NOTE: to support MSVC, half __restrict__ +// (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 + * Threads + */ +template <int MaxNumOfLinesToCopy, int BLOCK_WARPS> +__device__ __forceinline__ void CopyFromGlobalToShared( + half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. 
+ bool Pred = true) { + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x % 8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; +#pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif \ No newline at end of file diff --git a/csrc/quantization/fp_eXmY/utils/parallel_dequant.cuh b/csrc/quantization/fp_eXmY/utils/parallel_dequant.cuh new file mode 100644 index 0000000000000..3405ddaafbfcc --- /dev/null +++ b/csrc/quantization/fp_eXmY/utils/parallel_dequant.cuh @@ -0,0 +1,148 @@ +// Copyright 2024 FP6-LLM authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file is modified from +// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh +// To support MSVC, all instances of u_int32_t are changed to uint32_t. + +#ifndef UTILS_PARALLELDEQUANT_CUH +#define UTILS_PARALLELDEQUANT_CUH + +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. + */ +template <int EXPONENT, int MANTISSA> +__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t* In, uint32_t* Out1, + uint32_t* Out2) { + // + constexpr int RIGHT_SHIFT = 5 - EXPONENT; + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | MASK3 >> 16; + // + *Out1 = *In & 0x80008000; + *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT; + // + *In = (*In) << 8; + *Out2 = *In & 0x80008000; + *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT; +} + +template <int EXPONENT, int MANTISSA> +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, + half Scale) { + constexpr int BIAS_OFFSET = (int(1) << (5 - 1)) - (int(1) << (EXPONENT - 1)); + constexpr int BIAS = int(1) << BIAS_OFFSET; + // + half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast<half*>(&output); + output_half_ptr[0] = + __hmul(__hmul(*FP16_1, __float2half(1.0f * BIAS)), Scale); + output_half_ptr[1] = + __hmul(__hmul(*FP16_2, __float2half(1.0f * BIAS)), Scale); + return output; +} + +// MODIFICATION NOTE: to support MSVC +// - u_int32_t __restrict__ Reg[][4] is changed to below. +// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. 
similarly for +// read_RPTR_2bit and read_RPTR_4bit +template <int EXPONENT, int MANTISSA> +__device__ __forceinline__ void Dequant_32FP6_4Way( + uint32_t (*__restrict__ Reg)[4], uint32_t* __restrict__ read_RPTR_1bit, + uint32_t* __restrict__ read_RPTR_2bit, + uint32_t* __restrict__ read_RPTR_4bit, uint32_t* Scales) { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + // + uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg); + uint32_t* Frag_PTR_1bit = read_RPTR_1bit; + uint32_t* Frag_PTR_2bit = read_RPTR_2bit; + uint32_t* Frag_PTR_4bit = read_RPTR_4bit; + half* Scale_RPTR = reinterpret_cast<half*>(Scales); +// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + uint32_t Packed_FP6 = 0; + uint32_t tmp = 0; + // 1bit Frag + if (USE_SEG_1BIT) { + tmp = (*Frag_PTR_1bit) & 0x80808080; + Packed_FP6 |= tmp >> (BIT_WIDTH & 0); + if (i % 8 == 7) + Frag_PTR_1bit++; + else + (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1; + } + // 2bit Frag + if (USE_SEG_2BIT) { + tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0; + Packed_FP6 |= tmp >> (BIT_WIDTH & 1); + if (i % 4 == 3) + Frag_PTR_2bit++; + else + (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2; + } + // 4bit Frag2 + if (USE_SEG_4BIT) { + tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0; + Packed_FP6 |= tmp >> (BIT_WIDTH & 3); + if (i % 2 == 1) + Frag_PTR_4bit++; + else + (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; + } + // + uint32_t out1, out2; + FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2); + // + *OutputRegs = MultScale<EXPONENT, MANTISSA>( + out1, Scale_RPTR[0]); // Multiply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale<EXPONENT, MANTISSA>( + out2, Scale_RPTR[1]); // Multiply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if (i % 2 == 1) Scale_RPTR += 2; + } +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales( + uint32_t* Scales, half* WARP_SPTR_Scales) { + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; +#pragma unroll + for (int i = 0; i < 4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b8185c24d5628..70d023633230e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -249,6 +249,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "SymInt size_k) -> Tensor"); // conditionally compiled so impl registration is in source file + // FP_eXmY kernel for quantization to custom + // irregular bit-widths. 
+ ops.def( + "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA," + " Tensor _in_feats, Tensor _weights," + " Tensor _scales, int splitK=1) -> Tensor"); + ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA, + &fp_eXmY_linear_forward_cuda); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( diff --git a/format.sh b/format.sh index be6ee0ce46dcb..83b6ed4c0ce22 100755 --- a/format.sh +++ b/format.sh @@ -249,6 +249,14 @@ CLANG_FORMAT_EXCLUDES=( 'csrc/quantization/gguf/vecdotq.cuh' 'csrc/quantization/gguf/mmq.cuh' 'csrc/quantization/gguf/mmvq.cuh' + 'csrc/quantization/fp_eXmY/configs.h' + 'csrc/quantization/fp_eXmY/cp_async.cuh' + 'csrc/quantization/fp_eXmY/mma.cuh' + 'csrc/quantization/fp_eXmY/quant_matmul.cuh' + 'csrc/quantization/fp_eXmY/quant_reduction.cuh' + 'csrc/quantization/fp_eXmY/utils/core.cuh' + 'csrc/quantization/fp_eXmY/gmem.cuh' + 'csrc/quantization/fp_eXmY/parallel_dequant.cuh' ) # Format specified files with clang-format diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f57414bd5197e..8ea2c52c36380 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -317,6 +317,20 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# fp_eXmY +def fp_eXmY_linear_forward_cuda( + EXPONENT: int, + MANTISSA: int, + _in_feats: torch.Tensor, + _weights: torch.Tensor, + _scales: torch.Tensor, + splitK: int = 1, +) -> torch.Tensor: + return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA, + _in_feats, _weights, + _scales, splitK) + + if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @register_fake("_C::gptq_marlin_24_gemm") @@ -443,6 +457,19 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + @register_fake("_C::fp_eXmY_linear_forward_cuda") + def _fp_eXmY_linear_forward_cuda_fake( + EXPONENT: int, + MANTISSA: int, + _in_feats: torch.Tensor, + _weights: torch.Tensor, + _scales: torch.Tensor, + splitK: int = 1, + ) -> torch.Tensor: + return torch.empty((1, 1), + dtype=_in_feats.dtype, + device=_in_feats.device) + @register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index a1fba98233b80..ee63fe3f9438b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -313,7 +313,8 @@ def _verify_quantization(self) -> None: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8" + "compressed-tensors", "experts_int8", "fp4_weights", "fp5_weights", + "fp6_weights", "fp7_weights" ] tpu_supported_quantization = ["tpu_int8"] neuron_supported_quantization = ["neuron_quant"] @@ -345,6 +346,30 @@ def _verify_quantization(self) -> None: f"method specified in the `quantization` argument " f"({self.quantization}).") + from vllm.model_executor.layers.quantization.fp_eXmY import ( + DEFAULT_FP_EXMY_EXP_BITS, VALID_FP_EXMY_METHODS) + if self.quantization is not None and self.quantization in \ + VALID_FP_EXMY_METHODS: + fp_bits = int(self.quantization[2]) + exp_bits = DEFAULT_FP_EXMY_EXP_BITS[fp_bits] + self.hf_config.quantization_config = { + "bits": fp_bits, + "exp_bits": exp_bits, + "quant_method": self.quantization + } + # TODO(alpin): Investigate supporting bfloat16 dtype + if self.dtype != torch.float16: + logger.info( + "%s data type is not supported for " + "fp%s 
quantization. Using float16 instead.", self.dtype, + fp_bits) + self.dtype = torch.float16 + # In some cases, CUDA graph execution breaks this quant method + logger.warning( + "CUDA Graph execution may not work with fp%s " + "quantization. You can try disabling it " + "with `enforce_eager=True` if you run into issues.", fp_bits) + if self.quantization is not None: if self.quantization not in supported_quantization: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index da841d052d728..7940869144f3e 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -15,6 +15,7 @@ ExpertsInt8Config) from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fp_eXmY import FP_eXmYConfig from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -50,6 +51,11 @@ "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, + # Aliases for the different fpX default configs + "fp4_weights": FP_eXmYConfig, + "fp5_weights": FP_eXmYConfig, + "fp6_weights": FP_eXmYConfig, + "fp7_weights": FP_eXmYConfig, "ipex": IPEXConfig, } diff --git a/vllm/model_executor/layers/quantization/fp_eXmY.py b/vllm/model_executor/layers/quantization/fp_eXmY.py new file mode 100644 index 0000000000000..bbb796a0907d9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp_eXmY.py @@ -0,0 +1,208 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.fp_eXmY_utils import ( + _SPLIT_K_MAP, from_scaled_tc_fpx, to_scaled_tc_fpx) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +# Used in vllm/config.py::ModelConfig::_verify_quantization +VALID_FP_EXMY_METHODS = [ + "fp4_weights", "fp5_weights", "fp6_weights", "fp7_weights" +] +DEFAULT_FP_EXMY_EXP_BITS = { + 4: 2, + 5: 2, + 6: 2, + 7: 3, +} + + +class FP_eXmYConfig(QuantizationConfig): + """Config for FP_eXmY quantizer. It supports fp4, + fp5, fp6, fp7. + + Reference: https://arxiv.org/abs/2401.14112 + + Args: + weight_bits: the target quantization bits, should be one of + 4, 5, 6, 7. 
+    """
+
+    def __init__(
+        self,
+        weight_bits: int = 6,
+        exp_bits: int = 2,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.exponent_bits = exp_bits
+
+        self.mantissa_bits = weight_bits - self.exponent_bits - 1
+
+        self.valid_types = [torch.float16]
+
+        if self.weight_bits not in DEFAULT_FP_EXMY_EXP_BITS:
+            raise ValueError(
+                "Currently, only 4-bit, 5-bit, 6-bit, and 7-bit "
+                "weight-only quantization are supported for fp_eXmY "
+                f"quantization, but got {self.weight_bits} bits.")
+
+        if self.exponent_bits not in range(7):
+            raise ValueError(
+                "Exponent bits should be between 0 and 6, but got "
+                f"{self.exponent_bits}.")
+
+        if get_tensor_model_parallel_rank() == 0:
+            logger.info("Loading model in FP%s_E%sM%s format.",
+                        self.weight_bits, self.exponent_bits,
+                        self.mantissa_bits)
+
+    def __repr__(self) -> str:
+        return (f"FP_eXmYConfig(weight_bits={self.weight_bits}, "
+                f"exponent_bits={self.exponent_bits})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "FP_eXmY"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "FP_eXmYConfig":
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        exp_bits = cls.get_from_keys(config, ["exp_bits"])
+        return cls(weight_bits=weight_bits, exp_bits=exp_bits)
+
+    def get_linear_method(self) -> "FP_eXmYLinearMethod":
+        return FP_eXmYLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return []
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["FP_eXmYLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return FP_eXmYLinearMethod(self)
+        return None
+
+
+class FP_eXmYLinearMethod(LinearMethodBase):
+    """Linear method for FP_eXmY quantizer.
+    Args:
+        quant_config: the FP_eXmY quantization config.
+    """
+
+    def __init__(self, quant_config: FP_eXmYConfig):
+        self.quant_config = quant_config
+        self.weight = None
+
+    def create_weights(self,
+                       layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int],
+                       input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype,
+                       weight_loader=None,
+                       **extra_weight_attrs):
+        del output_size
+        del input_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight = FP_eXmYParameter(
+            torch.Size((output_size_per_partition, input_size_per_partition)),
+            params_dtype=params_dtype,
+            quant_config=self.quant_config,
+        )
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+        layer.register_parameter("weight", weight)
+
+        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
+            # Calls the original weight loader (if any), quantizes the result,
+            # and then loads the quantized parameter.
+ if weight_loader is not None: + orig_param_data = param.data + param.data = param.quant_llmdequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.quant_llmquantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + weights = weight.data + scales = weight.scales + out_dim, in_dim = weights.shape + bsize = x.reshape(-1, x.shape[-1]).shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // + 64].get(out_dim, 1) if bsize <= 768 else 1 + if bias is None: + return ops.fp_eXmY_linear_forward_cuda( + self.quant_config.exponent_bits, + self.quant_config.mantissa_bits, x, weights, scales, splitK) + else: + return ops.fp_eXmY_linear_forward_cuda( + self.quant_config.exponent_bits, + self.quant_config.mantissa_bits, x, weights, scales, + splitK) + bias + + +class FP_eXmYParameter(nn.Parameter): + """ + FP_eXmY quantized parameter class that implements fp5/fp6/fp7 + quantization. Weights are stored in quantized form on + GPUs, and can be directly applied to float16 activations. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: FP_eXmYConfig): + + data = torch.empty(torch.Size( + (orig_shape[0], orig_shape[1] * quant_config.weight_bits // 8)), + dtype=torch.uint8) + + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.scales = torch.empty(orig_shape[0], dtype=torch.float16) + self.quant_config = quant_config + self.orig_shape = orig_shape + return self + + def quant_llmquantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + data, scales = to_scaled_tc_fpx(tensor.data, + self.quant_config.exponent_bits, + self.quant_config.mantissa_bits) + self.data.copy_(data) + self.scales.copy_(scales) + + def quant_llmdequantize(self, output_dtype=None): + output_dtype = output_dtype or torch.get_default_dtype() + return from_scaled_tc_fpx(self.data, self.quant_config.exponent_bits, + self.quant_config.mantissa_bits, + self.scales).to(output_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/fp_eXmY_utils.py b/vllm/model_executor/layers/quantization/utils/fp_eXmY_utils.py new file mode 100644 index 0000000000000..4c27f451277f9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp_eXmY_utils.py @@ -0,0 +1,592 @@ +# ruff: noqa +# Copyright (c) the vLLM team. +# Copyright (c) the PygmalionAI team. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3). +# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain: +# 1. No encodings are reserved for special values (+/-inf, NaN). +# 2. When downcasting from FP32 to FPx, +# - Rounding mode is round to nearest, ties to even. +# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx +# magnitude (sign is preserved). 
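To make the rounding and clamping behavior described in the note above concrete, here is a small editorial sketch (not part of the patch) that round-trips a few FP32 values through the conversion helpers defined later in this file, using FP6 E2M3 (ebits=2, mbits=3) purely as an example:

import torch

from vllm.model_executor.layers.quantization.utils.fp_eXmY_utils import (
    _f32_to_fpx_unpacked, _fpx_unpacked_to_f32)

x = torch.tensor([0.0, 1.0, 1.1, 1e6, -1e6], dtype=torch.float32)

# Downcast FP32 -> unpacked FP6 E2M3 codes (one code per uint8, stored in
# the least significant bits).
codes = _f32_to_fpx_unpacked(x, ebits=2, mbits=3)

# Upcast back to FP32: in-range values are rounded to nearest, ties to even;
# +/-1e6 come back as the largest FP6 E2M3 magnitude (+/-7.5) because no
# inf/NaN encodings are reserved and out-of-range values are clamped.
roundtrip = _fpx_unpacked_to_f32(codes, ebits=2, mbits=3)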
+from functools import reduce +from typing import Tuple + +import torch +from torch import Tensor + + +def _n_ones(n: int) -> int: + return (1 << n) - 1 + + +EBITS_F32, MBITS_F32 = 8, 23 +F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) + +# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +_SPLIT_K_MAP = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7 + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6 + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4 + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4 + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4 + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3 + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3 + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1 + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1 + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + } +] + + +def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: + """Convert FP32 numbers to sub-byte floating point numbers with the given + number of exponent and mantissa bits. + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + Note: there are no special values (NaN, inf) support in this code. Values + outside the representable range of FPx after rounding are clamped to the + maximum FPx magnitude (sign is preserved). 
+ Code below is an adaptation of https://fburl.com/code/ciwofcg4 + Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501 + Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 + """ + assert x.dtype == torch.float + assert 1 + ebits + mbits <= 8 + + # calculate constants + exp_bias = _n_ones(ebits - 1) + max_int = _n_ones(ebits + mbits) + sign_mask = 1 << (ebits + mbits) + + # TODO document this better + magic_adder = _n_ones(MBITS_F32 - mbits - 1) + + # all E bits and M bits are 1s + max_normal = (2**(_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / + (2**mbits))) + + # E bits = 1, M bits = 0 + min_normal = 2**(1 - exp_bias) + + denorm_exp = ( + # exp bias conversion between formats + (F32_EXP_BIAS - exp_bias) + # mantissa length difference between formats + + (MBITS_F32 - mbits) + # add one to encoded exponent for denormalized numbers + + 1) + denorm_mask_int = denorm_exp << MBITS_F32 + + # reinterpret int32 as float32 + denorm_mask_float = torch.tensor(denorm_mask_int, + dtype=torch.int32).view(torch.float32) + + # save the sign + # Note that we have torch.uint32, but some ops like cpu bit shifts + # do not work on it. So, we stay in int32. + x = x.view(torch.int32) + sign = x & 0x80000000 + + # set everything to positive, will add sign back at the end + x = x ^ sign + + # TODO: can the branch floating point comparisons below be done without + # converting to float? probably but need to verify + x = x.view(torch.float) + + # rewrite saturate/denorm/norm branches without explicit data dependent + # control flow, to be more compiler friendly + saturate_mask = x >= max_normal + denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), + x < min_normal) + normal_mask = torch.logical_not( + torch.logical_or(saturate_mask, denormal_mask)) + + # + # branch 1: saturate to max val - handled later in the code which combines + # the branches + # + + # + # branch 2: to conversion to denormal as well as rounding up to normal + # + denormal_x = x + denorm_mask_float + denormal_x = denormal_x.view(torch.int32) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + + # + # branch 3: stay in normal range, adjust the exponent and round + # + normal_x = x.view(torch.int32) + # resulting mantissa is odd + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 + # update exponent, rounding bias part 1 + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder + normal_x += val_to_add + # rounding bias part 2 + normal_x += mant_odd + # take the bits! + normal_x = normal_x >> (MBITS_F32 - mbits) + normal_x = normal_x.to(torch.uint8) + + # + # combine the branches + # + x = torch.full_like(x, max_int, dtype=torch.uint8) + x = torch.where(denormal_mask, denormal_x, x) + x = torch.where(normal_mask, normal_x, x) + + # add sign back + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) + sign_lp = sign_lp.to(torch.uint8) + # Right shift of a negative signed integer can fill the least significant + # bits with either 1s or 0s, depending on the implementation. Since PyTorch + # doesn't have an uint32 dtype, we mask out these bits to get just the + # f4 sign bit + sign_lp = sign_lp & sign_mask + x = x | sign_lp + + return x.to(torch.uint8) + + +# TODO(alpin): check if LUT for everything is faster than bit shifting, +# especially for fp4 (only 2^4=16 unique values). 
+def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: + """Convert sub-byte floating point numbers with the given number of + exponent and mantissa bits to FP32. + + Input: torch.Tensor of dtype uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value + """ + assert x.dtype == torch.uint8 + assert 1 + ebits + mbits <= 8 + + sign_mask = 1 << (ebits + mbits) + exp_bias = _n_ones(ebits - 1) + mantissa_mask = _n_ones(mbits) + + # save the sign + sign_lp = x & sign_mask + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_lp + + # + # 1. Calculate zero mask + # + zero_mask = x_pos == 0 + + # + # 2. Calculate the denormal path mask + # + denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) + + # + # 3. Calculate the normal path + # + + # calculate the new exponent and shift it to bits 2:9 of the result + exp_biased_lp = x_pos >> mbits + exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS + exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 + + # shift the mantissa to bits 10:32 of the result + mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) + mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) + result = exp_biased_f32 | mantissa_f32 + + # + # 4. Add the zero and denormal casts to the already casted normal path + # + result[zero_mask] = 0 + + denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS + + # fast path. + # without this, performance for FP4_E2M1 is slower by 2x + if mbits == 1: + result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32 + + else: + # iterate over all possible values of mantissa + # i=0, j=1 + # i=1, j=10,11 + # i=2, j=100,101,110,111 + # and so on + for i in range(mbits): + for mantissa_cmp in range(1 << i, 1 << (i + 1)): + # left shift mantissa until it overflows (create an implicit 1) + # subtract exponent by the same amount + left_shift = mbits - i + mantissa_f32 = (mantissa_cmp - + (1 << i)) << (left_shift + MBITS_F32 - mbits) + exp_biased_f32 = (denormal_exp_biased - + left_shift) << MBITS_F32 + + # we can update this in-place since the values won't overlap + # torch.compile() may complain unsupported operand type(s) + # for |: 'SymInt' and 'int', thus we use + instead of | here + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = ( + exp_biased_f32 + mantissa_f32) + + result = torch.where(denormal_mask, mantissa_lp_int32, result) + + # add sign back + sign_f32 = sign_lp.to( + torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) + result = result | sign_f32 + + return result.view(torch.float) + + +_ONES_TABLE = [_n_ones(i) for i in range(8)] + + +def _pack(x: Tensor, n_bits: int) -> Tensor: + return reduce(torch.bitwise_or, [ + x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ]) + + +def _unpack(x: Tensor, n_bits: int) -> Tensor: + return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits)], + dim=-1).flatten(-2) + + +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 +def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: + # the original code unpacks/packs the values from/to uint32 while we + # unpack/pack the values from/to uint8 + # thus, we need to reverse byte order 
within a uint32 word. + x = x.reshape(-1, 4).flip(1) + + x = _unpack(x, n_bits) + x = x.view(-1, 4 * (8 // n_bits)) + + if not undo: + bit_order = { + 1: [ + 1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, 0, + 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30 + ], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6], + }[n_bits] + + else: + # this is inverse of the above, obtained by running + # [v.index(i) for i in range(len(v))] + bit_order = { + 1: [ + 16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, 20, + 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15 + ], + 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], + 4: [4, 0, 6, 2, 5, 1, 7, 3], + }[n_bits] + + x = x[:, bit_order] + x = _pack(x, n_bits) + + # reverse byte order within a uint32 word again. + x = x.reshape(-1, 4).flip(1) + return x.flatten() + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h +def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + # Pass 1 from original code + tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) + tensor = tensor.reshape(-1, 32, 2) + tensor = tensor.permute(1, 0, 2) + tensor = tensor.flatten() + + used_bits = 0 + fragments = [] + + for y in [1, 2, 4]: + if nbits & y: + mask = (1 << y) - 1 + tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask + tensor_ybit = _pack(tensor_ybit, y) + + tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) + fragments.append(tensor_ybit) + used_bits += y + + return torch.cat(fragments, dim=0).view(M, -1) + + +# more optimized version of _pack_tc_fpx() for FP6 by merging ops +def _pack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor = tensor.flip(3) + + tensor_2bit = (tensor >> 4) & 0b11 + tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) + tensor_2bit = _pack(tensor_2bit.flatten(), 2) + + tensor_4bit = tensor & 0b1111 + tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) + tensor_4bit = _pack(tensor_4bit.flatten(), 4) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) + + +# currently only optimize for TC-FP6 packing +def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _pack_tc_fp6(tensor) + return _pack_tc_fpx(tensor, nbits) + + +def to_scaled_tc_fpx(tensor: Tensor, ebits: int, + mbits: int) -> Tuple[Tensor, Tensor]: + # _n_ones() is not compatible with torch.compile() due to << operator + # https://github.com/pytorch/pytorch/issues/119152 + # exp_bias = _n_ones(ebits - 1) + # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + # workaround: global lookup table + exp_bias = _ONES_TABLE[ebits - 1] + max_normal = (2**(_ONES_TABLE[ebits] - exp_bias) * + (_ONES_TABLE[mbits + 1] / (2**mbits))) + + tensor = tensor.float() + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits) + return 
tensor_tc_fpx, scale.half() + + +# inverse of _pack_tc_fpx() +def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + size = tensor.numel() + tensor = tensor.flatten() + offset = 0 + used_bits = 0 + + tensor_fpx = None + + for y in [1, 2, 4]: + if nbits & y: + size_ybit = size // nbits * y + tensor_ybit = tensor[offset:offset + size_ybit] + offset += size_ybit + + # undo Pass 3 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) + # undo Pass 2 + tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + + tensor_ybit = _unpack(tensor_ybit.flatten(), y) + tensor_ybit = tensor_ybit << (nbits - used_bits - y) + used_bits += y + + if tensor_fpx is None: + tensor_fpx = tensor_ybit + else: + tensor_fpx |= tensor_ybit + + # undo Pass 1 + # NOTE(alpin): the view op here fails for FP8 + if tensor_fpx is None: + tensor_fpx = torch.zeros_like(tensor_ybit) + tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2) + tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_fpx = tensor_fpx.reshape(M, -1) + return tensor_fpx + + +# more optimized version of _unpack_tc_fpx() for FP6 by merging ops +# inverse of _unpack_tc_fp6() +def _unpack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + N = tensor.shape[1] // 3 * 4 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + tensor = tensor.view(-1) + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + + tensor_2bit = _unpack(tensor_2bit, 2) + tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) + + tensor_4bit = _unpack(tensor_4bit, 4) + tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) + return tensor_fp6 + + +def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _unpack_tc_fp6(tensor) + return _unpack_tc_fpx(tensor, nbits) + + +def from_scaled_tc_fpx(tensor: Tensor, + ebits: int, + mbits: int, + scale=None) -> Tensor: + fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits) + tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits) + if scale is not None: + tensor = tensor * scale.float().view(-1, 1) + return tensor
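Taken together, a rough end-to-end usage sketch of these helpers and the custom op added by this patch (editorial illustration; the shapes and the FP6 E2M3 bit split are example values, and a CUDA build containing this patch is assumed):

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp_eXmY_utils import (
    from_scaled_tc_fpx, to_scaled_tc_fpx)

ebits, mbits = 2, 3  # FP6 E2M3, the default exponent split for fp6_weights
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Ahead-of-time quantization: tensor-core-friendly packed uint8 weights plus
# one FP16 scale per output row. Both weight dims must be multiples of 64.
packed, scales = to_scaled_tc_fpx(weight, ebits, mbits)

# Reference dequantization on the same device, e.g. for checking the layout.
weight_ref = from_scaled_tc_fpx(packed, ebits, mbits, scales).half()

# Quantized GEMM through the newly registered custom op.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
out = ops.fp_eXmY_linear_forward_cuda(ebits, mbits, x, packed, scales,
                                      splitK=1)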