Skip to content

Add Riemannian manifold HMC #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c
- Unit metric: `UnitEuclideanMetric(dim)`
- Diagonal metric: `DiagEuclideanMetric(dim)`
- Dense metric: `DenseEuclideanMetric(dim)`
- Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)`

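where `dim` is the dimensionality of the sampling space. For `DenseRiemannianMetric`, `size` is a tuple giving the shape of the sampling space (e.g. `(dim,)`), `G` is a function that maps a position `θ` to the metric matrix, and `∂G∂θ` is a function that provides the derivative of `G` with respect to `θ`.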

Expand Down
21 changes: 19 additions & 2 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@ module AdvancedHMC

using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
Symmetric,
UpperTriangular,
mul!,
ldiv!,
dot,
I,
diag,
cholesky,
UniformScaling,
logdet,
tr,
eigen,
diagm
using StatsFuns: logaddexp, logsumexp, loghalf
using Random: Random, AbstractRNG
using ProgressMeter: ProgressMeter
Expand Down Expand Up @@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end
export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric

include("hamiltonian.jl")
export Hamiltonian
Expand All @@ -50,6 +62,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("riemannian/metric.jl")
export IdentityMap, SoftAbsMap, DenseRiemannianMetric

include("riemannian/hamiltonian.jl")

include("trajectory.jl")
export Trajectory,
HMCKernel,
Expand Down
298 changes: 33 additions & 265 deletions src/riemannian/hamiltonian.jl
Original file line number Diff line number Diff line change
@@ -1,257 +1,16 @@
using Random

### integrator.jl

import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size

"""
$(TYPEDEF)

Generalized leapfrog integrator with fixed step size `ϵ`.

# Fields

$(TYPEDFIELDS)
"""
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
# Number of fixed-point iterations: `step` uses it as the bound of both inner
# loops (`for j in 1:(lf.n)`) that solve the implicit momentum and position updates.
n::Int
end
# Compact one-line display: the step size is rounded to 3 significant digits
# (element-wise, so a vector of step sizes also works) followed by `n`.
function Base.show(io::IO, l::GeneralizedLeapfrog)
    print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))")
    return nothing
end

# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ
"""
    ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)

Generic fallback for Hamiltonians without a caching implementation.

Evaluates `∂H∂θ(h, θ, r)` directly. The `cache` keyword is accepted but ignored.
When `return_cache` is `true` the result is returned as `(dv, nothing)` so callers
can always destructure a `(value, cache)` pair; otherwise only `dv` is returned.
"""
function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing)
    # No `where {T}` here: the signature never mentions a type parameter, and an
    # unused one makes Julia warn about an unused type variable at definition time.
    dv = ∂H∂θ(h, θ, r)
    return return_cache ? (dv, nothing) : dv
end

# TODO Make sure vectorization works
# TODO Check if tempering is valid
# Simulate `abs(n_steps)` generalized leapfrog steps of the Hamiltonian `h`,
# starting from the phase point `z`.
#
# Arguments
# - `lf`: integrator; provides the step size (`step_size(lf)`) and the number of
#   fixed-point iterations `lf.n`.
# - `h`: the Hamiltonian whose `∂H∂θ` / `∂H∂r` are evaluated.
# - `z`: initial phase point (its `θ`, `r` and cached `ℓπ` are read).
# - `n_steps`: number of steps; its sign selects the default direction.
# - `fwd`: integrate forward (`true`) or with a negated step size (`false`).
# - `full_trajectory`: `Val(true)` returns every visited phase point as a vector,
#   `Val(false)` (default) returns only the last one.
#
# Returns the last phase point, or a `Vector` of phase points when
# `full_trajectory` is `Val(true)`. Integration stops early at the first
# non-finite phase point; in full-trajectory mode the unassigned tail is dropped.
function step(
lf::GeneralizedLeapfrog{T},
h::Hamiltonian,
z::P,
n_steps::Int=1;
fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
full_trajectory::Val{FullTraj}=Val(false),
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
n_steps = abs(n_steps) # to support `n_steps < 0` cases

# Negate the step size to integrate backward in time.
ϵ = fwd ? step_size(lf) : -step_size(lf)
# Adjoint of the step size: a no-op for a real scalar.
# NOTE(review): presumably turns a vector of step sizes into a row so it
# broadcasts across columns in matrix mode — confirm.
ϵ = ϵ'

# Full-trajectory mode preallocates one slot per step; otherwise `res` simply
# tracks the most recent phase point.
res = if FullTraj
Vector{P}(undef, n_steps)
else
z
end

for i in 1:n_steps
θ_init, r_init = z.θ, z.r
# Tempering
#r = temper(lf, r, (i=i, is_half=true), n_steps)
#! Eq (16) of Girolami & Calderhead (2011)
# Implicit momentum half step, approximated by `lf.n` fixed-point iterations
# that always restart from `r_init` with the latest gradient.
r_half = copy(r_init)
# `cache` is assigned at j == 2 and read at j >= 3, so it must outlive a
# single iteration of the inner loop.
local cache
for j in 1:(lf.n)
# Reuse cache for the first iteration
if j == 1
# Value and gradient already stored on the incoming phase point.
(; value, gradient) = z.ℓπ
elseif j == 2 # cache intermediate values that depend on θ only (which are unchanged)
retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true)
(; value, gradient) = retval
else # reuse cache
(; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache)
end
r_half = r_init - ϵ / 2 * gradient
# println("r_half: ", r_half)
end
#! Eq (17) of Girolami & Calderhead (2011)
# Implicit position full step, again via `lf.n` fixed-point iterations; only
# the term evaluated at the current `θ_full` changes between iterations.
θ_full = copy(θ_init)
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
for j in 1:(lf.n)
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
# println("θ_full :", θ_full)
end
#! Eq (18) of Girolami & Calderhead (2011)
# Explicit second momentum half step at the new position.
(; value, gradient) = ∂H∂θ(h, θ_full, r_half)
r_full = r_half - ϵ / 2 * gradient
# println("r_full: ", r_full)
# Tempering
#r = temper(lf, r, (i=i, is_half=false), n_steps)
# Create a new phase point by caching the logdensity and gradient
z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient))
# Update result
if FullTraj
res[i] = z
else
res = z
end
# Stop at the first non-finite phase point; it is still recorded in `res`.
if !isfinite(z)
# Remove undef
if FullTraj
# Keep only the slots that were actually written.
res = res[isassigned.(Ref(res), 1:n_steps)]
end
break
end
# @assert false
end
return res
end

# TODO Make the order of θ and r consistent with neg_energy
# Three-argument forms for generic Hamiltonians: each forwards to the existing
# two-argument method and drops the argument it does not need.
function ∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    # The momentum `r` is ignored.
    return ∂H∂θ(h, θ)
end
function ∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat)
    # The position `θ` is ignored.
    return ∂H∂r(h, r)
end

### hamiltonian.jl

import AdvancedHMC: refresh, phasepoint
using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric

# To change L180 of hamiltonian.jl
"""
    phasepoint(rng, θ, h)

Build a phase point at position `θ` with a freshly drawn momentum. The momentum
comes from `rand_momentum`, which also receives `θ` in addition to the metric and
kinetic energy of `h`.
"""
function phasepoint(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    θ::AbstractVecOrMat{T},
    h::Hamiltonian,
) where {T<:Real}
    r = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r)
end

# To change L191 of hamiltonian.jl
"""
    refresh(rng, ::FullMomentumRefreshment, h, z)

Return a phase point at the position of `z` whose momentum is replaced by a new
draw from `rand_momentum`, evaluated at `z.θ`.
"""
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ::FullMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    θ = z.θ
    r_new = rand_momentum(rng, h.metric, h.kinetic, θ)
    return phasepoint(h, θ, r_new)
end

# To change L215 of hamiltonian.jl
"""
    refresh(rng, ref::PartialMomentumRefreshment, h, z)

Return a phase point at the position of `z` whose momentum is the mixture
`ref.α * z.r + sqrt(1 - ref.α^2) * r_fresh`, where `r_fresh` is a new draw from
`rand_momentum` evaluated at `z.θ`.
"""
function refresh(
    rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
    ref::PartialMomentumRefreshment,
    h::Hamiltonian,
    z::PhasePoint,
)
    r_fresh = rand_momentum(rng, h.metric, h.kinetic, z.θ)
    r_mixed = ref.α * z.r + sqrt(1 - ref.α^2) * r_fresh
    return phasepoint(h, z.θ, r_mixed)
end

### metric.jl

import AdvancedHMC: _rand
using AdvancedHMC: AbstractMetric
using LinearAlgebra: eigen, cholesky, Symmetric

# Supertype of the metrics defined in this file (see `DenseRiemannianMetric`).
abstract type AbstractRiemannianMetric <: AbstractMetric end

# Supertype of the callable maps stored in a metric's `map` field; they are
# applied to the output of the metric's `G(θ)` before it is used as the metric matrix.
abstract type AbstractHessianMap end

# Map that leaves its input untouched.
struct IdentityMap <: AbstractHessianMap end

# Calling an `IdentityMap` returns `x` itself.
(::IdentityMap)(x) = x

# Map that applies `softabs` with the stored parameter `α`.
struct SoftAbsMap{T} <: AbstractHessianMap
# Passed as the second argument of `softabs`.
α::T
end

# TODO Register softabs with ReverseDiff
#! The definition of SoftAbs from Page 3 of Betancourt (2012)
"""
    softabs(X, α=20.0)

Apply the SoftAbs map to the square matrix `X`.

With the eigendecomposition `X = Q * diagm(λ) * Q'`, every eigenvalue `λᵢ` is
replaced by `λᵢ * coth(α * λᵢ)`. When `α * λᵢ` is exactly zero the limit `1 / α`
is used, because the direct formula would evaluate to `0 * Inf = NaN`.

# Returns
The tuple `(Q * diagm(softabsλ) * Q', Q, λ, softabsλ)`: the transformed matrix,
the eigenvectors, the original eigenvalues and the transformed eigenvalues.
"""
function softabs(X, α=20.0)
    F = eigen(X) # ReverseDiff cannot diff through `eigen`
    Q = F.vectors
    λ = F.values
    softabsλ = map(λ) do λᵢ
        x = α * λᵢ
        # λ coth(αλ) → 1/α as λ → 0; `oftype` keeps the element type uniform.
        iszero(x) ? oftype(x, inv(α)) : λᵢ * coth(x)
    end
    return Q * diagm(softabsλ) * Q', Q, λ, softabsλ
end

# Calling a `SoftAbsMap` returns only the transformed matrix, i.e. the first
# element of the tuple produced by `softabs`.
function (m::SoftAbsMap)(x)
    return first(softabs(x, m.α))
end

# Position-dependent dense metric: the metric matrix at position `θ` is obtained
# as `map(G(θ))` (see `neg_energy` and `∂H∂r` in this file).
#
# Type parameters: `T` is the element type of the scratch buffer `_temp`, `TM` the
# map type, `A` the type of the size tuple, and `TG` / `T∂G∂θ` the types of the
# user-supplied `G` and `∂G∂θ` objects.
struct DenseRiemannianMetric{
T,
TM<:AbstractHessianMap,
A<:Union{Tuple{Int},Tuple{Int,Int}},
AV<:AbstractVecOrMat{T},
TG,
T∂G∂θ,
} <: AbstractRiemannianMetric
# Returned by `Base.size(::DenseRiemannianMetric)`.
size::A
# Callable evaluated as `G(θ)`.
G::TG # TODO store G⁻¹ here instead
# NOTE(review): presumably the derivative of `G` with respect to `θ`; its use
# is not visible in this part of the file — confirm.
∂G∂θ::T∂G∂θ
# Applied to `G(θ)` before the result is used as the metric matrix.
map::TM
# Scratch buffer; `neg_energy` writes `inv(G) * r` into it via `mul!`.
_temp::AV
end

# TODO Make dense mass matrix support matrix-mode parallel
"""
    DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())

Convenience constructor that allocates the internal scratch buffer.

# Arguments
- `size`: tuple whose first entry is the length of the scratch buffer.
- `G`, `∂G∂θ`: stored as given.
- `map`: map applied to `G(θ)`; defaults to `IdentityMap()`.
"""
function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap())
    # No `where {T<:AbstractFloat}` here: no argument mentions `T`, so the
    # parameter could never be bound and Julia warns about it at definition time.
    # The buffer element type stays `Float64`, exactly as before.
    _temp = Vector{Float64}(undef, first(size))
    return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
end
# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))
# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D)
# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz)))
# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)

# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)

# Size queries read the stored `size` tuple.
function Base.size(e::DenseRiemannianMetric)
    return e.size
end
function Base.size(e::DenseRiemannianMetric, dim::Int)
    return e.size[dim]
end
# Fixed placeholder display; the fields are not printed.
function Base.show(io::IO, dem::DenseRiemannianMetric)
    return print(io, "DenseRiemannianMetric(...)")
end

function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DenseRiemannianMetric{T},
kinetic,
#! Eq (14) of Girolami & Calderhead (2011)
function ∂H∂r(
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic},
θ::AbstractVecOrMat,
) where {T}
r = _randn(rng, T, size(metric)...)
G⁻¹ = inv(metric.map(metric.G(θ)))
chol = cholesky(Symmetric(G⁻¹))
ldiv!(chol.U, r)
return r
end

### hamiltonian.jl

import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r
using LinearAlgebra: logabsdet, tr

# QUES Do we want to change everything to position dependent by default?
# Add θ to ∂H∂r for DenseRiemannianMetric
"""
    phasepoint(h, θ, r; ℓπ, ℓκ)

Build a `PhasePoint` for a Hamiltonian with a `DenseRiemannianMetric`.

Unless supplied, `ℓπ` defaults to `∂H∂θ(h, θ)` and `ℓκ` to a `DualValue` holding
`neg_energy(h, r, θ)` and `∂H∂r(h, θ, r)`; both kinetic quantities receive the
position `θ` as well as the momentum `r`.
"""
function phasepoint(
    h::Hamiltonian{<:DenseRiemannianMetric},
    θ::V,
    r::V;
    ℓπ=∂H∂θ(h, θ),
    ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {V<:AbstractVecOrMat}
    return PhasePoint(θ, r, ℓπ, ℓκ)
end

# Negative kinetic energy
#! Eq (13) of Girolami & Calderhead (2011)
function neg_energy(
h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T
) where {T<:AbstractVecOrMat}
G = h.metric.map(h.metric.G(θ))
D = size(G, 1)
# Need to consider the normalizing term as it is no longer same for different θs
logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined
mul!(h.metric._temp, inv(G), r)
return -logZ - dot(r, h.metric._temp) / 2
r::AbstractVecOrMat,
)
H = h.metric.G(θ)
G = h.metric.map(H)
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
end

# QUES L31 of hamiltonian.jl now reads a bit weird (semantically)
function ∂H∂θ(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T},
) where {T}
Expand Down Expand Up @@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
end

function ∂H∂θ(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T},
) where {T}
return ∂H∂θ_cache(h, θ, r)
end
function ∂H∂θ_cache(
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
θ::AbstractVecOrMat{T},
r::AbstractVecOrMat{T};
return_cache=false,
Expand Down Expand Up @@ -342,17 +101,26 @@ function ∂H∂θ_cache(
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
end

#! Eq (14) of Girolami & Calderhead (2011)
function ∂H∂r(
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
)
H = h.metric.G(θ)
# if !all(isfinite, H)
# println("θ: ", θ)
# println("H: ", H)
# end
G = h.metric.map(H)
# return inv(G) * r
# println("G \ r: ", G \ r)
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
# QUES Do we want to change everything to position dependent by default?
# Add θ to ∂H∂r for DenseRiemannianMetric
"""
    phasepoint(h, θ, r; ℓπ, ℓκ)

Build a `PhasePoint` for a Hamiltonian with a `DenseRiemannianMetric`.

Unless supplied, `ℓπ` defaults to `∂H∂θ(h, θ)` and `ℓκ` to a `DualValue` holding
`neg_energy(h, r, θ)` and `∂H∂r(h, θ, r)`; both kinetic quantities receive the
position `θ` as well as the momentum `r`.
"""
function phasepoint(
    h::Hamiltonian{<:DenseRiemannianMetric},
    θ::V,
    r::V;
    ℓπ=∂H∂θ(h, θ),
    ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
) where {V<:AbstractVecOrMat}
    return PhasePoint(θ, r, ℓπ, ℓκ)
end

#! Eq (13) of Girolami & Calderhead (2011)
"""
    neg_energy(h, r, θ)

Negative kinetic energy of momentum `r` at position `θ` for a Gaussian kinetic
energy with a position-dependent metric `G = h.metric.map(h.metric.G(θ))`:

    -(D * log(2π) + logdet(G)) / 2 - dot(r, G \\ r) / 2

where `D = size(G, 1)`. The solution of `G x = r` is written into the scratch
buffer `h.metric._temp`. It is the caller's responsibility to ensure that `G` is
symmetric positive definite, so that `logdet(G)` is defined.
"""
function neg_energy(
    h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T
) where {T<:AbstractVecOrMat}
    G = h.metric.map(h.metric.G(θ))
    D = size(G, 1)
    # The normalizing term depends on θ, so it cannot be dropped as a constant.
    logZ = (D * log(2π) + logdet(G)) / 2
    # Solve `G x = r` instead of forming `inv(G)` explicitly, in line with the
    # `G \ r` used by `∂H∂r`; the result still lands in the scratch buffer.
    h.metric._temp .= G \ r
    return -logZ - dot(r, h.metric._temp) / 2
end
Loading
Loading