Skip to content

Commit 5d3cda2

Browse files
Improve precision of 3x3 symmetric eigen solver (#714)
Port of 3x3 iterative eigen solver from GeometricTools GTEngine, based on symmetric QR.
1 parent d317933 commit 5d3cda2

File tree

1 file changed

+175
-164
lines changed

1 file changed

+175
-164
lines changed

src/eigen.jl

Lines changed: 175 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -221,188 +221,199 @@ end
221221
return Eigen(vals,vecs)
222222
end
223223

224-
# A small part of the code in the following method was inspired by works of David
225-
# Eberly, Geometric Tools LLC, in code released under the Boost Software
226-
# License (included at the end of this file).
224+
# Port of https://www.geometrictools.com/GTEngine/Include/Mathematics/GteSymmetricEigensolver3x3.h
225+
# released by David Eberly, Geometric Tools, Redmond WA 98052
226+
# under the Boost Software License, Version 1.0 (included at the end of this file)
227+
# The original documentation states
228+
# (see https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf )
229+
# [This] is an implementation of Algorithm 8.2.3 (Symmetric QR Algorithm) described in
230+
# Matrix Computations,2nd edition, by G. H. Golub and C. F. Van Loan, The Johns Hopkins
231+
# University Press, Baltimore MD, Fourth Printing 1993. Algorithm 8.2.1 (Householder
232+
# Tridiagonalization) is used to reduce matrix A to tridiagonal D′. Algorithm 8.2.2
233+
# (Implicit Symmetric QR Step with Wilkinson Shift) is used for the iterative reduction
234+
# from tridiagonal to diagonal. Numerically, we have errors E=RTAR−D. Algorithm 8.2.3
235+
# mentions that one expects |E| is approximately μ|A|, where |M| denotes the Frobenius norm
236+
# of M and where μ is the unit roundoff for the floating-point arithmetic: 2−23 for float,
237+
# which is FLTEPSILON = 1.192092896e-7f, and 2−52 for double, which is
238+
# DBLEPSILON = 2.2204460492503131e-16.
239+
# TODO ensure right-handedness of the eigenvalue matrix
227240
# TODO extend the method to complex hermitian
228241
@inline function _eig(::Size{(3,3)}, A::LinearAlgebra.HermOrSym{T}, permute, scale) where {T <: Real}
229-
S = arithmetic_closure(T)
230-
Sreal = real(S)
231-
232-
@inbounds a11 = convert(Sreal, A.data[1])
233-
@inbounds a22 = convert(Sreal, A.data[5])
234-
@inbounds a33 = convert(Sreal, A.data[9])
235-
if A.uplo == 'U'
236-
@inbounds a12 = convert(S, A.data[4])
237-
@inbounds a13 = convert(S, A.data[7])
238-
@inbounds a23 = convert(S, A.data[8])
239-
else
240-
@inbounds a12 = conj(convert(S, A.data[2]))
241-
@inbounds a13 = conj(convert(S, A.data[3]))
242-
@inbounds a23 = conj(convert(S, A.data[6]))
243-
end
244-
245-
p1 = abs2(a12) + abs2(a13) + abs2(a23)
246-
if (p1 == 0)
247-
# Matrix is diagonal
248-
v1 = SVector(one(S), zero(S), zero(S))
249-
v2 = SVector(zero(S), one(S), zero(S))
250-
v3 = SVector(zero(S), zero(S), one(S) )
251-
252-
if a11 < a22
253-
if a22 < a33
254-
return Eigen(SVector((a11, a22, a33)), hcat(v1,v2,v3))
255-
elseif a33 < a11
256-
return Eigen(SVector((a33, a11, a22)), hcat(v3,v1,v2))
257-
else
258-
return Eigen(SVector((a11, a33, a22)), hcat(v1,v3,v2))
259-
end
260-
else #a22 < a11
261-
if a11 < a33
262-
return Eigen(SVector((a22, a11, a33)), hcat(v2,v1,v3))
263-
elseif a33 < a22
264-
return Eigen(SVector((a33, a22, a11)), hcat(v3,v2,v1))
265-
else
266-
return Eigen(SVector((a22, a33, a11)), hcat(v2,v3,v1))
267-
end
242+
function converged(aggressive, bdiag0, bdiag1, bsuper)
243+
if aggressive
244+
bsuper == 0
245+
else
246+
diag_sum = abs(bdiag0) + abs(bdiag1)
247+
diag_sum + bsuper == diag_sum
268248
end
269249
end
270250

271-
q = (a11 + a22 + a33) / 3
272-
p2 = abs2(a11 - q) + abs2(a22 - q) + abs2(a33 - q) + 2 * p1
273-
p = sqrt(p2 / 6)
274-
invp = inv(p)
275-
b11 = (a11 - q) * invp
276-
b22 = (a22 - q) * invp
277-
b33 = (a33 - q) * invp
278-
b12 = a12 * invp
279-
b13 = a13 * invp
280-
b23 = a23 * invp
281-
B = SMatrix{3,3,S}((b11, conj(b12), conj(b13), b12, b22, conj(b23), b13, b23, b33))
282-
r = real(det(B)) / 2
283-
284-
# In exact arithmetic for a symmetric matrix -1 <= r <= 1
285-
# but computation error can leave it slightly outside this range.
286-
if (r <= -1)
287-
phi = Sreal(pi) / 3
288-
elseif (r >= 1)
289-
phi = zero(Sreal)
290-
else
291-
phi = acos(r) / 3
292-
end
293-
294-
eig3 = q + 2 * p * cos(phi)
295-
eig1 = q + 2 * p * cos(phi + (2*Sreal(pi)/3))
296-
eig2 = 3 * q - eig1 - eig3 # since tr(A) = eig1 + eig2 + eig3
297-
298-
if r > 0 # Helps with conditioning the eigenvector calculation
299-
(eig1, eig3) = (eig3, eig1)
300-
end
301-
302-
# Calculate the first eigenvector
303-
# This should be orthogonal to these three rows of A - eig1*I
304-
# Use all combinations of cross products and choose the "best" one
305-
r₁ = SVector(a11 - eig1, a12, a13)
306-
r₂ = SVector(conj(a12), a22 - eig1, a23)
307-
r₃ = SVector(conj(a13), conj(a23), a33 - eig1)
308-
n₁ = sum(abs2, r₁)
309-
n₂ = sum(abs2, r₂)
310-
n₃ = sum(abs2, r₃)
311-
312-
r₁₂ = r₁ × r₂
313-
r₂₃ = r₂ × r₃
314-
r₃₁ = r₃ × r₁
315-
n₁₂ = sum(abs2, r₁₂)
316-
n₂₃ = sum(abs2, r₂₃)
317-
n₃₁ = sum(abs2, r₃₁)
318-
319-
# we want best angle so we put all norms on same footing
320-
# (cheaper to multiply by third nᵢ rather than divide by the two involved)
321-
if n₁₂ * n₃ > n₂₃ * n₁
322-
if n₁₂ * n₃ > n₃₁ * n₂
323-
eigvec1 = r₁₂ / sqrt(n₁₂)
324-
else
325-
eigvec1 = r₃₁ / sqrt(n₃₁)
326-
end
327-
else
328-
if n₂₃ * n₁ > n₃₁ * n₂
329-
eigvec1 = r₂₃ / sqrt(n₂₃)
251+
function get_cos_sin(u::T,v::T) where {T}
252+
max_abs = max(abs(u), abs(v))
253+
if max_abs > 0
254+
u,v = (u,v) ./ max_abs
255+
len = sqrt(u^2 + v^2)
256+
cs, sn = (u,v) ./ len
257+
if cs > 0
258+
cs = -cs
259+
sn = -sn
260+
end
261+
T(cs), T(sn)
330262
else
331-
eigvec1 = r₃₁ / sqrt(n₃₁)
263+
T(-1), T(0)
332264
end
333265
end
334266

335-
# Calculate the second eigenvector
336-
# This should be orthogonal to the previous eigenvector and the three
337-
# rows of A - eig2*I. However, we need to "solve" the remaining 2x2 subspace
338-
# problem in case the cross products are identically or nearly zero
339-
340-
# The remaing 2x2 subspace is:
341-
@inbounds if abs(eigvec1[1]) < abs(eigvec1[2]) # safe to set one component to zero, depending on this
342-
orthogonal1 = SVector(-eigvec1[3], zero(S), eigvec1[1]) / sqrt(abs2(eigvec1[1]) + abs2(eigvec1[3]))
343-
else
344-
orthogonal1 = SVector(zero(S), eigvec1[3], -eigvec1[2]) / sqrt(abs2(eigvec1[2]) + abs2(eigvec1[3]))
267+
function _sortperm3(v)
268+
local perm = SVector(1,2,3)
269+
# unrolled bubble-sort
270+
(v[perm[1]] > v[perm[2]]) && (perm = SVector(perm[2], perm[1], perm[3]))
271+
(v[perm[2]] > v[perm[3]]) && (perm = SVector(perm[1], perm[3], perm[2]))
272+
(v[perm[1]] > v[perm[2]]) && (perm = SVector(perm[2], perm[1], perm[3]))
273+
perm
345274
end
346-
orthogonal2 = eigvec1 × orthogonal1
347-
348-
# The projected 2x2 eigenvalue problem is C x = 0 where C is the projection
349-
# of (A - eig2*I) onto the subspace {orthogonal1, orthogonal2}
350-
@inbounds a_orth1_1 = a11 * orthogonal1[1] + a12 * orthogonal1[2] + a13 * orthogonal1[3]
351-
@inbounds a_orth1_2 = conj(a12) * orthogonal1[1] + a22 * orthogonal1[2] + a23 * orthogonal1[3]
352-
@inbounds a_orth1_3 = conj(a13) * orthogonal1[1] + conj(a23) * orthogonal1[2] + a33 * orthogonal1[3]
353-
354-
@inbounds a_orth2_1 = a11 * orthogonal2[1] + a12 * orthogonal2[2] + a13 * orthogonal2[3]
355-
@inbounds a_orth2_2 = conj(a12) * orthogonal2[1] + a22 * orthogonal2[2] + a23 * orthogonal2[3]
356-
@inbounds a_orth2_3 = conj(a13) * orthogonal2[1] + conj(a23) * orthogonal2[2] + a33 * orthogonal2[3]
357-
358-
@inbounds c11 = conj(orthogonal1[1])*a_orth1_1 + conj(orthogonal1[2])*a_orth1_2 + conj(orthogonal1[3])*a_orth1_3 - eig2
359-
@inbounds c12 = conj(orthogonal1[1])*a_orth2_1 + conj(orthogonal1[2])*a_orth2_2 + conj(orthogonal1[3])*a_orth2_3
360-
@inbounds c22 = conj(orthogonal2[1])*a_orth2_1 + conj(orthogonal2[2])*a_orth2_2 + conj(orthogonal2[3])*a_orth2_3 - eig2
361-
362-
# Solve this robustly (some values might be small or zero)
363-
c11² = abs2(c11)
364-
c12² = abs2(c12)
365-
c22² = abs2(c22)
366-
if c11² >= c22²
367-
if c11² > 0 || c12² > 0
368-
if c11² >= c12²
369-
tmp = c12 / c11 # TODO check for compex input
370-
p2 = inv(sqrt(1 + abs2(tmp)))
371-
p1 = tmp * p2
372-
else
373-
tmp = c11 / c12 # TODO check for compex input
374-
p1 = inv(sqrt(1 + abs2(tmp)))
375-
p2 = tmp * p1
275+
276+
# Givens reflections
277+
update0(Q, c, s) = Q * @SMatrix [c 0 -s; s 0 c; 0 1 0]
278+
update1(Q, c, s) = Q * @SMatrix [0 1 0; c 0 s; -s 0 c]
279+
# Householder reflections
280+
update2(Q, c, s) = Q * @SMatrix [c s 0; s -c 0; 0 0 1]
281+
update3(Q, c, s) = Q * @SMatrix [1 0 0; 0 c s; 0 s -c]
282+
283+
is_rotation = false
284+
285+
# If `aggressive` is `true`, the iterations occur until a superdiagonal
286+
# entry is exactly zero, otherwise they occur until it is effectively zero
287+
# compared to the magnitude of its diagonal neighbors. Generally the non-
288+
# aggressive convergence is acceptable.
289+
#
290+
# Even with `aggressive = true` this method is faster than the one it
291+
# replaces and in order to keep the old interface, aggressive is set to true
292+
aggressive = true
293+
294+
# the input is symmetric, so we only consider the unique elements:
295+
a00, a01, a02, a11, a12, a22 = A[1,1], A[1,2], A[1,3], A[2,2], A[2,3], A[3,3]
296+
297+
# Compute the Householder reflection H and B = H * A * H where b02 = 0
298+
299+
c,s = get_cos_sin(a12, -a02)
300+
301+
Q = @SMatrix [c s 0; s -c 0; 0 0 1]
302+
303+
term0 = c * a00 + s * a01
304+
term1 = c * a01 + s * a11
305+
b00 = c * term0 + s * term1
306+
b01 = s * term0 - c * term1
307+
term0 = s * a00 - c * a01
308+
term1 = s * a01 - c * a11
309+
b11 = s * term0 - c * term1
310+
b12 = s * a02 - c * a12
311+
b22 = a22
312+
313+
# Givens reflections, B' = G^T * B * G, preserve tridiagonal matrices
314+
max_iteration = 2 * (1 + precision(T) - exponent(floatmin(T)))
315+
316+
if abs(b12) <= abs(b01)
317+
saveB00, saveB01, saveB11 = b00, b01, b11
318+
for iteration in 1:max_iteration
319+
# compute the Givens reflection
320+
c2, s2 = get_cos_sin((b00 - b11) / 2, b01)
321+
s = sqrt((1 - c2) / 2)
322+
c = s2 / 2s
323+
324+
# update Q by the Givens reflection
325+
Q = update0(Q, c, s)
326+
is_rotation = !is_rotation
327+
328+
# update B ← Q^T * B * Q, ensuring that b02 is zero and |b12| has
329+
# strictly decreased
330+
saveB00, saveB01, saveB11 = b00, b01, b11
331+
term0 = c * saveB00 + s * saveB01
332+
term1 = c * saveB01 + s * saveB11
333+
b00 = c * term0 + s * term1
334+
b11 = b22
335+
term0 = c * saveB01 - s * saveB00
336+
term1 = c * saveB11 - s * saveB01
337+
b22 = c * term1 - s * term0
338+
b01 = s * b12
339+
b12 = c * b12
340+
341+
if converged(aggressive, b00, b11, b01)
342+
# compute the Householder reflection
343+
c2, s2 = get_cos_sin((b00 - b11) / 2, b01)
344+
s = sqrt((1 - c2) / 2)
345+
c = s2 / 2s
346+
347+
# update Q by the Householder reflection
348+
Q = update2(Q, c, s)
349+
is_rotation = !is_rotation
350+
351+
# update D = Q^T * B * Q
352+
saveB00, saveB01, saveB11 = b00, b01, b11
353+
term0 = c * saveB00 + s * saveB01
354+
term1 = c * saveB01 + s * saveB11
355+
b00 = c * term0 + s * term1
356+
term0 = s * saveB00 - c * saveB01
357+
term1 = s * saveB01 - c * saveB11
358+
b11 = s * term0 - c * term1
359+
break
376360
end
377-
eigvec2 = p1*orthogonal1 - p2*orthogonal2
378-
else # c11 == 0 && c12 == 0 && c22 == 0 (smaller than c11)
379-
eigvec2 = orthogonal1
380361
end
381362
else
382-
if c22² >= c12²
383-
tmp = c12 / c22 # TODO check for compex input
384-
p1 = inv(sqrt(1 + abs2(tmp)))
385-
p2 = tmp * p1
386-
else
387-
tmp = c22 / c12 # TODO check for compex input
388-
p2 = inv(sqrt(1 + abs2(tmp)))
389-
p1 = tmp * p2
363+
saveB11, saveB12, saveB22 = b11, b12, b22
364+
for iteration in 1:max_iteration
365+
# compute the Givens reflection
366+
c2, s2 = get_cos_sin((b22 - b11) / 2, b12)
367+
s = sqrt((1 - c2) / 2)
368+
c = s2 / 2s
369+
370+
# update Q by the Givens reflection
371+
Q = update1(Q, c, s)
372+
is_rotation = !is_rotation
373+
374+
# update B ← Q^T * B * Q ensuring that b02 is zero and |b12| has
375+
# strictly decreased.
376+
saveB11, saveB12, saveB22 = b11, b12, b22
377+
378+
term0 = c * saveB22 + s * saveB12
379+
term1 = c * saveB12 + s * saveB11
380+
b22 = c * term0 + s * term1
381+
b11 = b00
382+
term0 = c * saveB12 - s * saveB22
383+
term1 = c * saveB11 - s * saveB12
384+
b00 = c * term1 - s * term0
385+
b12 = s * b01
386+
b01 = c * b01
387+
388+
if converged(aggressive, b11, b22, b12)
389+
# compute the Householder reflection
390+
c2, s2 = get_cos_sin((b11 - b22) / 2, b12)
391+
s = sqrt((1 - c2) / 2)
392+
c = s2 / 2s
393+
394+
# update Q by the Householder reflection
395+
Q = update3(Q, c, s)
396+
is_rotation = !is_rotation
397+
398+
# update D = Q^T * B * Q
399+
saveB11, saveB12, saveB22 = b11, b12, b22
400+
term0 = c * saveB11 + s * saveB12
401+
term1 = c * saveB12 + s * saveB22
402+
b11 = c * term0 + s * term1
403+
term0 = s * saveB11 - c * saveB12
404+
term1 = s * saveB12 - c * saveB22
405+
b22 = s * term0 - c * term1
406+
break
407+
end
390408
end
391-
eigvec2 = p1*orthogonal1 - p2*orthogonal2
392409
end
393410

394-
# The third eigenvector is a simple cross product of the other two
395-
eigvec3 = eigvec1 × eigvec2 # should be normalized already
396-
397-
# Sort them back to the original ordering, if necessary
398-
if r > 0
399-
(eig1, eig3) = (eig3, eig1)
400-
(eigvec1, eigvec3) = (eigvec3, eigvec1)
401-
end
402-
403-
return Eigen(SVector(eig1, eig2, eig3), hcat(eigvec1, eigvec2, eigvec3))
411+
evals = @SVector [b00, b11, b22]
412+
perm = _sortperm3(evals)
413+
Eigen(evals[perm], Q[:,perm])
404414
end
405415

416+
406417
@inline function eigen(A::StaticMatrix; permute::Bool=true, scale::Bool=true)
407418
_eig(Size(A), A, permute, scale)
408419
end

0 commit comments

Comments
 (0)