cuda_utils.cuh
// Utilities for use in __device__ code
#ifndef CUDA_UTILS_CUH
#define CUDA_UTILS_CUH
#include "cuda_common.h"
// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// on GPUs that support them (the LDG.128 and STS.128 instructions)
// This is a bit similar to the use of float4 in the case of 32-bit floats, but
// supports elements of arbitrary precision.
template<class ElementType>
struct alignas(16) Packed128 {
Packed128() = default;
__device__ explicit Packed128(int4 bits) {
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&payload, &bits, sizeof(bits));
}
__device__ static Packed128 constant(ElementType value) {
Packed128 result;
for(int k = 0; k < size; ++k) {
result.payload[k] = value;
}
return result;
}
__device__ static Packed128 zeros() {
return constant(0.f);
}
__device__ static Packed128 ones() {
return constant(1.f);
}
__device__ ElementType& operator[](int index) {
return payload[index];
}
__device__ const ElementType& operator[](int index) const {
return payload[index];
}
__device__ int4 get_bits() const {
int4 bits;
static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
memcpy(&bits, &payload, sizeof(bits));
return bits;
}
static constexpr const size_t size = sizeof(int4) / sizeof(ElementType);
ElementType payload[size];
};
// load a Packed128 from an aligned memory address
template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}
// load a Packed128 from an aligned memory address with streaming cache hint
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}
// store a Packed128 to an aligned memory address
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
*reinterpret_cast<int4*>(target) = value.get_bits();
}
// store a Packed128 to an aligned memory address with streaming cache hint
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__stcs(reinterpret_cast<int4*>(target), value.get_bits());
}
// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1
template<class ElementType>
__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
__stcg(reinterpret_cast<int4*>(target), value.get_bits());
}
// short-form typedefs
typedef Packed128<float> f128;
typedef Packed128<floatX> x128;
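// Example usage (an illustrative sketch, not part of this header; `scale_inplace_kernel`
// and `alpha` are hypothetical). Each thread handles x128::size consecutive elements,
// so every global load/store moves a full 128 bits. Assumes n is a multiple of
// x128::size and that `data` is 16-byte aligned:
//
//   __global__ void scale_inplace_kernel(floatX* data, float alpha, size_t n) {
//       size_t idx = ((size_t)blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
//       if (idx < n) {
//           x128 packed = load128cs(data + idx);   // streaming load: hint that data won't be re-read soon
//           for (int k = 0; k < x128::size; ++k) {
//               packed[k] = (floatX)((float)packed[k] * alpha);
//           }
//           store128cs(data + idx, packed);        // streaming store
//       }
//   }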
// ----------------------------------------------------------------------------
// DType support
// enumerator to identify the datatype of a tensor.
enum class DType : uint8_t {
FP32, FP16, BF16
};
// Given a datatype enum, returns the underlying number of bytes
// for a scalar of that type
size_t sizeof_dtype(DType type) {
switch (type) {
case DType::FP32:
return sizeof(float);
case DType::FP16:
return sizeof(half);
case DType::BF16:
return sizeof(nv_bfloat16);
default: // handled explicitly to avoid a compiler warning about missing cases
fprintf(stderr, "Unknown datatype\n");
exit(EXIT_FAILURE);
}
}
DType dtype_of(float* f) { return DType::FP32; }
DType dtype_of(nv_bfloat16 * f) { return DType::BF16; }
DType dtype_of(half * f) { return DType::FP16; }
// ----------------------------------------------------------------------------
// Copy, cast functions
// device functions and the kernel to cast data between types
template<typename Td, typename Ts>
__device__ Td cast_value(Ts val);
template<>
__device__ float cast_value<float, float>(float val) {
return val;
}
template<>
__device__ float cast_value<float, half>(half val) {
return __half2float(val);
}
template<>
__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
return __bfloat162float(val);
}
template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// todo: try grid-stride looping for more perf later
if (idx < n) {
dst[idx + stride_dst * blockIdx.y] = cast_value<Td, Ts>(src[idx + stride_src * blockIdx.y]);
}
}
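// Example launch (an illustrative sketch; `n`, `num_slices`, and the strides are
// hypothetical). blockIdx.y selects which slice is copied, offset by the two strides:
//
//   const int block_size = 512;
//   const dim3 grid_size((n + block_size - 1) / block_size, num_slices);
//   copy_and_cast_kernel<<<grid_size, block_size, 0, stream>>>(dst, src, n, stride_dst, stride_src);
//   cudaCheck(cudaGetLastError());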
// ----------------------------------------------------------------------------
// Warp/Block communication primitives
// warp-level reduction for summing values
__device__ inline float warpReduceSum(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
}
return val;
}
// warp-level reduction for finding the maximum value
__device__ inline float warpReduceMax(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));
}
return val;
}
// requires all 32 threads in the warp to be active, but should work for any block size
// uses static (non-dynamic) shared memory, so every call increases shared memory requirements by 128 bytes
// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end
// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to true
using reduction_func_t = float (*) (float);
template<reduction_func_t warp_reduction>
__device__ inline float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) {
// a reduction over up to 1024 threads, done in three steps:
// 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
__shared__ float shared_val[WARP_SIZE];
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
float warp_val = warp_reduction(val);
if (lane_id == 0) { shared_val[warp_id] = warp_val; }
__syncthreads();
warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
float block_val = warp_reduction(warp_val);
if (final_sync) {
__syncthreads(); // only needed in loops when effectively reusing shared memory etc.
}
return block_val;
}
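// Example usage (an illustrative sketch; `row_sum_kernel` is hypothetical and not part of
// this header). One block per row; every thread in the block must reach the blockReduce
// call, and blockDim.x should be a multiple of WARP_SIZE:
//
//   __global__ void row_sum_kernel(float* out, const float* x, int C) {
//       const float* row = x + blockIdx.x * C;
//       float sum = 0.0f;
//       for (int i = threadIdx.x; i < C; i += blockDim.x) {
//           sum += row[i];
//       }
//       sum = blockReduce<warpReduceSum>(sum);
//       if (threadIdx.x == 0) { out[blockIdx.x] = sum; }
//   }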
// Performs a _deterministic_ sum reduction. Determinism is achieved by requiring that
// only a single block be used.
template<class Float>
__global__ void global_sum_single_block_kernel(float* result, const Float* values, size_t count) {
assert(gridDim.x == 1); // only a single block!
float thread_sum = 0;
for(size_t index = threadIdx.x; index < count; index += blockDim.x) {
thread_sum += (float)values[index];
}
float reduction = blockReduce<warpReduceSum>(thread_sum, true);
if(threadIdx.x == 0) {
*result = reduction;
}
}
template<class Float>
void global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) {
global_sum_single_block_kernel<<<1, 1024, 0, stream>>>(result, values, count);
cudaCheck(cudaGetLastError());
}
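// Example usage from host code (an illustrative sketch; `d_result` and `d_losses` are
// hypothetical device buffers):
//
//   float* d_result;
//   cudaCheck(cudaMalloc(&d_result, sizeof(float)));
//   global_sum_deterministic(d_result, d_losses, num_losses, stream);
//   float total_loss;
//   cudaCheck(cudaMemcpyAsync(&total_loss, d_result, sizeof(float), cudaMemcpyDeviceToHost, stream));
//   cudaCheck(cudaStreamSynchronize(stream));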
// ----------------------------------------------------------------------------
// memory management
// allocate memory, preferably on the device
// returns a status code. 0 = OK, 1 = fell back to managed memory
int cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {
// try to allocate
cudaError_t err = cudaMalloc(out, bytes);
if(err == cudaErrorMemoryAllocation) {
// if we OOM, fall back to a managed allocation; slower, but at least it won't crash.
cudaGetLastError(); // reset the error before the next API call
cudaCheck_(cudaMallocManaged(out, bytes), file, line);
cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line);
return 1;
} else {
cudaCheck_(err, file, line);
return 0;
}
}
#define cudaMallocConditionallyManaged(out, bytes)\
(cudaMallocConditionallyManaged((void**)out, bytes, __FILE__, __LINE__))
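// Example usage (an illustrative sketch; `opt_state` and `shard_bytes` are hypothetical):
//
//   void* opt_state;
//   if (cudaMallocConditionallyManaged(&opt_state, shard_bytes)) {
//       printf("warning: optimizer state fell back to managed memory, expect a slowdown\n");
//   }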
// ----------------------------------------------------------------------------
// Random Number Generation used in Stochastic Rounding
// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5)
// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU
// todo - possibly overkill; we may not need such high-quality random numbers (tbd)
// http://eiserloh.net/noise/SquirrelNoise5.hpp
__device__ __host__ constexpr unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed)
{
constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011
constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011
constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101
unsigned int mangledBits = positionX;
mangledBits *= SQ5_BIT_NOISE1;
mangledBits += seed;
mangledBits ^= (mangledBits >> 9);
mangledBits += SQ5_BIT_NOISE2;
mangledBits ^= (mangledBits >> 11);
mangledBits *= SQ5_BIT_NOISE3;
mangledBits ^= (mangledBits >> 13);
mangledBits += SQ5_BIT_NOISE4;
mangledBits ^= (mangledBits >> 15);
mangledBits *= SQ5_BIT_NOISE5;
mangledBits ^= (mangledBits >> 17);
return mangledBits;
}
__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed)
{
constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits
unsigned int x = static_cast<unsigned int>(indexX);
unsigned int y = static_cast<unsigned int>(indexY);
return SquirrelNoise5(x + (PRIME_NUMBER * y), seed);
}
// stochastic rounding built on top of Squirrel Noise above (with the seed updated per step via xorshift)
__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {
// todo - is this stochastic rounding *too good*? can we cut any corners?
// makes sure each thread gets a different random number
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed);
unsigned int threshold = random & 0xFFFF;
unsigned int float_bits = __float_as_uint(in);
unsigned int rounded_bits = float_bits & 0x0000FFFF;
float_bits = (rounded_bits > threshold) ? (float_bits | 0xFFFF) : (float_bits & ~0xFFFF);
*out = __float2bfloat16_rn(__uint_as_float(float_bits));
}
__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) {
*out = (half)in; // todo - implement stochastic rounding for FP16; for now, plain round-to-nearest conversion
}
__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) {
*out = in; // dummy function for when floatX is float (FP32 mode)
}
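// Example usage (an illustrative sketch; `cast_with_sr_kernel` is hypothetical and not part
// of this header). Each thread derives its own random threshold from its thread/block index,
// so only a single per-step seed needs to be passed in:
//
//   __global__ void cast_with_sr_kernel(floatX* out, const float* master, size_t n, unsigned int seed) {
//       size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
//       if (idx < n) {
//           stochastic_rounding(master[idx], &out[idx], seed);
//       }
//   }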
#endif