Commit 7cfeae1

std.crypto.onetimeauth.ghash: faster GHASH on modern CPUs (ziglang#13566)
* std.crypto.onetimeauth.ghash: faster GHASH on modern CPUs

  Carryless multiplication was slow on older Intel CPUs, which justified the use of
  Karatsuba multiplication. This is no longer the case: using 4 multiplications to
  multiply two 128-bit numbers is now faster than 3 multiplications plus extra shifts
  and additions. This is also true on aarch64.

  Keep using Karatsuba only when targeting x86 (granted, this is a bit of a brutal
  shortcut; we should really list all the CPU models that had a slow clmul instruction).

  Also remove the useless agg_2_treshold and restore the ability to precompute only
  H and H^2 in ReleaseSmall. Finally, avoid using u256: working with 128-bit registers
  is actually faster.

* Use a switch, add some comments
1 parent 58d9004 commit 7cfeae1
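
For context, here is a minimal, CPU-independent sketch (not part of the commit) of the two decompositions the message contrasts. The names clmul64, Wide, mulSchoolbook and mulKaratsuba are illustrative, and the portable bit-by-bit multiply merely stands in for the pclmulqdq/pmull instructions the real code uses; the builtin syntax follows the same Zig version as the diff below.

const std = @import("std");

// Portable 64x64 -> 128-bit carryless multiplication. This stands in for the
// hardware clmul instruction (pclmulqdq on x86_64, pmull on aarch64).
fn clmul64(a: u64, b: u64) u128 {
    var r: u128 = 0;
    var i: u7 = 0;
    while (i < 64) : (i += 1) {
        if (((b >> @intCast(u6, i)) & 1) != 0) r ^= @as(u128, a) << i;
    }
    return r;
}

const Wide = struct { hi: u128, lo: u128, mid: u128 };

// Schoolbook: 4 carryless multiplications, and the middle term needs no fix-up.
fn mulSchoolbook(x: u128, y: u128) Wide {
    const x_lo = @truncate(u64, x);
    const x_hi = @truncate(u64, x >> 64);
    const y_lo = @truncate(u64, y);
    const y_hi = @truncate(u64, y >> 64);
    return .{
        .hi = clmul64(x_hi, y_hi),
        .lo = clmul64(x_lo, y_lo),
        .mid = clmul64(x_hi, y_lo) ^ clmul64(x_lo, y_hi),
    };
}

// Karatsuba: only 3 carryless multiplications, but reconstructing the middle
// term costs extra XORs.
fn mulKaratsuba(x: u128, y: u128) Wide {
    const x_lo = @truncate(u64, x);
    const x_hi = @truncate(u64, x >> 64);
    const y_lo = @truncate(u64, y);
    const y_hi = @truncate(u64, y >> 64);
    const r_lo = clmul64(x_lo, y_lo);
    const r_hi = clmul64(x_hi, y_hi);
    const r_mid = clmul64(x_lo ^ x_hi, y_lo ^ y_hi) ^ r_lo ^ r_hi;
    return .{ .hi = r_hi, .lo = r_lo, .mid = r_mid };
}

test "schoolbook and Karatsuba decompositions agree" {
    const x: u128 = 0x0123456789abcdef0011223344556677;
    const y: u128 = 0xfedcba98765432100f1e2d3c4b5a6978;
    const s = mulSchoolbook(x, y);
    const k = mulKaratsuba(x, y);
    try std.testing.expectEqual(s.lo, k.lo);
    try std.testing.expectEqual(s.mid, k.mid);
    try std.testing.expectEqual(s.hi, k.hi);
}

Both functions return the same (hi, mid, lo) triple; the trade-off is four hardware clmul invocations with no fix-up versus three clmuls plus the extra XORs (and, with a u256 result, shifts) needed to reconstruct the middle term. That is the choice the new mul_algorithm switch in the diff encodes.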

File tree

1 file changed: +124 -74 lines changed
lib/std/crypto/ghash.zig

Lines changed: 124 additions & 74 deletions
@@ -18,12 +18,19 @@ pub const Ghash = struct {
     pub const mac_length = 16;
     pub const key_length = 16;
 
-    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 4;
-    const agg_2_treshold = 5;
+    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
     const agg_4_treshold = 22;
     const agg_8_treshold = 84;
     const agg_16_treshold = 328;
 
+    // Before the Haswell architecture, the carryless multiplication instruction was
+    // extremely slow. Even with 128-bit operands, using Karatsuba multiplication was
+    // thus faster than a schoolbook multiplication.
+    // This is no longer the case -- Modern CPUs, including ARM-based ones, have a fast
+    // carryless multiplication instruction; using 4 multiplications is now faster than
+    // 3 multiplications with extra shifts and additions.
+    const mul_algorithm = if (builtin.cpu.arch == .x86) .karatsuba else .schoolbook;
+
     hx: [pc_count]Precomp,
     acc: u128 = 0,
 
@@ -43,10 +50,10 @@ pub const Ghash = struct {
         var hx: [pc_count]Precomp = undefined;
         hx[0] = h;
         hx[1] = gcmReduce(clsq128(hx[0])); // h^2
-        hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
-        hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
 
         if (builtin.mode != .ReleaseSmall) {
+            hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
+            hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
             if (block_count >= agg_8_treshold) {
                 hx[4] = gcmReduce(clmul128(hx[3], h)); // h^5
                 hx[5] = gcmReduce(clsq128(hx[2])); // h^6 = h^3^2
@@ -69,47 +76,71 @@ pub const Ghash = struct {
         return Ghash.initForBlockCount(key, math.maxInt(usize));
     }
 
-    const Selector = enum { lo, hi };
+    const Selector = enum { lo, hi, hi_lo };
 
     // Carryless multiplication of two 64-bit integers for x86_64.
     inline fn clmulPclmul(x: u128, y: u128, comptime half: Selector) u128 {
-        if (half == .hi) {
-            const product = asm (
-                \\ vpclmulqdq $0x11, %[x], %[y], %[out]
-                : [out] "=x" (-> @Vector(2, u64)),
-                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
-        } else {
-            const product = asm (
-                \\ vpclmulqdq $0x00, %[x], %[y], %[out]
-                : [out] "=x" (-> @Vector(2, u64)),
-                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
+        switch (half) {
+            .hi => {
+                const product = asm (
+                    \\ vpclmulqdq $0x11, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .lo => {
+                const product = asm (
+                    \\ vpclmulqdq $0x00, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .hi_lo => {
+                const product = asm (
+                    \\ vpclmulqdq $0x10, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
         }
     }
 
     // Carryless multiplication of two 64-bit integers for ARM crypto.
     inline fn clmulPmull(x: u128, y: u128, comptime half: Selector) u128 {
-        if (half == .hi) {
-            const product = asm (
-                \\ pmull2 %[out].1q, %[x].2d, %[y].2d
-                : [out] "=w" (-> @Vector(2, u64)),
-                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
-        } else {
-            const product = asm (
-                \\ pmull %[out].1q, %[x].1d, %[y].1d
-                : [out] "=w" (-> @Vector(2, u64)),
-                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
+        switch (half) {
+            .hi => {
+                const product = asm (
+                    \\ pmull2 %[out].1q, %[x].2d, %[y].2d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .lo => {
+                const product = asm (
+                    \\ pmull %[out].1q, %[x].1d, %[y].1d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .hi_lo => {
+                const product = asm (
+                    \\ pmull %[out].1q, %[x].1d, %[y].1d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x >> 64)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
         }
     }
 
@@ -144,38 +175,63 @@ pub const Ghash = struct {
             (z3 & 0x88888888888888888888888888888888) ^ extra;
     }
 
+    const I256 = struct {
+        hi: u128,
+        lo: u128,
+        mid: u128,
+    };
+
+    inline fn xor256(x: *I256, y: I256) void {
+        x.* = I256{
+            .hi = x.hi ^ y.hi,
+            .lo = x.lo ^ y.lo,
+            .mid = x.mid ^ y.mid,
+        };
+    }
+
     // Square a 128-bit integer in GF(2^128).
-    fn clsq128(x: u128) u256 {
-        const lo = @truncate(u64, x);
-        const hi = @truncate(u64, x >> 64);
-        const mid = lo ^ hi;
-        const r_lo = clmul(x, x, .lo);
-        const r_hi = clmul(x, x, .hi);
-        const r_mid = clmul(mid, mid, .lo) ^ r_lo ^ r_hi;
-        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    fn clsq128(x: u128) I256 {
+        return .{
+            .hi = clmul(x, x, .hi),
+            .lo = clmul(x, x, .lo),
+            .mid = 0,
+        };
     }
 
     // Multiply two 128-bit integers in GF(2^128).
-    inline fn clmul128(x: u128, y: u128) u256 {
-        const x_hi = @truncate(u64, x >> 64);
-        const y_hi = @truncate(u64, y >> 64);
-        const r_lo = clmul(x, y, .lo);
-        const r_hi = clmul(x, y, .hi);
-        const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
-        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    inline fn clmul128(x: u128, y: u128) I256 {
+        if (mul_algorithm == .karatsuba) {
+            const x_hi = @truncate(u64, x >> 64);
+            const y_hi = @truncate(u64, y >> 64);
+            const r_lo = clmul(x, y, .lo);
+            const r_hi = clmul(x, y, .hi);
+            const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
+            return .{
+                .hi = r_hi,
+                .lo = r_lo,
+                .mid = r_mid,
+            };
+        } else {
+            return .{
+                .hi = clmul(x, y, .hi),
+                .lo = clmul(x, y, .lo),
+                .mid = clmul(x, y, .hi_lo) ^ clmul(y, x, .hi_lo),
+            };
+        }
     }
 
     // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
     // This is done *without reversing the bits*, using Shay Gueron's black magic demysticated here:
     // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
-    inline fn gcmReduce(x: u256) u128 {
+    inline fn gcmReduce(x: I256) u128 {
+        const hi = x.hi ^ (x.mid >> 64);
+        const lo = x.lo ^ (x.mid << 64);
         const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
-        const lo = @truncate(u128, x);
         const a = clmul(lo, p64, .lo);
         const b = ((lo << 64) | (lo >> 64)) ^ a;
         const c = clmul(b, p64, .lo);
         const d = ((b << 64) | (b >> 64)) ^ c;
-        return d ^ @truncate(u128, x >> 128);
+        return d ^ hi;
     }
 
     const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
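
A note on the new I256 representation (my reading of the change, not text from the commit): clmul128 and clsq128 now keep the 256-bit product as three 128-bit registers, lo ^ (mid · 2^64) ^ (hi · 2^128), instead of widening to a single u256. gcmReduce first folds the middle limb back into hi and lo and then applies the usual two clmul-by-p64 reduction steps. Because reduction modulo the GHASH polynomial is GF(2)-linear, i.e. reduce(a ^ b) = reduce(a) ^ reduce(b), several unreduced products can be combined with xor256 and reduced once, which is exactly what the aggregated loops below rely on.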
@@ -202,7 +258,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[15 - 0]);
                 comptime var j = 1;
                 inline while (j < 16) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]));
                 }
                 acc = gcmReduce(u);
             }
@@ -212,7 +268,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[7 - 0]);
                 comptime var j = 1;
                 inline while (j < 8) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]));
                 }
                 acc = gcmReduce(u);
             }
@@ -222,31 +278,25 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[3 - 0]);
                 comptime var j = 1;
                 inline while (j < 4) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]));
                 }
                 acc = gcmReduce(u);
             }
-        } else if (msg.len >= agg_2_treshold * block_length) {
-            // 2-blocks aggregated reduction
-            while (i + 32 <= msg.len) : (i += 32) {
-                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
-                comptime var j = 1;
-                inline while (j < 2) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]);
-                }
-                acc = gcmReduce(u);
+        }
+        // 2-blocks aggregated reduction
+        while (i + 32 <= msg.len) : (i += 32) {
+            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
+            comptime var j = 1;
+            inline while (j < 2) : (j += 1) {
+                xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]));
             }
+            acc = gcmReduce(u);
         }
         // remaining blocks
         if (i < msg.len) {
-            const n = (msg.len - i) / 16;
-            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[n - 1 - 0]);
-            var j: usize = 1;
-            while (j < n) : (j += 1) {
-                u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[n - 1 - j]);
-            }
-            i += n * 16;
+            const u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[0]);
             acc = gcmReduce(u);
+            i += 16;
         }
         assert(i == msg.len);
         st.acc = acc;
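
The aggregated loops rest on the standard Horner expansion (again my reading, not commit text): for a group of k blocks m1..mk, the per-block update acc <- (acc ^ m1)·H, acc <- (acc ^ m2)·H, ... unrolls to acc <- (acc ^ m1)·H^k ^ m2·H^(k-1) ^ ... ^ mk·H. That is why hx[] holds precomputed powers of H, why block j of a group is multiplied by st.hx[k - 1 - j], and why a single gcmReduce per group suffices once the unreduced products are combined with xor256. With agg_2_treshold gone, the 2-block loop now simply runs unconditionally on whatever pairs remain after the larger groups, and at most one final block is handled by the tail.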
