Skip to content

Commit 1786532

Browse files
authored
RFC: Introduce AdjointRotation to avoid subtyping AbsMat (#46233)
1 parent 639a4ff commit 1786532

File tree

2 files changed

+44
-35
lines changed

2 files changed

+44
-35
lines changed

stdlib/LinearAlgebra/src/givens.jl

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,29 @@
33
# givensAlgorithm functions are derived from LAPACK, see below
44

55
abstract type AbstractRotation{T} end
6+
struct AdjointRotation{T,S<:AbstractRotation{T}} <: AbstractRotation{T}
7+
R::S
8+
end
69

710
transpose(R::AbstractRotation) = error("transpose not implemented for $(typeof(R)). Consider using adjoint instead of transpose.")
811

912
function (*)(R::AbstractRotation{T}, A::AbstractVecOrMat{S}) where {T,S}
1013
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
1114
lmul!(convert(AbstractRotation{TS}, R), copy_similar(A, TS))
1215
end
13-
(*)(A::AbstractVector, adjR::Adjoint{<:Any,<:AbstractRotation}) = _absvecormat_mul_adjrot(A, adjR)
14-
(*)(A::AbstractMatrix, adjR::Adjoint{<:Any,<:AbstractRotation}) = _absvecormat_mul_adjrot(A, adjR)
15-
function _absvecormat_mul_adjrot(A::AbstractVecOrMat{T}, adjR::Adjoint{<:Any,<:AbstractRotation{S}}) where {T,S}
16-
R = adjR.parent
16+
function (*)(adjR::AdjointRotation{T}, A::AbstractVecOrMat{S}) where {T,S}
17+
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
18+
lmul!(convert(AbstractRotation{TS}, adjR.R)', copy_similar(A, TS))
19+
end
20+
(*)(A::AbstractVector, adjR::AdjointRotation) = _absvecormat_mul_adjrot(A, adjR)
21+
(*)(A::AbstractMatrix, adjR::AdjointRotation) = _absvecormat_mul_adjrot(A, adjR)
22+
function _absvecormat_mul_adjrot(A::AbstractVecOrMat{T}, adjR::AdjointRotation{S}) where {T,S}
1723
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
18-
rmul!(TS.(A), convert(AbstractRotation{TS}, R)')
24+
rmul!(copy_similar(A, TS), convert(AbstractRotation{TS}, adjR.R)')
1925
end
2026
function(*)(A::AbstractMatrix{T}, R::AbstractRotation{S}) where {T,S}
2127
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
22-
rmul!(TS.(A), convert(AbstractRotation{TS}, R))
28+
rmul!(copy_similar(A, TS), convert(AbstractRotation{TS}, R))
2329
end
2430

2531
"""
@@ -55,12 +61,11 @@ AbstractRotation{T}(G::Givens) where {T} = Givens{T}(G)
5561
AbstractRotation{T}(R::Rotation) where {T} = Rotation{T}(R)
5662

5763
adjoint(G::Givens) = Givens(G.i1, G.i2, G.c', -G.s)
58-
adjoint(R::Rotation) = Adjoint(R)
59-
function Base.copy(aG::Adjoint{<:Any,<:Givens})
60-
G = aG.parent
61-
return Givens(G.i1, G.i2, conj(G.c), -G.s)
62-
end
63-
Base.copy(aR::Adjoint{<:Any,Rotation{T}}) where {T} = Rotation{T}(reverse!([r' for r in aR.parent.rotations]))
64+
adjoint(R::AbstractRotation) = AdjointRotation(R)
65+
adjoint(adjR::AdjointRotation) = adjR.R
66+
67+
Base.copy(aR::AdjointRotation{T,Rotation{T}}) where {T} =
68+
Rotation{T}([r' for r in Iterators.reverse(aR.R.rotations)])
6469

6570
floatmin2(::Type{Float32}) = reinterpret(Float32, 0x26000000)
6671
floatmin2(::Type{Float64}) = reinterpret(Float64, 0x21a0000000000000)
@@ -291,7 +296,7 @@ function givens(f::T, g::T, i1::Integer, i2::Integer) where T
291296
c, s, r = givensAlgorithm(f, g)
292297
if i1 > i2
293298
s = -conj(s)
294-
i1,i2 = i2,i1
299+
i1, i2 = i2, i1
295300
end
296301
Givens(i1, i2, c, s), r
297302
end
@@ -329,9 +334,7 @@ B[i2] = 0
329334
330335
See also [`LinearAlgebra.Givens`](@ref).
331336
"""
332-
givens(x::AbstractVector, i1::Integer, i2::Integer) =
333-
givens(x[i1], x[i2], i1, i2)
334-
337+
givens(x::AbstractVector, i1::Integer, i2::Integer) = givens(x[i1], x[i2], i1, i2)
335338

336339
function getindex(G::Givens, i::Integer, j::Integer)
337340
if i == j
@@ -386,23 +389,24 @@ function lmul!(R::Rotation, A::AbstractMatrix)
386389
end
387390
return A
388391
end
389-
function rmul!(A::AbstractMatrix, adjR::Adjoint{<:Any,<:Rotation})
390-
R = adjR.parent
392+
function rmul!(A::AbstractMatrix, R::Rotation)
393+
@inbounds for i = 1:length(R.rotations)
394+
rmul!(A, R.rotations[i])
395+
end
396+
return A
397+
end
398+
function lmul!(adjR::AdjointRotation{<:Any,<:Rotation}, A::AbstractMatrix)
399+
R = adjR.R
400+
@inbounds for i = 1:length(R.rotations)
401+
lmul!(adjoint(R.rotations[i]), A)
402+
end
403+
return A
404+
end
405+
function rmul!(A::AbstractMatrix, adjR::AdjointRotation{<:Any,<:Rotation})
406+
R = adjR.R
391407
@inbounds for i = 1:length(R.rotations)
392408
rmul!(A, adjoint(R.rotations[i]))
393409
end
394410
return A
395411
end
396-
*(G1::Givens{T}, G2::Givens{T}) where {T} = Rotation(push!(push!(Givens{T}[], G2), G1))
397-
398-
# TODO: None of the following disambiguation methods are great. They should perhaps
399-
# instead be MethodErrors, or revised.
400-
#
401-
# disambiguation methods: *(Adj/Trans of AbsVec or AbsMat, Adj of AbstractRotation)
402-
*(A::Adjoint{<:Any,<:AbstractVector}, B::Adjoint{<:Any,<:AbstractRotation}) = copy(A) * B
403-
*(A::Adjoint{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractRotation}) = copy(A) * B
404-
*(A::Transpose{<:Any,<:AbstractVector}, B::Adjoint{<:Any,<:AbstractRotation}) = copy(A) * B
405-
*(A::Transpose{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractRotation}) = copy(A) * B
406-
# disambiguation methods: *(Diag/AbsTri, Adj of AbstractRotation)
407-
*(A::Diagonal, B::Adjoint{<:Any,<:AbstractRotation}) = A * copy(B)
408-
*(A::AbstractTriangular, B::Adjoint{<:Any,<:AbstractRotation}) = A * copy(B)
412+
*(G1::Givens{T}, G2::Givens{T}) where {T} = Rotation([G2, G1])

stdlib/LinearAlgebra/test/givens.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module TestGivens
44

55
using Test, LinearAlgebra, Random
6-
using LinearAlgebra: rmul!, lmul!, Givens
6+
using LinearAlgebra: Givens, Rotation
77

88
# Test givens rotations
99
@testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
@@ -14,7 +14,7 @@ using LinearAlgebra: rmul!, lmul!, Givens
1414
end
1515
@testset for A in (raw_A, view(raw_A, 1:10, 1:10))
1616
Ac = copy(A)
17-
R = LinearAlgebra.Rotation(LinearAlgebra.Givens{elty}[])
17+
R = Rotation(Givens{elty}[])
1818
for j = 1:8
1919
for i = j+2:10
2020
G, _ = givens(A, j+1, i, j)
@@ -25,14 +25,19 @@ using LinearAlgebra: rmul!, lmul!, Givens
2525
@test lmul!(G,Matrix{elty}(I, 10, 10)) == [G[i,j] for i=1:10,j=1:10]
2626

2727
@testset "transposes" begin
28-
@test G'*G*Matrix(elty(1)I, 10, 10) Matrix(I, 10, 10)
28+
@test (@inferred G'*G)*Matrix(elty(1)I, 10, 10) Matrix(I, 10, 10)
2929
@test (G*Matrix(elty(1)I, 10, 10))*G' Matrix(I, 10, 10)
30-
@test copy(R')*(R*Matrix(elty(1)I, 10, 10)) Matrix(I, 10, 10)
30+
@test (@inferred copy(R'))*(R*Matrix(elty(1)I, 10, 10)) Matrix(I, 10, 10)
3131
@test_throws ErrorException transpose(G)
3232
@test_throws ErrorException transpose(R)
3333
end
3434
end
3535
end
36+
@test (R')' === R
37+
@test R * A (A' * R')' lmul!(R, copy(A))
38+
@test A * R (R' * A')' rmul!(copy(A), R)
39+
@test R' * A lmul!(R', copy(A))
40+
@test A * R' rmul!(copy(A), R')
3641
@test_throws ArgumentError givens(A, 3, 3, 2)
3742
@test_throws ArgumentError givens(one(elty),zero(elty),2,2)
3843
G, _ = givens(one(elty),zero(elty),11,12)

0 commit comments

Comments
 (0)