diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index c7d10f8..10326d8 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -2,10 +2,12 @@ module OverflowContexts include("macros.jl") include("base_ext.jl") +include("base_ext_sat.jl") include("abstractarraymath_ext.jl") -export @default_checked, @default_unchecked, @checked, @unchecked, +export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating, + checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs, unchecked_neg, unchecked_add, unchecked_sub, unchecked_mul, unchecked_negsub, unchecked_pow, unchecked_abs, - checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs + saturating_neg, saturating_add, saturating_sub, saturating_mul, saturating_pow, saturating_negsub, saturating_abs end # module diff --git a/src/base_ext.jl b/src/base_ext.jl index e478b7d..1552f23 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -1,11 +1,13 @@ -using Base: promote, afoldl, @_inline_meta +using Base: BitInteger, promote, afoldl, @_inline_meta import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs +using Base.Checked: mul_with_overflow if VERSION ≥ v"1.11-alpha" + import Base: power_by_squaring import Base.Checked: checked_pow else - using Base: BitInteger, throw_domerr_powbysq, to_power_type - using Base.Checked: mul_with_overflow, throw_overflowerr_binaryop + using Base: throw_domerr_powbysq, to_power_type + using Base.Checked: throw_overflowerr_binaryop end # The Base methods have unchecked semantics, so just pass through @@ -22,6 +24,10 @@ checked_add(a, b, c, xs...) = @checked (@_inline_meta; afoldl(+, (+)((+)(a, b), checked_sub(a, b, c, xs...) = @checked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) checked_mul(a, b, c, xs...) = @checked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) +saturating_add(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) +saturating_sub(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) +saturating_mul(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) + # promote unmatched number types to same type checked_add(x::Number, y::Number) = checked_add(promote(x, y)...) @@ -29,6 +35,11 @@ checked_sub(x::Number, y::Number) = checked_sub(promote(x, y)...) checked_mul(x::Number, y::Number) = checked_mul(promote(x, y)...) checked_pow(x::Number, y::Number) = checked_pow(promote(x, y)...) +saturating_add(x::Number, y::Number) = saturating_add(promote(x, y)...) +saturating_sub(x::Number, y::Number) = saturating_sub(promote(x, y)...) +saturating_mul(x::Number, y::Number) = saturating_mul(promote(x, y)...) +saturating_pow(x::Number, y::Number) = saturating_pow(promote(x, y)...) + # fallback to `unchecked_` for `Number` types that don't have more specific `checked_` methods checked_neg(x::T) where T <: Number = unchecked_neg(x) @@ -38,6 +49,13 @@ checked_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y) checked_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y) checked_abs(x::T) where T <: Number = unchecked_abs(x) +saturating_neg(x::T) where T <: Number = unchecked_neg(x) +saturating_add(x::T, y::T) where T <: Number = unchecked_add(x, y) +saturating_sub(x::T, y::T) where T <: Number = unchecked_sub(x, y) +saturating_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y) +saturating_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y) +saturating_abs(x::T) where T <: Number = unchecked_abs(x) + # fallback to `unchecked_` for non-`Number` types checked_neg(x) = unchecked_neg(x) @@ -51,50 +69,38 @@ checked_abs(x) = unchecked_abs(x) if VERSION < v"1.11" # Base.Checked only gained checked powers in 1.11 -function checked_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} - @_inline_meta - z, b = pow_with_overflow(x, y) - b && throw_overflowerr_binaryop(:^, x, y) - z -end +checked_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = + power_by_squaring(x_, p; mul = checked_mul) -function pow_with_overflow(x_, p::Integer) +# Base.@assume_effects :terminates_locally # present in Julia 1.11 code, but only supported from 1.8 on +function power_by_squaring(x_, p::Integer; mul=*) x = to_power_type(x_) if p == 1 - return (copy(x), false) + return copy(x) elseif p == 0 - return (one(x), false) + return one(x) elseif p == 2 - return mul_with_overflow(x, x) + return mul(x, x) elseif p < 0 - isone(x) && return (copy(x), false) - isone(-x) && return (iseven(p) ? one(x) : copy(x), false) + isone(x) && return copy(x) + isone(-x) && return iseven(p) ? one(x) : copy(x) throw_domerr_powbysq(x, p) end t = trailing_zeros(p) + 1 p >>= t - b = false while (t -= 1) > 0 - x, b1 = mul_with_overflow(x, x) - b |= b1 + x = mul(x, x) end y = x while p > 0 t = trailing_zeros(p) + 1 p >>= t while (t -= 1) >= 0 - x, b1 = mul_with_overflow(x, x) - b |= b1 + x = mul(x, x) end - y, b1 = mul_with_overflow(y, x) - b |= b1 + y = mul(y, x) end - return y, b -end -pow_with_overflow(x::Bool, p::Unsigned) = ((p==0) | x, false) -function pow_with_overflow(x::Bool, p::Integer) - p < 0 && !x && throw_domerr_powbysq(x, p) - return (p==0) | x, false + return y end end diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl new file mode 100644 index 0000000..102bcd2 --- /dev/null +++ b/src/base_ext_sat.jl @@ -0,0 +1,70 @@ +import Base: BitInteger +import Base.Checked: mul_with_overflow + +if VERSION ≤ v"1.11-alpha" + import Base: power_by_squaring +end + +# saturating implementations +const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} + +saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) + +if VERSION ≥ v"1.5" + using Base: llvmcall + + # These intrinsics were added in LLVM 8, which was first supported with Julia 1.5 + @generated function saturating_add(x::T, y::T) where T <: BitInteger + llvm_su = T <: Signed ? "s" : "u" + llvm_t = "i" * string(8sizeof(T)) + llvm_intrinsic = "llvm.$(llvm_su)add.sat.$llvm_t" + :(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y)) + end + + @generated function saturating_sub(x::T, y::T) where T <: BitInteger + llvm_su = T <: Signed ? "s" : "u" + llvm_t = "i" * string(8sizeof(T)) + llvm_intrinsic = "llvm.$(llvm_su)sub.sat.$llvm_t" + :(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y)) + end + +else + import Base.Checked: add_with_overflow, sub_with_overflow + + function saturating_add(x::T, y::T) where T <: BitInteger + result, overflow_flag = add_with_overflow(x, y) + if overflow_flag + return sign(x) > 0 ? + typemax(T) : + typemin(T) + end + return result + end + + function saturating_sub(x::T, y::T) where T <: BitInteger + result, overflow_flag = sub_with_overflow(x, y) + if overflow_flag + return y > x ? + typemin(T) : + typemax(T) + end + return result + end +end + +function saturating_mul(x::T, y::T) where T <: BitInteger + result, overflow_flag = mul_with_overflow(x, y) + return overflow_flag ? + (sign(x) == sign(y) ? + typemax(T) : + typemin(T)) : + result +end + +saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = + power_by_squaring(x_, p; mul = saturating_mul) + +function saturating_abs(x::T) where T <: SignedBitInteger + result = flipsign(x, x) + return result < 0 ? typemax(T) : result +end diff --git a/src/macros.jl b/src/macros.jl index 45af0a5..3820edf 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -50,6 +50,41 @@ macro default_unchecked() end end +""" + @default_saturating + +Redirect default integer math to saturating operators for the current module. Only works at top-level. +""" +macro default_saturating() + quote + if !isdefined(@__MODULE__, :__OverflowContextDefaultSet) + any(Base.isbindingresolved.(Ref(@__MODULE__), op_method_symbols)) && + error("A default context may only be set before any reference to the affected methods (+, -, *, ^, abs) in the target module.") + else + @warn "A previous default was set for this module. Previously defined methods in this module will be recompiled with this new default." + end + (@__MODULE__).eval(:(-(x) = OverflowContexts.saturating_neg(x))) + (@__MODULE__).eval(:(+(x...) = OverflowContexts.saturating_add(x...))) + (@__MODULE__).eval(:(-(x...) = OverflowContexts.saturating_sub(x...))) + (@__MODULE__).eval(:(*(x...) = OverflowContexts.saturating_mul(x...))) + (@__MODULE__).eval(:(^(x...) = OverflowContexts.saturating_pow(x...))) + (@__MODULE__).eval(:(abs(x) = OverflowContexts.saturating_abs(x))) + (@__MODULE__).eval(:(__OverflowContextDefaultSet = true)) + nothing + end +end + +""" + @checked expr + +Perform all integer operations in `expr` using overflow-checked arithmetic. +""" +macro checked(expr) + isa(expr, Expr) || return expr + expr = copy(expr) + return esc(replace_op!(expr, op_checked)) +end + """ @unchecked expr @@ -62,14 +97,14 @@ macro unchecked(expr) end """ - @checked expr + @saturating expr -Perform all integer operations in `expr` using overflow-checked arithmetic. +Perform all integer operations in `expr` using saturating arithmetic. """ -macro checked(expr) +macro saturating(expr) isa(expr, Expr) || return expr expr = copy(expr) - return esc(replace_op!(expr, op_checked)) + return esc(replace_op!(expr, op_saturating)) end const op_checked = Dict( @@ -92,6 +127,16 @@ const op_unchecked = Dict( :abs => :(unchecked_abs) ) +const op_saturating = Dict( + Symbol("unary-") => :(saturating_neg), + Symbol("ambig-") => :(saturating_negsub), + :+ => :(saturating_add), + :- => :(saturating_sub), + :* => :(saturating_mul), + :^ => :(saturating_pow), + :abs => :(saturating_abs) +) + const broadcast_op_map = Dict( :.+ => :+, :.- => :-, @@ -115,6 +160,8 @@ unchecked_negsub(x) = unchecked_neg(x) unchecked_negsub(x, y) = unchecked_sub(x, y) checked_negsub(x) = checked_neg(x) checked_negsub(x, y) = checked_sub(x, y) +saturating_negsub(x) = saturating_neg(x) +saturating_negsub(x, y) = saturating_sub(x, y) # copied from CheckedArithmetic.jl and modified it function replace_op!(expr::Expr, op_map::Dict) @@ -182,7 +229,7 @@ function replace_op!(expr::Expr, op_map::Dict) elseif isexpr(expr, :.) # broadcast function op = expr.args[1] expr.args[1] = get(op_map, op, op) - elseif !isexpr(expr, :macrocall) || expr.args[1] ∉ (Symbol("@checked"), Symbol("@unchecked")) + elseif !isexpr(expr, :macrocall) || expr.args[1] ∉ (Symbol("@checked"), Symbol("@unchecked"), Symbol("@saturating")) for a in expr.args if isa(a, Expr) replace_op!(a, op_map) diff --git a/test/runtests.jl b/test/runtests.jl index af57507..e818c60 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,27 @@ end @test @unchecked(abs(typemin(Int))) == typemin(Int) end +@testset "saturating expressions" begin + @test @saturating(-typemin(Int)) == typemax(Int) + @test @saturating(-UInt(1)) == typemin(UInt) + + @test @saturating(typemax(Int) + 1) == typemax(Int) + @test @saturating(typemax(UInt) + 1) == typemax(UInt) + + @test @saturating(typemin(Int) - 1) == typemin(Int) + @test @saturating(typemin(UInt) - 1) == typemin(UInt) + + @test @saturating(typemax(Int) * 2) == typemax(Int) + @test @saturating(typemin(Int) * 2) == typemin(Int) + @test @saturating(typemax(UInt) * 2) == typemax(UInt) + + @test @saturating(typemax(Int) ^ 2) == typemax(Int) + @test @saturating(typemin(Int) ^ 2) == typemax(Int) + @test @saturating(typemax(UInt) ^ 2) == typemax(UInt) + + @test @saturating(abs(typemin(Int))) == typemax(Int) +end + @testset "juxtaposed multiplication works" begin @test_throws OverflowError @checked 2typemax(Int) @test_throws OverflowError @checked 2typemin(Int) @@ -52,6 +73,9 @@ end @test @unchecked(2typemax(Int)) == -2 @test @unchecked(2typemin(Int)) == 0 @test @unchecked(2typemax(UInt)) == typemax(UInt) - 1 + @test @saturating(2typemax(Int)) == typemax(Int) + @test @saturating(2typemin(Int)) == typemin(Int) + @test @saturating(2typemax(UInt)) == typemax(UInt) end @testset "exhaustive checks over 16 bit math" begin @@ -59,39 +83,39 @@ end if T <: Signed @testset "$T negation" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(-i) == @unchecked(-i) == -i + @test @checked(-i) == @unchecked(-i) == @saturating(-i) == -i end end end @testset "$T addition" begin for i ∈ typemin(T):typemax(T) - T(1) - @test @checked(i + T(1)) == @unchecked(i + T(1)) == i + T(1) + @test @checked(i + T(1)) == @unchecked(i + T(1)) == @saturating(i + T(1)) == i + T(1) end end @testset "$T subtraction" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(i - T(1)) == @unchecked(i - T(1)) == i - T(1) + @test @checked(i - T(1)) == @unchecked(i - T(1)) == @saturating(i - T(1)) == i - T(1) end end @testset "$T multiplication" begin for i ∈ typemin(T) ÷ T(2):typemax(T) ÷ T(2) - @test @checked(2i) == @unchecked(2i) == 2i + @test @checked(2i) == @unchecked(2i) == @saturating(2i) == 2i end end @testset "$T power" begin if T <: Signed for i ∈ ceil(T, -√(typemax(T))):floor(T, √(typemax(T))) - @test @checked(i ^ 2) == @unchecked(i ^ 2) == i ^ 2 + @test @checked(i ^ 2) == @unchecked(i ^ 2) == @saturating(i ^ 2) == i ^ 2 end else for i ∈ T(0):floor(T, √(typemax(T))) - @test @checked(i ^ 2) == @unchecked(i ^ 2) == i ^ 2 + @test @checked(i ^ 2) == @unchecked(i ^ 2) == @saturating(i ^ 2) == i ^ 2 end end end @testset "$T abs" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(abs(i)) == @unchecked(abs(i)) == abs(i) + @test @checked(abs(i)) == @unchecked(abs(i)) == @saturating(abs(i)) == abs(i) end end end @@ -100,38 +124,55 @@ end @testset "lowest-level macro takes priority" begin @checked begin @test @unchecked(typemax(Int) + 1) == typemin(Int) + @test @saturating(typemax(Int) + 1) == typemax(Int) end @unchecked begin @test_throws OverflowError @checked typemax(Int) + 1 + @test @saturating(typemax(Int) + 1) == typemax(Int) + end + @saturating begin + @test @unchecked(typemax(Int) + 1) == typemin(Int) + @test_throws OverflowError @checked typemax(Int) + 1 end end @testset "literals passthrough" begin @test @checked(-1) == -1 @test @unchecked(-1) == -1 + @test @saturating(-1) == -1 end @testset "non-integer math still works" begin @test @checked(-1.0) == -1 @test @unchecked(-1.0) == -1 + @test @saturating(-1.0) == -1 @test @checked(1.0 + 3.0) == 4.0 @test @unchecked(1.0 + 3.0) == 4.0 + @test @saturating(1.0 + 3.0) == 4.0 @test @checked(1 + 3.0) == 4.0 @test @unchecked(1 + 3.0) == 4.0 + @test @saturating(1 + 3.0) == 4.0 @test @checked(1.0 - 3.0) == -2.0 @test @unchecked(1.0 - 3.0) == -2.0 + @test @saturating(1.0 - 3.0) == -2.0 @test @checked(1 - 3.0) == -2.0 @test @unchecked(1 - 3.0) == -2.0 + @test @saturating(1 - 3.0) == -2.0 @test @checked(1.0 * 3.0) == 3.0 @test @unchecked(1.0 * 3.0) == 3.0 + @test @saturating(1.0 * 3.0) == 3.0 @test @checked(1 * 3.0) == 3.0 @test @unchecked(1 * 3.0) == 3.0 + @test @saturating(1 * 3.0) == 3.0 @test @checked(1.0 ^ 3.0) == 1.0 @test @unchecked(1.0 ^ 3.0) == 1.0 + @test @saturating(1.0 ^ 3.0) == 1.0 @test @checked(1 ^ 3.0) == 1.0 @test @unchecked(1 ^ 3.0) == 1.0 + @test @saturating(1 ^ 3.0) == 1.0 @test @checked(abs(-1.0)) == 1.0 @test @unchecked(abs(-1.0)) == 1.0 + @test @saturating(abs(-1.0)) == 1.0 end @testset "symbol replacement" begin @@ -141,11 +182,17 @@ end expr = @macroexpand @unchecked foldl(+, []) @test expr.args[2] == :unchecked_add + expr = @macroexpand @saturating foldl(+, []) + @test expr.args[2] == :saturating_add + expr = @macroexpand @checked foldl(-, []) @test expr.args[2] == :checked_negsub expr = @macroexpand @unchecked foldl(-, []) @test expr.args[2] == :unchecked_negsub + + expr = @macroexpand @saturating foldl(-, []) + @test expr.args[2] == :saturating_negsub expr = @macroexpand @checked foldl(*, []) @test expr.args[2] == :checked_mul @@ -153,42 +200,63 @@ end expr = @macroexpand @unchecked foldl(*, []) @test expr.args[2] == :unchecked_mul + expr = @macroexpand @saturating foldl(*, []) + @test expr.args[2] == :saturating_mul + expr = @macroexpand @checked foldl(^, []) @test expr.args[2] == :checked_pow expr = @macroexpand @unchecked foldl(^, []) @test expr.args[2] == :unchecked_pow + expr = @macroexpand @saturating foldl(^, []) + @test expr.args[2] == :saturating_pow + expr = @macroexpand @checked foldl(:abs, []) @test expr.args[2] == :checked_abs expr = @macroexpand @unchecked foldl(:abs, []) @test expr.args[2] == :unchecked_abs + + expr = @macroexpand @saturating foldl(:abs, []) + @test expr.args[2] == :saturating_abs end @testset "negsub helper methods dispatch correctly" begin + @test checked_negsub(1) == -1 + @test checked_negsub(1, 2) == 1 - 2 @test unchecked_negsub(1) == -1 @test unchecked_negsub(1, 2) == 1 - 2 + @test saturating_negsub(1) == -1 + @test saturating_negsub(1, 2) == 1 - 2 end @testset "assignment operators" begin a = typemax(Int) @test_throws OverflowError @checked a += 1 + @saturating a += 1 + @test a == typemax(Int) @unchecked a += 1 @test a == typemin(Int) a = typemin(Int) @test_throws OverflowError @checked a -= 1 + @saturating a -= 1 + @test a == typemin(Int) @unchecked a -= 1 @test a == typemax(Int) a = typemax(Int) @test_throws OverflowError @checked a *= 2 + @saturating a *= 2 + @test a == typemax(Int) @unchecked a *= 2 @test a == -2 a = typemax(Int) @test_throws OverflowError @checked a ^= 2 + @saturating a ^= 2 + @test a == typemax(Int) @unchecked a ^= 2 @test a == 1 end @@ -198,11 +266,19 @@ end checkminus(x, y) = x - y end +@saturating begin + satplus(x, y) = x + y + satminus(x, y) = x - y +end + @testset "rewrite inside block body" begin @test checkplus(0x10, 0x20) === 0x30 @test_throws OverflowError checkplus(0xf0, 0x20) @test checkminus(0x30, 0x20) === 0x10 @test_throws OverflowError checkminus(0x20, 0x30) + + @test satplus(0xf0, 0x20) === 0xff + @test satminus(0x20, 0x30) === 0x00 end module CheckedModule @@ -215,6 +291,12 @@ module CheckedModule @default_unchecked testfunc() = @test typemax(Int) + 1 == typemin(Int) end + + module NestedSaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + end end module UncheckedModule @@ -227,19 +309,49 @@ module UncheckedModule @default_checked testfunc() = @test_throws OverflowError typemax(Int) + 1 end + + module NestedSaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + end +end + +module SaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + + module NestedCheckedModule + using OverflowContexts, Test + @default_checked + testfunc() = @test_throws OverflowError typemax(Int) + 1 + end + + module NestedUncheckedModule + using OverflowContexts, Test + @default_unchecked + testfunc() = @test typemax(Int) + 1 == typemin(Int) + end end @testset "module-specific contexts" begin CheckedModule.testfunc() CheckedModule.NestedUncheckedModule.testfunc() + CheckedModule.NestedSaturatingModule.testfunc() UncheckedModule.testfunc() UncheckedModule.NestedCheckedModule.testfunc() + UncheckedModule.NestedSaturatingModule.testfunc() + SaturatingModule.testfunc() + SaturatingModule.NestedCheckedModule.testfunc() + SaturatingModule.NestedUncheckedModule.testfunc() end @testset "default methods error if Base symbol already resolved" begin x = 1 + 1 @test_throws ErrorException @default_checked @test_throws ErrorException @default_unchecked + @test_throws ErrorException @default_saturating (@__MODULE__).eval(:( module BadCheckedModule @@ -254,6 +366,13 @@ end x = 1 + 1 @test_throws ErrorException @default_unchecked end)) + + (@__MODULE__).eval(:( + module BadSaturatingModule + using OverflowContexts, Test + x = 1 + 1 + @test_throws ErrorException @default_saturating + end)) end @testset "default methods warn if default is changed" begin @@ -273,8 +392,9 @@ end end @testset "ensure pow methods don't promote on the power" begin - @test typeof(@unchecked 3 ^ UInt(4)) == Int @test typeof(@checked 3 ^ UInt(4)) == Int + @test typeof(@unchecked 3 ^ UInt(4)) == Int + @test typeof(@saturating 3 ^ UInt(4)) == Int end @testset "multiargument methods" begin