Skip to content

Commit a782be8

Browse files
committed
Improve accuracy of conversions to floating-point numbers
1 parent ee5bd54 commit a782be8

File tree

2 files changed

+173
-18
lines changed

2 files changed

+173
-18
lines changed

src/normed.jl

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,135 @@ rem(x::Float16, ::Type{T}) where {T <: Normed} = rem(Float32(x), T) # avoid ove
105105

106106
float(x::Normed) = convert(floattype(x), x)
107107

108-
Base.BigFloat(x::Normed) = reinterpret(x)*(1/BigFloat(rawone(x)))
108+
macro f32(x::Float64) # just for hexadecimal floating-point literals
109+
:(Float32($x))
110+
end
111+
macro exp2(n)
112+
:(_exp2(Val($(esc(n)))))
113+
end
114+
@generated _exp2(::Val{N}) where {N} = (x = exp2(N); :($x))
115+
116+
# for Julia v1.0, which does not fold `div_float` before inlining
117+
@generated inv_rawone(::T) where {T <: Normed} = (x = 1.0 / rawone(T); :($x))
118+
109119
function (::Type{T})(x::Normed) where {T <: AbstractFloat}
110-
y = reinterpret(x)*(one(rawtype(x))/convert(T, rawone(x)))
120+
# The following optimization for constant division may cause rounding errors.
121+
# y = reinterpret(x)*(one(rawtype(x))/convert(T, rawone(x)))
122+
# Therefore, we use a simple form here.
123+
# If you prefer speed over accuracy, consider using `scaledual` instead.
124+
y = reinterpret(x) / convert(promote_type(T, floattype(x)), rawone(x))
111125
convert(T, y) # needed for types like Float16 which promote arithmetic to Float32
112126
end
127+
128+
function Base.Float16(x::Normed{Ti,f}) where {Ti <: Union{UInt8, UInt16, UInt32}, f}
129+
f == 1 ? Float16(x.i) : Float16(Float32(x))
130+
end
131+
function Base.Float16(x::Normed{Ti,f}) where {Ti <: Union{UInt64, UInt128}, f}
132+
f == 1 ? Float16(x.i) : Float16(Float64(x))
133+
end
134+
135+
function Base.Float32(x::Normed{UInt8,f}) where f
136+
f == 1 && return Float32(x.i)
137+
f == 2 && return Float32(Int32(x.i) * 0x101) * @f32(0x550055p-32)
138+
f == 3 && return Float32(Int32(x.i) * 0x00b) * @f32(0xd4c77bp-30)
139+
f == 4 && return Float32(Int32(x.i) * 0x101) * @f32(0x110011p-32)
140+
f == 5 && return Float32(Int32(x.i) * 0x003) * @f32(0xb02c0bp-30)
141+
f == 6 && return Float32(Int32(x.i) * 0x049) * @f32(0xe40039p-36)
142+
f == 7 && return Float32(Int32(x.i) * 0x01f) * @f32(0x852b5fp-35)
143+
f == 8 && return Float32(Int32(x.i) * 0x155) * @f32(0xc0f0fdp-40)
144+
0.0f0
145+
end
146+
function Base.Float32(x::Normed{UInt16,f}) where {f}
147+
f32 = Float32(x.i)
148+
f == 1 && return f32
149+
f == 2 && return f32 * @f32(0x55p-8) + f32 * @f32(0x555555p-32)
150+
f == 3 && return f32 * @f32(0x49p-9) + f32 * @f32(0x249249p-33)
151+
f == 4 && return f32 * @f32(0x11p-8) + f32 * @f32(0x111111p-32)
152+
f == 5 && return f32 * @f32(0x21p-10) + f32 * @f32(0x108421p-35)
153+
f == 6 && return f32 * @f32(0x41p-12) + f32 * @f32(0x041041p-36)
154+
f == 7 && return f32 * @f32(0x81p-14) + f32 * @f32(0x204081p-42)
155+
f == 16 && return f32 * @f32(0x01p-16) + f32 * @f32(0x010001p-48)
156+
Float32(x.i / rawone(x))
157+
end
158+
function Base.Float32(x::Normed{UInt32,f})::Float32 where f
159+
f == 1 && return Float32(x.i)
160+
i32 = unsafe_trunc(Int32, x.i)
161+
if f == 32
162+
rh, rl = Float32(i32>>>16), Float32((i32&0xFFFF)<<8 | (i32>>>24))
163+
return muladd(rh, @f32(0x1p-16), rl * @f32(0x1p-40))
164+
elseif f >= 25
165+
rh, rl = Float32(i32>>>16),Float32(((i32&0xFFFF)<<14) + (i32>>>(f-14)))
166+
return muladd(rh, Float32(@exp2(16-f)), rl * Float32(@exp2(-14-f)))
167+
end
168+
# FIXME: avoid the branch in native x86_64 (non-SIMD) codes
169+
m = ifelse(i32 < 0, 0x1p32 * inv_rawone(x), 0.0)
170+
Float32(muladd(Float64(i32), inv_rawone(x), m))
171+
end
172+
function Base.Float32(x::Normed{Ti,f}) where {Ti <: Union{UInt64, UInt128}, f}
173+
f == 1 ? Float32(x.i) : Float32(Float64(x))
174+
end
175+
176+
function Base.Float64(x::Normed{Ti,f}) where {Ti <: Union{UInt8, UInt16}, f}
177+
Float64(Normed{UInt32,f}(x))
178+
end
179+
function Base.Float64(x::Normed{UInt32,f}) where f
180+
f64 = Float64(x.i)
181+
f == 1 && return f64
182+
f == 2 && return (f64 * 0x040001) * 0x15555000015555p-72
183+
f == 3 && return (f64 * 0x108421) * 0x11b6db76924929p-75
184+
f == 4 && return (f64 * 0x010101) * 0x11000011000011p-72
185+
f == 5 && return (f64 * 0x108421) * 0x04000002000001p-75
186+
f == 6 && return (f64 * 0x09dfb1) * 0x1a56b8e38e6d91p-78
187+
f == 7 && return (f64 * 0x000899) * 0x0f01480001e029p-70
188+
f == 8 && return (f64 * 0x0a5a5b) * 0x18d300000018d3p-80
189+
f == 9 && return (f64 * 0x001001) * 0x080381c8e3f201p-72
190+
f == 10 && return (f64 * 0x100001) * 0x04010000000401p-80
191+
f == 11 && return (f64 * 0x000009) * 0x0e3aaae3955639p-66
192+
f == 12 && return (f64 * 0x0a8055) * 0x186246e46e4cfdp-84
193+
f == 13 && return (f64 * 0x002001) * 0x10000004000001p-78
194+
f == 14 && return (f64 * 0x03400d) * 0x13b13b14ec4ec5p-84
195+
f == 15 && return (f64 * 0x000259) * 0x06d0c5a4f3a5e9p-75
196+
f == 16 && return (f64 * 0x011111) * 0x00f000ff00fff1p-80
197+
f == 18 && return (f64 * 0x0b06d1) * 0x17377445dd1231p-90
198+
f == 19 && return (f64 * 0x080001) * 0x00004000000001p-76
199+
f == 20 && return (f64 * 0x000101) * 0x0ff010ef10ff01p-80
200+
f == 21 && return (f64 * 0x004001) * 0x01fff8101fc001p-84
201+
f == 22 && return (f64 * 0x002945) * 0x18d0000000018dp-88
202+
f == 23 && return (f64 * 0x044819) * 0x07794a23729429p-92
203+
f == 27 && return (f64 * 0x000a21) * 0x0006518c7df9e1p-81
204+
f == 28 && return (f64 * 0x00000d) * 0x13b13b14ec4ec5p-84
205+
f == 30 && return (f64 * 0x001041) * 0x00fc003f03ffc1p-90
206+
f == 32 && return (f64 * 0x010101) * 0x00ff0000ffff01p-96
207+
f64 / rawone(x)
208+
end
209+
function Base.Float64(x::Normed{UInt64,f}) where f
210+
f == 1 && return Float64(x.i)
211+
if f >= 53
212+
rh = Float64(unsafe_trunc(Int64, x.i >> 16)) * @exp2(16-f) # upper 48 bits
213+
rl = Float64(unsafe_trunc(Int32, x.i&0xFFFF)) * @exp2(-f) # lower 16 bits
214+
return rh + muladd(rh, @exp2(-f), rl)
215+
end
216+
x.i / rawone(x)
217+
end
218+
function Base.Float64(x::Normed{UInt128,f}) where f
219+
f == 1 && return Float64(x.i)
220+
ih, il = unsafe_trunc(Int64, x.i>>64), unsafe_trunc(Int64, x.i)
221+
rh = Float64(ih>>>16) * @exp2(f <= 53 ? 80 : 80 - f) # upper 48 bits
222+
km = @exp2(f <= 53 ? 48 : 48 - f) # for middle 32 bits
223+
rm = Float64(unsafe_trunc(Int32, ih&0xFFFF)) * (0x1p16 * km) +
224+
Float64(unsafe_trunc(Int32, il>>>48)) * km
225+
rl = Float64(il&0xFFFFFFFFFFFF) * @exp2(f <= 53 ? 0 : -f) # lower 48 bits
226+
if f <= 53
227+
return (rh + (rm + rl)) / rawone(x)
228+
elseif f < 76
229+
return rh + (rm + muladd(rh, @exp2(-f), rl))
230+
else
231+
return rh + (rm + rl)
232+
end
233+
end
234+
235+
Base.BigFloat(x::Normed) = reinterpret(x)*(1/BigFloat(rawone(x)))
236+
113237
Base.Bool(x::Normed) = x == zero(x) ? false : true
114238
Base.Integer(x::Normed) = convert(Integer, x*1.0)
115239
(::Type{T})(x::Normed) where {T <: Integer} = convert(T, x*(1/oneunit(T)))

