Skip to content

Commit ea5c9cb

Browse files
authored
Make normalize work for Numbers (#49342)
1 parent 99c0dad commit ea5c9cb

File tree

3 files changed

+52
-29
lines changed

3 files changed

+52
-29
lines changed

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,21 +1804,18 @@ function normalize!(a::AbstractArray, p::Real=2)
18041804
__normalize!(a, nrm)
18051805
end
18061806

1807-
@inline function __normalize!(a::AbstractArray, nrm::Real)
1807+
@inline function __normalize!(a::AbstractArray, nrm)
18081808
# The largest positive floating point number whose inverse is less than infinity
18091809
δ = inv(prevfloat(typemax(nrm)))
1810-
18111810
if nrm δ # Safe to multiply with inverse
18121811
invnrm = inv(nrm)
18131812
rmul!(a, invnrm)
1814-
18151813
else # scale elements to avoid overflow
18161814
εδ = eps(one(nrm))/δ
18171815
rmul!(a, εδ)
18181816
rmul!(a, inv(nrm*εδ))
18191817
end
1820-
1821-
a
1818+
return a
18221819
end
18231820

18241821
"""

stdlib/LinearAlgebra/test/generic.jl

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ using .Main.Quaternions
1212
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
1313
using .Main.OffsetArrays
1414

15+
isdefined(Main, :DualNumbers) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "DualNumbers.jl"))
16+
using .Main.DualNumbers
1517

1618
Random.seed!(123)
1719

@@ -78,30 +80,7 @@ n = 5 # should be odd
7880
end
7981

8082
@testset "det with nonstandard Number type" begin
81-
struct MyDual{T<:Real} <: Real
82-
val::T
83-
eps::T
84-
end
85-
Base.:+(x::MyDual, y::MyDual) = MyDual(x.val + y.val, x.eps + y.eps)
86-
Base.:*(x::MyDual, y::MyDual) = MyDual(x.val * y.val, x.eps * y.val + y.eps * x.val)
87-
Base.:/(x::MyDual, y::MyDual) = x.val / y.val
88-
Base.:(==)(x::MyDual, y::MyDual) = x.val == y.val && x.eps == y.eps
89-
Base.zero(::MyDual{T}) where {T} = MyDual(zero(T), zero(T))
90-
Base.zero(::Type{MyDual{T}}) where {T} = MyDual(zero(T), zero(T))
91-
Base.one(::MyDual{T}) where {T} = MyDual(one(T), zero(T))
92-
Base.one(::Type{MyDual{T}}) where {T} = MyDual(one(T), zero(T))
93-
# the following line is required for BigFloat, IDK why it doesn't work via
94-
# promote_rule like for all other types
95-
Base.promote_type(::Type{MyDual{BigFloat}}, ::Type{BigFloat}) = MyDual{BigFloat}
96-
Base.promote_rule(::Type{MyDual{T}}, ::Type{S}) where {T,S<:Real} =
97-
MyDual{promote_type(T, S)}
98-
Base.promote_rule(::Type{MyDual{T}}, ::Type{MyDual{S}}) where {T,S} =
99-
MyDual{promote_type(T, S)}
100-
Base.convert(::Type{MyDual{T}}, x::MyDual) where {T} =
101-
MyDual(convert(T, x.val), convert(T, x.eps))
102-
if elty <: Real
103-
@test det(triu(MyDual.(A, zero(A)))) isa MyDual
104-
end
83+
elty <: Real && @test det(Dual.(triu(A), zero(A))) isa Dual
10584
end
10685
end
10786

@@ -390,6 +369,7 @@ end
390369
[1.0 2.0 3.0; 4.0 5.0 6.0], # 2-dim
391370
rand(1,2,3), # higher dims
392371
rand(1,2,3,4),
372+
Dual.(randn(2,3), randn(2,3)),
393373
OffsetArray([-1,0], (-2,)) # no index 1
394374
)
395375
@test normalize(arr) == normalize!(copy(arr))

test/testhelpers/DualNumbers.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module DualNumbers
4+
5+
export Dual
6+
7+
# Dual numbers type with minimal interface
8+
# example of a (real) number type that subtypes Number, but not Real.
9+
# Can be used to test generic linear algebra functions.
10+
11+
struct Dual{T<:Real} <: Number
12+
val::T
13+
eps::T
14+
end
15+
Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.eps + y.eps)
16+
Base.:-(x::Dual, y::Dual) = Dual(x.val - y.val, x.eps - y.eps)
17+
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.eps * y.val + y.eps * x.val)
18+
Base.:*(x::Number, y::Dual) = Dual(x*y.val, x*y.eps)
19+
Base.:*(x::Dual, y::Number) = Dual(x.val*y, x.eps*y)
20+
Base.:/(x::Dual, y::Dual) = Dual(x.val / y.val, (x.eps*y.val - x.val*y.eps)/(y.val*y.val))
21+
22+
Base.:(==)(x::Dual, y::Dual) = x.val == y.val && x.eps == y.eps
23+
24+
Base.promote_rule(::Type{Dual{T}}, ::Type{T}) where {T} = Dual{T}
25+
Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T,S<:Real} = Dual{promote_type(T, S)}
26+
Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T,S} = Dual{promote_type(T, S)}
27+
28+
Base.convert(::Type{Dual{T}}, x::Dual{T}) where {T} = x
29+
Base.convert(::Type{Dual{T}}, x::Dual) where {T} = Dual(convert(T, x.val), convert(T, x.eps))
30+
Base.convert(::Type{Dual{T}}, x::Real) where {T} = Dual(convert(T, x), zero(T))
31+
32+
Base.float(x::Dual) = Dual(float(x.val), float(x.eps))
33+
# the following two methods are needed for normalize (to check for potential overflow)
34+
Base.typemax(x::Dual) = Dual(typemax(x.val), zero(x.eps))
35+
Base.prevfloat(x::Dual{<:AbstractFloat}) = prevfloat(x.val)
36+
37+
Base.abs2(x::Dual) = x*x
38+
Base.abs(x::Dual) = sqrt(abs2(x))
39+
Base.sqrt(x::Dual) = Dual(sqrt(x.val), x.eps/(2sqrt(x.val)))
40+
41+
Base.isless(x::Dual, y::Dual) = x.val < y.val
42+
Base.isless(x::Real, y::Dual) = x < y.val
43+
Base.isinf(x::Dual) = isinf(x.val) & isfinite(x.eps)
44+
Base.real(x::Dual) = x # since we curently only consider Dual{<:Real}
45+
46+
end # module

0 commit comments

Comments
 (0)