Skip to content

Commit ff297de

Browse files
syuoni
authored and committed
bitmask v3
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 0ec7b57 commit ff297de

File tree

1 file changed

+77
-37
lines changed

1 file changed

+77
-37
lines changed

cpp/tensorrt_llm/kernels/logitsBitmask.cu

Lines changed: 77 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,57 +49,100 @@ __device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
4949
}
5050

5151
template <typename T, typename PackedT>
52+
__device__ PackedT packedNegativeInfinity()
53+
{
54+
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
55+
T packed[kAlignment];
56+
#pragma unroll
57+
for (int i = 0; i < kAlignment; i++)
58+
{
59+
packed[i] = negativeInfinity<T>();
60+
}
61+
return *reinterpret_cast<PackedT*>(packed);
62+
}
63+
64+
template <typename T, typename PackedT, int32_t kBitsPerThread>
5265
__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
5366
T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
5467
{
5568
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
69+
uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
70+
5671
int const batchIdx = blockIdx.y;
5772

58-
int const logitsGmemOffset = kThreadsPerBlock * blockIdx.x * kBitsPerMaskElement;
59-
T* logitsGmemPtr = logits[batchIdx] + logitsGmemOffset;
60-
__shared__ T logitsSmem[kThreadsPerBlock * kBitsPerMaskElement];
73+
int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
74+
T* logitsGmemPtr = logits[batchIdx] + blockOffset;
75+
76+
uint32_t const* bitmaskGmemPtr = bitmask[batchIdx] + blockOffset / kBitsPerMaskElement;
77+
int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
78+
T logitsReg[kAlignment];
6179

6280
#pragma unroll
63-
for (int offset = 0; offset < kThreadsPerBlock * kBitsPerMaskElement; offset += kThreadsPerBlock * kAlignment)
81+
for (int offset = threadIdx.x * kAlignment; offset < kThreadsPerBlock * kBitsPerThread;
82+
offset += kThreadsPerBlock * kAlignment)
6483
{
65-
int localOffset = offset + threadIdx.x * kAlignment;
66-
if (logitsGmemOffset + localOffset >= vocabSizePadded)
84+
if (blockOffset + offset >= vocabSizePadded)
6785
{
6886
break;
6987
}
70-
*reinterpret_cast<PackedT*>(logitsSmem + localOffset)
71-
= *reinterpret_cast<PackedT*>(logitsGmemPtr + localOffset);
72-
}
73-
__syncthreads();
7488

75-
int const bitmaskIdx = kThreadsPerBlock * blockIdx.x + threadIdx.x;
76-
uint32_t const bitmaskVal = bitmask[batchIdx][bitmaskIdx];
89+
uint32_t const bitmaskVal
90+
= (~bitmaskGmemPtr[offset / kBitsPerMaskElement] >> (bitmaskInnerIdx * kAlignment)) & kPackedMask;
7791

78-
#pragma unroll
79-
for (int i = 0; i < kBitsPerMaskElement; ++i)
80-
{
81-
int offset = (i + threadIdx.x) % warpSize;
82-
if (bitmaskIdx * kBitsPerMaskElement + offset >= vocabSizePadded)
92+
if (bitmaskVal == 0)
8393
{
8494
continue;
8595
}
86-
if (!((bitmaskVal >> offset) & 1))
96+
97+
if (bitmaskVal == kPackedMask)
8798
{
88-
logitsSmem[threadIdx.x * kBitsPerMaskElement + offset] = negativeInfinity<T>();
99+
*reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = packedNegativeInfinity<T, PackedT>();
100+
continue;
89101
}
90-
}
91-
__syncthreads();
92102

103+
*reinterpret_cast<PackedT*>(logitsReg) = *reinterpret_cast<PackedT*>(logitsGmemPtr + offset);
93104
#pragma unroll
94-
for (int offset = 0; offset < kThreadsPerBlock * kBitsPerMaskElement; offset += kThreadsPerBlock * kAlignment)
95-
{
96-
int localOffset = offset + threadIdx.x * kAlignment;
97-
if (logitsGmemOffset + localOffset >= vocabSizePadded)
105+
for (int i = 0; i < kAlignment; i++)
98106
{
99-
break;
107+
if (((bitmaskVal >> i) & 1))
108+
{
109+
logitsReg[i] = negativeInfinity<T>();
110+
}
100111
}
101-
*reinterpret_cast<PackedT*>(logitsGmemPtr + localOffset)
102-
= *reinterpret_cast<PackedT*>(logitsSmem + localOffset);
112+
*reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = *reinterpret_cast<PackedT*>(logitsReg);
113+
}
114+
}
115+
116+
template <typename T, typename PackedT>
117+
void logitsBitmaskDispatchToBitsPerThread(
118+
T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
119+
{
120+
int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
121+
int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * 128, batchSize);
122+
int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
123+
int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
124+
125+
dim3 const block(kThreadsPerBlock);
126+
127+
if (numBitsPerThread <= 4 && kAlignment <= 4)
128+
{
129+
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
130+
logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
131+
}
132+
else if (numBitsPerThread <= 8 && kAlignment <= 8)
133+
{
134+
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
135+
logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
136+
}
137+
else if (numBitsPerThread <= 16 && kAlignment <= 16)
138+
{
139+
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
140+
logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
141+
}
142+
else
143+
{
144+
dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
145+
logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
103146
}
104147
}
105148
} // namespace
@@ -108,25 +151,22 @@ template <typename T>
108151
void invokeLogitsBitmask(
109152
T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
110153
{
111-
int bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
112-
dim3 grid(ceilDiv(bitmaskSize, kThreadsPerBlock), batchSize);
113-
dim3 block(kThreadsPerBlock);
114-
154+
// Dispatch to PackedT
115155
if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
116156
{
117-
logitsBitmaskKernel<T, float4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
157+
logitsBitmaskDispatchToBitsPerThread<T, float4>(logits, bitmask, batchSize, vocabSizePadded, stream);
118158
}
119159
else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
120160
{
121-
logitsBitmaskKernel<T, float2><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
161+
logitsBitmaskDispatchToBitsPerThread<T, float2>(logits, bitmask, batchSize, vocabSizePadded, stream);
122162
}
123163
else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
124164
{
125-
logitsBitmaskKernel<T, float><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
165+
logitsBitmaskDispatchToBitsPerThread<T, float>(logits, bitmask, batchSize, vocabSizePadded, stream);
126166
}
127167
else
128168
{
129-
logitsBitmaskKernel<T, T><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
169+
logitsBitmaskDispatchToBitsPerThread<T, T>(logits, bitmask, batchSize, vocabSizePadded, stream);
130170
}
131171
}
132172

0 commit comments

Comments (0)