test/normed.jl

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ end
112112

113113
@test reinterpret(U(zero(Tf))) == 0x0
114114

115-
# TODO: fix issue #129
116-
# input_typemax = Tf(typemax(U))
117-
input_typemax = Tf(BigFloat(typemax(T)) / r)
115+
input_typemax = Tf(typemax(U))
118116
if isinf(input_typemax)
119117
@test reinterpret(U(floatmax(Tf))) >= round(T, floatmax(Tf))
120118
else
@@ -136,6 +134,46 @@ end
136134
@test N0f32(Float32(0x0.800001p-32)) == eps(N0f32)
137135
end
138136

137+
@testset "conversions to float" begin
138+
x = N0f8(0.3)
139+
for T in (Float16, Float32, Float64, BigFloat)
140+
y = convert(T, x)
141+
@test isa(y, T)
142+
end
143+
144+
for Tf in (Float16, Float32, Float64)
145+
@testset "$Tf(::Normed{$Ti})" for Ti in (UInt8, UInt16)
146+
@testset "$Tf(::Normed{$Ti,$f})" for f = 1:(sizeof(Ti)*8)
147+
T = Normed{Ti,f}
148+
float_err = 0.0
149+
for i = typemin(Ti):typemax(Ti)
150+
f_expected = Tf(i / BigFloat(FixedPointNumbers.rawone(T)))
151+
isinf(f_expected) && break # for Float16(::Normed{UInt16,1})
152+
f_actual = Tf(reinterpret(T, i))
153+
float_err += abs(f_actual - f_expected)
154+
end
155+
@test float_err == 0.0
156+
end
157+
end
158+
@testset "$Tf(::Normed{$Ti})" for Ti in (UInt32, UInt64, UInt128)
159+
@testset "$Tf(::Normed{$Ti,$f})" for f = 1:(sizeof(Ti)*8)
160+
T = Normed{Ti,f}
161+
error_count = 0
162+
for i in vcat(Ti(0x00):Ti(0xFF), (typemax(Ti)-0xFF):typemax(Ti))
163+
f_expected = Tf(i / BigFloat(FixedPointNumbers.rawone(T)))
164+
isinf(f_expected) && break # for Float16() and Float32()
165+
f_actual = Tf(reinterpret(T, i))
166+
f_actual == f_expected && continue
167+
f_actual == prevfloat(f_expected) && continue
168+
f_actual == nextfloat(f_expected) && continue
169+
error_count += 1
170+
end
171+
@test error_count == 0
172+
end
173+
end
174+
end
175+
end
176+
139177
@testset "modulus" begin
140178
@test N0f8(0.2) % N0f8 === N0f8(0.2)
141179
@test N2f14(1.2) % N0f16 === N0f16(0.20002)
@@ -313,25 +351,26 @@ end
313351
af8 = reinterpret(N0f8, a)
314352
b = 0.5
315353

