Skip to content

[RFC] Move arithmetic functions into submodule FixedPointArithmetic #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 31 additions & 199 deletions src/FixedPointNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ import Base: ==, <, <=, -, +, *, /, ~, isapprox,

import Random: Random, AbstractRNG, SamplerType, rand!

import Base.Checked: checked_neg, checked_abs, checked_add, checked_sub, checked_mul,
checked_div, checked_fld, checked_cld, checked_rem, checked_mod

using Base: @pure

"""
Expand All @@ -35,14 +32,11 @@ export
# "special" typealiases
# Q and N typealiases are exported in separate source files
# Functions
scaledual,
wrapping_neg, wrapping_abs, wrapping_add, wrapping_sub, wrapping_mul,
wrapping_div, wrapping_fld, wrapping_cld, wrapping_rem, wrapping_mod,
saturating_neg, saturating_abs, saturating_add, saturating_sub, saturating_mul,
saturating_div, saturating_fld, saturating_cld, saturating_rem, saturating_mod,
wrapping_fdiv, saturating_fdiv, checked_fdiv
scaledual

include("utilities.jl")
using .Utilities
import .Utilities: floattype, rawone, nbitsfrac, rawtype, signbits, nbitsint, scaledual

# reinterpretation
reinterpret(x::FixedPoint) = x.i
Expand All @@ -57,18 +51,6 @@ rawtype(::Type{X}) where {T, X <: FixedPoint{T}} = T
signbits(::Type{X}) where {T, X <: FixedPoint{T}} = T <: Unsigned ? 0 : 1
nbitsint(::Type{X}) where {X <: FixedPoint} = bitwidth(X) - nbitsfrac(X) - signbits(X)

# construction using the (approximate) intended value, i.e., N0f8
*(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)
wrapping_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = x % X
saturating_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = clamp(x, X)
checked_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)

# type modulus
rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
wrapping_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
saturating_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
checked_rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)

# constructor-style conversions
(::Type{X})(x::X) where {X <: FixedPoint} = x
(::Type{X})(x::Number) where {X <: FixedPoint} = _convert(X, x)
Expand Down Expand Up @@ -139,9 +121,6 @@ zero(::Type{X}) where {X <: FixedPoint} = X(zero(rawtype(X)), 0)
oneunit(::Type{X}) where {X <: FixedPoint} = X(rawone(X), 0)
one(::Type{X}) where {X <: FixedPoint} = oneunit(X)

# for Julia v1.0, which does not fold `div_float` before inlining
inv_rawone(x) = (@generated) ? (y = 1.0 / rawone(x); :($y)) : 1.0 / rawone(x)

# traits
eps(::Type{X}) where {X <: FixedPoint} = X(oneunit(rawtype(X)), 0)
typemax(::Type{T}) where {T <: FixedPoint} = T(typemax(rawtype(T)), 0)
Expand Down Expand Up @@ -192,164 +171,12 @@ RGB{Float32}

`RGB` itself is not a subtype of `AbstractFloat`, but unlike `RGB{N0f8}` operations with `RGB{Float32}` are not subject to integer overflow.
"""
floattype(::Type{T}) where {T <: AbstractFloat} = T # fallback (we want a MethodError if no method producing AbstractFloat is defined)
floattype(::Type{T}) where {T <: Union{ShortInts, Bool}} = Float32
floattype(::Type{T}) where {T <: Integer} = Float64
floattype(::Type{T}) where {T <: LongInts} = BigFloat
floattype(::Type{T}) where {I <: Integer, T <: Rational{I}} = typeof(zero(I)/oneunit(I))
floattype(::Type{<:AbstractIrrational}) = Float64
floattype(::Type{X}) where {T <: ShortInts, X <: FixedPoint{T}} = Float32
floattype(::Type{X}) where {T <: Integer, X <: FixedPoint{T}} = Float64
floattype(::Type{X}) where {T <: LongInts, X <: FixedPoint{T}} = BigFloat

# Non-Real types
floattype(::Type{Complex{T}}) where T = Complex{floattype(T)}
floattype(::Type{Base.TwicePrecision{Float64}}) = Float64 # wider would be nice, but hardware support is paramount
floattype(::Type{Base.TwicePrecision{T}}) where T<:Union{Float16,Float32} = widen(T)

float(x::FixedPoint) = convert(floattype(x), x)

