Skip to content

Commit d214d57

Browse files
jw3126Sacha0
authored andcommitted
use TypeArithmetic trait in cumsum! implementation (#21666)
* use TypeArithmetic trait in cumsum! implementation
1 parent 25f241c commit d214d57

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

base/multidimensional.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -574,13 +574,18 @@ function accumulate_pairwise(op, v::AbstractVector{T}) where T
574574
end
575575

576576
function cumsum!(out, v::AbstractVector, axis::Integer=1)
577-
# for types prone to numerical stability issues, we want
578-
# accumulate_pairwise.
579-
axis == 1 ? accumulate_pairwise!(+, out, v) : copy!(out,v)
577+
# we dispatch on the possibility of numerical stability issues
578+
_cumsum!(out, v, axis, TypeArithmetic(eltype(out)))
580579
end
581580

582-
function cumsum!(out, v::AbstractVector{<:Integer}, axis::Integer=1)
583-
axis == 1 ? accumulate!(+, out, v) : copy!(out,v)
581+
function _cumsum!(out, v, axis, ::ArithmeticRounds)
582+
axis == 1 ? accumulate_pairwise!(+, out, v) : copy!(out, v)
583+
end
584+
function _cumsum!(out, v, axis, ::ArithmeticUnknown)
585+
_cumsum!(out, v, axis, ArithmeticRounds())
586+
end
587+
function _cumsum!(out, v, axis, ::TypeArithmetic)
588+
axis == 1 ? accumulate!(+, out, v) : copy!(out, v)
584589
end
585590

586591
"""

test/arrayops.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,28 @@ end
20552055
@test accumulate(op, [10 20 30], 2) == [10 op(10, 20) op(op(10, 20), 30)] == [10 40 110]
20562056
end
20572057

2058+
struct F21666{T <: Base.TypeArithmetic}
2059+
x::Float32
2060+
end
2061+
2062+
@testset "Exactness of cumsum # 21666" begin
2063+
# test that cumsum uses more stable algorithm
2064+
# for types with unknown/rounding arithmetic
2065+
Base.TypeArithmetic(::Type{F21666{T}}) where {T} = T
2066+
Base.:+(x::F, y::F) where {F <: F21666} = F(x.x + y.x)
2067+
Base.convert(::Type{Float64}, x::F21666) = Float64(x.x)
2068+
# we make v pretty large, because stable algorithm may have a large base case
2069+
v = zeros(300); v[1] = 2; v[200:end] = eps(Float32)
2070+
2071+
f_rounds = Float64.(cumsum(F21666{Base.ArithmeticRounds}.(v)))
2072+
f_unknown = Float64.(cumsum(F21666{Base.ArithmeticUnknown}.(v)))
2073+
f_truth = cumsum(v)
2074+
f_inexact = Float64.(accumulate(+, Float32.(v)))
2075+
@test f_rounds == f_unknown
2076+
@test f_rounds != f_inexact
2077+
@test norm(f_truth - f_rounds) < norm(f_truth - f_inexact)
2078+
end
2079+
20582080
@testset "zeros and ones" begin
20592081
@test ones([1,2], Float64, (2,3)) == ones(2,3)
20602082
@test ones(2) == ones(Int, 2) == ones([2,3], Float32, 2) == [1,1]

0 commit comments

Comments
 (0)