354+
# LHSs of the following `@test`s with `af8` can be slightly more accurate
316355
bd, eld = scaledual(b, af8[1])
317-
@test b*af8[1] == bd*eld
356+
@test b*af8[1] ≈ bd*eld rtol=1e-15
318357
bd, ad = scaledual(b, af8)
319-
@test b*af8 == bd*ad
358+
@test b*af8 ≈ bd*ad rtol=1e-15
320359

321360
bd, eld = scaledual(b, a[1])
322361
@test b*a[1] == bd*eld
323362
bd, ad = scaledual(b, a)
324363
@test b*a == bd*ad
325364

326365
bd, eld = scaledual(Float64, af8[1])
327-
@test 1.0*af8[1] == bd*eld
366+
@test 1.0*af8[1] ≈ bd*eld rtol=1e-15
328367
bd, ad = scaledual(Float64, af8)
329-
@test 1.0*af8 == bd*ad
368+
@test 1.0*af8 ≈ bd*ad rtol=1e-15
330369

331370
bd, eld = scaledual(Float64, a[1])
332371
@test 1.0*a[1] == bd*eld
333372
bd, ad = scaledual(Float64, a)
334-
@test 1.0*a == bd*ad
373+
@test 1.0*a == bd*ad
335374
end
336375

337376
@testset "reductions" begin
@@ -345,14 +384,6 @@ end
345384
@test prod(a, dims=1) == [acmp]
346385
end
347386

348-
@testset "convert" begin
349-
x = N0f8(0.3)
350-
for T in (Float16, Float32, Float64, BigFloat)
351-
y = convert(T, x)
352-
@test isa(y, T)
353-
end
354-
end
355-
356387
@testset "rand" begin
357388
for T in (Normed{UInt8,8}, Normed{UInt8,6},
358389
Normed{UInt16,16}, Normed{UInt16,14},

0 commit comments

Comments
 (0)