Skip to content

Commit 103b885

Browse files
author
expikr
authored
math.hypot: fix incorrect over/underflow behavior (#19472)
1 parent b2588de commit 103b885

File tree

3 files changed

+113
-144
lines changed

3 files changed

+113
-144
lines changed

lib/std/math.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub const floatTrueMin = @import("math/float.zig").floatTrueMin;
5252
pub const floatMin = @import("math/float.zig").floatMin;
5353
pub const floatMax = @import("math/float.zig").floatMax;
5454
pub const floatEps = @import("math/float.zig").floatEps;
55+
pub const floatEpsAt = @import("math/float.zig").floatEpsAt;
5556
pub const inf = @import("math/float.zig").inf;
5657
pub const nan = @import("math/float.zig").nan;
5758
pub const snan = @import("math/float.zig").snan;

lib/std/math/float.zig

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ pub inline fn floatEps(comptime T: type) T {
9494
return reconstructFloat(T, -floatFractionalBits(T), mantissaOne(T));
9595
}
9696

97+
/// Returns the local epsilon of floating point type T at `x`: the absolute
/// difference between `x` and the value whose bit pattern is that of `x`
/// with the lowest bit flipped.
/// Passing a non-float type for T is a compile error.
98+
pub inline fn floatEpsAt(comptime T: type, x: T) T {
99+
switch (@typeInfo(T)) {
100+
.Float => |F| {
101+
// U is the unsigned integer type with the same number of bits as T.
const U: type = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = F.bits } });
102+
const u: U = @bitCast(x);
103+
// Flip the lowest bit of the representation and read it back as a T.
const y: T = @bitCast(u ^ 1);
104+
return @abs(x - y);
105+
},
106+
else => @compileError("floatEpsAt only supports floats"),
107+
}
108+
}
109+
97110
/// Returns the value inf for floating point type T.
98111
pub inline fn inf(comptime T: type) T {
99112
return reconstructFloat(T, floatExponentMax(T) + 1, mantissaOne(T));

lib/std/math/hypot.zig

Lines changed: 99 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,176 +1,131 @@
1-
// Ported from musl, which is licensed under the MIT license:
2-
// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
3-
//
4-
// https://git.musl-libc.org/cgit/musl/tree/src/math/hypotf.c
5-
// https://git.musl-libc.org/cgit/musl/tree/src/math/hypot.c
6-
71
const std = @import("../std.zig");
82
const math = std.math;
93
const expect = std.testing.expect;
10-
const maxInt = std.math.maxInt;
4+
const isNan = math.isNan;
5+
const isInf = math.isInf;
6+
const inf = math.inf;
7+
const nan = math.nan;
8+
const floatEpsAt = math.floatEpsAt;
9+
const floatEps = math.floatEps;
10+
const floatMin = math.floatMin;
11+
const floatMax = math.floatMax;
1112

1213
/// Returns sqrt(x * x + y * y), avoiding unnecessary overflow and underflow.
1314
///
1415
/// Special Cases:
1516
///
1617
/// | x | y | hypot |
1718
/// |-------|-------|-------|
18-
/// | +inf | num | +inf |
19-
/// | num | +-inf | +inf |
20-
/// | nan | any | nan |
21-
/// | any | nan | nan |
19+
/// | +-inf | any | +inf |
20+
/// | any | +-inf | +inf |
21+
/// | nan | fin | nan |
22+
/// | fin | nan | nan |
2223
pub fn hypot(x: anytype, y: anytype) @TypeOf(x, y) {
2324
const T = @TypeOf(x, y);
24-
return switch (T) {
25-
f32 => hypot32(x, y),
26-
f64 => hypot64(x, y),
25+
switch (@typeInfo(T)) {
26+
.Float => {},
27+
.ComptimeFloat => return @sqrt(x * x + y * y),
2728
else => @compileError("hypot not implemented for " ++ @typeName(T)),
28-
};
29-
}
30-
31-
fn hypot32(x: f32, y: f32) f32 {
32-
var ux = @as(u32, @bitCast(x));
33-
var uy = @as(u32, @bitCast(y));
34-
35-
ux &= maxInt(u32) >> 1;
36-
uy &= maxInt(u32) >> 1;
37-
if (ux < uy) {
38-
const tmp = ux;
39-
ux = uy;
40-
uy = tmp;
4129
}
42-
43-
var xx = @as(f32, @bitCast(ux));
44-
var yy = @as(f32, @bitCast(uy));
45-
if (uy == 0xFF << 23) {
46-
return yy;
47-
}
48-
if (ux >= 0xFF << 23 or uy == 0 or ux - uy >= (25 << 23)) {
49-
return xx + yy;
50-
}
51-
52-
var z: f32 = 1.0;
53-
if (ux >= (0x7F + 60) << 23) {
54-
z = 0x1.0p90;
55-
xx *= 0x1.0p-90;
56-
yy *= 0x1.0p-90;
57-
} else if (uy < (0x7F - 60) << 23) {
58-
z = 0x1.0p-90;
59-
xx *= 0x1.0p-90;
60-
yy *= 0x1.0p-90;
30+
const lower = @sqrt(floatMin(T));
31+
const upper = @sqrt(floatMax(T) / 2);
32+
const incre = @sqrt(floatEps(T) / 2);
33+
const scale = floatEpsAt(T, incre);
34+
const hypfn = if (emulateFma(T)) hypotUnfused else hypotFused;
35+
var major: T = x;
36+
var minor: T = y;
37+
if (isInf(major) or isInf(minor)) return inf(T);
38+
if (isNan(major) or isNan(minor)) return nan(T);
39+
if (T == f16) return @floatCast(@sqrt(@mulAdd(f32, x, x, @as(f32, y) * y)));
40+
if (T == f32) return @floatCast(@sqrt(@mulAdd(f64, x, x, @as(f64, y) * y)));
41+
major = @abs(major);
42+
minor = @abs(minor);
43+
if (minor > major) {
44+
const tempo = major;
45+
major = minor;
46+
minor = tempo;
6147
}
62-
63-
return z * @sqrt(@as(f32, @floatCast(@as(f64, x) * x + @as(f64, y) * y)));
48+
if (major * incre >= minor) return major;
49+
if (major > upper) return hypfn(T, major * scale, minor * scale) / scale;
50+
if (minor < lower) return hypfn(T, major / scale, minor / scale) * scale;
51+
return hypfn(T, major, minor);
6452
}
6553

66-
fn sq(hi: *f64, lo: *f64, x: f64) void {
67-
const split: f64 = 0x1.0p27 + 1.0;
68-
const xc = x * split;
69-
const xh = x - xc + xc;
70-
const xl = x - xh;
71-
hi.* = x * x;
72-
lo.* = xh * xh - hi.* + 2 * xh * xl + xl * xl;
54+
/// Returns true when `hypot` should pick `hypotUnfused` over `hypotFused`
/// for float type T. Currently a fixed guess: true only for f80 and f128.
inline fn emulateFma(comptime T: type) bool {
55+
// If @mulAdd lowers to the software implementation,
56+
// hypotUnfused should be used in place of hypotFused.
57+
// This takes an educated guess, but ideally we should
58+
// properly detect at comptime when that fallback will
59+
// occur.
60+
return (T == f128 or T == f80);
7361
}
7462

75-
fn hypot64(x: f64, y: f64) f64 {
76-
var ux = @as(u64, @bitCast(x));
77-
var uy = @as(u64, @bitCast(y));
78-
79-
ux &= maxInt(u64) >> 1;
80-
uy &= maxInt(u64) >> 1;
81-
if (ux < uy) {
82-
const tmp = ux;
83-
ux = uy;
84-
uy = tmp;
85-
}
86-
87-
const ex = ux >> 52;
88-
const ey = uy >> 52;
89-
var xx = @as(f64, @bitCast(ux));
90-
var yy = @as(f64, @bitCast(uy));
91-
92-
// hypot(inf, nan) == inf
93-
if (ey == 0x7FF) {
94-
return yy;
95-
}
96-
if (ex == 0x7FF or uy == 0) {
97-
return xx;
98-
}
99-
100-
// hypot(x, y) ~= x + y * y / x / 2 with inexact for small y/x
101-
if (ex - ey > 64) {
102-
return xx + yy;
103-
}
63+
/// Computes sqrt(x * x + y * y) for float type F using @mulAdd, then
/// subtracts a correction term from the first estimate.
/// `hypot` calls this with non-negative arguments where x >= y (after its
/// abs/swap step), optionally rescaled; no special-case handling is done here.
inline fn hypotFused(comptime F: type, x: F, y: F) F {
64+
// First estimate of the result.
const r = @sqrt(@mulAdd(F, x, x, y * y));
65+
const rr = r * r;
66+
const xx = x * x;
67+
// z estimates the residual r*r - x*x - y*y of the first estimate.
// @mulAdd(F, r, r, -rr) and @mulAdd(F, x, x, -xx) recover the difference
// between the products r*r, x*x and their rounded values rr, xx.
const z = @mulAdd(F, -y, y, rr - xx) + @mulAdd(F, r, r, -rr) - @mulAdd(F, x, x, -xx);
68+
// NOTE(review): r - z / (2 * r) has the shape of one Newton step for
// r*r = x*x + y*y — confirm against the algorithm's reference.
return r - z / (2 * r);
69+
}
10470

105-
var z: f64 = 1;
106-
if (ex > 0x3FF + 510) {
107-
z = 0x1.0p700;
108-
xx *= 0x1.0p-700;
109-
yy *= 0x1.0p-700;
110-
} else if (ey < 0x3FF - 450) {
111-
z = 0x1.0p-700;
112-
xx *= 0x1.0p700;
113-
yy *= 0x1.0p700;
71+
inline fn hypotUnfused(comptime F: type, x: F, y: F) F {
72+
const r = @sqrt(x * x + y * y);
73+
if (r <= 2 * y) { // 30deg or steeper
74+
const dx = r - y;
75+
const z = x * (2 * dx - x) + (dx - 2 * (x - y)) * dx;
76+
return r - z / (2 * r);
77+
} else { // shallower than 30 deg
78+
const dy = r - x;
79+
const z = 2 * dy * (x - 2 * y) + (4 * dy - y) * y + dy * dy;
80+
return r - z / (2 * r);
11481
}
115-
116-
var hx: f64 = undefined;
117-
var lx: f64 = undefined;
118-
var hy: f64 = undefined;
119-
var ly: f64 = undefined;
120-
121-
sq(&hx, &lx, x);
122-
sq(&hy, &ly, y);
123-
124-
return z * @sqrt(ly + lx + hy + hx);
12582
}
12683

84+
const hypot_test_cases = .{
85+
.{ 0.0, -1.2, 1.2 },
86+
.{ 0.2, -0.34, 0.3944616584663203993612799816649560759946493601889826495362 },
87+
.{ 0.8923, 2.636890, 2.7837722899152509525110650481670176852603253522923737962880 },
88+
.{ 1.5, 5.25, 5.4600824169603887033229768686452745953332522619323580787836 },
89+
.{ 37.45, 159.835, 164.16372840856167640478217141034363907565754072954443805164 },
90+
.{ 89.123, 382.028905, 392.28687638576315875933966414927490685367196874260165618371 },
91+
.{ 123123.234375, 529428.707813, 543556.88524707706887251269205923830745438413088753096759371 },
92+
};
93+
12794
test hypot {
128-
const x32: f32 = 0.0;
129-
const y32: f32 = -1.2;
130-
const x64: f64 = 0.0;
131-
const y64: f64 = -1.2;
132-
try expect(hypot(x32, y32) == hypot32(0.0, -1.2));
133-
try expect(hypot(x64, y64) == hypot64(0.0, -1.2));
95+
try expect(hypot(0.3, 0.4) == 0.5);
13496
}
13597

136-
test hypot32 {
137-
const epsilon = 0.000001;
138-
139-
try expect(math.approxEqAbs(f32, hypot32(0.0, -1.2), 1.2, epsilon));
140-
try expect(math.approxEqAbs(f32, hypot32(0.2, -0.34), 0.394462, epsilon));
141-
try expect(math.approxEqAbs(f32, hypot32(0.8923, 2.636890), 2.783772, epsilon));
142-
try expect(math.approxEqAbs(f32, hypot32(1.5, 5.25), 5.460083, epsilon));
143-
try expect(math.approxEqAbs(f32, hypot32(37.45, 159.835), 164.163742, epsilon));
144-
try expect(math.approxEqAbs(f32, hypot32(89.123, 382.028905), 392.286865, epsilon));
145-
try expect(math.approxEqAbs(f32, hypot32(123123.234375, 529428.707813), 543556.875, epsilon));
98+
test "hypot.correct" {
99+
inline for (.{ f16, f32, f64, f128 }) |T| {
100+
inline for (hypot_test_cases) |v| {
101+
const a: T, const b: T, const c: T = v;
102+
try expect(math.approxEqRel(T, hypot(a, b), c, @sqrt(floatEps(T))));
103+
}
104+
}
146105
}
147106

148-
test hypot64 {
149-
const epsilon = 0.000001;
150-
151-
try expect(math.approxEqAbs(f64, hypot64(0.0, -1.2), 1.2, epsilon));
152-
try expect(math.approxEqAbs(f64, hypot64(0.2, -0.34), 0.394462, epsilon));
153-
try expect(math.approxEqAbs(f64, hypot64(0.8923, 2.636890), 2.783772, epsilon));
154-
try expect(math.approxEqAbs(f64, hypot64(1.5, 5.25), 5.460082, epsilon));
155-
try expect(math.approxEqAbs(f64, hypot64(37.45, 159.835), 164.163728, epsilon));
156-
try expect(math.approxEqAbs(f64, hypot64(89.123, 382.028905), 392.286876, epsilon));
157-
try expect(math.approxEqAbs(f64, hypot64(123123.234375, 529428.707813), 543556.885247, epsilon));
107+
test "hypot.precise" {
108+
inline for (.{ f16, f32, f64 }) |T| { // f128 seems to be 5 ulp
109+
inline for (hypot_test_cases) |v| {
110+
const a: T, const b: T, const c: T = v;
111+
try expect(math.approxEqRel(T, hypot(a, b), c, floatEps(T)));
112+
}
113+
}
158114
}
159115

160-
test "hypot32.special" {
161-
try expect(math.isPositiveInf(hypot32(math.inf(f32), 0.0)));
162-
try expect(math.isPositiveInf(hypot32(-math.inf(f32), 0.0)));
163-
try expect(math.isPositiveInf(hypot32(0.0, math.inf(f32))));
164-
try expect(math.isPositiveInf(hypot32(0.0, -math.inf(f32))));
165-
try expect(math.isNan(hypot32(math.nan(f32), 0.0)));
166-
try expect(math.isNan(hypot32(0.0, math.nan(f32))));
167-
}
116+
test "hypot.special" {
117+
inline for (.{ f16, f32, f64, f128 }) |T| {
118+
try expect(math.isNan(hypot(nan(T), 0.0)));
119+
try expect(math.isNan(hypot(0.0, nan(T))));
120+
121+
try expect(math.isPositiveInf(hypot(inf(T), 0.0)));
122+
try expect(math.isPositiveInf(hypot(0.0, inf(T))));
123+
try expect(math.isPositiveInf(hypot(inf(T), nan(T))));
124+
try expect(math.isPositiveInf(hypot(nan(T), inf(T))));
168125

169-
test "hypot64.special" {
170-
try expect(math.isPositiveInf(hypot64(math.inf(f64), 0.0)));
171-
try expect(math.isPositiveInf(hypot64(-math.inf(f64), 0.0)));
172-
try expect(math.isPositiveInf(hypot64(0.0, math.inf(f64))));
173-
try expect(math.isPositiveInf(hypot64(0.0, -math.inf(f64))));
174-
try expect(math.isNan(hypot64(math.nan(f64), 0.0)));
175-
try expect(math.isNan(hypot64(0.0, math.nan(f64))));
126+
try expect(math.isPositiveInf(hypot(-inf(T), 0.0)));
127+
try expect(math.isPositiveInf(hypot(0.0, -inf(T))));
128+
try expect(math.isPositiveInf(hypot(-inf(T), nan(T))));
129+
try expect(math.isPositiveInf(hypot(nan(T), -inf(T))));
130+
}
176131
}

0 commit comments

Comments
 (0)