diff --git a/setup.py b/setup.py
index 31dd6a1094..7e60acbfa8 100644
--- a/setup.py
+++ b/setup.py
@@ -55,10 +55,6 @@ def read_version(file_path="version.txt"):
     and platform.system() == "Darwin"
 )
 
-use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux"
-
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
-
 version_prefix = read_version()
 # Version is version.dev year month date if using nightlies and version if not
 version = (
@@ -83,6 +79,8 @@ def use_debug_mode():
     _get_cuda_arch_flags,
 )
 
+IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
+
 
 class BuildOptions:
     def __init__(self):
@@ -257,35 +255,28 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-    if CUDA_HOME is None and torch.version.cuda:
-        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
+        print(
+            "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
+        )
         print(
             "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )
-    if ROCM_HOME is None and torch.version.hip:
-        print("ROCm is not available. Skipping compilation of ROCm extensions")
-        print("If you'd like to compile ROCm extensions locally please install ROCm")
-
-    use_cuda = torch.version.cuda and CUDA_HOME is not None
-    use_hip = torch.version.hip and ROCM_HOME is not None
-    extension = CUDAExtension if (use_cuda or use_hip) else CppExtension
-
-    nvcc_args = [
-        "-DNDEBUG" if not debug_mode else "-DDEBUG",
-        "-O3" if not debug_mode else "-O0",
-        "-t=0",
-        "-std=c++17",
-    ]
-    hip_args = [
-        "-DNDEBUG" if not debug_mode else "-DDEBUG",
-        "-O3" if not debug_mode else "-O0",
-        "-std=c++17",
-    ]
+
+    use_cuda = torch.cuda.is_available() and (
+        CUDA_HOME is not None or ROCM_HOME is not None
+    )
+    extension = CUDAExtension if use_cuda else CppExtension
 
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": nvcc_args if use_cuda else hip_args,
+        "nvcc": [
+            "-DNDEBUG" if not debug_mode else "-DDEBUG",
+            "-O3" if not debug_mode else "-O0",
+            "-t=0",
+            "-std=c++17",
+        ],
     }
 
     if not IS_WINDOWS:
@@ -293,17 +284,6 @@ def get_extensions():
             ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
         )
 
-    if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7:
-        if torch._C._cpu._is_avx512_supported():
-            extra_compile_args["cxx"].extend(
-                [
-                    "-DCPU_CAPABILITY_AVX512",
-                    "-march=native",
-                    "-mfma",
-                    "-fopenmp",
-                ]
-            )
-
     if debug_mode:
         extra_compile_args["cxx"].append("-g")
         if "nvcc" in extra_compile_args:
@@ -319,95 +299,48 @@ def get_extensions():
             extra_compile_args["nvcc"].append("-g")
         extra_link_args.append("/DEBUG")
 
-    hip_sparse_marlin_supported = True
-    if use_hip:
-        # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
-        found_col16 = False
-        found_vec_ext = False
-        print("ROCM_HOME", ROCM_HOME)
-        hipblaslt_headers = list(
-            glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
-        )
-        print("hipblaslt_headers", hipblaslt_headers)
-        for header in hipblaslt_headers:
-            with open(header) as f:
-                text = f.read()
-                if "HIPBLASLT_ORDER_COL16" in text:
-                    found_col16 = True
-                if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
-                    found_vec_ext = True
-        if found_col16:
-            extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
-            print("hipblaslt found extended col order enums")
-        else:
-            print("hipblaslt does not have extended col order enums")
-        if found_vec_ext:
-            extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
-            print("hipblaslt found vec ext")
-        else:
-            print("hipblaslt does not have vec ext")
-
-        # sparse_marlin depends on features in ROCm 6.4, __builtin_amdgcn_global_load_lds
-        ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split(".")[:2])
-        hip_sparse_marlin_supported = ROCM_VERSION >= (6, 4)
-
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
 
     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
-    if IS_WINDOWS:
-        # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C
-        excluded_sources = list(
-            glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
-        )
-        sources = [s for s in sources if s not in excluded_sources]
-
     # Collect CUDA source files
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
         glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
    )
 
-    # Collect HIP source files
     extensions_hip_dir = os.path.join(
         extensions_dir, "cuda", "tensor_core_tiled_layout"
     )
     hip_sources = list(
         glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
     )
-    if hip_sparse_marlin_supported:
-        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
-        hip_sources += list(
-            glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-        )
-    extensions_hip_dir = os.path.join(extensions_dir, "rocm")
-    hip_sources += list(
-        glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
-    )
+    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
     hip_sources += list(
-        glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True)
+        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
     )
 
-    # Add CUDA source files if needed
-    if use_cuda:
+    # Collect CUDA source files if needed
+    if not IS_ROCM and use_cuda:
         sources += cuda_sources
 
-    # TODO: Remove this and use what CUDA has once we fix all the builds.
-    # Add HIP source files if needed
-    if use_hip:
+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
         # Add ROCm GPU architecture check
         gpu_arch = torch.cuda.get_device_properties(0).name
         if gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print("Currently only gfx942 is supported. Compiling only for gfx942.")
-            extra_compile_args["nvcc"].append("--offload-arch=gfx942")
-        sources += hip_sources
+            print(
+                "Currently only gfx942 is supported. 
Skipping compilation of ROCm extensions" + ) + else: + sources += hip_sources use_cutlass = False cutlass_90a_sources = None - if use_cuda and not IS_WINDOWS: + if use_cuda and not IS_ROCM and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py deleted file mode 100644 index 9596e71a7a..0000000000 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ /dev/null @@ -1,217 +0,0 @@ -import itertools - -import pytest -import torch -import torch.utils.checkpoint -from torch._dynamo.utils import counters -from torch._inductor import config -from torch._inductor.test_case import TestCase, run_tests -from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CPU -from torch.utils.cpp_extension import IS_WINDOWS - -import torchao -from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 - - -class SelfAttnLikeModule(torch.nn.Module): - def __init__( - self, - input_dim, - has_mask, - num_attention_heads=None, - attention_head_size=None, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.softmax = torch.nn.Softmax(dim=-1) - assert num_attention_heads is not None - assert attention_head_size is not None - self.num_attention_heads = num_attention_heads - self.attention_head_size = attention_head_size - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) - self.dropout = torch.nn.Dropout(0) - self.has_mask = has_mask - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute([0, 2, 1, 3]) - - def forward(self, x, mask): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - q = self.transpose_for_scores(q) - k = self.transpose_for_scores(k) - v = self.transpose_for_scores(v) - scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) - if self.has_mask and mask.dtype != scores.dtype: - scores = scores + mask - attention = self.softmax(scores) - attention = self.dropout(attention) - context_layer = torch.matmul(attention, v) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - context_layer = context_layer.view( - context_layer.size()[:-2] + (self.all_head_size,) - ) - return self.dense(context_layer) - - -class TestSDPAPatternRewriterTemplate(TestCase): - def _clone_inputs(self, inputs): - def clone(x): - if not isinstance(x, torch.Tensor): - return x - return x.clone() - - return [clone(x) for x in inputs] - - def _check_common( - self, - dot_prod_attention, - args1=None, - contains=True, - atol=1e-5, - has_fuse_pattern=True, - has_dropout=False, - check_train=True, - override_check_equal=False, - dtype=torch.float, - rtol=1.3e-6, - ): - if args1 is None: - tensor_shape = (4, 2, 16, 32) - args1 = [ - torch.randn(tensor_shape, device=self.device, dtype=dtype), - torch.randn(tensor_shape, device=self.device, 
dtype=dtype), - torch.randn(tensor_shape, device=self.device, dtype=dtype), - ] - else: - args1 = list(args1) - args2 = self._clone_inputs(args1) - - for training in [False, True] if check_train else [False]: - for x in itertools.chain(args1[:], args2[:]): - if isinstance(x, torch.Tensor) and x.is_floating_point(): - x.requires_grad = training - - dropout_arg = [training] if has_dropout else [] - torch.manual_seed(1234) - result1 = dot_prod_attention(*(args1 + dropout_arg)) - - counters.clear() - torch.manual_seed(1234) - compiled_model = torch.compile(dot_prod_attention, fullgraph=True) - result2, source_code = run_and_get_code( - compiled_model, - *(args2 + dropout_arg), - ) - source_code = "\n".join(source_code) - if has_fuse_pattern: - self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) - if contains: - # many of the patterns get re-expanded in dispatcher - self.assertIn( - "torchao.scaled_dot_product_int8", - source_code, - ) - - # some tests configured with very low dropout where we still want to check equality - if not has_dropout or override_check_equal: - self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) - - if training: - result1.sum().backward() - result2.sum().backward() - for arg1, arg2 in zip(args1, args2): - if ( - isinstance(arg1, torch.Tensor) - and arg1.is_floating_point() - and (not has_dropout or override_check_equal) - ): - self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) - - @skipIfRocm - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" - ) - @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") - @config.patch({"freezing": True}) - def _test_sdpa_int8_rewriter(self): - from torch.export import export_for_training - - import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( - X86InductorQuantizer, - ) - - # pattern is different for bs=1 - torch.manual_seed(1234) - for dtype, has_mask, bs in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [56, 1] - ): - seqlen, numhead, headsize = 197, 16, 64 - mod = SelfAttnLikeModule( - input_dim=headsize * numhead, - has_mask=has_mask, - num_attention_heads=numhead, - attention_head_size=headsize, - ).eval() - inputs = ( - torch.randn( - (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ), - torch.randn((bs, 1, 1, seqlen), device=self.device) - if has_mask - else None, - ) - enable_autocast = dtype == torch.bfloat16 - with ( - torch.no_grad(), - torch.amp.autocast( - self.device, enabled=enable_autocast, dtype=torch.bfloat16 - ), - ): - _int8_sdpa_init() - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - quantizer.set_function_type_qconfig( - torch.matmul, quantizer.get_global_quantization_config() - ) - export_model = export_for_training( - mod, - inputs, - strict=True, - ).module() - prepare_model = prepare_pt2e(export_model, quantizer) - prepare_model(*inputs) - convert_model = convert_pt2e(prepare_model) - torchao.quantization.pt2e.move_exported_model_to_eval(convert_model) - self._check_common( - convert_model, args1=inputs, check_train=False, atol=1.0 - ) - - -if HAS_CPU: - - class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): - device = "cpu" - test_sdpa_int8_rewriter_cpu = ( - 
TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter - ) - - -if __name__ == "__main__": - if IS_LINUX: - run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 5025b8a19b..1cdce2cd81 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import itertools -import math import sys import pytest @@ -15,7 +14,6 @@ parametrize, ) from torch.testing._internal.optests import opcheck -from torch.utils.cpp_extension import IS_WINDOWS import torchao from torchao.dtypes.floatx import from_scaled_tc_floatx @@ -25,14 +23,10 @@ ) from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_7, - compute_max_diff, -) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff -IS_CUDA = torch.cuda.is_available() and torch.version.cuda -IS_ROCM = torch.cuda.is_available() and torch.version.hip +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) try: import torchao.ops @@ -58,7 +52,7 @@ def _create_floatx_inputs( fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) def test_quant_llm_linear(self, ebits, mbits, dtype): @@ -88,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): test_utils=test_utils, ) - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -115,135 +109,6 @@ def test_quant_llm_linear_correctness( rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 assert relative_error < rtol - def _scaled_dot_product_int8_op_ref( - self, - q, - k, - v, - attn_mask=None, - dropout_p=0, - is_causal=False, - q_scale=1.0, - q_zp=0, - k_scale=1.0, - k_zp=0, - v_scale=1.0, - v_zp=0, - a_scale=1.0, - a_zp=0, - o_scale=1.0, - o_zp=0, - ): - q = (q.to(torch.float) - q_zp) * q_scale - k = (k.to(torch.float) - k_zp) * k_scale - v = (v.to(torch.float) - v_zp) * v_scale - scale_factor = 1 / math.sqrt(q.size(-1)) - attn = q @ k.transpose(-2, -1) - attn = attn * scale_factor - if attn_mask is not None: - attn = attn + attn_mask.to(torch.float) - attn_max = attn.max(dim=-1, keepdim=True).values - attn = attn - attn_max - attn = torch.exp(attn) - attn_sum = torch.sum(attn, dim=-1, keepdim=True) - attn = attn / attn_sum - attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255) - attn = (attn - a_zp) * a_scale - out = attn @ v - out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) - return out.to(torch.uint8) - - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" - ) - @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") - @parametrize("batch_size", [56, 120]) - @parametrize("n_head", [2, 16]) - @parametrize("q_seq_len", [18, 
89]) - @parametrize("kv_seq_len", [100, 253]) - @parametrize("head_dim", [32, 64]) - @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16]) - def test_scaled_dot_product_int8_op( - self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype - ): - torch.manual_seed(1234) - device = "cpu" - q_scale = float(1.7907238006591797) - q_zp = int(127) - k_scale = float(1.8039721250534058) - k_zp = int(125) - v_scale = float(1.839004635810852) - v_zp = int(127) - a_scale = float(0.003919653594493866) - a_zp = int(120) - o_scale = float(1.8191684484481812) - o_zp = int(128) - q_shape = [batch_size, q_seq_len, n_head, head_dim] - kv_shape = [batch_size, kv_seq_len, n_head, head_dim] - mask_shape = [batch_size, 1, 1, kv_seq_len] - q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 - k = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - v = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - q = q.to(torch.uint8) - k = k.to(torch.uint8) - v = v.to(torch.uint8) - attn_mask = ( - torch.randn(mask_shape, dtype=mask_dtype, device=device) - if mask_dtype is not None - else None - ) - q2, k2, v2, attn_mask_2 = ( - q.clone(), - k.clone(), - v.clone(), - attn_mask.clone() if mask_dtype is not None else None, - ) - - math_ref = self._scaled_dot_product_int8_op_ref( - q2, - k2, - v2, - attn_mask=attn_mask, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, - q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - actual = torch.ops.torchao.scaled_dot_product_int8( - q, - k, - v, - attn_mask=attn_mask_2, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, - q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - - self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) - instantiate_parametrized_tests(TestOps) @@ -274,7 +139,7 @@ def make_test_id(param): return f"tiles_{param}" -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): @@ -292,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): # TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): @@ -338,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): return dq.reshape(n, k) -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -406,7 +271,7 @@ def 
test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -472,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( assert diff_op_ao < 1e-1 -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -583,7 +448,7 @@ def reshape_w(w): ) -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -673,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto ) -@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -752,27 +617,5 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) -@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available") -def test_swizzle_mm(): - test_utils = [ - "test_schema", - "test_autograd_registration", - "test_faketensor", - ] - - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") - - mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda") - mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda") - - opcheck( - torch.ops.torchao.swizzle_mm, - (mat1, mat2, False, False), - test_utils=test_utils, - ) - - if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/__init__.py b/torchao/__init__.py index fb96282d77..752aa94a4f 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -43,14 +43,13 @@ quantize_, ) -from . import dtypes, optim, swizzle, testing +from . import dtypes, optim, testing __all__ = [ "dtypes", "autoquant", "optim", "quantize_", - "swizzle", "testing", "ops", ] diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp deleted file mode 100644 index 36cd24ab5e..0000000000 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ /dev/null @@ -1,1907 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - -#include -#include -#include -#include - -namespace torchao { - -namespace { - -inline double calculate_scale( - const at::Tensor& query, - double scale) { - return scale == 0.0 ? 
1.0 / std::sqrt(query.size(-1)) : scale; -} - -#ifdef CPU_CAPABILITY_AVX512 - -template -inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto data_vec = at::vec::Vectorized(val); - int64_t d = 0; - for (; d < size - (size % vec_size); d += vec_size) { - data_vec.store(data + d); - } - if (d < size) { - data_vec.store(data + d, size - d); - } -} - -void reshape_attn_mask_to_4d( - at::Tensor& attn_mask, - int64_t batchSize, - int64_t num_head, - int64_t qSize, - int64_t kvSize) { - // Support mask shapes: - // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) - // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) - // Guaranteed in check_attn_mask_shape - int64_t attn_mask_size_0 = 1; - int64_t attn_mask_size_1 = 1; - if (attn_mask.dim() == 4) { - if (attn_mask.size(0) == batchSize) { - attn_mask_size_0 = batchSize; - } - if (attn_mask.size(1) == num_head) { - attn_mask_size_1 = num_head; - } - } - attn_mask = attn_mask - .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) - .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); -} - -// TODO: Use at::native::_store instead when it supports Half. -template -inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { - src.store(dst, size); -} - -template -inline typename std::enable_if_t || std::is_same_v, void> -_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { - auto res = at::vec::convert(src); - res.store(dst, size); -} - -/* -1. dequant -2. add mask -3. max reduce for softmax -*/ -template -inline void _dequant_mask_max_fusion_kernel( - const int32_t* in, - const mask_t* mask_ptr, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldm, // leading dimension mask - const int& ldo, - const int32_t& beta, // zp_a*zp_b*k - const float& alpha, // scale_a*scale_b*scale_sdpa - float* out, - float* sfm_max_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_beta = at::vec::Vectorized(beta); - auto vec_alpha = at::vec::Vectorized(alpha); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - float* tmp_out = out + row * ldo; - const mask_t* mask_data_ptr = mask_ptr + row * ldm; - float tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - long col = 0; - for (; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); - auto tmp7 = at::vec::convert(tmp6); - auto tmp8 = tmp5 + tmp7; - vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); - _store(tmp_out + col, tmp8); - } - if (col < N) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col, N - col); - auto tmp7 = 
at::vec::convert(tmp6); - auto tmp8 = tmp5 + tmp7; - _store(tmp_out + col, tmp8, N - col); - vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp8), N - col); - } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); - } -} - -/* -1. dequant -2. max reduce for softmax -*/ -inline void _dequant_max_fusion_kernel( - const int32_t* in, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldo, - const int32_t& beta, // zp_a*zp_b*k - const float& alpha, // scale_a*scale_b*scale_sdpa - float* out, - float* sfm_max_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_beta = at::vec::Vectorized(beta); - auto vec_alpha = at::vec::Vectorized(alpha); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - float* tmp_out = out + row * ldo; - float tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - long col = 0; - for (; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); - _store(tmp_out + col, tmp5); - } - if (col < N) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - _store(tmp_out + col, tmp5, N - col); - vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp5), N - col); - } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); - } -} - -/* -1. Softmax: sub max, exp, sum reduce, div sum -2. quant -3. 
sum for attention -*/ -template -inline void _sub_exp_sum_div_quant_sum_fusion_kernel( - const float* in, - const int64_t& M, - const int64_t& N_step, - const int64_t& NSlice, - const int& ldi, - const int& ldo, - const int& kvSize, - const int& rndkvSplitSize, - const int& av_gemm_K, - const int32_t& beta1, // zp_a - const int32_t& beta2, // zp_b - const float& alpha, // scale_a - float* local, - scalar_t* out, - float* sfm_max_ptr, - float* sfm_sum_ptr, - int32_t* sum_a_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - scalar_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - float beta1_float = (float) beta1; - auto vec_beta1 = at::vec::Vectorized(beta1_float); - for (int64_t row = 0; row < M; ++row) { - auto sfm_max = sfm_max_ptr[row]; - auto vec_max = at::vec::Vectorized(sfm_max); - // sub max, exp, sum reduce - const float* qk_block_data = in + row * rndkvSplitSize; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - const float* tmp_in = qk_block_data + l * ldi; - float tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_out = local + n; - long col = 0; - for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(tmp_out + col, tmp2); - } - if (col < kvBlockSize) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - _store(tmp_out + col, tmp2, kvBlockSize - col); - vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); - } - sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); - } - // div sum, sum for attention - auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; - auto vec_sum_scale = at::vec::Vectorized(sum_scale); - scalar_t* qk_reduced_block_data = out + row * av_gemm_K; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - int32_t tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_in = local + n; - scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - long col = 0; - for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp4); - auto tmp6 = at::vec::convert(tmp4); - vec_tmp_sum += tmp6; - } - if (col < kvBlockSize) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp4, kvBlockSize - col); - auto tmp6 = at::vec::convert(tmp4); - vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); - } - sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2; - // set zero - col = kvBlockSize; - for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - _store(tmp_out + col, vec_zero); - } - if (col < av_gemm_K) { - _store(tmp_out + col, vec_zero, av_gemm_K - col); - } - } - } 
-} - -/* -1. Softmax: sub max, exp, sum reduce, div sum -2. quant -*/ -template -inline void _sub_exp_sum_div_quant_fusion_kernel( - const float* in, - const int64_t& M, - const int64_t& N_step, - const int64_t& NSlice, - const int& ldi, - const int& ldo, - const int& kvSize, - const int& rndkvSplitSize, - const int& av_gemm_K, - const int32_t& beta1, // zp_a - const float& alpha, // scale_a - float* local, - scalar_t* out, - float* sfm_max_ptr, - float* sfm_sum_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - scalar_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - float beta1_float = (float) beta1; - auto vec_beta1 = at::vec::Vectorized(beta1_float); - for (int64_t row = 0; row < M; ++row) { - auto sfm_max = sfm_max_ptr[row]; - auto vec_max = at::vec::Vectorized(sfm_max); - // sub max, exp, sum reduce - const float* qk_block_data = in + row * rndkvSplitSize; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - const float* tmp_in = qk_block_data + l * ldi; - float tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_out = local + n; - long col = 0; - for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(tmp_out + col, tmp2); - } - if (col < kvBlockSize) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); - _store(tmp_out + col, tmp2, kvBlockSize - col); - } - sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); - } - // div sum, sum for attention - auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; - auto vec_sum_scale = at::vec::Vectorized(sum_scale); - scalar_t* qk_reduced_block_data = out + row * av_gemm_K; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - float* tmp_in = local + n; - scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - long col = 0; - for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp4); - } - if (col < kvBlockSize) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp4, kvBlockSize - col); - } - // set zero - col = kvBlockSize; - for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - _store(tmp_out + col, vec_zero); - } - if (col < av_gemm_K) { - _store(tmp_out + col, vec_zero, av_gemm_K - col); - } - } - } -} - -/* -1. dequant -2. 
quant -*/ -template -inline void _dequant_quant_fusion_kernel( - const int32_t* in, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldo, - const int32_t& beta1, // zp_a*zp_b*k - const int32_t& beta2, // zp_c - const float& alpha, // scale_a*scale_b/scale_c - scalar_t* out) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - auto vec_beta1 = at::vec::Vectorized(beta1); - auto vec_alpha = at::vec::Vectorized(alpha); - float beta2_float = (float) beta2; - auto vec_beta2 = at::vec::Vectorized(beta2_float); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - scalar_t* tmp_out = out + row * ldo; - long col = 0; - for (; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = tmp5.round(); - auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp8); - } - if (col < N) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = tmp5.round(); - auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp8, N - col); - } - } -} - -/* -1. dequant -2. 
quant -*/ -template -inline void _dequant_quant_fusion_kernel( - const int32_t* in, - const int32_t* sum_a_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldo, - const int32_t& beta2, // zp_c - const float& alpha, // scale_a*scale_b/scale_c - scalar_t* out) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - // auto vec_beta1 = at::vec::Vectorized(beta1); - auto vec_alpha = at::vec::Vectorized(alpha); - float beta2_float = (float) beta2; - auto vec_beta2 = at::vec::Vectorized(beta2_float); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - scalar_t* tmp_out = out + row * ldo; - long col = 0; - for (; col < vec_size * (N / vec_size); col += vec_size) { - auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp3 = tmp1 - vec_sum_a; - // auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = tmp5.round(); - auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp8); - } - if (col < N) { - auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col); - auto tmp3 = tmp1 - vec_sum_a; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = tmp5.round(); - auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - _store(tmp_out + col, tmp8, N - col); - } - } -} - -template -inline void _int_sum_b_contiguous_kernel_helper( - const scalar_t* in, - int32_t* out, - const int& N, - const int32_t& scale) { - const int32_t vec_size = at::vec::Vectorized::size(); - int32_t tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - long i = 0; - for (; i < vec_size * (N / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(in + i); - auto tmp1 = at::vec::convert(tmp0); - vec_tmp_sum = vec_tmp_sum + tmp1; - } - if (i < N) { - auto tmp0 = at::vec::Vectorized::loadu(in + i, N - i); - auto tmp1 = at::vec::convert(tmp0); - vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp1, N - i); - } - out[0] = vec_tmp_sum.reduce_add() * scale; -} - -// reduce along dim b for shape [a, b], with sum shape [a] -template -inline void _int_sum_b_contiguous_kernel( - const scalar_t* in, - int32_t* out, - const int& M, - const int& N, - const int& ld, - const int32_t& scale) { - for (long r = 0; r < M; r += 1) { - _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); - } -} - -// reduce along dim a for shape [a, b], with sum shape [b] -template -inline void _int_sum_a_contiguous_kernel( - const scalar_t* in, - int32_t* out, - const int& M, - const int& N, - const int& ld, - const int32_t& scale) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - // initialization with 0 - int32_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - long i = 0; - for (; i < vec_size * (M / vec_size); i += vec_size) { - _store(out + i, vec_zero); - } - if (i < M) { - _store(out + i, vec_zero, M - i); - } - // sum - for (long j = 0; j < N; j++) { - const scalar_t* tmp_in = in + j * ld; - long k = 0; - for (; k < vec_size * (M / vec_size); k += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k); - auto tmp1 = at::vec::Vectorized::loadu(out + k); - 
auto tmp2 = at::vec::convert(tmp0); - auto tmp3 = tmp1 + tmp2; - _store(out + k, tmp3); - } - if (k < M) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k, M - k); - auto tmp1 = at::vec::Vectorized::loadu(out + k, M - k); - auto tmp2 = at::vec::convert(tmp0); - auto tmp3 = tmp1 + tmp2; - _store(out + k, tmp3, M - k); - } - } - // scale - i = 0; - for (; i < vec_size * (M / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(out + i); - auto tmp1 = tmp0 * vec_scale; - _store(out + i, tmp1); - } - if (i < M) { - auto tmp0 = at::vec::Vectorized::loadu(out + i, M - i); - auto tmp1 = tmp0 * vec_scale; - _store(out + i, tmp1, M - i); - } -} - -// do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] -template -inline void do_transpose( - scalar_t* src, - scalar_t* dst, - int64_t in_rows, - int64_t in_cols, - int64_t ldi, - int64_t ldo) { - for (int64_t r=0; r [prows, pcols] -template -inline void pad_remain_row_col( - scalar_t* value_ptr, - int rows, - int cols, - int prows, - int pcols, - int ldi, - scalar_t pad_val=0) { - auto psize = pcols - cols; - if (psize == 0 && prows == rows) { - return; - } - const int32_t vec_size = at::vec::Vectorized::size(); - auto pad = at::vec::Vectorized(pad_val); - if (psize > 0) { - for (int i = 0; i < rows; i++) { - int j = 0; - for (; j < psize - (psize % vec_size); j += vec_size) { - pad.store(value_ptr + i * ldi + cols + j); - } - if (j < psize) { - pad.store(value_ptr + i * ldi + cols + j, psize - j); - } - } - } - - for (int i = rows; i < prows; i++) { - int j = 0; - for (; j < pcols - (pcols % vec_size); j += vec_size) { - pad.store(value_ptr + i * ldi + j); - } - if (j < pcols) { - pad.store(value_ptr + i * ldi + j, pcols - j); - } - } -} - -// copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] -template -inline void copy_value_with_pad( - scalar_t* value_ptr, - scalar_t* dst_ptr, - int rows, - int cols, - int prows, - int pcols, - int ldi, - scalar_t pad_val=0) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto pad = at::vec::Vectorized(pad_val); - int i = 0; - for (; i < rows; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(dst_ptr + i * pcols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(dst_ptr + i * pcols + j, cols - j); - } - - // col padding - auto psize = pcols - cols; - if (psize > 0) { - int pj = 0; - for (; pj < psize - (psize % vec_size); pj += vec_size) { - pad.store(dst_ptr + i * pcols + cols + pj); - } - if (pj < psize) { - pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); - } - } - } - - // row padding - for (; i < prows; i++) { - int j = 0; - for (; j < pcols - (pcols % vec_size); j += vec_size) { - pad.store(dst_ptr + i * pcols + j); - } - if (j < pcols) { - pad.store(dst_ptr + i * pcols + j, pcols - j); - } - - } - -} - -// UINT8 - one parallel loop with u8u8s32 GEMM -template = 0> -inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( - const at::Tensor& output, - const at::Tensor& q, - const at::Tensor& k, - const at::Tensor& v, - double dropout_p, - bool is_causal, - std::optional attention_mask, - double scale, - float q_scale, - int32_t q_zp, - float k_scale, - int32_t k_zp, - float v_scale, - int32_t v_zp, - float a_scale, - int32_t a_zp, - float o_scale, - int32_t o_zp) { - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // -> (Batch x Q_seq_len x 
Num_heads x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor query = q.transpose(1, 2); - at::Tensor key = k.transpose(1, 2); - at::Tensor value = v.transpose(1, 2); - - using accum_t = float; - accum_t scaling_factor = calculate_scale(query, scale); - int block_64 = 64; - auto u8_dt = at::ScalarType::Byte; - - // Sizes - TORCH_CHECK( - (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); - TORCH_CHECK( - kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); - - int64_t batchSize = query.size(0); - int64_t qSize = query.size(1); - int64_t kvSize = value.size(1); - int64_t num_head = query.size(2); - int64_t headSize = query.size(3); - - bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); - if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); - } - - // Strides - int64_t qStrideB = query.stride(0); - int64_t qStrideM = query.stride(1); - int64_t qStrideH = query.stride(2); - int64_t kStrideB = key.stride(0); - int64_t kStrideN = key.stride(1); - int64_t kStrideH = key.stride(2); - int64_t vStrideB = value.stride(0); - int64_t vStrideN = value.stride(1); - int64_t vStrideH = value.stride(2); - int64_t oStrideB = output.stride(0); - int64_t oStrideM = output.stride(1); - int64_t oStrideH = output.stride(2); - int64_t mStrideB = - (has_attn_mask && attention_mask.value().size(0) > 1) - ? attention_mask.value().stride(0) - : 0; - int64_t mStrideH = - (has_attn_mask && attention_mask.value().size(1) > 1) - ? attention_mask.value().stride(1) - : 0; - int64_t mStrideM = - (has_attn_mask && attention_mask.value().size(2) > 1) - ? attention_mask.value().stride(2) - : 0; - int64_t mStrideN = - (has_attn_mask && attention_mask.value().size(3) > 1) - ? attention_mask.value().stride(3) - : 0; - - int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; - int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; - int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - int64_t num_thread = at::get_num_threads(); - - int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; - - bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; - int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - int av_gemm_K = kvSplitSize + av_gemm_K_padding; - - // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.has_value() - ? attention_mask.value().data_ptr() - : nullptr; - scalar_t* out_data = output.data_ptr(); - - bool headSize_mul64 = headSize % 64 == 0; - int qk_gemm_K_padding = headSize_mul64 ? 
0 : 64 - headSize % 64; - int qk_gemm_K = headSize + qk_gemm_K_padding; - - int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; - int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; - - int64_t total_size_uint8_per_thread = - /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + - /* qk_local */ kvSlice * av_gemm_K * 4 + - /* qk_reduce */ kvSlice * qk_reduce_strideL + - /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + - /* dst_s32 */ qSplitSize * rndHeadSize * 4 + - /* softmax_sum */ qSplitSize * 4 + - /* query_sum */ qSplitSize * 4 + - /* attention_sum */ qSplitSize * 4 + - /* softmax max */ qSplitSize * 4 + - /* query_padding_data */ qSplitSize * qk_gemm_K + - /* key_sum */ kvSize * 4 + - /* value_sum */ headSize * 4 + - /* key_t_reorder */ qk_gemm_K * rndkvSize + - /* value_t_reorder */ kvSlice * v_reorder_strideL; - - at::Tensor total_buf = at::empty( - {num_thread, total_size_uint8_per_thread}, - query.options()); - scalar_t* total_buf_data = total_buf.data_ptr(); - - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qk_reduce_strideL; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * qk_gemm_K; - - int32_t* k_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += kvSize * 4; - int32_t* v_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += headSize * 4; - scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qk_gemm_K * rndkvSize; - scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - - uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; - - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - - // sum k and v - if (q_zp == 0) { - fill_stub(k_sum_ptr, static_cast(0), kvSize); - } else { - _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, - k_sum_ptr, - kvSize, headSize, kStrideN, q_zp); - } - if (a_zp == 0) { - fill_stub(v_sum_ptr, static_cast(0), headSize); - } else { - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - } - - // transpose and packing - for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - bool istail = kvBlockSize - b < block_64; - int64_t trans_rows = 
istail ? kvBlockSize - b : block_64; - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - trans_rows, - headSize, - kStrideN, - block_64); - if (!headSize_mul64 || istail) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - trans_rows, - qk_gemm_K, - block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } - - // sdpa core - for (int64_t k = 0; k < qSlice; k++) { - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); - // sum q - if (k_zp != 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); - } else { - fill_stub( - q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - // Calculate q @ k.T - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvSplitSize, //ldi - mStrideM, //ldm - rndkvSplitSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvSplitSize, //ldi - rndkvSplitSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } - } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr - ); - } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); - } - // Calculate Softmax(q @ k.T) @ v - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = value_reorder_ptr + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //block_64, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } - } - - // After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - if (a_zp == 0) { - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - } else { - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - } - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - } - }); - // Once all computations are done, need to release HW context. 
- at::native::cpublas::brgemm_release(); -} - -// UINT8 - several parallel loops with u8u8s32 GEMM -template = 0> -inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( - const at::Tensor& output, - const at::Tensor& q, - const at::Tensor& k, - const at::Tensor& v, - double dropout_p, - bool is_causal, - std::optional attention_mask, - double scale, - float q_scale, - int32_t q_zp, - float k_scale, - int32_t k_zp, - float v_scale, - int32_t v_zp, - float a_scale, - int32_t a_zp, - float o_scale, - int32_t o_zp) { - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor query = q.transpose(1, 2); - at::Tensor key = k.transpose(1, 2); - at::Tensor value = v.transpose(1, 2); - - using accum_t = float; - accum_t scaling_factor = calculate_scale(query, scale); - int block_64 = 64; - auto u8_dt = at::ScalarType::Byte; - - // Sizes - TORCH_CHECK( - (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); - TORCH_CHECK( - kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); - - int64_t batchSize = query.size(0); - int64_t qSize = query.size(1); - int64_t kvSize = value.size(1); - int64_t num_head = query.size(2); - int64_t headSize = query.size(3); - - bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); - if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); - } - - // Strides - int64_t qStrideB = query.stride(0); - int64_t qStrideM = query.stride(1); - int64_t qStrideH = query.stride(2); - int64_t kStrideB = key.stride(0); - int64_t kStrideN = key.stride(1); - int64_t kStrideH = key.stride(2); - int64_t vStrideB = value.stride(0); - int64_t vStrideN = value.stride(1); - int64_t vStrideH = value.stride(2); - int64_t oStrideB = output.stride(0); - int64_t oStrideM = output.stride(1); - int64_t oStrideH = output.stride(2); - int64_t mStrideB = - (has_attn_mask && attention_mask.value().size(0) > 1) - ? attention_mask.value().stride(0) - : 0; - int64_t mStrideH = - (has_attn_mask && attention_mask.value().size(1) > 1) - ? attention_mask.value().stride(1) - : 0; - int64_t mStrideM = - (has_attn_mask && attention_mask.value().size(2) > 1) - ? attention_mask.value().stride(2) - : 0; - int64_t mStrideN = - (has_attn_mask && attention_mask.value().size(3) > 1) - ? attention_mask.value().stride(3) - : 0; - - int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; - int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; - int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - int64_t num_thread = at::get_num_threads(); - - int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; - - bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; - int av_gemm_K_padding = av_gemm_K_mul4 ? 
0 : 4 - kvSplitSize % 4; - int av_gemm_K = kvSplitSize + av_gemm_K_padding; - - // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.has_value() - ? attention_mask.value().data_ptr() - : nullptr; - scalar_t* out_data = output.data_ptr(); - - bool headSize_mul64 = headSize % 64 == 0; - int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64; - int qk_gemm_K = headSize + qk_gemm_K_padding; - - int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; - int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; - - int64_t total_size_uint8_per_thread = - /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + - /* qk_local */ kvSlice * av_gemm_K * 4 + - /* qk_reduce */ kvSlice * qk_reduce_strideL + - /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + - /* dst_s32 */ qSplitSize * rndHeadSize * 4 + - /* softmax_sum */ qSplitSize * 4 + - /* query_sum */ qSplitSize * 4 + - /* attention_sum */ qSplitSize * 4 + - /* softmax max */ qSplitSize * 4 + - /* query_padding_data */ qSplitSize * qk_gemm_K; - - at::Tensor total_buf = at::empty( - {num_thread, total_size_uint8_per_thread}, - query.options()); - scalar_t* total_buf_data = total_buf.data_ptr(); - - int64_t kv_sum_size_per_BH = - /* key_sum */ kvSize + - /* value_sum */ headSize; - - at::Tensor kv_sum_buf = at::empty( - {batchSize, num_head, kv_sum_size_per_BH}, - query.options().dtype(at::kInt)); - int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); - - int64_t kv_reorder_size_per_BH = - /* key_t_reorder */ qk_gemm_K * rndkvSize + - /* value_t_reorder */ kvSlice * v_reorder_strideL; - - at::Tensor kv_reorder_buf = at::empty( - {batchSize, num_head, kv_reorder_size_per_BH}, - query.options()); - scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); - scalar_t* key_reorder_ptr = kv_reorder_buf_data; - scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; - - // sum k and v - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* kv_sum_ptr = kv_sum_buf_data - + i * num_head * kv_sum_size_per_BH - + j * kv_sum_size_per_BH; - int32_t* k_sum_ptr = kv_sum_ptr; - int32_t* v_sum_ptr = kv_sum_ptr + kvSize; - if (q_zp == 0) { - fill_stub(k_sum_ptr, static_cast(0), kvSize); - } else { - _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, - k_sum_ptr, - kvSize, headSize, kStrideN, q_zp); - } - if (a_zp == 0) { - fill_stub(v_sum_ptr, static_cast(0), headSize); - } else { - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - } - }); - - // transpose and packing - at::parallel_for( - 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, l = 0, n = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, l, kvSlice); - uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - n = l * kvSplitSize; - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; - auto v_reorder = value_reorder_ptr + - i * num_head * kvSlice * 
v_reorder_strideL + - j * kvSlice * v_reorder_strideL + n * rndHeadSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - bool istail = kvBlockSize - b < block_64; - int64_t trans_rows = istail ? kvBlockSize - b : block_64; - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - trans_rows, - headSize, - kStrideN, - block_64); - if (!headSize_mul64 || istail) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - trans_rows, - qk_gemm_K, - block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); - } - }); - - at::parallel_for( - 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, k = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, k, qSlice); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qk_reduce_strideL; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - - int32_t* kv_sum_ptr = kv_sum_buf_data - + i * num_head * kv_sum_size_per_BH - + j * kv_sum_size_per_BH; - int32_t* k_sum_ptr = kv_sum_ptr; - int32_t* v_sum_ptr = kv_sum_ptr + kvSize; - - // sdpa core - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qSplitSize, - qk_gemm_K, - qStrideM); - // sum q - if (k_zp != 0) { - _int_sum_b_contiguous_kernel(query_t_padding_ptr, - q_sum_ptr, qBlockSize, headSize, qk_gemm_K, k_zp); - } else { - fill_stub( - 
q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (kvSize - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; - // Calculate q @ k.T - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvSplitSize, //ldi - mStrideM, //ldm - rndkvSplitSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvSplitSize, //ldi - rndkvSplitSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } - } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr - ); - } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); - } - // Calculate Softmax(q @ k.T) @ v - auto v_reorder = value_reorder_ptr + - i * num_head * kvSlice * v_reorder_strideL + - j * kvSlice * v_reorder_strideL; - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = v_reorder + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } - } - - // After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - if (a_zp == 0) { - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, 
//sum_a_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - } else { - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); - } - }); - // Once all computations are done, need to release HW context. - at::native::cpublas::brgemm_release(); -} - - -template -inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( - bool use_one_parallel_loop, - const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - std::optional attn_mask, - double scale, - float q_scale, - int32_t q_zp, - float k_scale, - int32_t k_zp, - float v_scale, - int32_t v_zp, - float a_scale, - int32_t a_zp, - float o_scale, - int32_t o_zp) { - if (use_one_parallel_loop) { - sdpa_int8_fused_kernel_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } else { - sdpa_int8_fused_kernel_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } -} - - -#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Bool, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Float, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Double, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Half, mask_t, __VA_ARGS__)) - -void sdpa_int8_fused_kernel( - const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - std::optional attn_mask, - double scale, - float q_scale, - int32_t q_zp, - float k_scale, - int32_t k_zp, - float v_scale, - int32_t v_zp, - float a_scale, - int32_t a_zp, - float o_scale, - int32_t o_zp) { - TORCH_CHECK(query.scalar_type() == c10::kByte); - int64_t batchSize = query.size(0); - int64_t num_head = query.size(1); - int64_t q_seq_len = query.size(2); - int64_t kv_seq_len = key.size(2); - int64_t q_split_size = 32; - if (q_seq_len >= 768) { - q_split_size = 256; - } else if (q_seq_len >= 192) { - q_split_size = 64; - } - // Heuristic to decide whether to use one parallel loop or not - // true: one parallel loop for sum+packing+core - // false: three parallel loops for sum, packing, core - uint32_t l2_cache_size = at::cpu::L2_cache_size(); - int64_t num_thread = at::get_num_threads(); - int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; - bool use_one_parallel_loop = (batchSize * num_head > num_thread) && - (attn_size > 1.5 * l2_cache_size); - if (!attn_mask.has_value()) { - if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } else { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } - } else { - AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { - if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } else { - sdpa_int8_fused_kernel_impl( - use_one_parallel_loop, - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - } - }); - } -} -#endif // CPU_CAPABILITY_AVX512 - -at::Tensor sdpa_int8_math_kernel( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - std::optional attn_mask, - double scale, - float q_scale, - int32_t q_zp, - float k_scale, - int32_t k_zp, - float v_scale, - int32_t v_zp, - float 
a_scale, - int32_t a_zp, - float o_scale, - int32_t o_zp) { - // dequant q/k/v - auto q = (query.to(at::kFloat) - q_zp) * q_scale; - auto k = (key.to(at::kFloat) - k_zp) * k_scale; - auto v = (value.to(at::kFloat) - v_zp) * v_scale; - const auto scaling_factor = calculate_scale(q, scale); - auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; - if (attn_mask.has_value() && attn_mask.value().numel()) { - attn = attn.add(attn_mask.value().to(at::kFloat)); - } - attn = at::softmax(attn, -1); - // quant attn - attn = at::clamp_max( - at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 - ); - // dequant attn - attn = (attn - a_zp) * a_scale; - auto output = at::matmul(attn, v); - // quant output - output = at::clamp_max( - at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 - ).to(at::kByte); - return output; -} - - -at::Tensor _scaled_dot_product_int8_cpu( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - std::optional attn_mask, - double dropout_p, - bool is_causal, - double scale, - double q_scale, - int64_t q_zp, - double k_scale, - int64_t k_zp, - double v_scale, - int64_t v_zp, - double a_scale, - int64_t a_zp, - double o_scale, - int64_t o_zp) { - const auto dtype = query.scalar_type(); - TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), - "_scaled_dot_product_int8_cpu: Only accept plain inputs"); - TORCH_CHECK(!is_causal, - "_scaled_dot_product_int8_cpu: is_causal not supported."); - TORCH_CHECK(dtype == at::ScalarType::Byte, - "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); - TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, - "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); - TORCH_CHECK(dropout_p == 0.0, - "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); - TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); - TORCH_CHECK(!attn_mask.has_value() || - attn_mask.value().scalar_type() == at::kFloat || - attn_mask.value().scalar_type() == at::kBFloat16, - "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); - TORCH_CHECK(!attn_mask.has_value() || - (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), - "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - - #ifdef CPU_CAPABILITY_AVX512 - if (at::native::cpublas::could_pack(dtype)) { - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); - sdpa_int8_fused_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - return output.transpose(1, 2); - } else { - #endif // CPU_CAPABILITY_AVX512 - return sdpa_int8_math_kernel(query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp).transpose(1, 2).contiguous().transpose(1, 2); - #ifdef CPU_CAPABILITY_AVX512 - } - #endif // CPU_CAPABILITY_AVX512 -} - - -} // anonymous namespace - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); -} - -// } // at::native -} // namespace torchao diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp deleted file mode 100644 index bfaf6bf466..0000000000 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ /dev/null @@ -1,911 +0,0 
@@ -// setup.py glob includes all *.cpp files -// but only build this for ROCm -#ifdef USE_ROCM -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using at::Scalar; -using at::Tensor; -using at::TensorArg; -using c10::kFloat; -using c10::ScalarType; -using c10::IntArrayRef; -using at::cuda::ScalarTypeToCudaDataType; - -// -// copied from aten/src/ATen/cuda/CUDABlas.cpp -// -namespace { - -static hipblasOperation_t _cublasOpFromChar(char op) { - // NOLINTNEXTLINE(bugprone-switch-missing-default-case) - switch (op) { - case 'n': - case 'N': - return HIPBLAS_OP_N; - case 't': - case 'T': - return HIPBLAS_OP_T; - case 'c': - case 'C': - return HIPBLAS_OP_C; - } - TORCH_CHECK(false, - "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); -} - -static void _cublasAdjustLdLevel3( - char transa, - char transb, - int64_t m, - int64_t n, - int64_t k, - int64_t* lda, - int64_t* ldb, - int64_t* ldc) { - bool transa_ = ((transa != 'n') && (transa != 'N')); - bool transb_ = ((transb != 'n') && (transb != 'N')); - - // Note: leading dimensions generally are checked that they are > 0 - // and at least as big the result requires (even if the value won't - // be used). - if (n <= 1) - *ldc = std::max(m, 1); - - if (transa_) { - if (m <= 1) - *lda = std::max(k, 1); - } else { - if (k <= 1) - *lda = std::max(m, 1); - } - - if (transb_) { - if (k <= 1) - *ldb = std::max(n, 1); - } else { - if (n <= 1) - *ldb = std::max(k, 1); - } -} - -// Following the pattern of CuSparseDescriptor -// Defined here for now because this is the only place cublas_lt interface is -// used but can be moved to a header once cublas_lt interface is used in -// multiple places. -template -struct HipBlasLtDeleter { - void operator()(T* x) { - if (x != nullptr) { - TORCH_CUDABLAS_CHECK(destructor(x)); - } - } -}; - -template -class HipBlasLtDescriptor { - public: - T* descriptor() const { - return descriptor_.get(); - } - T* descriptor() { - return descriptor_.get(); - } - - protected: - std::unique_ptr> descriptor_; -}; - -class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< - hipblasLtMatmulDescOpaque_t, - &hipblasLtMatmulDescDestroy> { - public: - HipBlasLtMatmulDescriptor( - hipblasComputeType_t compute_type, - hipDataType scale_type) { - hipblasLtMatmulDesc_t raw_descriptor = nullptr; - TORCH_CUDABLAS_CHECK( - hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); - descriptor_.reset(raw_descriptor); - } - template - inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { - // NOLINTNEXTLINE(bugprone-sizeof-expression) - TORCH_CUDABLAS_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value))); - } -}; - -class HipBlasLtMatrixLayout : public HipBlasLtDescriptor< - hipblasLtMatrixLayoutOpaque_t, - &hipblasLtMatrixLayoutDestroy> { - public: - HipBlasLtMatrixLayout( - hipDataType type, - uint64_t rows, - uint64_t cols, - int64_t ld, - bool t = false) { - hipblasLtMatrixLayout_t raw_descriptor = nullptr; - TORCH_CUDABLAS_CHECK( - hipblasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? 
rows : cols, ld)); - descriptor_.reset(raw_descriptor); - } - template - inline void setAttribute(hipblasLtMatrixLayoutAttribute_t attr, const T value) { - TORCH_CUDABLAS_CHECK(::hipblasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T))); - } -}; - -class HipBlasLtMatmulPreference : public HipBlasLtDescriptor< - hipblasLtMatmulPreferenceOpaque_t, - &hipblasLtMatmulPreferenceDestroy> { - public: - HipBlasLtMatmulPreference() { - hipblasLtMatmulPreference_t raw_descriptor = nullptr; - TORCH_CUDABLAS_CHECK(hipblasLtMatmulPreferenceCreate(&raw_descriptor)); - descriptor_.reset(raw_descriptor); - } - template - inline void setAttribute(hipblasLtMatmulPreferenceAttributes_t attr, const T value) { - TORCH_CUDABLAS_CHECK(::hipblasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T))); - } -}; - -static size_t _parseChosenWorkspaceSize() { - auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); - if (!val.has_value()) { - // accept either env var - val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); - } - size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ - - if (val.has_value()) { - try { - workspace_size = std::stoi(val.value()); - } catch(std::invalid_argument const& e) { - TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,", - " using default workspace size of ", workspace_size, " KiB."); - } catch(std::out_of_range const& e) { - TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,", - " using default workspace size of ", workspace_size, " KiB."); - } - } - return workspace_size * 1024; -} - -static size_t _getWorkspaceSize() { - static size_t workspace_size = _parseChosenWorkspaceSize(); - return workspace_size; -} - -static bool _scaled_mm_is_fnuz() { - auto dprops = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = dprops->gcnArchName; - static const std::vector archs = {"gfx940", "gfx941", "gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; -} - -} // namespace - -// -// copied from aten/src/ATen/native/cuda/Blas.cpp -// -namespace { - -// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 -c10::MaybeOwned inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { - if (resolve_conj && tensor.is_conj()) { - return c10::MaybeOwned::owned(tensor.resolve_conj()); - } else { - return c10::MaybeOwned::borrowed(tensor); - } -} - -c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { - if (tensor.is_non_overlapping_and_dense()) { // common case - transpose_tensor = tensor.is_contiguous(); - return resolve_conj_if_indicated(tensor, transpose_result ? 
transpose_tensor : !transpose_tensor); - } - IntArrayRef tensor_strides = tensor.strides(); - IntArrayRef tensor_sizes = tensor.sizes(); - if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { - transpose_tensor = false; - return resolve_conj_if_indicated(tensor, !transpose_result); - } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { - transpose_tensor = true; - return resolve_conj_if_indicated(tensor, transpose_result); - } else { - transpose_tensor = true; - return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); - } -} - -c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { - if (tensor.is_non_overlapping_and_dense()) { // common case - transpose_tensor = tensor.is_contiguous(); - return resolve_conj_if_indicated(tensor, true); - } - - IntArrayRef tensor_strides = tensor.strides(); - IntArrayRef tensor_sizes = tensor.sizes(); - if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { - transpose_tensor = false; - return resolve_conj_if_indicated(tensor, true); - } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { - transpose_tensor = true; - return resolve_conj_if_indicated(tensor, true); - } else { - transpose_tensor = true; - return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); - } -} - -struct cublasCommonArgs { - cublasCommonArgs( - const Tensor& mat1, - const Tensor& mat2, - bool swizzle1, - bool swizzle2, - Tensor& c, - const std::optional& scale_a = std::nullopt, - const std::optional& scale_b = std::nullopt, - const std::optional& scale_result = std::nullopt) { - bool transpose_result = false, transpose_a = false, transpose_b = false; - result = prepare_matrix_for_cublas(c, transpose_result); - mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); - matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); - - // Handle scale tensors if provided - if (scale_a && scale_b) { - // By default since we return in row-major we run the gemm - // as B.T @ A.T, check transpose_result to determine if we flip the scales - scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); - scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); - scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); - scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); - } - - if (scale_result) { - scale_result_ptr = scale_result->data_ptr(); - scale_result_dtype = scale_result->scalar_type(); - } - - // Update transpose flags - if (transpose_result) { - transpose_a = !transpose_a; - transpose_b = !transpose_b; - } - - auto sizes_a = mata->sizes(); - auto sizes_b = matb->sizes(); - - m = sizes_a[transpose_result ? 1 : 0]; - k = sizes_a[transpose_result ? 0 : 1]; - n = sizes_b[transpose_result ? 0 : 1]; - lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); - ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); - result_ld = result->stride(transpose_result ? 0 : 1); - transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; - transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; - - mata_is_swizzled = transpose_result ? swizzle2 : swizzle1; - matb_is_swizzled = transpose_result ? 
swizzle1 : swizzle2; - } - - // Matrix members - char transa, transb; - int64_t m, n, k; - int64_t lda, ldb, result_ld; - c10::MaybeOwned mata, matb, result; - - // Scale members - void* scale_mata_ptr = nullptr; - void* scale_matb_ptr = nullptr; - void* scale_result_ptr = nullptr; - std::optional scale_mata_dtype; - std::optional scale_matb_dtype; - std::optional scale_result_dtype; - - // swizzle members - bool mata_is_swizzled; - bool matb_is_swizzled; -}; - -enum class ScalingType { - TensorWise, - RowWise, - Error -}; - -ScalingType get_scaling_type( - const at::Tensor& scale_a, - const at::Tensor& scale_b, - int64_t dim_m, - int64_t dim_n) { - // Both Per-Tensor and Row-wise scaling expect fp32 tensors - TORCH_CHECK( - scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); - - // Check the singluar scale case for per-tensor scaling - if (scale_a.numel() == 1 && scale_b.numel() == 1) { - return ScalingType::TensorWise; - } - - // For non-TensorWise scaling, enforce 2D input tensors - TORCH_CHECK( - scale_a.dim() == 2 && scale_b.dim() == 2, - "For non-TensorWise scaling, scale tensors must be 2-dimensional, " - "but got scale_a.dim()=", - scale_a.dim(), - " and scale_b.dim()=", - scale_b.dim()); - - // Check for RowWise scaling - if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && - scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { -#if defined(HIPBLASLT_VEC_EXT) - TORCH_CHECK( - scale_a.is_contiguous() && scale_b.is_contiguous(), - "Both scale_a and scale_b must be contiguous for RowWise scaling."); - return ScalingType::RowWise; -#else - TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); - return ScalingType::Error; -#endif - } - - // If we reach here, the input doesn't match any valid scaling type - TORCH_CHECK( - false, - "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " - "For RowWise scaling, scale_a should be (", - dim_m, - ", 1) and scale_b should be (1, ", - dim_n, - "). 
" - "Got scale_a.size()=(", - scale_a.size(0), - ", ", - scale_a.size(1), - ") and ", - "scale_b.size()=(", - scale_b.size(0), - ", ", - scale_b.size(1), - ")"); - - return ScalingType::Error; -} - -} // namespace - -template -inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { - hipDataType abcType = HIP_R_32F; - hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; - hipDataType scaleType = HIP_R_32F; - if constexpr (std::is_same_v) { - abcType = HIP_R_64F; - computeType = HIPBLAS_COMPUTE_64F; - scaleType = HIP_R_64F; - } else if constexpr (std::is_same_v) { - } else if constexpr (std::is_same_v>) { - abcType = HIP_C_64F; - computeType = HIPBLAS_COMPUTE_64F; - scaleType = HIP_C_64F; - } else if constexpr (std::is_same_v>) { - abcType = HIP_C_32F; - scaleType = HIP_C_32F; - } else if constexpr (std::is_same_v) { - abcType = HIP_R_16F; - } else if constexpr (std::is_same_v) { - abcType = HIP_R_16BF; - } else { - static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented"); - } - - hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); - hipblasOperation_t opa = _cublasOpFromChar(transa); - hipblasOperation_t opb = _cublasOpFromChar(transb); - _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - - HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); - HipBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == HIPBLAS_OP_T); - HipBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == HIPBLAS_OP_T); - HipBlasLtMatrixLayout Cdesc(abcType, m, n, ldc); -#ifdef HIPBLASLT_HAS_ORDER_COL16 - if (mat1_is_swizzled) { - Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); - } - if (mat2_is_swizzled) { - Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); - } -#endif - - if (num_batches > 1) { - int num_batches_as_int = static_cast(num_batches); - Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); - Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); - Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); - Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea); - Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb); - Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec); - } - - hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, epilogue); - - HipBlasLtMatmulPreference preference; - // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind - // setting this to 1M. 
- size_t workspaceSize = _getWorkspaceSize(); - preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); - - auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); - - hipblasLtMatmulHeuristicResult_t heuristicResult = {}; - int returnedResult = 0; - TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic( - ltHandle, - computeDesc.descriptor(), - Adesc.descriptor(), - Bdesc.descriptor(), - Cdesc.descriptor(), - Cdesc.descriptor(), - preference.descriptor(), - 1, - &heuristicResult, - &returnedResult)); - if (returnedResult == 0) { - TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED); - } - - hipblasStatus_t cublasStatus = hipblasLtMatmul( - ltHandle, - computeDesc.descriptor(), - &alpha, - a, - Adesc.descriptor(), - b, - Bdesc.descriptor(), - &beta, - c, - Cdesc.descriptor(), - c, - Cdesc.descriptor(), - &heuristicResult.algo, - workspace.mutable_data_ptr(), - workspaceSize, - at::hip::getCurrentHIPStreamMasqueradingAsCUDA()); - TORCH_CHECK( - cublasStatus == HIPBLAS_STATUS_SUCCESS, - "CUDA error: ", - at::cuda::blas::_cublasGetErrorEnum(cublasStatus), - " when calling hipblasLtMatmul with transpose_mat1 ", - (opa == HIPBLAS_OP_T), - " transpose_mat2 ", - (opb == HIPBLAS_OP_T), - " m ", - m, - " n ", - n, - " k ", - k, - " lda ", - lda, - " ldb ", - ldb, - " ldc ", - ldc, - " abcType ", - abcType, - " computeType ", - computeType, - " scaleType ", - scaleType); -} - - -template -inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { - // forward to bgemm implementation but set strides and batches to 0 - bgemm_hipblaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0, mat1_is_swizzled, mat2_is_swizzled); -} - - -Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) { - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - - // NOLINTNEXTLINE(*c-array*) - TensorArg targs[]{{mat1, "mat1", 0}, {mat2, "mat2", 1}}; - checkAllSameGPU(__func__, targs); - - Tensor meta_mat1 = mat1.to("meta"); - Tensor meta_mat2 = mat2.to("meta"); - Tensor meta_result = at::mm(meta_mat1, meta_mat2); - Tensor result = at::empty_like(meta_result, mat1.device()); - at::ScalarType scalar_type = result.scalar_type(); - - cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, result); - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - scalar_type, - "addmm_cuda", - [&] { - using opmath_t = at::opmath_type; - opmath_t alpha_val = opmath_t(1.0); - opmath_t beta_val = opmath_t(0.0); - const scalar_t* mat1_ptr = args.mata->const_data_ptr(); - const scalar_t* mat2_ptr = args.matb->const_data_ptr(); - scalar_t* result_ptr = args.result->mutable_data_ptr(); - gemm_hipblaslt( - args.transa, - args.transb, - args.m, - args.n, - args.k, - alpha_val, - mat1_ptr, - args.lda, - mat2_ptr, - args.ldb, - beta_val, - result_ptr, - args.result_ld, - args.mata_is_swizzled, - args.matb_is_swizzled); - }); - - return result; -} - -void _scaled_gemm( - char transa, - char transb, - int64_t m, - int64_t n, - int64_t k, - const void* mat1_ptr, - const void* mat1_scale_ptr, - int64_t mat1_ld, - ScalarType mat1_dtype, - ScalarType mat1_scale_dtype, - bool mat1_is_swizzled, - const void* mat2_ptr, - const void* mat2_scale_ptr, - int64_t mat2_ld, - ScalarType mat2_dtype, - ScalarType 
mat2_scale_dtype, - bool mat2_is_swizzled, - const void* bias_ptr, - ScalarType bias_dtype, - void* result_ptr, - const void *result_scale_ptr, - int64_t result_ld, - ScalarType result_dtype, - bool use_rowwise) { - const auto computeType = HIPBLAS_COMPUTE_32F; - const auto scaleType = HIP_R_32F; - const float alpha_val = 1.0; - const float beta_val = 0.0; - HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); - hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; - hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(HIPBLASLT_VEC_EXT) - if (use_rowwise) { - matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; - matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; - } -#else - // rowwise isn't supported using cublaslt or older hipblaslt - TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); -#endif - computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); - computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); - if (result_scale_ptr != nullptr) { - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); - } - HipBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't'); - HipBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't'); - // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable. - HipBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld); - HipBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld); - if (bias_ptr) { - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); - computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); - } - -#ifdef HIPBLASLT_HAS_ORDER_COL16 - if (mat1_is_swizzled) { - Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16); - } - if (mat2_is_swizzled) { - Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16); - } -#endif - - auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(); - size_t workspaceSize = _getWorkspaceSize(); - auto& allocator = *::c10::hip::HIPCachingAllocatorMasqueradingAsCUDA::get(); - auto workspace = allocator.allocate(workspaceSize); - auto workspace_ptr = workspace.mutable_get(); - TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt"); - - HipBlasLtMatmulPreference preference; - preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); - hipblasLtMatmulHeuristicResult_t heuristicResult = {}; - int returnedResult = 0; - hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); - TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic( - ltHandle, - computeDesc.descriptor(), - Adesc.descriptor(), - Bdesc.descriptor(), - Cdesc.descriptor(), - Ddesc.descriptor(), - preference.descriptor(), - 1, - &heuristicResult, - &returnedResult)); - if (returnedResult == 0) { - // hipblaslt might be able to recover by returning all algos - std::vector all_algos; - TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAllAlgos( - ltHandle, - hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - _cublasOpFromChar(transa), - 
_cublasOpFromChar(transb), - ScalarTypeToCudaDataType(mat1_dtype), - ScalarTypeToCudaDataType(mat2_dtype), - // C is nullptr and beta=0, so set to something reasonable. See above. - //ScalarTypeToCudaDataType(bias_dtype), - ScalarTypeToCudaDataType(result_dtype), - ScalarTypeToCudaDataType(result_dtype), - HIPBLAS_COMPUTE_32F, - all_algos)); - if (all_algos.size() == 0) { - TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED); - } - // pick first valid solution - bool found = false; - for (size_t i = 0; i < all_algos.size(); i++) { - size_t ret_workspace_size = 0; - auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( - ltHandle, - computeDesc.descriptor(), - &alpha_val, - Adesc.descriptor(), - Bdesc.descriptor(), - &beta_val, - Cdesc.descriptor(), - Ddesc.descriptor(), - all_algos[i].algo, - ret_workspace_size); - if (is_valid_status == HIPBLAS_STATUS_SUCCESS) { - if (ret_workspace_size <= workspaceSize) { - heuristicResult = all_algos[i]; - found = true; - break; - } - } - } - TORCH_CHECK(found, "could not find valid hipblaslt solution"); - } - hipblasStatus_t cublasStatus = hipblasLtMatmul( - ltHandle, - computeDesc.descriptor(), - &alpha_val, - mat1_ptr, - Adesc.descriptor(), - mat2_ptr, - Bdesc.descriptor(), - &beta_val, - result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr - Cdesc.descriptor(), - result_ptr, - Ddesc.descriptor(), - &heuristicResult.algo, - workspace_ptr, - workspaceSize, - stream); - TORCH_CHECK( - cublasStatus == HIPBLAS_STATUS_SUCCESS, - "CUDA error: ", - at::cuda::blas::_cublasGetErrorEnum(cublasStatus), - " when calling hipblasLtMatmul with transpose_mat1 ", - transa, - " transpose_mat2 ", - transb, - " m ", - m, - " n ", - n, - " k ", - k, - " mat1_ld ", - mat1_ld, - " mat2_ld ", - mat2_ld, - " result_ld ", - result_ld, - " computeType ", - computeType, - " scaleType ", - scaleType); - return; -} - -Tensor& -_scaled_mm_out(const Tensor& mat1, const Tensor& mat2, - bool mat1_is_swizzled, - bool mat2_is_swizzled, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - Tensor& out) { - // Check sizes - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - // Check what type of scaling we are doing based on inputs - ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); - TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); - - TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), - "scale_result must be a float scalar"); - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], - " but got ", bias->numel()); - TORCH_CHECK( - mat1.sizes()[1] % 16 == 0, - "Expected trailing dimension of mat1 to be divisible by 16 ", - "but got mat1 shape: (", - mat1.sizes()[0], - "x", - mat1.sizes()[1], - ")."); - TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x", - mat2.sizes()[1], ") must be divisible by 16"); - // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be 
Float8 matrix got ", mat1.scalar_type()); - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - if (bias) { - TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32"); - TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half, - "Bias must be either Half or BFloat16, but got ", bias->scalar_type()); - TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) || - bias->scalar_type() == ScalarType::BFloat16, - "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); - TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half, - "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); - } - { - auto bias_ = bias.value_or(Tensor()); - auto scale_result_ = scale_result.value_or(Tensor()); - - // NOLINTNEXTLINE(*c-array*) - TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}, - {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5}, - {scale_result_, "scale_result", 6}}; - checkAllSameGPU(__func__, targs); - } - // Validation checks have passed lets resize the output to actual size - IntArrayRef mat1_sizes = mat1.sizes(); - IntArrayRef mat2_sizes = mat2.sizes(); - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - - // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels - // do not support this case). - if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) { - // `out` was created with `at::empty`. In the case where we are multiplying - // MxK by KxN and K is the zero dim, we need to initialize here to properly - // return a tensor of zeros. - if (mat1_sizes[1] == 0) { - out.zero_(); - } - - return out; - } - - if (scaling_choice == ScalingType::RowWise) { - // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. - Tensor b = mat2; - if (_scaled_mm_is_fnuz()) { - TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); - } - else { - TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); - } - // Until more than bf16 is supported. - TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, - "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); - } - - cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, out, scale_a, scale_b, scale_result); - const auto out_dtype_ = args.result->scalar_type(); - TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); - - { - _scaled_gemm( - args.transa, - args.transb, - args.m, - args.n, - args.k, - args.mata->data_ptr(), - args.scale_mata_ptr, - args.lda, - args.mata->scalar_type(), - args.scale_mata_dtype.value(), - args.mata_is_swizzled, - args.matb->data_ptr(), - args.scale_matb_ptr, - args.ldb, - args.matb->scalar_type(), - args.scale_matb_dtype.value(), - args.matb_is_swizzled, - bias ? bias->data_ptr(): nullptr, - bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_, - args.result->data_ptr(), - args.scale_result_ptr, - args.result_ld, - out_dtype_, - scaling_choice == ScalingType::RowWise); - } - - return out; -} - -Tensor -swizzle_scaled_mm(const Tensor& mat_a, const Tensor& mat_b, - bool mat1_is_swizzled, - bool mat2_is_swizzled, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype) { - const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); - Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); - return _scaled_mm_out(mat_a, mat_b, mat1_is_swizzled, mat2_is_swizzled, scale_a, scale_b, bias, scale_result, out_dtype, out); -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::swizzle_mm", &swizzle_mm); - m.impl("torchao::swizzle_scaled_mm", &swizzle_scaled_mm); -} -#endif // USE_ROCM diff --git a/torchao/ops.py b/torchao/ops.py index 2fb17e40b5..82de7528ec 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -7,7 +7,7 @@ from typing import Optional import torch -from torch import Tensor, dtype +from torch import Tensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 @@ -39,12 +39,6 @@ lib.define( "to_sparse_semi_structured_cutlass_sm9x_f8(Tensor weight) -> (Tensor, Tensor)" ) -lib.define( - "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" -) -lib.define( - "swizzle_scaled_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None) -> Tensor" -) # Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor # to honor the layout constraints for `b` in the two ops below. lib.define( @@ -55,9 +49,6 @@ "mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor", tags=[torch._C.Tag.needs_fixed_stride_order], ) -lib.define( - "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor" -) def register_custom_op(name): @@ -162,94 +153,6 @@ def _( return _in_feats.new_empty((BS, OC)) -def scaled_dot_product_int8( - query: Tensor, - key: Tensor, - value: Tensor, - attn_mask: Optional[Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float = 0.0, - q_scale: float = 1.0, - q_zp: int = 0, - k_scale: float = 1.0, - k_zp: int = 0, - v_scale: float = 1.0, - v_zp: int = 0, - a_scale: float = 1.0, - a_zp: int = 0, - o_scale: float = 1.0, - o_zp: int = 0, -) -> Tensor: - """ - Quantized SDPA with uint8 inputs and outputs. 
- - Arguments - query: input query tensor, - key: input key tensor, - value: input value tensor, - attn_mask: attention mask tensor, - dropout_p: dropout probability, - is_causal: causal flag, - scale: scaling factor applied prior to softmax, - q_scale: scale for query from linear quantization, - q_zp: zero point for query from linear quantization, - k_scale: scale for key from linear quantization, - k_zp: zero point of key from linear quantization, - v_scale: zero point for value from linear quantization, - v_zp: zero point of value from linear quantization, - a_scale: scale for attention from softmax quantization, - a_zp: zero point for attention from softmax quantization, - o_scale: scale for output from linear quantization, - o_zp: zero point for output from linear quantization, - - Returns - output of quantized SDPA - """ - return torch.ops.torchao.scaled_dot_product_int8.default( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - q_scale, - q_zp, - k_scale, - k_zp, - v_scale, - v_zp, - a_scale, - a_zp, - o_scale, - o_zp, - ) - - -@register_custom_op("torchao::scaled_dot_product_int8") -def _( - query: Tensor, - key: Tensor, - value: Tensor, - attn_mask: Optional[Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float = 0.0, - q_scale: float = 1.0, - q_zp: int = 0, - k_scale: float = 1.0, - k_zp: int = 0, - v_scale: float = 1.0, - v_zp: int = 0, - a_scale: float = 1.0, - a_zp: int = 0, - o_scale: float = 1.0, - o_zp: int = 0, -) -> Tensor: - return query - - def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -826,68 +729,6 @@ def _( ) -def swizzle_mm( - mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool -) -> Tensor: - """ - Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. - - """ - return torch.ops.torchao.swizzle_mm.default( - mat1, mat2, mat1_is_swizzled, mat2_is_swizzled - ) - - -@register_custom_op("torchao::swizzle_mm") -def _( - mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool -) -> Tensor: - return mat1.new_empty(mat1.shape[0], mat2.shape[1]) - - -def swizzle_scaled_mm( - mat1: Tensor, - mat2: Tensor, - mat1_is_swizzled: bool, - mat2_is_swizzled: bool, - scale_a: Tensor, - scale_b: Tensor, - bias: Optional[Tensor], - scale_result: Optional[Tensor], - out_dtype: Optional[dtype], -) -> Tensor: - """ - Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. 
- - """ - return torch.ops.torchao.swizzle_scaled_mm.default( - mat1, - mat2, - mat1_is_swizzled, - mat2_is_swizzled, - scale_a, - scale_b, - bias, - scale_result, - out_dtype, - ) - - -@register_custom_op("torchao::swizzle_scaled_mm") -def _( - mat1: Tensor, - mat2: Tensor, - mat1_is_swizzled: bool, - mat2_is_swizzled: bool, - scale_a: Tensor, - scale_b: Tensor, - bias: Optional[Tensor], - scale_result: Optional[Tensor], - out_dtype: Optional[dtype], -) -> Tensor: - return mat1.new_empty(mat1.shape[0], mat2.shape[1]) - - @functools.lru_cache() def _get_dtypes(): """TODO: when e8m0 is hardened and major release lets remove uint8 support""" diff --git a/torchao/prototype/inductor/__init__.py b/torchao/prototype/inductor/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md deleted file mode 100644 index 9171f508a8..0000000000 --- a/torchao/prototype/inductor/fx_passes/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Inductor FX Passes - -This directory contains the FX passes of Inductor. FX passes are transformations applied to the FX graph to optimize and modify it for better performance and functionality. - -In TorchAO, you can replace the following customized graph passes of Inductor: -- `pre_grad_custom_pass` -- `joint_custom_pre_pass` -- `joint_custom_post_pass` -- `post_grad_custom_post_pass` -- `post_grad_custom_pre_pass` - -## Directory Structure - -- `int8_sdpa_fusion`: Pattern match for int8 sdpa fusion. - -## Getting Started - -To get started with using the FX passes in TorchAO, you can register and apply them to your FX graph as follows: - -```python -from torch._inductor import config -from torch._inductor.pattern_matcher import PatternMatcherPass - -# Example usage -patterns = PatternMatcherPass() # create a pattern matcher pass -_register_patterns(...) # register your own patterns -config.custom_pass = patterns.apply # define the custom pass with the patterns - -``` - -## Limitations - -For now, we can only register one pass as the custom pass. -In the future, it is better to extend it to a list. 
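As a slightly more concrete sketch of the registration flow the README above outlines: the toy pattern, the handler name, and the choice of `post_grad_custom_post_pass` (one of the hooks the README lists) are illustrative assumptions, not part of the removed module.

```python
# Sketch of the pass-registration flow described in the README above.
# The relu pattern and hook choice are illustrative; the removed int8_sdpa_fusion
# module registers much larger SDPA patterns instead.
import torch
from torch._inductor import config
from torch._inductor.pattern_matcher import (
    CallFunction,
    KeywordArg,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)

patterns = PatternMatcherPass()  # create a pattern matcher pass


@register_graph_pattern(
    CallFunction(torch.ops.aten.relu.default, KeywordArg("x")),  # toy pattern to match
    pass_dict=patterns,
)
def _handle_relu(match: Match, *args, **kwargs):
    # A real pass would rewrite the matched subgraph here (as int8_sdpa_fusion
    # replaces the dequant/softmax/quant chain with one fused op).
    pass


# Install as one of the custom Inductor passes listed in the README.
config.post_grad_custom_post_pass = patterns.apply
```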
diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py deleted file mode 100644 index aae6d5348a..0000000000 --- a/torchao/prototype/inductor/fx_passes/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .int8_sdpa_fusion import _int8_sdpa_init - -__all__ = [ - "_int8_sdpa_init", -] diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py deleted file mode 100644 index a8f181f2db..0000000000 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ /dev/null @@ -1,370 +0,0 @@ -import functools -import itertools - -import torch -from torch._dynamo.utils import counters -from torch._inductor import config -from torch._inductor.fx_passes.post_grad import register_lowering_pattern -from torch._inductor.lowering import lowerings as L -from torch._inductor.lowering import make_fallback -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - KeywordArg, - Match, - PatternMatcherPass, -) - -__all__ = [ - "_int8_sdpa_init", -] - -make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) - -aten = torch.ops.aten -patterns = PatternMatcherPass() - - -def _is_valid_int8_sdpa_pattern(): - def fn(match): - assert all(k in match.kwargs for k in ("query", "key", "value")) - query = match.kwargs["query"].meta["val"] - key = match.kwargs["key"].meta["val"] - value = match.kwargs["value"].meta["val"] - return ( - query.dtype == torch.uint8 - and key.dtype == torch.uint8 - and value.dtype == torch.uint8 - and query.device.type == "cpu" - and key.device == query.device - and value.device == query.device - ) - - return fn - - -def _register_int8_sdpa_pattern(pattern): - @register_lowering_pattern( - pattern, - extra_check=_is_valid_int8_sdpa_pattern(), - ) - def int8_sdpa(match: Match, *args, **kwargs): - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - inv_scale = kwargs["inv_scale"] - attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None - q_scale = kwargs["q_scale"] - q_zp = kwargs["q_zp"] - k_scale = kwargs["k_scale"] - k_zp = kwargs["k_zp"] - v_scale = kwargs["v_scale"] - v_zp = kwargs["v_zp"] - a_scale = kwargs["a_scale"] - a_zp = kwargs["a_zp"] - o_scale = kwargs["o_scale"] - o_zp = kwargs["o_zp"] - counters["inductor"]["int8_fuse_attention"] += 1 - counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) - - trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) - trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) - trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) - output = L[torch.ops.torchao.scaled_dot_product_int8.default]( - trans_query, - trans_key, - trans_value, - attn_mask, - 0.0, # dropout - False, # is_causal - 1.0 / inv_scale, # scale - q_scale, - q_zp, - k_scale, - k_zp, - v_scale, - v_zp, - a_scale, - a_zp, - o_scale, - o_zp, - ) - trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) - return L[aten.clone.default]( - trans_output, memory_format=torch.contiguous_format - ) - - return int8_sdpa - - -def _get_int8_sdpa_qkv_pattern( - is_batch_size_1: bool, has_convert: bool, input_name: str -): - assert input_name in ["query", "key", "value"] - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - KeywordArg(input_name), - Arg(), - ) - if input_name == "key": - # do transpose - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - int8_sdpa_qkv_pattern_before_dequant, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = 
CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - int8_sdpa_qkv_pattern_before_dequant, - KeywordArg(input_name[0] + "_scale"), - KeywordArg(input_name[0] + "_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_qkv_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, - CallFunction( - aten.clone.default, - int8_sdpa_qkv_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_score_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "query" - ) - int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "key" - ) - int8_sdpa_score_basic_pattern = CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_q_pattern, - int8_sdpa_k_pattern, - ), - Arg(), - ) - if is_reduced_type and not has_mask: - int8_sdpa_score_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_score_basic_pattern, - Arg(), - ) - if has_mask: - return CallFunction( - aten.add.Tensor, - CallFunction( - aten.div.Tensor, - int8_sdpa_score_basic_pattern, - KeywordArg("inv_scale"), - ), - KeywordArg("attn_mask"), - _users=2, - ) - else: - return CallFunction( - aten.mul.Tensor, - int8_sdpa_score_basic_pattern, - Arg(), - _users=2, - ) - - -def _get_int8_sdpa_exp_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_exp_basic_pattern = CallFunction( - aten.sub.Tensor, - int8_sdpa_score_pattern, - CallFunction( - aten.amax.default, - int8_sdpa_score_pattern, - Arg(), - Arg(), - ), - ) - if has_mask: - return CallFunction( - aten.exp.default, - int8_sdpa_exp_basic_pattern, - _users=2, - ) - else: - return CallFunction( - aten.exp.default, - CallFunction( - aten.div.Tensor, - int8_sdpa_exp_basic_pattern, - KeywordArg("inv_scale"), - ), - _users=2, - ) - - -def _get_int8_sdpa_attn_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_div_pattern = CallFunction( - aten.div.Tensor, - int8_sdpa_exp_pattern, - CallFunction( - aten.sum.dim_IntList, - int8_sdpa_exp_pattern, - Arg(), - Arg(), - ), - ) - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - int8_sdpa_div_pattern, - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if is_reduced_type: - if has_mask: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - else: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - 
CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_div_pattern, - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - return CallFunction( - aten.reshape.default, - CallFunction( - aten.expand.default, - int8_sdpa_softmax_pattern, - Arg(), - ), - Arg(), - ) - - -# Parameters to generate various patterns: -# has_mask: if SDPA has attention mask -# is_batch_size_1: if the batch size is 1 -# is_reduced_type: if autocast is enabled -# has_convert: convert type if dequant out dtype is assigned -def _get_int8_sdpa_final_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "value" - ) - int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - return CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - aten.clone.default, - CallFunction( - aten.permute.default, - CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_attn_pattern, - int8_sdpa_v_pattern, - ), - Arg(), - ), - Arg(), - ), - memory_format=Arg(), - ), - KeywordArg("o_scale"), - KeywordArg("o_zp"), - Arg(), - Arg(), - Arg(), - ) - - -def _register_int8_sdpa_lowerings(): - for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product( - [True, False], [True, False], [True, False], [True, False] - ): - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=has_mask, - is_batch_size_1=is_batch_size_1, - is_reduced_type=is_reduced_type, - has_convert=has_convert, - ) - ) - - -@functools.lru_cache(None) -def _int8_sdpa_init(): - _register_int8_sdpa_lowerings() - config.post_grad_custom_pre_pass = patterns.apply diff --git a/torchao/swizzle/__init__.py b/torchao/swizzle/__init__.py deleted file mode 100644 index 7aa001267c..0000000000 --- a/torchao/swizzle/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from .swizzle_tensor import SwizzleTensor - -__all__ = ["SwizzleTensor"] diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py deleted file mode 100644 index 8ddfd9308a..0000000000 --- a/torchao/swizzle/swizzle_tensor.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch.utils._pytree import tree_map - - -# copied from float8_utils.py -def _get_min_alignment(size: int, alignment_value: int) -> int: - return (1 + ((size - 1) // alignment_value)) * alignment_value - - -class SwizzleTensor(torch.Tensor): - """ - A Python-only swizzled tensor subclass. - - Intended usage of this abstraction: - Swizzle weight Tensor to avoid LDS use during GEMMs on ROCm hardware. 
- """ - - def __new__( - cls, - original: torch.Tensor, - shallow: bool = False, - ): - wrapper = torch.empty_like(original, device="meta") - return torch.Tensor._make_subclass(cls, wrapper) - - def __init__(self, original, shallow=False): - if shallow: - return - # assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) - assert original.ndim == 2, "SwizzleTensor only supports ndim 2" - assert original.itemsize == 1 or original.itemsize == 2 - kdiv = 32 if original.itemsize == 2 else 64 - lastdim = 8 if original.itemsize == 2 else 16 - if original.ndim == 2: - M, K = original.shape - B = 0 - if original.ndim == 3: - B, M, K = original.shape - alignedM = _get_min_alignment(M, 16) - alignedK = _get_min_alignment(K, kdiv) - paddedM = alignedM - M - paddedK = alignedK - K - x = torch.nn.functional.pad(original, (0, paddedK, 0, paddedM), "constant", 0) - if original.ndim == 2: - x = x.view(alignedM // 16, 16, alignedK // kdiv, 4, lastdim) - x = x.permute(0, 2, 3, 1, 4) - if original.ndim == 3: - x = x.view(B, alignedM // 16, 16, alignedK // kdiv, 4, lastdim) - x = x.permute(0, 1, 3, 4, 2, 5) - self.x = x.contiguous() - self.B = B - self.M = M - self.K = K - self.alignedM = alignedM - self.alignedK = alignedK - self.paddedM = paddedM - self.paddedK = paddedK - self.original_ndim = original.ndim - self.is_transposed = False - - def __repr__(self): - return f"{self.__class__.__name__}(original={self.unswizzle()})" - - def unswizzle(self): - undone = None - if self.original_ndim == 2: - undone = self.x.permute(0, 3, 1, 2, 4).contiguous() - undone = undone.reshape(self.alignedM, self.alignedK) - undone = undone[0 : self.M, 0 : self.K] - undone = undone.reshape(self.M, self.K) - if self.is_transposed: - undone = undone.T - if self.original_ndim == 3: - undone = self.x.permute(0, 1, 4, 2, 3, 5).contiguous() - undone = undone.reshape(self.B, self.alignedM, self.alignedK) - undone = undone[0 : self.B, 0 : self.M, 0 : self.K] - undone = undone.reshape(self.B, self.M, self.K) - return undone - - def as_tensor(self): - # note the transpose because this causes col major hipblaslt op to be TN - if self.original_ndim == 2: - tmp = self.x.reshape(self.alignedM, self.alignedK) - if self.is_transposed: - tmp = tmp.T - return tmp - if self.original_ndim == 3: - tmp = self.x.reshape(self.B, self.alignedM, self.alignedK) - if self.is_transposed: - tmp = tmp.T - return tmp - - def shallow_transpose(self): - shape = ( - (self.M, self.K) if self.original_ndim == 2 else (self.B, self.M, self.K), - ) - new_obj = SwizzleTensor( - torch.empty(*shape, dtype=self.dtype, layout=self.layout, device="meta"), - True, - ) - new_obj.x = self.x - new_obj.B = self.B - new_obj.M = self.M - new_obj.K = self.K - new_obj.alignedM = self.alignedM - new_obj.alignedK = self.alignedK - new_obj.paddedM = self.paddedM - new_obj.paddedK = self.paddedK - new_obj.original_ndim = self.original_ndim - new_obj.is_transposed = not self.is_transposed - return new_obj - - @property - def shape(self): - return torch.Size((self.K, self.M) if self.is_transposed else (self.M, self.K)) - - def stride(self): - return (1, self.K) if self.is_transposed else (self.K, 1) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - # Lazy import to avoid circular dependency - from torchao.swizzle.swizzle_ops import SWIZZLE_OPS_TABLE - - if func in SWIZZLE_OPS_TABLE: - return SWIZZLE_OPS_TABLE[func](func, args, kwargs) - - def unwrap(e): - return e.unswizzle() if isinstance(e, SwizzleTensor) else e - - def wrap(e): - return 
SwizzleTensor(e) if isinstance(e, torch.Tensor) else e
-
-        return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
-
-    # Do not force the SwizzleTensor type on the returned tensor
-    __torch_function__ = torch._C._disabled_torch_function_impl
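For reference, a small round-trip sketch of the SwizzleTensor removed above. It assumes the pre-removal `torchao.swizzle` package and illustrates the pad-and-permute layout described in the class, with the tensor sizes chosen here purely for illustration.

```python
# Sketch only: torchao.swizzle is deleted by this diff, so this runs only against
# a checkout from before this change.
import torch
from torchao.swizzle import SwizzleTensor

# 2-D bf16 weight (itemsize 2): M is padded to a multiple of 16 (129 -> 144) and
# K to a multiple of 32 (70 -> 96) before the data is permuted into 16x8 tiles.
w = torch.randn(129, 70, dtype=torch.bfloat16)
sw = SwizzleTensor(w)

assert sw.shape == w.shape             # logical shape is preserved
assert torch.equal(sw.unswizzle(), w)  # unswizzle() undoes the padding/permutation

# shallow_transpose() flips the logical view without touching the swizzled data;
# the swizzled buffer is what the removed torchao.ops.swizzle_mm wrapper consumed.
assert sw.shallow_transpose().shape == (70, 129)
```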