Skip to content

Commit 8a525f1

Browse files
Cholesky numerical stability: inverse transform (#356)
* Improve numerical stability of Cholesky invlink
* Simplify further
* Use logcosh Co-authored-by: David Widmann <[email protected]>
* Swap over one more occurrence of logcosh
* Update Stan documentation link
* Simplify loop in _logabsdetjac_inv_chol
* Apply same fix in _inv_link_chol_lkj_rrule
* Minor performance optimisation
* Change broadcasting to map Co-authored-by: David Widmann <[email protected]>
--------- Co-authored-by: David Widmann <[email protected]>
1 parent c420ff0 commit 8a525f1

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.5"
3+
version = "0.15.6"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/bijectors/corr.jl

+18-26
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
CorrBijector <: Bijector
33
44
A bijector implementation of Stan's parametrization method for Correlation matrix:
5-
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
5+
https://mc-stan.org/docs/reference-manual/transforms.html#correlation-matrix-transform.section
66
77
Basically, an unconstrained strictly upper triangular matrix `y` is transformed to
88
a correlation matrix by the following readable but not very efficient form:
@@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix)
348348
T = float(eltype(W))
349349
logJ = zero(T)
350350

351-
idx = 1
352351
@inbounds for j in 1:K
353352
log_remainder = zero(T) # log of proportion of unit vector remaining
354353
for i in 1:(j - 1)
355354
z = tanh(Y[i, j])
356355
W[i, j] = z * exp(log_remainder)
357-
log_remainder += log1p(-z^2) / 2
356+
log_remainder -= LogExpFunctions.logcosh(Y[i, j])
358357
logJ += log_remainder
359358
end
360359
logJ += log_remainder
@@ -375,15 +374,18 @@ function _inv_link_chol_lkj(y::AbstractVector)
375374
T = float(eltype(W))
376375
logJ = zero(T)
377376

377+
z_vec = map(tanh, y)
378+
lc_vec = map(LogExpFunctions.logcosh, y)
379+
378380
idx = 1
379381
@inbounds for j in 1:K
380382
log_remainder = zero(T) # log of proportion of unit vector remaining
381383
for i in 1:(j - 1)
382-
z = tanh(y[idx])
383-
idx += 1
384+
z = z_vec[idx]
384385
W[i, j] = z * exp(log_remainder)
385-
log_remainder += log1p(-z^2) / 2
386+
log_remainder -= lc_vec[idx]
386387
logJ += log_remainder
388+
idx += 1
387389
end
388390
logJ += log_remainder
389391
W[j, j] = exp(log_remainder)
@@ -404,18 +406,19 @@ function _inv_link_chol_lkj_rrule(y::AbstractVector)
404406
T = typeof(log(one(eltype(W))))
405407
logJ = zero(T)
406408

407-
z_vec = tanh.(y)
409+
z_vec = map(tanh, y)
410+
lc_vec = map(LogExpFunctions.logcosh, y)
408411

409412
idx = 1
410413
W[1, 1] = 1
411414
@inbounds for j in 2:K
412415
log_remainder = zero(T) # log of proportion of unit vector remaining
413416
for i in 1:(j - 1)
414417
z = z_vec[idx]
415-
idx += 1
416418
W[i, j] = z * exp(log_remainder)
417-
log_remainder += log1p(-z^2) / 2
419+
log_remainder -= lc_vec[idx]
418420
logJ += log_remainder
421+
idx += 1
419422
end
420423
logJ += log_remainder
421424
W[j, j] = exp(log_remainder)
@@ -461,13 +464,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix)
461464
K = LinearAlgebra.checksquare(Y)
462465

463466
result = float(zero(eltype(Y)))
464-
for j in 2:K, i in 1:(j - 1)
465-
@inbounds abs_y_i_j = abs(Y[i, j])
466-
result +=
467-
(K - i + 1) * (
468-
IrrationalConstants.logtwo -
469-
(abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j))
470-
)
467+
@inbounds for j in 2:K, i in 1:(j - 1)
468+
result -= (K - i + 1) * LogExpFunctions.logcosh(Y[i, j])
471469
end
472470
return result
473471
end
@@ -477,13 +475,8 @@ function _logabsdetjac_inv_corr(y::AbstractVector)
477475

478476
result = float(zero(eltype(y)))
479477
for (i, y_i) in enumerate(y)
480-
abs_y_i = abs(y_i)
481478
row_idx = vec_to_triu1_row_index(i)
482-
result +=
483-
(K - row_idx + 1) * (
484-
IrrationalConstants.logtwo -
485-
(abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i))
486-
)
479+
result -= (K - row_idx + 1) * LogExpFunctions.logcosh(y_i)
487480
end
488481
return result
489482
end
@@ -496,10 +489,9 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
496489
@inbounds for j in 2:K
497490
tmp = zero(result)
498491
for _ in 1:(j - 1)
499-
z = tanh(y[idx])
500-
logz = log(1 - z^2)
501-
result += logz + (tmp / 2)
502-
tmp += logz
492+
logcoshy = LogExpFunctions.logcosh(y[idx])
493+
tmp -= logcoshy
494+
result += tmp - logcoshy
503495
idx += 1
504496
end
505497
end

0 commit comments

Comments
 (0)