2
2
CorrBijector <: Bijector
3
3
4
4
A bijector implementation of Stan's parametrization method for Correlation matrix:
5
- https://mc-stan.org/docs/2_23/ reference-manual/correlation-matrix-transform-section.html
5
+ https://mc-stan.org/docs/reference-manual/transforms.html# correlation-matrix-transform.section
6
6
7
7
Basically, a unconstrained strictly upper triangular matrix `y` is transformed to
8
8
a correlation matrix by following readable but not that efficient form:
@@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix)
348
348
T = float (eltype (W))
349
349
logJ = zero (T)
350
350
351
- idx = 1
352
351
@inbounds for j in 1 : K
353
352
log_remainder = zero (T) # log of proportion of unit vector remaining
354
353
for i in 1 : (j - 1 )
355
354
z = tanh (Y[i, j])
356
355
W[i, j] = z * exp (log_remainder)
357
- log_remainder += log1p ( - z ^ 2 ) / 2
356
+ log_remainder -= LogExpFunctions . logcosh (Y[i, j])
358
357
logJ += log_remainder
359
358
end
360
359
logJ += log_remainder
@@ -375,15 +374,18 @@ function _inv_link_chol_lkj(y::AbstractVector)
375
374
T = float (eltype (W))
376
375
logJ = zero (T)
377
376
377
+ z_vec = map (tanh, y)
378
+ lc_vec = map (LogExpFunctions. logcosh, y)
379
+
378
380
idx = 1
379
381
@inbounds for j in 1 : K
380
382
log_remainder = zero (T) # log of proportion of unit vector remaining
381
383
for i in 1 : (j - 1 )
382
- z = tanh (y[idx])
383
- idx += 1
384
+ z = z_vec[idx]
384
385
W[i, j] = z * exp (log_remainder)
385
- log_remainder += log1p ( - z ^ 2 ) / 2
386
+ log_remainder -= lc_vec[idx]
386
387
logJ += log_remainder
388
+ idx += 1
387
389
end
388
390
logJ += log_remainder
389
391
W[j, j] = exp (log_remainder)
@@ -404,18 +406,19 @@ function _inv_link_chol_lkj_rrule(y::AbstractVector)
404
406
T = typeof (log (one (eltype (W))))
405
407
logJ = zero (T)
406
408
407
- z_vec = tanh .(y)
409
+ z_vec = map (tanh, y)
410
+ lc_vec = map (LogExpFunctions. logcosh, y)
408
411
409
412
idx = 1
410
413
W[1 , 1 ] = 1
411
414
@inbounds for j in 2 : K
412
415
log_remainder = zero (T) # log of proportion of unit vector remaining
413
416
for i in 1 : (j - 1 )
414
417
z = z_vec[idx]
415
- idx += 1
416
418
W[i, j] = z * exp (log_remainder)
417
- log_remainder += log1p ( - z ^ 2 ) / 2
419
+ log_remainder -= lc_vec[idx]
418
420
logJ += log_remainder
421
+ idx += 1
419
422
end
420
423
logJ += log_remainder
421
424
W[j, j] = exp (log_remainder)
@@ -461,13 +464,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix)
461
464
K = LinearAlgebra. checksquare (Y)
462
465
463
466
result = float (zero (eltype (Y)))
464
- for j in 2 : K, i in 1 : (j - 1 )
465
- @inbounds abs_y_i_j = abs (Y[i, j])
466
- result +=
467
- (K - i + 1 ) * (
468
- IrrationalConstants. logtwo -
469
- (abs_y_i_j + LogExpFunctions. log1pexp (- 2 * abs_y_i_j))
470
- )
467
+ @inbounds for j in 2 : K, i in 1 : (j - 1 )
468
+ result -= (K - i + 1 ) * LogExpFunctions. logcosh (Y[i, j])
471
469
end
472
470
return result
473
471
end
@@ -477,13 +475,8 @@ function _logabsdetjac_inv_corr(y::AbstractVector)
477
475
478
476
result = float (zero (eltype (y)))
479
477
for (i, y_i) in enumerate (y)
480
- abs_y_i = abs (y_i)
481
478
row_idx = vec_to_triu1_row_index (i)
482
- result +=
483
- (K - row_idx + 1 ) * (
484
- IrrationalConstants. logtwo -
485
- (abs_y_i + LogExpFunctions. log1pexp (- 2 * abs_y_i))
486
- )
479
+ result -= (K - row_idx + 1 ) * LogExpFunctions. logcosh (y_i)
487
480
end
488
481
return result
489
482
end
@@ -496,10 +489,9 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
496
489
@inbounds for j in 2 : K
497
490
tmp = zero (result)
498
491
for _ in 1 : (j - 1 )
499
- z = tanh (y[idx])
500
- logz = log (1 - z^ 2 )
501
- result += logz + (tmp / 2 )
502
- tmp += logz
492
+ logcoshy = LogExpFunctions. logcosh (y[idx])
493
+ tmp -= logcoshy
494
+ result += tmp - logcoshy
503
495
idx += 1
504
496
end
505
497
end
0 commit comments