From 7883d497f430c03182598705f6e63ac858e57300 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Mon, 5 May 2025 23:57:36 +0800 Subject: [PATCH 1/2] Add Riemannian manifold HMC --- docs/src/api.md | 1 + research/tests/runtests.jl | 2 - src/AdvancedHMC.jl | 21 ++- src/riemannian/hamiltonian.jl | 298 ++++------------------------------ src/riemannian/metric.jl | 63 +++++++ src/sampler.jl | 6 +- src/trajectory.jl | 7 +- test/Project.toml | 2 + test/demo.jl | 10 +- test/integrator.jl | 5 +- test/riemannian.jl | 120 ++++++++++++-- test/trajectory.jl | 65 +++----- 12 files changed, 268 insertions(+), 332 deletions(-) create mode 100644 src/riemannian/metric.jl diff --git a/docs/src/api.md b/docs/src/api.md index 54b5939d..a1c488fb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c - Unit metric: `UnitEuclideanMetric(dim)` - Diagonal metric: `DiagEuclideanMetric(dim)` - Dense metric: `DenseEuclideanMetric(dim)` + - Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)` where `dim` is the dimensionality of the sampling space. 
diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index da95548d..0633bc59 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -5,11 +5,9 @@ Pkg.add(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl.git"); # include the source code for experimental HMC include("../src/relativistic_hmc.jl") -include("../src/riemannian_hmc.jl") # include the tests for experimental HMC include("relativistic_hmc.jl") -include("riemannian_hmc.jl") Comonicon.@main function runtests(patterns...; dry::Bool=false) return retest(patterns...; dry=dry, verbose=Inf) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index b25710d5..41d934e6 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -2,7 +2,19 @@ module AdvancedHMC using Statistics: mean, var, middle using LinearAlgebra: - Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling + Symmetric, + UpperTriangular, + mul!, + ldiv!, + dot, + I, + diag, + cholesky, + UniformScaling, + logdet, + tr, + eigen, + diagm using StatsFuns: logaddexp, logsumexp, loghalf using Random: Random, AbstractRNG using ProgressMeter: ProgressMeter @@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end export GaussianKinetic include("metric.jl") -export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric +export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric include("hamiltonian.jl") export Hamiltonian @@ -50,6 +62,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog include("riemannian/integrator.jl") export GeneralizedLeapfrog +include("riemannian/metric.jl") +export IdentityMap, SoftAbsMap, DenseRiemannianMetric + +include("riemannian/hamiltonian.jl") + include("trajectory.jl") export Trajectory, HMCKernel, diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index feddb411..f8acc797 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,257 +1,16 @@ -using 
Random - -### integrator.jl - -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size - -""" -$(TYPEDEF) - -Generalized leapfrog integrator with fixed step size `ϵ`. - -# Fields - -$(TYPEDFIELDS) -""" -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} - "Step size." - ϵ::T - n::Int -end -function Base.show(io::IO, l::GeneralizedLeapfrog) - return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") -end - -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv -end - -# TODO Make sure vectorization works -# TODO Check if tempering is valid -function step( - lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, - z::P, - n_steps::Int=1; - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 - full_trajectory::Val{FullTraj}=Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} - n_steps = abs(n_steps) # to support `n_steps < 0` cases - - ϵ = fwd ? step_size(lf) : -step_size(lf) - ϵ = ϵ' - - res = if FullTraj - Vector{P}(undef, n_steps) - else - z - end - - for i in 1:n_steps - θ_init, r_init = z.θ, z.r - # Tempering - #r = temper(lf, r, (i=i, is_half=true), n_steps) - #! Eq (16) of Girolami & Calderhead (2011) - r_half = copy(r_init) - local cache - for j in 1:(lf.n) - # Reuse cache for the first iteration - if j == 1 - (; value, gradient) = z.ℓπ - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) - (; value, gradient) = retval - else # reuse cache - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) - end - r_half = r_init - ϵ / 2 * gradient - # println("r_half: ", r_half) - end - #! 
Eq (17) of Girolami & Calderhead (2011) - θ_full = copy(θ_init) - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop - for j in 1:(lf.n) - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) - # println("θ_full :", θ_full) - end - #! Eq (18) of Girolami & Calderhead (2011) - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) - r_full = r_half - ϵ / 2 * gradient - # println("r_full: ", r_full) - # Tempering - #r = temper(lf, r, (i=i, is_half=false), n_steps) - # Create a new phase point by caching the logdensity and gradient - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - # Update result - if FullTraj - res[i] = z - else - res = z - end - if !isfinite(z) - # Remove undef - if FullTraj - res = res[isassigned.(Ref(res), 1:n_steps)] - end - break - end - # @assert false - end - return res -end - -# TODO Make the order of θ and r consistent with neg_energy -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) - -### hamiltonian.jl - -import AdvancedHMC: refresh, phasepoint -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric - -# To change L180 of hamiltonian.jl -function phasepoint( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - θ::AbstractVecOrMat{T}, - h::Hamiltonian, -) where {T<:Real} - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) -end - -# To change L191 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ::FullMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) -end - -# To change L215 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ref::PartialMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint( - h, - z.θ, - ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, 
h.metric, h.kinetic, z.θ), - ) -end - -### metric.jl - -import AdvancedHMC: _rand -using AdvancedHMC: AbstractMetric -using LinearAlgebra: eigen, cholesky, Symmetric - -abstract type AbstractRiemannianMetric <: AbstractMetric end - -abstract type AbstractHessianMap end - -struct IdentityMap <: AbstractHessianMap end - -(::IdentityMap)(x) = x - -struct SoftAbsMap{T} <: AbstractHessianMap - α::T -end - -# TODO Register softabs with ReverseDiff -#! The definition of SoftAbs from Page 3 of Betancourt (2012) -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` - Q = hcat(F.vectors) - λ = F.values - softabsλ = λ .* coth.(α * λ) - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ -end - -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] - -struct DenseRiemannianMetric{ - T, - TM<:AbstractHessianMap, - A<:Union{Tuple{Int},Tuple{Int,Int}}, - AV<:AbstractVecOrMat{T}, - TG, - T∂G∂θ, -} <: AbstractRiemannianMetric - size::A - G::TG # TODO store G⁻¹ here instead - ∂G∂θ::T∂G∂θ - map::TM - _temp::AV -end - -# TODO Make dense mass matrix support matrix-mode parallel -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} - _temp = Vector{Float64}(undef, size[1]) - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) -end -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) - -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) - -Base.size(e::DenseRiemannianMetric) = e.size -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") - -function rand_momentum( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, 
- metric::DenseRiemannianMetric{T}, - kinetic, +#! Eq (14) of Girolami & Calderhead (2011) +function ∂H∂r( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, -) where {T} - r = _randn(rng, T, size(metric)...) - G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) - return r -end - -### hamiltonian.jl - -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r -using LinearAlgebra: logabsdet, tr - -# QUES Do we want to change everything to position dependent by default? -# Add θ to ∂H∂r for DenseRiemannianMetric -function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, - θ::T, - r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), -) where {T<:AbstractVecOrMat} - return PhasePoint(θ, r, ℓπ, ℓκ) -end - -# Negative kinetic energy -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 + r::AbstractVecOrMat, +) + H = h.metric.G(θ) + G = h.metric.map(H) + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't end -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} @@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} end function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + 
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} return ∂H∂θ_cache(h, θ, r) end function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}; return_cache=false, @@ -342,17 +101,26 @@ function ∂H∂θ_cache( return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv end -#! Eq (14) of Girolami & Calderhead (2011) -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat -) - H = h.metric.G(θ) - # if !all(isfinite, H) - # println("θ: ", θ) - # println("H: ", H) - # end - G = h.metric.map(H) - # return inv(G) * r - # println("G \ r: ", G \ r) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't +# QUES Do we want to change everything to position dependent by default? +# Add θ to ∂H∂r for DenseRiemannianMetric +function phasepoint( + h::Hamiltonian{<:DenseRiemannianMetric}, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), +) where {T<:AbstractVecOrMat} + return PhasePoint(θ, r, ℓπ, ℓκ) +end + +#! 
Eq (13) of Girolami & Calderhead (2011) +function neg_energy( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T +) where {T<:AbstractVecOrMat} + G = h.metric.map(h.metric.G(θ)) + D = size(G, 1) + # Need to consider the normalizing term as it is no longer same for different θs + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined + mul!(h.metric._temp, inv(G), r) + return -logZ - dot(r, h.metric._temp) / 2 end diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl new file mode 100644 index 00000000..41d11127 --- /dev/null +++ b/src/riemannian/metric.jl @@ -0,0 +1,63 @@ +abstract type AbstractRiemannianMetric <: AbstractMetric end + +abstract type AbstractHessianMap end + +struct IdentityMap <: AbstractHessianMap end + +(::IdentityMap)(x) = x + +struct SoftAbsMap{T} <: AbstractHessianMap + α::T +end + +function softabs(X, α=20.0) + F = eigen(X) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q * diagm(softabsλ) * Q', Q, λ, softabsλ +end + +(map::SoftAbsMap)(x) = softabs(x, map.α)[1] + +# TODO Register softabs with ReverseDiff +#! 
The definition of SoftAbs from Page 3 of Betancourt (2012) +struct DenseRiemannianMetric{ + T, + TM<:AbstractHessianMap, + A<:Union{Tuple{Int},Tuple{Int,Int}}, + AV<:AbstractVecOrMat{T}, + TG, + T∂G∂θ, +} <: AbstractRiemannianMetric + size::A + G::TG # TODO store G⁻¹ here instead + ∂G∂θ::T∂G∂θ + map::TM + _temp::AV +end + +# TODO Make dense mass matrix support matrix-mode parallel +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) + _temp = Vector{Float64}(undef, first(size)) + return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) +end + +Base.size(e::DenseRiemannianMetric) = e.size +Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] +function Base.show(io::IO, drm::DenseRiemannianMetric) + return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") +end + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) + G⁻¹ = inv(metric.map(metric.G(θ))) + chol = cholesky(Symmetric(G⁻¹)) + ldiv!(chol.U, r) + return r +end diff --git a/src/sampler.jl b/src/sampler.jl index c0a42681..e0138819 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -117,7 +117,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) return sample( Random.default_rng(), @@ -130,7 +130,7 @@ function sample( drop_warmup=drop_warmup, verbose=verbose, progress=progress, - (pm_next!)=pm_next!, + (pm_next!)=(pm_next!), ) end @@ -168,7 +168,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." 
# Prepare containers to store sampling results diff --git a/src/trajectory.jl b/src/trajectory.jl index a7680760..aa8c90ca 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -133,8 +133,9 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -SliceTS(rng::AbstractRNG, z0::PhasePoint) = +function SliceTS(rng::AbstractRNG, z0::PhasePoint) SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) +end """ $(TYPEDEF) @@ -278,7 +279,7 @@ function transition( hamiltonian_energy=H, hamiltonian_energy_error=H - H0, # check numerical error in proposed phase point. - numerical_error=!all(isfinite, H′), + numerical_error=(!all(isfinite, H′)), ), stat(τ.integrator), ) @@ -717,7 +718,7 @@ function transition( ( n_steps=tree.nα, is_accept=true, - acceptance_rate=tree.sum_α / tree.nα, + acceptance_rate=(tree.sum_α / tree.nα), log_density=zcand.ℓπ.value, hamiltonian_energy=H, hamiltonian_energy_error=H - H0, diff --git a/test/Project.toml b/test/Project.toml index f3821481..3e2a793b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,12 +7,14 @@ Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MCMCLogDensityProblems = "8a639fad-7908-4fe4-8003-906e9297f002" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/test/demo.jl b/test/demo.jl index 98315daa..c9010a7f 100644 --- 
a/test/demo.jl +++ b/test/demo.jl @@ -10,8 +10,9 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ + 0 + }() # Choose parameter dimensionality and initial parameter value D = 10 @@ -66,8 +67,9 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ + 0 + }() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index b9eb1407..f5a3dbea 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,8 +112,9 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ + 0 + }() negU = NegU(1) diff --git a/test/riemannian.jl b/test/riemannian.jl index 67b1cad0..0cfcb823 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -1,28 +1,63 @@ -using ReTest, AdvancedHMC - -include("../src/riemannian_hmc.jl") -include("../src/riemannian_hmc_utility.jl") - +using ReTest, Random +using AdvancedHMC, ForwardDiff, AbstractMCMC +using LinearAlgebra +using MCMCLogDensityProblems using FiniteDiff: finite_difference_gradient, finite_difference_hessian, finite_difference_jacobian -using Distributions: MvNormal -using AdvancedHMC: neg_energy, energy +using AdvancedHMC: neg_energy, energy, ∂H∂θ, ∂H∂r + +# Fisher information metric 
+function gen_∂G∂θ_fwd(Vfunc, x; f=identity) + _Hfunc = gen_hess_fwd(Vfunc, x) + Hfunc = x -> _Hfunc(x)[3] + # QUES What's the best output format of this function? + cfg = ForwardDiff.JacobianConfig(Hfunc, x) + d = length(x) + out = zeros(eltype(x), d^2, d) + # The closure fills and returns the shared d^2 × d buffer `out`, shape [∂H∂x₁; ∂H∂x₂; ...]; NOTE(review): keyword `f` is accepted but unused + return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) +end + +function gen_hess_fwd(func, x::AbstractVector) + function hess(x::AbstractVector) + return nothing, nothing, ForwardDiff.hessian(func, x) + end + return hess +end + +function reshape_∂G∂θ(H) + d = size(H, 2) + return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) +end -# Taken from https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/test/finitedifftests.jl -δ(a, b) = maximum(abs.(a - b)) +function prepare_sample(ℓπ, initial_θ, λ) + Vfunc = x -> -ℓπ(x) + _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) # x -> (value, gradient, hessian) + Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug -@testset "Riemannian" begin - hps = (; λ=1e-2, α=20.0, ϵ=0.1, n=6, L=8) + fstabilize = H -> H + λ * I + Gfunc = x -> begin + H = fstabilize(Hfunc(x)[3]) + all(isfinite, H) ? 
H : diagm(ones(length(x))) + end + _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) + ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) + + return Vfunc, Hfunc, Gfunc, ∂G∂θfunc +end +@testset "Constructors tests" begin + δ(a, b) = maximum(abs.(a - b)) @testset "$(nameof(typeof(target)))" for target in [HighDimGaussian(2), Funnel()] rng = MersenneTwister(1110) + λ = 1e-2 θ₀ = rand(rng, dim(target)) ℓπ = MCMCLogDensityProblems.gen_logpdf(target) ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) - Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ) + Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample(ℓπ, θ₀, λ) D = dim(target) # ==2 for this test x = zeros(D) # randn(rng, D) @@ -36,7 +71,7 @@ using AdvancedHMC: neg_energy, energy end @testset "$(nameof(typeof(hessmap)))" for hessmap in - [IdentityMap(), SoftAbsMap(hps.α)] + [IdentityMap(), SoftAbsMap(20.0)] metric = DenseRiemannianMetric((D,), Gfunc, ∂G∂θfunc, hessmap) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) @@ -67,3 +102,62 @@ using AdvancedHMC: neg_energy, energy end end end + +@testset "Multi variate Normal with Riemannian HMC" begin + # Set the number of samples to draw and warmup iterations + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + # Define a Hamiltonian system + metric = DenseRiemannianMetric((D,), G, ∂G∂θ) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Define a leapfrog solver, with the initial step size chosen heuristically + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC sampler with the following components + # - multinomial sampling scheme, + # - generalised No-U-Turn criteria, and + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the 
samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + @test length(samples) == n_samples + @test length(stats) == n_samples +end + +@testset "Multi variate Normal with Riemannian HMC softabs metric" begin + # Set the number of samples to draw + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + + # Define a Hamiltonian system whose metric passes G through the SoftAbs map (α = 20.0) + metric = DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0)) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Define a generalized leapfrog integrator with step size 0.01 (second argument 6: presumably the fixed-point iteration count — confirm) + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC kernel with the following components + # - `EndPointTS` as the trajectory sampler type parameter, and + # - `FixedNSteps(8)` passed alongside the integrator + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + @test length(samples) == n_samples + @test length(stats) == n_samples +end diff --git a/test/trajectory.jl b/test/trajectory.jl index 403fd446..4bf0ac4d 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,46 +257,35 @@ end traj_r = hcat(map(z -> z.r, traj_z)...) 
rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = - hand_isturn.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_fwd = - ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = hand_isturn.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_fwd = ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = - hand_isturn_generalised.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_generalised_fwd = - ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = - ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)] From 1fcdd0982625ee4e04416fcf237c82d22a1870ac Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sat, 10 May 2025 17:57:47 +0800 Subject: [PATCH 2/2] Format --- src/trajectory.jl | 2 +- test/demo.jl | 10 +++---- test/integrator.jl | 5 ++-- test/trajectory.jl | 65 +++++++++++++++++++++++++++------------------- 4 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/trajectory.jl b/src/trajectory.jl index 
aa8c90ca..2764993d 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -134,7 +134,7 @@ Slice sampler for the starting single leaf tree. Slice variable is initialized. """ function SliceTS(rng::AbstractRNG, z0::PhasePoint) - SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) + return SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) end """ diff --git a/test/demo.jl b/test/demo.jl index c9010a7f..98315daa 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -10,9 +10,8 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblem}) = + LogDensityProblems.LogDensityOrder{0}() # Choose parameter dimensionality and initial parameter value D = 10 @@ -67,9 +66,8 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = + LogDensityProblems.LogDensityOrder{0}() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index f5a3dbea..b9eb1407 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,9 +112,8 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{NegU}) = + LogDensityProblems.LogDensityOrder{0}() negU = NegU(1) diff --git a/test/trajectory.jl b/test/trajectory.jl index 4bf0ac4d..403fd446 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,35 +257,46 @@ end traj_r = hcat(map(z -> z.r, traj_z)...) 
rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = hand_isturn.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_fwd = ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = + hand_isturn.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_fwd = + ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = + hand_isturn_generalised.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_generalised_fwd = + ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = + ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)]