Skip to content

Commit 134646f

Browse files
authored
Optimize multiplication for Normed (#213)
This adds `wrapping_mul`, `saturating_mul` and `checked_mul` binary operations. However, this does not specialize them for `Fixed` and does not change `*` for `Fixed`. This replaces most of Normed's implementation of multiplication with integer operations. This improves the speed in many cases and the accuracy in some cases.
1 parent 5794adf commit 134646f

File tree

4 files changed

+108
-4
lines changed

4 files changed

+108
-4
lines changed

src/FixedPointNumbers.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ wrapping_neg(x::X) where {X <: FixedPoint} = X(-x.i, 0)
206206
wrapping_abs(x::X) where {X <: FixedPoint} = X(abs(x.i), 0)
207207
wrapping_add(x::X, y::X) where {X <: FixedPoint} = X(x.i + y.i, 0)
208208
wrapping_sub(x::X, y::X) where {X <: FixedPoint} = X(x.i - y.i, 0)
209+
wrapping_mul(x::X, y::X) where {X <: FixedPoint} = (float(x) * float(y)) % X
209210

210211
# saturating arithmetic
211212
saturating_neg(x::X) where {X <: FixedPoint} = X(~min(x.i - true, x.i), 0)
@@ -222,6 +223,7 @@ saturating_sub(x::X, y::X) where {X <: FixedPoint} =
222223
X(x.i - ifelse(x.i < 0, min(y.i, x.i - typemin(x.i)), max(y.i, x.i - typemax(x.i))), 0)
223224
saturating_sub(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i - min(x.i, y.i), 0)
224225

226+
saturating_mul(x::X, y::X) where {X <: FixedPoint} = clamp(float(x) * float(y), X)
225227

226228
# checked arithmetic
227229
checked_neg(x::X) where {X <: FixedPoint} = checked_sub(zero(X), x)
@@ -241,6 +243,11 @@ function checked_sub(x::X, y::X) where {X <: FixedPoint}
241243
f && throw_overflowerror(:-, x, y)
242244
z
243245
end
246+
function checked_mul(x::X, y::X) where {X <: FixedPoint}
247+
z = float(x) * float(y)
248+
typemin(X) - eps(X)/2 <= z < typemax(X) + eps(X)/2 || throw_overflowerror(:*, x, y)
249+
z % X
250+
end
244251

245252
# default arithmetic
246253
const DEFAULT_ARITHMETIC = :wrapping
@@ -251,7 +258,7 @@ for (op, name) in ((:-, :neg), (:abs, :abs))
251258
$op(x::X) where {X <: FixedPoint} = $f(x)
252259
end
253260
end
254-
for (op, name) in ((:+, :add), (:-, :sub))
261+
for (op, name) in ((:+, :add), (:-, :sub), (:*, :mul))
255262
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
256263
@eval begin
257264
$op(x::X, y::X) where {X <: FixedPoint} = $f(x, y)

src/normed.jl

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ function rem(x::Float64, ::Type{N}) where {f, N <: Normed{UInt64,f}}
127127
reinterpret(N, r << UInt8(f - 53) - unsigned(signed(r) >> 0x35))
128128
end
129129

130-
131130
function (::Type{T})(x::Normed) where {T <: AbstractFloat}
132131
# The following optimization for constant division may cause rounding errors.
133132
# y = reinterpret(x)*(one(rawtype(x))/convert(T, rawone(x)))
@@ -248,8 +247,45 @@ Base.BigFloat(x::Normed) = reinterpret(x) / BigFloat(rawone(x))
248247

249248
Base.Rational(x::Normed) = reinterpret(x)//rawone(x)
250249

251-
# unchecked arithmetic
252-
*(x::T, y::T) where {T <: Normed} = convert(T,convert(floattype(T), x)*convert(floattype(T), y))
250+
# Division by `2^f-1` with RoundNearest. The result would be in the lower half bits.
251+
div_2fm1(x::T, ::Val{f}) where {T, f} = (x + (T(1)<<(f - 1) - 0x1)) ÷ (T(1) << f - 0x1)
252+
div_2fm1(x::T, ::Val{1}) where T = x
253+
div_2fm1(x::UInt16, ::Val{8}) = (((x + 0x80) >> 0x8) + x + 0x80) >> 0x8
254+
div_2fm1(x::UInt32, ::Val{16}) = (((x + 0x8000) >> 0x10) + x + 0x8000) >> 0x10
255+
div_2fm1(x::UInt64, ::Val{32}) = (((x + 0x80000000) >> 0x20) + x + 0x80000000) >> 0x20
256+
div_2fm1(x::UInt128, ::Val{64}) = (((x + 0x8000000000000000) >> 0x40) + x + 0x8000000000000000) >> 0x40
257+
258+
# wrapping arithmetic
259+
function wrapping_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
260+
z = widemul(x.i, y.i)
261+
N(div_2fm1(z, Val(Int(f))) % T, 0)
262+
end
263+
264+
# saturating arithmetic
265+
function saturating_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
266+
f == bitwidth(T) && return wrapping_mul(x, y)
267+
z = min(widemul(x.i, y.i), widemul(typemax(N).i, rawone(N)))
268+
N(div_2fm1(z, Val(Int(f))) % T, 0)
269+
end
270+
271+
# checked arithmetic
272+
function checked_mul(x::N, y::N) where {N <: Normed}
273+
z = float(x) * float(y)
274+
z < typemax(N) + eps(N)/2 || throw_overflowerror(:*, x, y)
275+
z % N
276+
end
277+
function checked_mul(x::N, y::N) where {T <: Union{UInt8,UInt16,UInt32,UInt64}, f, N <: Normed{T,f}}
278+
f == bitwidth(T) && return wrapping_mul(x, y)
279+
z = widemul(x.i, y.i)
280+
m = widemul(typemax(N).i, rawone(N)) + (rawone(N) >> 0x1)
281+
z < m || throw_overflowerror(:*, x, y)
282+
N(div_2fm1(z, Val(Int(f))) % T, 0)
283+
end
284+
285+
# TODO: decide the default arithmetic for `Normed` mul
286+
# Override the default arithmetic with `checked` for backward compatibility
287+
*(x::N, y::N) where {N <: Normed} = checked_mul(x, y)
288+
253289
/(x::T, y::T) where {T <: Normed} = convert(T,convert(floattype(T), x)/convert(floattype(T), y))
254290

255291
# Functions

test/fixed.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,40 @@ end
373373
end
374374
end
375375

376+
@testset "mul" begin
377+
for F in target(Fixed; ex = :thin)
378+
@test wrapping_mul(typemax(F), zero(F)) === zero(F)
379+
@test saturating_mul(typemax(F), zero(F)) === zero(F)
380+
@test checked_mul(typemax(F), zero(F)) === zero(F)
381+
382+
# FIXME: Both the rhs and lhs of the following tests may be inaccurate due to `rem`
383+
F === Fixed{Int128,127} && continue
384+
385+
@test wrapping_mul(F(-1), typemax(F)) === -typemax(F)
386+
@test saturating_mul(F(-1), typemax(F)) === -typemax(F)
387+
@test checked_mul(F(-1), typemax(F)) === -typemax(F)
388+
389+
@test wrapping_mul(typemin(F), typemax(F)) === big(typemin(F)) * big(typemax(F)) % F
390+
if typemin(F) != -1
391+
@test saturating_mul(typemin(F), typemax(F)) === typemin(F)
392+
@test_throws OverflowError checked_mul(typemin(F), typemax(F))
393+
end
394+
395+
@test wrapping_mul(typemin(F), typemin(F)) === big(typemin(F))^2 % F
396+
@test saturating_mul(typemin(F), typemin(F)) === typemax(F)
397+
@test_throws OverflowError checked_mul(typemin(F), typemin(F))
398+
end
399+
for F in target(Fixed, :i8; ex = :thin)
400+
xs = typemin(F):eps(F):typemax(F)
401+
xys = ((x, y) for x in xs, y in xs)
402+
fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
403+
@test all(((x, y),) -> wrapping_mul(x, y) === fmul(x, y) % F, xys)
404+
@test all(((x, y),) -> saturating_mul(x, y) === clamp(fmul(x, y), F), xys)
405+
@test all(((x, y),) -> !(typemin(F) <= fmul(x, y) <= typemax(F)) ||
406+
wrapping_mul(x, y) === checked_mul(x, y), xys)
407+
end
408+
end
409+
376410
@testset "rounding" begin
377411
for sym in (:i8, :i16, :i32, :i64)
378412
T = symbol_to_inttype(Fixed, sym)

test/normed.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,33 @@ end
394394
end
395395
end
396396

397+
@testset "mul" begin
398+
for N in target(Normed; ex = :thin)
399+
@test wrapping_mul(typemax(N), zero(N)) === zero(N)
400+
@test saturating_mul(typemax(N), zero(N)) === zero(N)
401+
@test checked_mul(typemax(N), zero(N)) === zero(N)
402+
403+
@test wrapping_mul(one(N), typemax(N)) === typemax(N)
404+
@test saturating_mul(one(N), typemax(N)) === typemax(N)
405+
@test checked_mul(one(N), typemax(N)) === typemax(N)
406+
407+
@test wrapping_mul(typemax(N), typemax(N)) === big(typemax(N))^2 % N
408+
@test saturating_mul(typemax(N), typemax(N)) === typemax(N)
409+
if typemax(N) != 1
410+
@test_throws OverflowError checked_mul(typemax(N), typemax(N))
411+
end
412+
end
413+
for N in target(Normed, :i8; ex = :thin)
414+
xs = typemin(N):eps(N):typemax(N)
415+
xys = ((x, y) for x in xs, y in xs)
416+
fmul(x, y) = float(x) * float(y) # note that precision(Float32) < 32
417+
@test all(((x, y),) -> wrapping_mul(x, y) === fmul(x, y) % N, xys)
418+
@test all(((x, y),) -> saturating_mul(x, y) === clamp(fmul(x, y), N), xys)
419+
@test all(((x, y),) -> !(typemin(N) <= fmul(x, y) <= typemax(N)) ||
420+
wrapping_mul(x, y) === checked_mul(x, y), xys)
421+
end
422+
end
423+
397424
@testset "div/fld1" begin
398425
@test div(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x10), reinterpret(N0f8, 0x02)) == 8
399426
@test div(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == fld(reinterpret(N0f8, 0x0f), reinterpret(N0f8, 0x02)) == 7

0 commit comments

Comments
 (0)