# wrapping arithmetic
wrapping_neg(x::X) where {X <: FixedPoint} = X(-x.i, 0)
wrapping_abs(x::X) where {X <: FixedPoint} = X(abs(x.i), 0)
wrapping_add(x::X, y::X) where {X <: FixedPoint} = X(x.i + y.i, 0)
wrapping_sub(x::X, y::X) where {X <: FixedPoint} = X(x.i - y.i, 0)
wrapping_mul(x::X, y::X) where {X <: FixedPoint} = (float(x) * float(y)) % X
function wrapping_fdiv(x::X, y::X) where {X <: FixedPoint}
z = floattype(X)(x.i) / floattype(X)(y.i)
isfinite(z) ? z % X : zero(X)
end
function wrapping_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
isfinite(z) || return zero(T)
if T <: Unsigned
_unsafe_trunc(T, z)
else
z > typemax(T) ? typemin(T) : _unsafe_trunc(T, z)
end
end
wrapping_fld(x::X, y::X) where {X <: FixedPoint} = wrapping_div(x, y, RoundDown)
wrapping_cld(x::X, y::X) where {X <: FixedPoint} = wrapping_div(x, y, RoundUp)
wrapping_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}} =
X(x.i - wrapping_div(x, y, r) * y.i, 0)
wrapping_mod(x::X, y::X) where {X <: FixedPoint} = wrapping_rem(x, y, RoundDown)

# saturating arithmetic
saturating_neg(x::X) where {X <: FixedPoint} = X(~min(x.i - true, x.i), 0)
saturating_neg(x::X) where {X <: FixedPoint{<:Unsigned}} = zero(X)

saturating_abs(x::X) where {X <: FixedPoint} =
X(ifelse(signbit(abs(x.i)), typemax(x.i), abs(x.i)), 0)

saturating_add(x::X, y::X) where {X <: FixedPoint} =
X(x.i + ifelse(x.i < 0, max(y.i, typemin(x.i) - x.i), min(y.i, typemax(x.i) - x.i)), 0)
saturating_add(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i + min(~x.i, y.i), 0)

saturating_sub(x::X, y::X) where {X <: FixedPoint} =
X(x.i - ifelse(x.i < 0, min(y.i, x.i - typemin(x.i)), max(y.i, x.i - typemax(x.i))), 0)
saturating_sub(x::X, y::X) where {X <: FixedPoint{<:Unsigned}} = X(x.i - min(x.i, y.i), 0)

saturating_mul(x::X, y::X) where {X <: FixedPoint} = clamp(float(x) * float(y), X)

saturating_fdiv(x::X, y::X) where {X <: FixedPoint} =
clamp(floattype(X)(x.i) / floattype(X)(y.i), X)

function saturating_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
isnan(z) && return zero(T)
if T <: Unsigned
isfinite(z) ? _unsafe_trunc(T, z) : typemax(T)
else
_unsafe_trunc(T, clamp(z, typemin(T), typemax(T)))
end
end
saturating_fld(x::X, y::X) where {X <: FixedPoint} = saturating_div(x, y, RoundDown)
saturating_cld(x::X, y::X) where {X <: FixedPoint} = saturating_div(x, y, RoundUp)
function saturating_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
T <: Unsigned && r isa RoundingMode{:Up} && return zero(X)
X(x.i - saturating_div(x, y, r) * y.i, 0)
end
saturating_mod(x::X, y::X) where {X <: FixedPoint} = saturating_rem(x, y, RoundDown)

# checked arithmetic
checked_neg(x::X) where {X <: FixedPoint} = checked_sub(zero(X), x)
function checked_abs(x::X) where {X <: FixedPoint}
abs(x.i) >= 0 || throw_overflowerror_abs(x)
X(abs(x.i), 0)
end
function checked_add(x::X, y::X) where {X <: FixedPoint}
r, f = Base.Checked.add_with_overflow(x.i, y.i)
z = X(r, 0) # store first
f && throw_overflowerror(:+, x, y)
z
end
function checked_sub(x::X, y::X) where {X <: FixedPoint}
r, f = Base.Checked.sub_with_overflow(x.i, y.i)
z = X(r, 0) # store first
f && throw_overflowerror(:-, x, y)
z
end
function checked_mul(x::X, y::X) where {X <: FixedPoint}
z = float(x) * float(y)
typemin(X) - eps(X)/2 <= z < typemax(X) + eps(X)/2 || throw_overflowerror(:*, x, y)
z % X
end
function checked_fdiv(x::X, y::X) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
z = floattype(X)(x.i) / floattype(X)(y.i)
if T <: Unsigned
z < typemax(X) + eps(X)/2 || throw_overflowerror(:/, x, y)
else
typemin(X) - eps(X)/2 <= z < typemax(X) + eps(X)/2 || throw_overflowerror(:/, x, y)
end
z % X
end
function checked_div(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
z = round(floattype(X)(x.i) / floattype(X)(y.i), r)
if T <: Signed
z <= typemax(T) || throw_overflowerror_div(r, x, y)
end
_unsafe_trunc(T, z)
end
checked_fld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundDown)
checked_cld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundUp)
function checked_rem(x::X, y::X, r::RoundingMode = RoundToZero) where {T, X <: FixedPoint{T}}
y === zero(X) && throw(DivideError())
fx, fy = floattype(X)(x.i), floattype(X)(y.i)
z = fx - round(fx / fy, r) * fy
if T <: Unsigned && r isa RoundingMode{:Up}
z >= zero(z) || throw_overflowerror_rem(r, x, y)
end
X(_unsafe_trunc(T, z), 0)
end
checked_mod(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)

