
Commit ff15d93

xiaomengy authored and facebook-github-bot committed
Improve numerical stability of GroupNorm (pytorch#54921)
Summary: Pull Request resolved: pytorch#54921

Improve numerical stability of GroupNorm

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"

Reviewed By: ngimel

Differential Revision: D27414438

fbshipit-source-id: 815517240ca5ea3e2beb77ced3bd862e9c83d445
1 parent 095cd6a commit ff15d93
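Background note (not part of the commit message): the old GroupNorm CPU path computed the per-group variance as E[x^2] - E[x]^2, which cancels catastrophically in float32 when the mean is large relative to the standard deviation; this commit switches to a Welford-style update combined with cascade (pairwise) summation. A minimal standalone C++ sketch of the difference, with made-up data, for illustration only:

// Standalone illustration (not PyTorch code): naive vs. Welford variance.
#include <cstdint>
#include <cstdio>
#include <vector>

// Naive: accumulate sum and sum of squares, then var = E[x^2] - E[x]^2.
// Loses precision when |mean| >> stddev.
float NaiveVariance(const std::vector<float>& x) {
  float sum = 0.f, sum_sq = 0.f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum / x.size();
  return sum_sq / x.size() - mean * mean;
}

// Welford: update mean and M2 incrementally; var = M2 / n.
float WelfordVariance(const std::vector<float>& x) {
  float mean = 0.f, m2 = 0.f;
  int64_t n = 0;
  for (float v : x) {
    ++n;
    const float delta = v - mean;
    mean += delta / n;
    m2 += delta * (v - mean);
  }
  return m2 / n;
}

int main() {
  // Large offset makes the naive formula cancel badly in float32.
  std::vector<float> x(1 << 16);
  for (size_t i = 0; i < x.size(); ++i) {
    x[i] = 10000.f + (i % 2 ? 0.01f : -0.01f);  // true variance = 1e-4
  }
  std::printf("naive   = %g\n", NaiveVariance(x));
  std::printf("welford = %g\n", WelfordVariance(x));
  return 0;
}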

File tree: 10 files changed, +319 -92 lines changed


aten/src/ATen/native/SharedReduceOps.h

Lines changed: 10 additions & 3 deletions
@@ -80,8 +80,15 @@ struct WelfordData {
   scalar_t m2;
   index_t n;
   combine_t nf;
-  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
-  C10_DEVICE WelfordData(scalar_t mean, scalar_t m2, index_t n, combine_t nf) : mean(mean), m2(m2), n(n), nf(nf) {}
+
+  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
+
+  C10_HOST_DEVICE WelfordData(
+      scalar_t mean,
+      scalar_t m2,
+      index_t n,
+      combine_t nf)
+      : mean(mean), m2(m2), n(n), nf(nf) {}
 };

@@ -145,7 +152,7 @@ struct WelfordOps {
     };
   }
 #endif
-  WelfordOps(index_t correction, bool take_sqrt)
+  C10_HOST_DEVICE WelfordOps(index_t correction, bool take_sqrt)
       : correction(correction), take_sqrt(take_sqrt) {}
 };

aten/src/ATen/native/cpu/SumKernel.cpp

Lines changed: 5 additions & 16 deletions
@@ -1,11 +1,11 @@
-#include <ATen/Dispatch.h>
-#include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
-#include <ATen/native/cpu/Reduce.h>
-#include <c10/util/llvmMathExtras.h>
 
 #include <algorithm>
 
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cpu/Reduce.h>
+#include <ATen/native/cpu/utils.h>
 
 namespace at {
 namespace native {

@@ -48,17 +48,6 @@ void accumulate_result(char * C10_RESTRICT data, int64_t stride, int64_t index,
   }
 }
 
-int64_t ceil_log2(int64_t x) {
-  if (x <= 2) {
-    return 1;
-  }
-
-  auto ux = static_cast<uint64_t>(x);
-  // Last set bit is floor(log2(x)), floor + 1 is ceil
-  // except when x is an exact powers of 2, so subtract 1 first
-  return static_cast<int64_t>(llvm::findLastSet(ux - 1)) + 1;
-}
-
 /** Simultaneously sum over n rows at once
 
 This algorithm calculates the sum without loss of precision over large axes. It

@@ -101,7 +90,7 @@ std::array<scalar_t, nrows> multi_row_sum(
   constexpr int64_t num_levels = 4;
 
   const int64_t level_power =
-      std::max(int64_t(4), ceil_log2(size) / num_levels);
+      std::max(int64_t(4), utils::CeilLog2(size) / num_levels);
   const int64_t level_step = (1 << level_power);
   const int64_t level_mask = level_step - 1;
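Side note (illustration, not from the diff): the local ceil_log2 removed above is replaced by the shared at::native::utils::CeilLog2 from ATen/native/cpu/utils.h, which the new moments code also uses. Assuming it keeps the semantics documented in the removed comment, a small standalone sketch of that behavior:

// Standalone sketch of ceil-log2 semantics (assumed to match utils::CeilLog2).
#include <cassert>
#include <cstdint>

int64_t ceil_log2(int64_t x) {
  if (x <= 2) {
    return 1;
  }
  // Position of the highest set bit of (x - 1), plus one, is ceil(log2(x)).
  int64_t r = 0;
  uint64_t ux = static_cast<uint64_t>(x) - 1;
  while (ux != 0) {
    ux >>= 1;
    ++r;
  }
  return r;
}

int main() {
  assert(ceil_log2(1) == 1);  // values <= 2 are clamped to 1
  assert(ceil_log2(2) == 1);
  assert(ceil_log2(8) == 3);  // exact power of two
  assert(ceil_log2(9) == 4);  // rounds up otherwise
  return 0;
}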

aten/src/ATen/native/cpu/group_norm_kernel.cpp

Lines changed: 23 additions & 36 deletions
@@ -8,6 +8,7 @@
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/Dispatch.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/moments_utils.h>
 
 namespace at {
 namespace native {

@@ -38,47 +39,33 @@ void GroupNormKernelImplInternal(
   T* Y_data = Y.data_ptr<T>();
   T* mean_data = mean.data_ptr<T>();
   T* rstd_data = rstd.data_ptr<T>();
-  const T s = T(1) / static_cast<T>(D * HxW);
   const bool gamma_null = (gamma_data == nullptr);
   const bool beta_null = beta_data == nullptr;
+  const int64_t inner_size = D * HxW;
 
   at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) {
-    constexpr int64_t K = vec::Vectorized<T>::size();
-    const int64_t inner_size = D * HxW / K * K;
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-    std::array<T, K> mean_arr;
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-    std::array<T, K> rstd_arr;
     for (int64_t i = start; i < end; ++i) {
-      const T* X_ptr = X_data + i * D * HxW;
-      vec::Vectorized<T> mean_vec(0);
-      vec::Vectorized<T> rstd_vec(0);
-      for (int64_t j = 0; j < inner_size; j += K) {
-        const vec::Vectorized<T> x_vec = vec::Vectorized<T>::loadu(X_ptr + j);
-        mean_vec = mean_vec + x_vec;
-        rstd_vec = rstd_vec + x_vec * x_vec;
-      }
-      mean_vec.store(mean_arr.data());
-      rstd_vec.store(rstd_arr.data());
-      T mean_val = std::accumulate(mean_arr.cbegin(), mean_arr.cend(), T(0));
-      T rstd_val = std::accumulate(rstd_arr.cbegin(), rstd_arr.cend(), T(0));
-      for (int64_t j = inner_size; j < D * HxW; ++j) {
-        mean_val += X_ptr[j];
-        rstd_val += X_ptr[j] * X_ptr[j];
-      }
-      mean_val *= s;
-      rstd_val = std::max(rstd_val * s - mean_val * mean_val, T(0));
-      rstd_val = T(1) / std::sqrt(rstd_val + eps);
-
-      const int64_t g = i % G;
-      for (int64_t j = 0; j < D; ++j) {
-        const int64_t c = g * D + j;
-        const T scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]);
-        const T bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]);
-        X_ptr = X_data + (i * D + j) * HxW;
-        T* Y_ptr = Y_data + (i * D + j) * HxW;
-        for (int64_t k = 0; k < HxW; ++k) {
-          Y_ptr[k] = scale * X_ptr[k] + bias;
+      const T* X_ptr = X_data + i * inner_size;
+      T mean_val;
+      T rstd_val;
+      std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, inner_size);
+      rstd_val = T(1) / std::sqrt(std::max(rstd_val, T(0)) + eps);
+      if (gamma_null && beta_null) {
+        T* Y_ptr = Y_data + i * inner_size;
+        for (int j = 0; j < inner_size; ++j) {
+          Y_ptr[j] = (X_ptr[j] - mean_val) * rstd_val;
+        }
+      } else {
+        const int64_t g = i % G;
+        for (int64_t j = 0; j < D; ++j) {
+          const int64_t c = g * D + j;
+          const T scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]);
+          const T bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]);
+          X_ptr = X_data + (i * D + j) * HxW;
+          T* Y_ptr = Y_data + (i * D + j) * HxW;
+          for (int64_t k = 0; k < HxW; ++k) {
+            Y_ptr[k] = scale * X_ptr[k] + bias;
+          }
         }
       }
       mean_data[i] = mean_val;
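Side note (illustration, not from the diff): both branches of the rewritten kernel apply the same normalization. The gamma/beta branch folds the per-channel affine into one fused multiply-add, Y = scale * X + bias with scale = rstd * gamma[c] and bias = -scale * mean + beta[c], which is algebraically gamma[c] * (X - mean) * rstd + beta[c]. A tiny standalone check:

// Standalone check (illustration only) that the folded affine form used in
// the kernel matches the direct GroupNorm formula.
#include <cassert>
#include <cmath>

int main() {
  const float x = 1.7f, mean = 0.4f, rstd = 2.5f, gamma = 1.3f, beta = -0.2f;

  // Direct form: gamma * (x - mean) * rstd + beta.
  const float direct = gamma * (x - mean) * rstd + beta;

  // Folded form from the kernel: scale * x + bias.
  const float scale = rstd * gamma;
  const float bias = -scale * mean + beta;
  const float folded = scale * x + bias;

  assert(std::fabs(direct - folded) < 1e-6f);
  return 0;
}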
aten/src/ATen/native/cpu/moments_utils.h

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#pragma once
+
+#include <array>
+#include <cstring>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/utils.h>
+#include <c10/util/SmallVector.h>
+
+namespace at {
+namespace native {
+namespace utils {
+
+constexpr int64_t kChunkSize = 16;
+
+template <typename T>
+void AddMoments(
+    int64_t m0_add,
+    const T& m1_add,
+    const T& m2_add,
+    int64_t& m0,
+    T& m1,
+    T& m2) {
+  const int64_t n = m0 + m0_add;
+  const T c = n == 0 ? 0 : static_cast<T>(m0_add) / static_cast<T>(n);
+  const T delta = m1_add - m1;
+  m1 += c * delta;
+  m2 += m2_add + delta * delta * c * static_cast<T>(m0);
+  m0 = n;
+}
+
+template <typename T>
+void AddMomentsVec(
+    int64_t m0_add,
+    const vec::Vectorized<T>& m1_add,
+    const vec::Vectorized<T>& m2_add,
+    int64_t& m0,
+    vec::Vectorized<T>& m1,
+    vec::Vectorized<T>& m2) {
+  using Vec = vec::Vectorized<T>;
+  const int64_t n = m0 + m0_add;
+  const T c = n == 0 ? 0 : static_cast<T>(m0_add) / static_cast<T>(n);
+  const Vec c_vec(c);
+  const Vec delta = m1_add - m1;
+  m1 += c_vec * delta;
+  m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
+  m0 = n;
+}
+
+// Compute rowwise moments by Welford algorithm and cascade sum to improve
+// numerical stability.
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+// https://en.wikipedia.org/wiki/Pairwise_summation
+template <typename T, int64_t kMaxDepth>
+std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N) {
+  using Vec = vec::Vectorized<T>;
+
+  constexpr int64_t kVecSize = Vec::size();
+  const int64_t n = N / kVecSize;
+  const int64_t m = divup(n, kChunkSize);
+  const int64_t depth = CeilLog2(m);
+
+  const Vec kZeroVec(T(0));
+  c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
+  c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
+  c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
+
+  for (int64_t i = 0; i < m; ++i) {
+    const T* X_ptr = X + i * kChunkSize * kVecSize;
+    const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
+    Vec m1_vec(0);
+    Vec m2_vec(0);
+    for (int64_t j = 0; j < m0; ++j) {
+      const Vec x_vec = Vec::loadu(X_ptr + j * kVecSize);
+      const Vec delta_vec = x_vec - m1_vec;
+      const Vec c_vec = Vec(T(1) / static_cast<T>(j + 1));
+      m1_vec += delta_vec * c_vec;
+      m2_vec += delta_vec * (x_vec - m1_vec);
+    }
+    AddMomentsVec(m0, m1_vec, m2_vec, m0_stk[0], m1_stk[0], m2_stk[0]);
+    int64_t mask = i + 1;
+    for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
+      AddMomentsVec(
+          m0_stk[j - 1],
+          m1_stk[j - 1],
+          m2_stk[j - 1],
+          m0_stk[j],
+          m1_stk[j],
+          m2_stk[j]);
+      m0_stk[j - 1] = 0;
+      m1_stk[j - 1] = kZeroVec;
+      m2_stk[j - 1] = kZeroVec;
+      mask >>= 1;
+    }
+  }
+  for (int64_t i = 1; i < depth; ++i) {
+    AddMomentsVec(
+        m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
+  }
+
+  std::array<T, kVecSize> m1_arr{};
+  std::array<T, kVecSize> m2_arr{};
+  m1_stk[0].store(m1_arr.data());
+  m2_stk[0].store(m2_arr.data());
+
+  int64_t m0 = 0;
+  T m1 = 0;
+  T m2 = 0;
+  for (int64_t i = n * kVecSize; i < N; ++i) {
+    const T delta = X[i] - m1;
+    ++m0;
+    m1 += delta / static_cast<T>(m0);
+    m2 += delta * (X[i] - m1);
+  }
+  for (int64_t i = 0; i < kVecSize; ++i) {
+    AddMoments(n, m1_arr[i], m2_arr[i], m0, m1, m2);
+  }
+
+  return std::make_pair(m1, m2 / static_cast<T>(N));
+}
+
+template <typename T>
+std::pair<T, T> RowwiseMoments(const T* X, int64_t N) {
+  using Vec = vec::Vectorized<T>;
+  constexpr int64_t kVecSize = Vec::size();
+  const int64_t n = N / kVecSize;
+  const int64_t m = divup(n, kChunkSize);
+  const int64_t depth = CeilLog2(m);
+  if (depth <= 4) {
+    return RowwiseMomentsImpl<T, 4>(X, N);
+  } else if (depth <= 8) {
+    return RowwiseMomentsImpl<T, 8>(X, N);
+  } else if (depth <= 16) {
+    return RowwiseMomentsImpl<T, 16>(X, N);
+  } else if (depth <= 32) {
+    return RowwiseMomentsImpl<T, 32>(X, N);
+  } else {
+    return RowwiseMomentsImpl<T, 64>(X, N);
+  }
+}
+
+} // namespace utils
+} // namespace native
+} // namespace at
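Side note (illustration, not from the diff): RowwiseMoments returns the mean and the biased variance (m2 / N) of a contiguous row, which is what the GroupNorm CPU kernel consumes before clamping, adding eps, and taking the reciprocal square root. A hedged usage sketch; NormalizeRow is a made-up helper, and the snippet assumes it is compiled inside ATen so the header and vec::Vectorized are available:

// Illustrative only: how a caller might consume RowwiseMoments.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <tuple>

#include <ATen/native/cpu/moments_utils.h>

void NormalizeRow(const float* X, float* Y, int64_t N, float eps) {
  float mean_val;
  float var_val;
  std::tie(mean_val, var_val) = at::native::utils::RowwiseMoments(X, N);
  // var_val is the biased variance m2 / N; clamp and convert to rstd.
  const float rstd_val = 1.f / std::sqrt(std::max(var_val, 0.f) + eps);
  for (int64_t i = 0; i < N; ++i) {
    Y[i] = (X[i] - mean_val) * rstd_val;
  }
}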

aten/src/ATen/native/cpu/utils.h

Lines changed: 7 additions & 4 deletions
@@ -3,15 +3,18 @@
 #include <ATen/cpu/vec/vec.h>
 #include <c10/util/llvmMathExtras.h>
 
-namespace at { namespace native { namespace {
+namespace at {
+namespace native {
+
+namespace {
 
 template <typename T>
 inline T data_index_init(T offset) {
   return offset;
 }
 
 template <typename T, typename... Args>
-inline T data_index_init(T offset, T &x, const T &X, Args &&... args) {
+inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
   offset = data_index_init(offset, std::forward<Args>(args)...);
   x = offset % X;
   return offset / X;

@@ -22,7 +25,7 @@ inline bool data_index_step() {
 }
 
 template <typename T, typename... Args>
-inline bool data_index_step(T &x, const T &X, Args &&... args) {
+inline bool data_index_step(T& x, const T& X, Args&&... args) {
   if (data_index_step(std::forward<Args>(args)...)) {
     x = ((x + 1) == X) ? 0 : (x + 1);
     return x == 0;

@@ -47,4 +50,4 @@ T CeilLog2(const T& x) {
 } // namespace utils
 
 } // namespace native
-} // namespace at// namespace at::native::<anonymous>
+} // namespace at

aten/src/ATen/native/cuda/block_reduce.cuh

Lines changed: 31 additions & 0 deletions
@@ -1,5 +1,8 @@
 #pragma once
 
+#include <thrust/tuple.h>
+
+#include <ATen/native/SharedReduceOps.h>
 #include <ATen/cuda/DeviceUtils.cuh>
 
 namespace at {

@@ -45,6 +48,34 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
   return val;
 }
 
+template <typename T, class ReduceOp>
+__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = op.combine(val, op.warp_shfl_down(val, offset));
+  }
+  return val;
+}
+
+template <typename T, class ReduceOp>
+__inline__ __device__ T
+BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
+  const int lid = threadIdx.x % C10_WARP_SIZE;
+  const int wid = threadIdx.x / C10_WARP_SIZE;
+  val = WarpReduce(val, op);
+  __syncthreads();
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid]
+                                                   : identity_element;
+  if (wid == 0) {
+    val = WarpReduce(val, op);
+  }
+  return val;
+}
+
 } // namespace cuda_utils
 } // namespace native
 } // namespace at
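Side note (illustration, not from the commit): BlockReduce only requires the ReduceOp to expose combine() and warp_shfl_down(), mirroring the ops in SharedReduceOps.h. The MaxOp struct and BlockMaxKernel below are made up to show the intended call pattern; blockDim.x is assumed to be a multiple of C10_WARP_SIZE:

// Illustration only: a minimal ReduceOp compatible with cuda_utils::BlockReduce.
#include <cfloat>

#include <cuda_runtime.h>

#include <ATen/native/cuda/block_reduce.cuh>

struct MaxOp {
  __device__ float combine(float a, float b) const { return a > b ? a : b; }
  __device__ float warp_shfl_down(float v, int offset) const {
    return __shfl_down_sync(0xffffffff, v, offset);
  }
};

__global__ void BlockMaxKernel(const float* in, float* out, int n) {
  __shared__ float shared[C10_WARP_SIZE];
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end contribute the identity element for max.
  const float kLowest = -FLT_MAX;
  float val = i < n ? in[i] : kLowest;
  val = at::native::cuda_utils::BlockReduce(val, MaxOp{}, kLowest, shared);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val;  // one partial maximum per block
  }
}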
