Skip to content

Fix UnitVector implementation #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Flux = "0.9"
LogDensityProblems = "^0.9.0"
julia = "^1"

Expand Down
8 changes: 4 additions & 4 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,18 @@ Return a transformation that transforms consecutive groups of real numbers to a
julia> t = as((asℝ₊, UnitVector(3)));

julia> dimension(t)
3
4

julia> transform(t, zeros(dimension(t)))
(1.0, [0.0, 0.0, 1.0])
(1.0, [1.0, 0.0, 0.0])

julia> t2 = as((σ = asℝ₊, u = UnitVector(3)));

julia> dimension(t2)
3
4

julia> transform(t2, zeros(dimension(t2)))
(σ = 1.0, u = [0.0, 0.0, 1.0])
(σ = 1.0, u = [1.0, 0.0, 0.0])
```
"""
as(transformations::NTransforms) = TransformTuple(transformations)
Expand Down
29 changes: 12 additions & 17 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`.
"""
UnitVector(n)

Transform `n-1` real numbers to a unit vector of length `n`, under the
Transform `n` real numbers to a unit vector of length `n`, under the
Euclidean norm.
"""
@calltrans struct UnitVector <: VectorTransform
Expand All @@ -46,35 +46,30 @@ Euclidean norm.
end
end

dimension(t::UnitVector) = t.n - 1
dimension(t::UnitVector) = t.n

function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
@unpack n = t
T = extended_eltype(x)
r = one(T)
y = Vector{T}(undef, n)
index′ = index + n
vx = view(x, index:(index′ - 1))
nx² = sum(abs2, vx)
y = nx² > 0 ? vx ./ √nx² : [one(T); zeros(T, n - 1)]
ℓ = logjac_zero(flag, T)
@inbounds for i in 1:(n - 1)
xi = x[index]
index += 1
y[i], r, ℓi = l2_remainder_transform(flag, xi, r)
ℓ += ℓi
if !(flag isa NoLogJac)
ℓ -= nx² / 2
end
y[end] = √r
y, ℓ, index
y, ℓ, index′
end

inverse_eltype(t::UnitVector, y::AbstractVector) = extended_eltype(y)

function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
@unpack n = t
@argcheck length(y) == n
r = one(eltype(y))
@inbounds for yi in axes(y, 1)[1:(end-1)]
x[index], r = l2_remainder_inverse(y[yi], r)
index += 1
end
index
index′ = index + n
setindex!(x, y, index:(index′ - 1))
index′
end


Expand Down
27 changes: 18 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
@testset "to unit vector" begin
@testset "dimension checks" begin
U = UnitVector(3)
x = zeros(3) # incorrect
x = zeros(2) # incorrect
@test_throws ArgumentError U(x)
@test_throws ArgumentError transform(U, x)
@test_throws ArgumentError transform_and_logjac(U, x)
Expand All @@ -94,10 +94,18 @@ end
@testset "consistency checks" begin
for K in 1:10
t = UnitVector(K)
@test dimension(t) == K - 1
@test dimension(t) == K
if K > 1
test_transformation(t, y -> sum(abs2, y) ≈ 1,
vec_y = y -> y[1:(end-1)])
test_transformation(t, y -> sum(abs2, y) ≈ 1;
test_inverse = false, test_logjac = false)
x = randn(K)
y = transform(t, x)
x2 = @inferred inverse(t, y)
@test x2 ≈ y
ι = inverse(t)
@test y ≈ ι(y)
@test transform_and_logjac(t, x)[2] ≈ -sum(abs2, x) ./ 2
@test transform(t, zeros(K)) ≈ [1; zeros(K-1)]
end
end
end
Expand Down Expand Up @@ -242,7 +250,8 @@ end
x = randn(dimension(tn))
y = @inferred transform(tn, x)
@test y isa NamedTuple{(:a,:b,:c)}
@test inverse(tn, y) ≈ x
x2 = inverse(tn, y)
@test inverse(tn, transform(tn, x2)) ≈ x2
index = 0
ljacc = 0.0
for (i, t) in enumerate((t1, t2, t3))
Expand Down Expand Up @@ -284,7 +293,7 @@ end
x = randn(dimension(tt))
y = tt(x)
x′ = inverse(tt, y)
@test x ≈ x′
@test inverse(tt, transform(tt, x′)) ≈ x′
end
end

Expand Down Expand Up @@ -385,11 +394,11 @@ end
u = UnitVector(3), L = CorrCholeskyFactor(4),
δ = as((asℝ₋, as𝕀))))
function f(θ)
@unpack μ, σ, β, α, δ = θ
-(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2])
@unpack μ, σ, β, α, u, δ = θ
-(abs2(μ) + abs2(σ) + abs2(β) + α + sum(u) + δ[1] + δ[2])
end
P = TransformedLogDensity(t, f)
x = zeros(dimension(t))
x = randn(dimension(t)) .* 1e-5
v = logdensity(P, x)
g = ForwardDiff.gradient(x -> logdensity(P, x), x)

Expand Down
13 changes: 8 additions & 5 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ automatic differentiation.
`test_inverse` determines whether the inverse is tested.
"""
function test_transformation(t::AbstractTransform, is_valid_y;
vec_y = identity, N = 1000, test_inverse = true)
vec_y = identity, N = 1000,
test_inverse = true, test_logjac = true)
for _ in 1:N
x = t isa ScalarTransform ? randn() : randn(dimension(t))
if t isa ScalarTransform
Expand All @@ -37,10 +38,12 @@ function test_transformation(t::AbstractTransform, is_valid_y;
@test t(x) == y # callable
y2, lj = @inferred transform_and_logjac(t, x)
@test y2 == y
if t isa ScalarTransform
@test lj ≈ AD_logjac(t, x)
else
@test lj ≈ AD_logjac(t, x, vec_y)
if test_logjac
if t isa ScalarTransform
@test lj ≈ AD_logjac(t, x)
else
@test lj ≈ AD_logjac(t, x, vec_y)
end
end
if test_inverse
x2 = inverse(t, y)
Expand Down