# default arithmetic
const DEFAULT_ARITHMETIC = :wrapping

for (op, name) in ((:-, :neg), (:abs, :abs))
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
@eval begin
$op(x::X) where {X <: FixedPoint} = $f(x)
end
end
for (op, name) in ((:+, :add), (:-, :sub), (:*, :mul))
f = Symbol(DEFAULT_ARITHMETIC, :_, name)
@eval begin
$op(x::X, y::X) where {X <: FixedPoint} = $f(x, y)
end
end
# force checked arithmetic
/(x::X, y::X) where {X <: FixedPoint} = checked_fdiv(x, y)
div(x::X, y::X, r::RoundingMode = RoundToZero) where {X <: FixedPoint} = checked_div(x, y, r)
fld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundDown)
cld(x::X, y::X) where {X <: FixedPoint} = checked_div(x, y, RoundUp)
rem(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundToZero)
rem(x::X, y::X, ::RoundingMode{:Down}) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)
rem(x::X, y::X, ::RoundingMode{:Up}) where {X <: FixedPoint} = checked_rem(x, y, RoundUp)
mod(x::X, y::X) where {X <: FixedPoint} = checked_rem(x, y, RoundDown)

function minmax(x::X, y::X) where {X <: FixedPoint}
a, b = minmax(reinterpret(x), reinterpret(y))
X(a,0), X(b,0)
Expand Down Expand Up @@ -518,6 +345,34 @@ include("normed.jl")
include("deprecations.jl")
const UF = (N0f8, N6f10, N4f12, N2f14, N0f16)

include("arithmetic/arithmetic.jl")
using .FixedPointArithmetic
# re-export
for name in names(FixedPointArithmetic.Wrapping)
startswith(string(name), "wrapping") || continue
@eval export $name
end
for name in names(FixedPointArithmetic.Saturating)
startswith(string(name), "saturating") || continue
@eval export $name
end
for name in names(FixedPointArithmetic.Checked)
startswith(string(name), "checked") || continue
@eval export $name
end

# construction using the (approximate) intended value, i.e., N0f8
*(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)
Wrapping.wrapping_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = x % X
Saturating.saturating_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = clamp(x, X)
Checked.checked_mul(x::Real, ::Type{X}) where {X <: FixedPoint} = _convert(X, x)

# type modulus
rem(x::Real, ::Type{X}) where {X <: FixedPoint} = _rem(x, X)
Wrapping.wrapping_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)
Saturating.saturating_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)
Checked.checked_rem(x::Real, ::Type{X}) where {X<:FixedPoint} = _rem(x, X)

# Promotions
promote_rule(::Type{X}, ::Type{Tf}) where {X <: FixedPoint, Tf <: AbstractFloat} =
promote_type(floattype(X), Tf)
Expand Down Expand Up @@ -585,29 +440,6 @@ scaledual(::Type{Tdual}, x::AbstractArray{T}) where {Tdual, T <: FixedPoint} =
throw(ArgumentError(String(take!(io))))
end

@noinline function throw_overflowerror(op::Symbol, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
print(io, x, ' ', op, ' ', y, " overflowed for type ")
showtype(io, typeof(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_abs(@nospecialize(x))
io = IOBuffer()
print(io, "abs(", x, ") overflowed for type ")
showtype(io, typeof(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_div(r::RoundingMode, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
op = r === RoundUp ? "cld(" : r === RoundDown ? "fld(" : "div("
print(io, op, x, ", ", y, ") overflowed for type ", rawtype(x))
throw(OverflowError(String(take!(io))))
end
@noinline function throw_overflowerror_rem(r::RoundingMode, @nospecialize(x), @nospecialize(y))
io = IOBuffer()
print(io, "rem(", x, ", ", y, ", ", r, ") overflowed for type ", typeof(x))
throw(OverflowError(String(take!(io))))
end

function Random.rand(r::AbstractRNG, ::SamplerType{X}) where X <: FixedPoint
X(rand(r, rawtype(X)), 0)
Expand Down
Loading
Loading