From 1eefc6cc9e4ab5c3612867facfe63753df69de78 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:42:01 -0800 Subject: [PATCH 01/14] Reimplement UnitVector --- src/special_arrays.jl | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 0fdf9dc..465322e 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -35,7 +35,7 @@ Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`. """ UnitVector(n) -Transform `n-1` real numbers to a unit vector of length `n`, under the +Transform `n` real numbers to a unit vector of length `n`, under the Euclidean norm. """ @calltrans struct UnitVector <: VectorTransform @@ -46,22 +46,20 @@ Euclidean norm. end end -dimension(t::UnitVector) = t.n - 1 +dimension(t::UnitVector) = t.n function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index) @unpack n = t T = extended_eltype(x) - r = one(T) - y = Vector{T}(undef, n) + index′ = index + n + vx = view(x, index:(index′ - 1)) + nx² = sum(abs2, vx) + y = vx ./ √nx² ℓ = logjac_zero(flag, T) - @inbounds for i in 1:(n - 1) - xi = x[index] - index += 1 - y[i], r, ℓi = l2_remainder_transform(flag, xi, r) - ℓ += ℓi + if !(flag isa NoLogJac) + ℓ -= nx² / 2 end - y[end] = √r - y, ℓ, index + y, ℓ, index′ end inverse_eltype(t::UnitVector, y::AbstractVector) = extended_eltype(y) @@ -69,12 +67,9 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = extended_eltype(y) function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector) @unpack n = t @argcheck length(y) == n - r = one(eltype(y)) - @inbounds for yi in axes(y, 1)[1:(end-1)] - x[index], r = l2_remainder_inverse(y[yi], r) - index += 1 - end - index + index′ = index + n + setindex!(x, y, index:(index′ - 1)) + index′ end From 9fc49ac4593fd2f568ab2ed775f3644681e38295 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:42:33 -0800 Subject: [PATCH 02/14] Add ability to opt out of logjac test --- test/utilities.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/utilities.jl b/test/utilities.jl index 0a83699..a6284b4 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -22,7 +22,8 @@ automatic differentiation. `test_inverse` determines whether the inverse is tested. """ function test_transformation(t::AbstractTransform, is_valid_y; - vec_y = identity, N = 1000, test_inverse = true) + vec_y = identity, N = 1000, + test_inverse = true, test_logjac = true) for _ in 1:N x = t isa ScalarTransform ? randn() : randn(dimension(t)) if t isa ScalarTransform @@ -37,10 +38,12 @@ function test_transformation(t::AbstractTransform, is_valid_y; @test t(x) == y # callable y2, lj = @inferred transform_and_logjac(t, x) @test y2 == y - if t isa ScalarTransform - @test lj ≈ AD_logjac(t, x) - else - @test lj ≈ AD_logjac(t, x, vec_y) + if test_logjac + if t isa ScalarTransform + @test lj ≈ AD_logjac(t, x) + else + @test lj ≈ AD_logjac(t, x, vec_y) + end end if test_inverse x2 = inverse(t, y) From a0494f6cf91e0a9d9a3601848247ce16375892ba Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:43:12 -0800 Subject: [PATCH 03/14] Update consistency test --- test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9e63bbc..a9acd1c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,7 +85,7 @@ end @testset "to unit vector" begin @testset "dimension checks" begin U = UnitVector(3) - x = zeros(3) # incorrect + x = zeros(2) # incorrect @test_throws ArgumentError U(x) @test_throws ArgumentError transform(U, x) @test_throws ArgumentError transform_and_logjac(U, x) @@ -94,10 +94,10 @@ end @testset "consistency checks" begin for K in 1:10 t = UnitVector(K) - @test dimension(t) == K - 1 + @test dimension(t) == K if K > 1 - test_transformation(t, y -> sum(abs2, y) ≈ 1, - vec_y = y -> y[1:(end-1)]) + test_transformation(t, y -> sum(abs2, y) ≈ 1; + test_inverse = false, test_logjac = false) end end end From 7058f468febcca9537ea1ba71510a33dd8a6b8dd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:43:30 -0800 Subject: [PATCH 04/14] Test right inverse and logjac --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index a9acd1c..14f683e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,6 +98,13 @@ end if K > 1 test_transformation(t, y -> sum(abs2, y) ≈ 1; test_inverse = false, test_logjac = false) + x = randn(K) + y = transform(t, x) + x2 = @inferred inverse(t, y) + @test x2 ≈ y + ι = inverse(t) + @test y ≈ ι(y) + @test transform_and_logjac(t, x)[2] ≈ -sum(abs2, x) ./ 2 end end end From a989b63a460aa664a237993b010d187ccc444fb7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:44:25 -0800 Subject: [PATCH 05/14] Test right inverse instead of left --- test/runtests.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 14f683e..1c21e5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -249,7 +249,8 @@ end x = randn(dimension(tn)) y = @inferred transform(tn, x) @test y isa NamedTuple{(:a,:b,:c)} - @test inverse(tn, y) ≈ x + x2 = inverse(tn, y) + @test inverse(tn, transform(tn, x2)) ≈ x2 index = 0 ljacc = 0.0 for (i, t) in enumerate((t1, t2, t3)) @@ -291,7 +292,7 @@ end x = randn(dimension(tt)) y = tt(x) x′ = inverse(tt, y) - @test x ≈ x′ + @test inverse(tt, transform(tt, x′)) ≈ x′ end end From 750471803cb0784258d2556976aa9e22baa23831 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 00:45:08 -0800 Subject: [PATCH 06/14] Don't initialize at origin --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1c21e5f..971cddf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -503,7 +503,7 @@ end t = UnitVector(3) d = dimension(t) - x = [zeros(d), zeros(d)] + x = [randn(d), randn(d)] @test t.(x) == map(t, x) end From e5da2ec39f7732b02640f1f40d77e366c3367e20 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 01:28:40 -0800 Subject: [PATCH 07/14] Set upper bound on Flux for Tracker Flux removed Tracker as a dependency in v0.10.0 --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index d36ddf2..478f3ef 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] +Flux = "0.9" LogDensityProblems = "^0.9.0" julia = "^1" From 7ca7efe7ee13e26cdc48ad6fba9780f3995142d0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 18:43:53 -0800 Subject: [PATCH 08/14] Don't initialize to 0 --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 971cddf..16c7d08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -397,7 +397,7 @@ end -(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2]) end P = TransformedLogDensity(t, f) - x = zeros(dimension(t)) + x = randn(dimension(t)) v = logdensity(P, x) g = ForwardDiff.gradient(x -> logdensity(P, x), x) From 982dd130604f0d43680a3648a126b63d7a5fde8f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 19:58:22 -0800 Subject: [PATCH 09/14] Transform origin to valid unit vector --- src/special_arrays.jl | 2 +- test/runtests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 465322e..c0c9b74 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -54,7 +54,7 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde index′ = index + n vx = view(x, index:(index′ - 1)) nx² = sum(abs2, vx) - y = vx ./ √nx² + y = nx² > 0 ? vx ./ √nx² : [zeros(T, n - 1); one(T)] ℓ = logjac_zero(flag, T) if !(flag isa NoLogJac) ℓ -= nx² / 2 diff --git a/test/runtests.jl b/test/runtests.jl index 16c7d08..cb5672d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,6 +105,7 @@ end ι = inverse(t) @test y ≈ ι(y) @test transform_and_logjac(t, x)[2] ≈ -sum(abs2, x) ./ 2 + @test transform(t, zeros(K)) ≈ [zeros(K-1); 1] end end end From d5a21c6f1818b819fdd406b199c11b1ac0ab1b40 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 19:58:51 -0800 Subject: [PATCH 10/14] Revert "Don't initialize to 0" This reverts commit 7ca7efe7ee13e26cdc48ad6fba9780f3995142d0. --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index cb5672d..3a9c8ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -398,7 +398,7 @@ end -(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2]) end P = TransformedLogDensity(t, f) - x = randn(dimension(t)) + x = zeros(dimension(t)) v = logdensity(P, x) g = ForwardDiff.gradient(x -> logdensity(P, x), x) From 2eccfae7ed68ea698520a88b058fdad5bc30ce0e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 19:59:21 -0800 Subject: [PATCH 11/14] Revert "Don't initialize at origin" This reverts commit 750471803cb0784258d2556976aa9e22baa23831. --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3a9c8ce..e00078f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -504,7 +504,7 @@ end t = UnitVector(3) d = dimension(t) - x = [randn(d), randn(d)] + x = [zeros(d), zeros(d)] @test t.(x) == map(t, x) end From 08674f269305fbed9e1429fbfa4612b208373287 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 20:04:54 -0800 Subject: [PATCH 12/14] Update docstring with dimension --- src/aggregation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index 59912fa..a34babf 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -131,7 +131,7 @@ Return a transformation that transforms consecutive groups of real numbers to a julia> t = as((asℝ₊, UnitVector(3))); julia> dimension(t) -3 +4 julia> transform(t, zeros(dimension(t))) (1.0, [0.0, 0.0, 1.0]) @@ -139,7 +139,7 @@ julia> transform(t, zeros(dimension(t))) julia> t2 = as((σ = asℝ₊, u = UnitVector(3))); julia> dimension(t2) -3 +4 julia> transform(t2, zeros(dimension(t2))) (σ = 1.0, u = [0.0, 0.0, 1.0]) From 9a1c54f29b6ef037df742156a8144433a6a9be13 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 7 Dec 2019 21:15:38 -0800 Subject: [PATCH 13/14] Change default unit vector convention This is slightly more standard. e.g. for the 1- and 3-spheres (related to 2D and 3D rotation groups), [1,0] and [1,0,0,0] correspond to identity elements. --- src/aggregation.jl | 4 ++-- src/special_arrays.jl | 2 +- test/runtests.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index a34babf..5c4e23b 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -134,7 +134,7 @@ julia> dimension(t) 4 julia> transform(t, zeros(dimension(t))) -(1.0, [0.0, 0.0, 1.0]) +(1.0, [1.0, 0.0, 0.0]) julia> t2 = as((σ = asℝ₊, u = UnitVector(3))); @@ -142,7 +142,7 @@ julia> dimension(t2) 4 julia> transform(t2, zeros(dimension(t2))) -(σ = 1.0, u = [0.0, 0.0, 1.0]) +(σ = 1.0, u = [1.0, 0.0, 0.0]) ``` """ as(transformations::NTransforms) = TransformTuple(transformations) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index c0c9b74..0f7a35c 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -54,7 +54,7 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde index′ = index + n vx = view(x, index:(index′ - 1)) nx² = sum(abs2, vx) - y = nx² > 0 ? vx ./ √nx² : [zeros(T, n - 1); one(T)] + y = nx² > 0 ? vx ./ √nx² : [one(T); zeros(T, n - 1)] ℓ = logjac_zero(flag, T) if !(flag isa NoLogJac) ℓ -= nx² / 2 diff --git a/test/runtests.jl b/test/runtests.jl index e00078f..b810394 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,7 +105,7 @@ end ι = inverse(t) @test y ≈ ι(y) @test transform_and_logjac(t, x)[2] ≈ -sum(abs2, x) ./ 2 - @test transform(t, zeros(K)) ≈ [zeros(K-1); 1] + @test transform(t, zeros(K)) ≈ [1; zeros(K-1)] end end end From 809c8dd6fe64b75e26c873627af2a498556d76bf Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 8 Dec 2019 01:22:38 -0800 Subject: [PATCH 14/14] Add function of unit vector to AD test And don't evaluate at exact origin where reverse-mode diffs are undefined. --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index b810394..0920e11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -394,11 +394,11 @@ end u = UnitVector(3), L = CorrCholeskyFactor(4), δ = as((asℝ₋, as𝕀)))) function f(θ) - @unpack μ, σ, β, α, δ = θ - -(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2]) + @unpack μ, σ, β, α, u, δ = θ + -(abs2(μ) + abs2(σ) + abs2(β) + α + sum(u) + δ[1] + δ[2]) end P = TransformedLogDensity(t, f) - x = zeros(dimension(t)) + x = randn(dimension(t)) .* 1e-5 v = logdensity(P, x) g = ForwardDiff.gradient(x -> logdensity(P, x), x)