@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,57 +49,100 @@ __device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
 }
 
 template <typename T, typename PackedT>
+__device__ PackedT packedNegativeInfinity()
+{
+    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
+    T packed[kAlignment];
+#pragma unroll
+    for (int i = 0; i < kAlignment; i++)
+    {
+        packed[i] = negativeInfinity<T>();
+    }
+    return *reinterpret_cast<PackedT*>(packed);
+}
+
+template <typename T, typename PackedT, int32_t kBitsPerThread>
 __global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
     T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
 {
     int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
+    uint32_t constexpr kPackedMask = (1 << kAlignment) - 1;
+
     int const batchIdx = blockIdx.y;
 
-    int const logitsGmemOffset = kThreadsPerBlock * blockIdx.x * kBitsPerMaskElement;
-    T* logitsGmemPtr = logits[batchIdx] + logitsGmemOffset;
-    __shared__ T logitsSmem[kThreadsPerBlock * kBitsPerMaskElement];
+    int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
+    T* logitsGmemPtr = logits[batchIdx] + blockOffset;
+
+    uint32_t const* bitmaskGmemPtr = bitmask[batchIdx] + blockOffset / kBitsPerMaskElement;
+    int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
+    T logitsReg[kAlignment];
 
 #pragma unroll
-    for (int offset = 0; offset < kThreadsPerBlock * kBitsPerMaskElement; offset += kThreadsPerBlock * kAlignment)
+    for (int offset = threadIdx.x * kAlignment; offset < kThreadsPerBlock * kBitsPerThread;
+         offset += kThreadsPerBlock * kAlignment)
     {
-        int localOffset = offset + threadIdx.x * kAlignment;
-        if (logitsGmemOffset + localOffset >= vocabSizePadded)
+        if (blockOffset + offset >= vocabSizePadded)
         {
            break;
        }
-        *reinterpret_cast<PackedT*>(logitsSmem + localOffset)
-            = *reinterpret_cast<PackedT*>(logitsGmemPtr + localOffset);
-    }
-    __syncthreads();
 
-    int const bitmaskIdx = kThreadsPerBlock * blockIdx.x + threadIdx.x;
-    uint32_t const bitmaskVal = bitmask[batchIdx][bitmaskIdx];
+        uint32_t const bitmaskVal
+            = (~bitmaskGmemPtr[offset / kBitsPerMaskElement] >> (bitmaskInnerIdx * kAlignment)) & kPackedMask;
 
-#pragma unroll
-    for (int i = 0; i < kBitsPerMaskElement; ++i)
-    {
-        int offset = (i + threadIdx.x) % warpSize;
-        if (bitmaskIdx * kBitsPerMaskElement + offset >= vocabSizePadded)
+        if (bitmaskVal == 0)
        {
            continue;
        }
-        if (!((bitmaskVal >> offset) & 1))
+
+        if (bitmaskVal == kPackedMask)
        {
-            logitsSmem[threadIdx.x * kBitsPerMaskElement + offset] = negativeInfinity<T>();
+            *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = packedNegativeInfinity<T, PackedT>();
+            continue;
        }
-    }
-    __syncthreads();
 
+        *reinterpret_cast<PackedT*>(logitsReg) = *reinterpret_cast<PackedT*>(logitsGmemPtr + offset);
 #pragma unroll
-    for (int offset = 0; offset < kThreadsPerBlock * kBitsPerMaskElement; offset += kThreadsPerBlock * kAlignment)
-    {
-        int localOffset = offset + threadIdx.x * kAlignment;
-        if (logitsGmemOffset + localOffset >= vocabSizePadded)
+        for (int i = 0; i < kAlignment; i++)
        {
-            break;
+            if (((bitmaskVal >> i) & 1))
+            {
+                logitsReg[i] = negativeInfinity<T>();
+            }
        }
-        *reinterpret_cast<PackedT*>(logitsGmemPtr + localOffset)
-            = *reinterpret_cast<PackedT*>(logitsSmem + localOffset);
+        *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = *reinterpret_cast<PackedT*>(logitsReg);
+    }
+}
+
+template <typename T, typename PackedT>
+void logitsBitmaskDispatchToBitsPerThread(
+    T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
+{
+    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
+    int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * 128, batchSize);
+    int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
+    int32_t bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
+
+    dim3 const block(kThreadsPerBlock);
+
+    if (numBitsPerThread <= 4 && kAlignment <= 4)
+    {
+        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
+        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+    }
+    else if (numBitsPerThread <= 8 && kAlignment <= 8)
+    {
+        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
+        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+    }
+    else if (numBitsPerThread <= 16 && kAlignment <= 16)
+    {
+        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
+        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+    }
+    else
+    {
+        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
+        logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
     }
 }
 } // namespace
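
The hunk above replaces the shared-memory staging pass with a per-thread register path: each thread extracts a kAlignment-wide slice of the inverted mask word, skips fully allowed groups, vector-stores negative infinity for fully masked groups, and only does a read-modify-write for mixed groups. The bit convention is unchanged: a set bit keeps the token, a cleared bit forces its logit to negative infinity. As a minimal CPU sketch of that convention (not part of the commit, assuming 32-bit mask words as the uint32_t element type suggests and float logits):

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Reference semantics only: bit == 1 keeps the token, bit == 0 forces its logit to -inf.
// The kernel reaches the same result with packed vector loads/stores and early-outs.
void applyBitmaskReference(std::vector<float>& logits, std::vector<uint32_t> const& bitmask)
{
    for (size_t tokenIdx = 0; tokenIdx < logits.size(); ++tokenIdx)
    {
        uint32_t const maskWord = bitmask[tokenIdx / 32]; // kBitsPerMaskElement == 32
        bool const allowed = (maskWord >> (tokenIdx % 32)) & 1u;
        if (!allowed)
        {
            logits[tokenIdx] = -std::numeric_limits<float>::infinity();
        }
    }
}
```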
@@ -108,25 +151,22 @@ template <typename T>
 void invokeLogitsBitmask(
     T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
 {
-    int bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);
-    dim3 grid(ceilDiv(bitmaskSize, kThreadsPerBlock), batchSize);
-    dim3 block(kThreadsPerBlock);
-
+    // Dispatch to PackedT
     if (vocabSizePadded % (sizeof(float4) / sizeof(T)) == 0)
     {
-        logitsBitmaskKernel<T, float4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskDispatchToBitsPerThread<T, float4>(logits, bitmask, batchSize, vocabSizePadded, stream);
     }
     else if (vocabSizePadded % (sizeof(float2) / sizeof(T)) == 0)
     {
-        logitsBitmaskKernel<T, float2><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskDispatchToBitsPerThread<T, float2>(logits, bitmask, batchSize, vocabSizePadded, stream);
     }
     else if (vocabSizePadded % (sizeof(float) / sizeof(T)) == 0)
     {
-        logitsBitmaskKernel<T, float><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskDispatchToBitsPerThread<T, float>(logits, bitmask, batchSize, vocabSizePadded, stream);
     }
     else
     {
-        logitsBitmaskKernel<T, T><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
+        logitsBitmaskDispatchToBitsPerThread<T, T>(logits, bitmask, batchSize, vocabSizePadded, stream);
     }
 }
 
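
The dispatch is now two-level: invokeLogitsBitmask first picks the widest PackedT whose element count divides vocabSizePadded, and logitsBitmaskDispatchToBitsPerThread then picks kBitsPerThread from a block-count heuristic. A host-side sketch of that selection, with kThreadsPerBlock assumed to be 256 purely for illustration (the real constant and ceilDiv helper are defined elsewhere in the file):

```cpp
#include <cstdint>
#include <cstdio>

// Assumed value, for illustration only; the real constant lives elsewhere in logitsBitmask.cu.
int32_t constexpr kThreadsPerBlock = 256;

int32_t constexpr ceilDiv(int32_t a, int32_t b)
{
    return (a + b - 1) / b;
}

// kAlignment = sizeof(PackedT) / sizeof(T), e.g. 8 when half logits are packed into float4.
int32_t selectBitsPerThread(int32_t batchSize, int32_t vocabSizePadded, int32_t kAlignment)
{
    // Same heuristic as logitsBitmaskDispatchToBitsPerThread: aim for a fixed block budget
    // per batch row, then derive how many vocabulary entries each thread must cover.
    int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * 128, batchSize);
    int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);

    if (numBitsPerThread <= 4 && kAlignment <= 4)
    {
        return 4;
    }
    if (numBitsPerThread <= 8 && kAlignment <= 8)
    {
        return 8;
    }
    if (numBitsPerThread <= 16 && kAlignment <= 16)
    {
        return 16;
    }
    return 32;
}

int main()
{
    // 128k-entry padded vocabulary, batch size 8, half logits packed as float4 (kAlignment = 8):
    // numBlocksPerRow = 128, numBitsPerThread = 4, so the <= 8 branch is taken (kAlignment > 4).
    std::printf("kBitsPerThread = %d\n", selectBitsPerThread(8, 131072, 8));
    return 0;
}
```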