From 419c7d96ad084f1de64cdd3b12e70893bc9ff6ea Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 27 Apr 2025 21:46:09 +0800 Subject: [PATCH 1/8] Add riemannian manifold HMC --- Project.toml | 6 ++ src/AdvancedHMC.jl | 13 +++- src/riemannian/hamiltonian.jl | 126 ++++++++++++++++++++++++++++++++++ src/riemannian/metric.jl | 112 ++++++++++++++++++++++++++++++ test/riemannian.jl | 63 +++++++++++++++++ 5 files changed, 318 insertions(+), 2 deletions(-) create mode 100644 src/riemannian/hamiltonian.jl create mode 100644 src/riemannian/metric.jl create mode 100644 test/riemannian.jl diff --git a/Project.toml b/Project.toml index c0f46d18..2bd520a2 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.7.1" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -15,6 +16,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002" [weakdeps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -23,6 +25,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +[sources] +VecTargets = {rev = "main", url = "https://github.com/chalk-lab/VecTargets.jl"} + [extensions] AdvancedHMCADTypesExt = "ADTypes" AdvancedHMCComponentArraysExt = "ComponentArrays" @@ -37,6 +42,7 @@ ArgCheck = "1, 2" ComponentArrays = "0.15" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" +ForwardDiff = "0.10.38" LinearAlgebra = "<0.1, 1" LogDensityProblems = "2" LogDensityProblemsAD = "1" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 51cf5f57..a6b3e4c0 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -2,7 +2,7 @@ module AdvancedHMC using Statistics: mean, var, middle using LinearAlgebra: - Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling + Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, diagm, cholesky, UniformScaling, logdet, tr using StatsFuns: logaddexp, logsumexp, loghalf using Random: Random, AbstractRNG using ProgressMeter: ProgressMeter @@ -19,8 +19,12 @@ using LogDensityProblemsAD: LogDensityProblemsAD using AbstractMCMC: AbstractMCMC, LogDensityModel +using VecTargets: VecTargets + import StatsBase: sample +using ForwardDiff: ForwardDiff + const DEFAULT_FLOAT_TYPE = typeof(float(0)) include("utilities.jl") @@ -40,7 +44,7 @@ struct GaussianKinetic <: AbstractKinetic end export GaussianKinetic include("metric.jl") -export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric +export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric include("hamiltonian.jl") export Hamiltonian @@ -50,6 +54,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog include("riemannian/integrator.jl") export GeneralizedLeapfrog +include("riemannian/metric.jl") +export IdentityMap, SoftAbsMap, DenseRiemannianMetric + +include("riemannian/hamiltonian.jl") + include("trajectory.jl") export Trajectory, HMCKernel, diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl new file mode 100644 index 00000000..2b939698 --- /dev/null +++ b/src/riemannian/hamiltonian.jl @@ -0,0 +1,126 @@ +#! Eq (14) of Girolami & Calderhead (2011) +function ∂H∂r( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, r::AbstractVecOrMat +) + H = h.metric.G(θ) + G = h.metric.map(H) + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't +end + +function ∂H∂θ( + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, + θ::AbstractVecOrMat{T}, + r::AbstractVecOrMat{T}, +) where {T} + ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) + G = h.metric.map(h.metric.G(θ)) + invG = inv(G) + ∂G∂θ = h.metric.∂G∂θ(θ) + d = length(∂ℓπ∂θ) + return DualValue( + ℓπ, + #! Eq (15) of Girolami & Calderhead (2011) + -mapreduce(vcat, 1:d) do i + ∂G∂θᵢ = ∂G∂θ[:, :, i] + ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r + # Gr = G \ r + # ∂ℓπ∂θ[i] - 1 / 2 * tr(G \ ∂G∂θᵢ) + 1 / 2 * Gr' * ∂G∂θᵢ * Gr + # 1 / 2 * tr(invG * ∂G∂θᵢ) + # 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r + end, + ) +end + +# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 +#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" +dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2 + +#! J as defined in middle of the right column of Page 3 of Betancourt (2012) +function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} + d = length(λ) + J = Matrix{T}(undef, d, d) + for i in 1:d, j in 1:d + J[i, j] = if (λ[i] == λ[j]) + dsoftabsdλ(α, λ[i]) + else + ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])) + end + end + return J +end + +function ∂H∂θ( + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, + θ::AbstractVecOrMat{T}, + r::AbstractVecOrMat{T}, +) where {T} + return ∂H∂θ_cache(h, θ, r) +end +function ∂H∂θ_cache( + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, + θ::AbstractVecOrMat{T}, + r::AbstractVecOrMat{T}; + return_cache=false, + cache=nothing, +) where {T} + # Terms that only dependent on θ can be cached in θ-unchanged loops + if isnothing(cache) + ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) + H = h.metric.G(θ) + ∂H∂θ = h.metric.∂G∂θ(θ) + + G, Q, λ, softabsλ = softabs(H, h.metric.map.α) + + R = diagm(1 ./ softabsλ) + + # softabsΛ = diagm(softabsλ) + # M = inv(softabsΛ) * Q' * r + # M = R * Q' * r # equiv to above but avoid inv + + J = make_J(λ, h.metric.map.α) + + #! Based on the two equations from the right column of Page 3 of Betancourt (2012) + term_1_cached = Q * (R .* J) * Q' + else + ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache + end + d = length(∂ℓπ∂θ) + D = diagm((Q' * r) ./ softabsλ) + term_2_cached = Q * D * J * D * Q' + g = + isdiag ? + -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) : + -mapreduce(vcat, 1:d) do i + ∂H∂θᵢ = ∂H∂θ[:, :, i] + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) + # NOTE Some further optimization can be done here: cache the 1st product all together + ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + end + + dv = DualValue(ℓπ, g) + return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv +end + +# QUES Do we want to change everything to position dependent by default? +# Add θ to ∂H∂r for DenseRiemannianMetric +function phasepoint( + h::Hamiltonian{<:DenseRiemannianMetric}, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), +) where {T<:AbstractVecOrMat} + return PhasePoint(θ, r, ℓπ, ℓκ) +end + +#! Eq (13) of Girolami & Calderhead (2011) +function neg_energy( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T +) where {T<:AbstractVecOrMat} + G = h.metric.map(h.metric.G(θ)) + D = size(G, 1) + # Need to consider the normalizing term as it is no longer same for different θs + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined + mul!(h.metric._temp, inv(G), r) + return -logZ - dot(r, h.metric._temp) / 2 +end diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl new file mode 100644 index 00000000..e79ad7f9 --- /dev/null +++ b/src/riemannian/metric.jl @@ -0,0 +1,112 @@ +abstract type AbstractRiemannianMetric <: AbstractMetric end + +abstract type AbstractHessianMap end + +struct IdentityMap <: AbstractHessianMap end + +(::IdentityMap)(x) = x + +struct SoftAbsMap{T} <: AbstractHessianMap + α::T +end + +function softabs(X, α=20.0) + F = eigen(X) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q * diagm(softabsλ) * Q', Q, λ, softabsλ +end + +(map::SoftAbsMap)(x) = softabs(x, map.α)[1] + +# TODO Register softabs with ReverseDiff +#! The definition of SoftAbs from Page 3 of Betancourt (2012) +struct DenseRiemannianMetric{ + T, + TM<:AbstractHessianMap, + A<:Union{Tuple{Int},Tuple{Int,Int}}, + AV<:AbstractVecOrMat{T}, + TG, + T∂G∂θ, +} <: AbstractRiemannianMetric + size::A + G::TG # TODO store G⁻¹ here instead + ∂G∂θ::T∂G∂θ + map::TM + _temp::AV +end + +# TODO Make dense mass matrix support matrix-mode parallel +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) + _temp = Vector{Float64}(undef, first(size)) + return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) +end + +# Convenient constructor +function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map = IdentityMap()) + _Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian) + Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug + + fstabilize = H -> H + λ * I + Gfunc = x -> begin + H = fstabilize(Hfunc(x)[3]) + all(isfinite, H) ? H : diagm(ones(length(x))) + end + _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) + ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) + + _temp = Vector{Float64}(undef, first(size)) + + return DenseRiemannianMetric(size, Gfunc, ∂G∂θfunc, map, _temp) +end + +function gen_hess_fwd(func, x::AbstractVector) + function hess(x::AbstractVector) + return nothing, nothing, ForwardDiff.hessian(func, x) + end + return hess +end + +#= possible integrate DI for AD-independent fisher information metric +function gen_∂G∂θ_rev(Vfunc, x; f=identity) + _Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x)) + Hfunc = x -> _Hfunc(x)[3] + # QUES What's the best output format of this function? + return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...] +end +=# + +# Fisher information metric +function gen_∂G∂θ_fwd(Vfunc, x; f=identity) + _Hfunc = gen_hess_fwd(Vfunc, x) + Hfunc = x -> _Hfunc(x)[3] + # QUES What's the best output format of this function? + cfg = ForwardDiff.JacobianConfig(Hfunc, x) + d = length(x) + out = zeros(eltype(x), d^2, d) + return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) + return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] +end + +function reshape_∂G∂θ(H) + d = size(H, 2) + return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) +end + +Base.size(e::DenseRiemannianMetric) = e.size +Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] +Base.show(io::IO, drm::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) + G⁻¹ = inv(metric.map(metric.G(θ))) + chol = cholesky(Symmetric(G⁻¹)) + ldiv!(chol.U, r) + return r +end diff --git a/test/riemannian.jl b/test/riemannian.jl new file mode 100644 index 00000000..27839d40 --- /dev/null +++ b/test/riemannian.jl @@ -0,0 +1,63 @@ +using ReTest, Random +using AdvancedHMC, ForwardDiff, AbstractMCMC +using LinearAlgebra + +@testset "Multi variate Normal with Riemannian HMC" begin + # Set the number of samples to draw and warmup iterations + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + # Define a Hamiltonian system + metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λ) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ) + + # Define a leapfrog solver, with the initial step size chosen heuristically + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC sampler with the following components + # - multinomial sampling scheme, + # - generalised No-U-Turn criteria, and + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample( + rng, hamiltonian, kernel, initial_θ, n_samples; progress=true + ) + @test length(samples) == n_samples + @test length(stats) == n_samples +end + +@testset "Multi variate Normal with Riemannian HMC softabs metric" begin + # Set the number of samples to draw and warmup iterations + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + + # Define a Hamiltonian system + metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λSoftAbsMap(20.0)) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ) + + # Define a leapfrog solver, with the initial step size chosen heuristically + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC sampler with the following components + # - multinomial sampling scheme, + # - generalised No-U-Turn criteria, and + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample( + rng, hamiltonian, kernel, initial_θ, n_samples; progress=true + ) + @test length(samples) == n_samples + @test length(stats) == n_samples +end \ No newline at end of file From 5fdc20df49ed16e5a80cc9f5fdb243432bf9ba8a Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 27 Apr 2025 22:12:37 +0800 Subject: [PATCH 2/8] format --- src/AdvancedHMC.jl | 13 ++++++++++++- src/riemannian/hamiltonian.jl | 21 ++++++++++++--------- src/riemannian/metric.jl | 6 ++++-- test/riemannian.jl | 10 +++------- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index c969c5bc..dff715c6 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -2,7 +2,18 @@ module AdvancedHMC using Statistics: mean, var, middle using LinearAlgebra: - Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, diagm, cholesky, UniformScaling, logdet, tr + Symmetric, + UpperTriangular, + mul!, + ldiv!, + dot, + I, + diag, + diagm, + cholesky, + UniformScaling, + logdet, + tr using StatsFuns: logaddexp, logsumexp, loghalf using Random: Random, AbstractRNG using ProgressMeter: ProgressMeter diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 2b939698..9c0fdb17 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,6 +1,8 @@ #! Eq (14) of Girolami & Calderhead (2011) function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, r::AbstractVecOrMat + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, + θ::AbstractVecOrMat, + r::AbstractVecOrMat, ) H = h.metric.G(θ) G = h.metric.map(H) @@ -87,15 +89,16 @@ function ∂H∂θ_cache( d = length(∂ℓπ∂θ) D = diagm((Q' * r) ./ softabsλ) term_2_cached = Q * D * J * D * Q' - g = - isdiag ? - -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) : + g = if isdiag + -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) + else -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly - end + ∂H∂θᵢ = ∂H∂θ[:, :, i] + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) + # NOTE Some further optimization can be done here: cache the 1st product all together + ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + end + end dv = DualValue(ℓπ, g) return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index e79ad7f9..e05385cf 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -44,7 +44,7 @@ function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) end # Convenient constructor -function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map = IdentityMap()) +function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map=IdentityMap()) _Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian) Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug @@ -96,7 +96,9 @@ end Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -Base.show(io::IO, drm::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") +function Base.show(io::IO, drm::DenseRiemannianMetric) + return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") +end function rand_momentum( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, diff --git a/test/riemannian.jl b/test/riemannian.jl index 27839d40..eb8795d0 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -25,9 +25,7 @@ using LinearAlgebra # Run the sampler to draw samples from the specified Gaussian, where # - `samples` will store the samples # - `stats` will store diagnostic statistics for each sample - samples, stats = sample( - rng, hamiltonian, kernel, initial_θ, n_samples; progress=true - ) + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) @test length(samples) == n_samples @test length(stats) == n_samples end @@ -55,9 +53,7 @@ end # Run the sampler to draw samples from the specified Gaussian, where # - `samples` will store the samples # - `stats` will store diagnostic statistics for each sample - samples, stats = sample( - rng, hamiltonian, kernel, initial_θ, n_samples; progress=true - ) + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) @test length(samples) == n_samples @test length(stats) == n_samples -end \ No newline at end of file +end From 75c2380f0a985d65a1ccd22d60fc8e0d0a9983c8 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 27 Apr 2025 22:32:56 +0800 Subject: [PATCH 3/8] Fix source --- Project.toml | 7 ++++--- src/riemannian/hamiltonian.jl | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 2bd520a2..d9153332 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002" +[sources] +VecTargets = {url = "https://github.com/chalk-lab/VecTargets.jl", rev = "main"} + + [weakdeps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -25,9 +29,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -[sources] -VecTargets = {rev = "main", url = "https://github.com/chalk-lab/VecTargets.jl"} - [extensions] AdvancedHMCADTypesExt = "ADTypes" AdvancedHMCComponentArraysExt = "ComponentArrays" diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 9c0fdb17..166d1b7b 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -93,11 +93,11 @@ function ∂H∂θ_cache( -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) else -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly - end + ∂H∂θᵢ = ∂H∂θ[:, :, i] + # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) + # NOTE Some further optimization can be done here: cache the 1st product all together + ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + end end dv = DualValue(ℓπ, g) From 8a31df4eb277d204198b33d1972baef4713bf1ae Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Tue, 29 Apr 2025 06:03:58 +0800 Subject: [PATCH 4/8] Dont hard dep on MCMCLogDensityProblems --- Project.toml | 7 ------ src/AdvancedHMC.jl | 4 --- src/riemannian/metric.jl | 51 ------------------------------------- test/riemannian.jl | 54 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 50 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index d9153332..c0f46d18 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.7.1" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -16,11 +15,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002" - -[sources] -VecTargets = {url = "https://github.com/chalk-lab/VecTargets.jl", rev = "main"} - [weakdeps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -43,7 +37,6 @@ ArgCheck = "1, 2" ComponentArrays = "0.15" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.38" LinearAlgebra = "<0.1, 1" LogDensityProblems = "2" LogDensityProblemsAD = "1" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index dff715c6..e4ee65a5 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -30,12 +30,8 @@ using LogDensityProblemsAD: LogDensityProblemsAD using AbstractMCMC: AbstractMCMC, LogDensityModel -using VecTargets: VecTargets - import StatsBase: sample -using ForwardDiff: ForwardDiff - const DEFAULT_FLOAT_TYPE = typeof(float(0)) include("utilities.jl") diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index e05385cf..41d11127 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -43,57 +43,6 @@ function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) end -# Convenient constructor -function DenseRiemannianMetric(size, ℓπ, initial_θ, λ, map=IdentityMap()) - _Hfunc = VecTargets.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian) - Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug - - fstabilize = H -> H + λ * I - Gfunc = x -> begin - H = fstabilize(Hfunc(x)[3]) - all(isfinite, H) ? H : diagm(ones(length(x))) - end - _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) - ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) - - _temp = Vector{Float64}(undef, first(size)) - - return DenseRiemannianMetric(size, Gfunc, ∂G∂θfunc, map, _temp) -end - -function gen_hess_fwd(func, x::AbstractVector) - function hess(x::AbstractVector) - return nothing, nothing, ForwardDiff.hessian(func, x) - end - return hess -end - -#= possible integrate DI for AD-independent fisher information metric -function gen_∂G∂θ_rev(Vfunc, x; f=identity) - _Hfunc = VecTargets.gen_hess(Vfunc, ReverseDiff.track.(x)) - Hfunc = x -> _Hfunc(x)[3] - # QUES What's the best output format of this function? - return x -> ReverseDiff.jacobian(x -> f(Hfunc(x)), x) # default output shape [∂H∂x₁; ∂H∂x₂; ...] -end -=# - -# Fisher information metric -function gen_∂G∂θ_fwd(Vfunc, x; f=identity) - _Hfunc = gen_hess_fwd(Vfunc, x) - Hfunc = x -> _Hfunc(x)[3] - # QUES What's the best output format of this function? - cfg = ForwardDiff.JacobianConfig(Hfunc, x) - d = length(x) - out = zeros(eltype(x), d^2, d) - return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) - return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] -end - -function reshape_∂G∂θ(H) - d = size(H, 2) - return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) -end - Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] function Base.show(io::IO, drm::DenseRiemannianMetric) diff --git a/test/riemannian.jl b/test/riemannian.jl index eb8795d0..eb609574 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -2,16 +2,60 @@ using ReTest, Random using AdvancedHMC, ForwardDiff, AbstractMCMC using LinearAlgebra +using Pkg +Pkg.develop(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl") +using MCMCLogDensityProblems + +# Fisher information metric +function gen_∂G∂θ_fwd(Vfunc, x; f=identity) + _Hfunc = gen_hess_fwd(Vfunc, x) + Hfunc = x -> _Hfunc(x)[3] + # QUES What's the best output format of this function? + cfg = ForwardDiff.JacobianConfig(Hfunc, x) + d = length(x) + out = zeros(eltype(x), d^2, d) + return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) + return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] +end + +function gen_hess_fwd(func, x::AbstractVector) + function hess(x::AbstractVector) + return nothing, nothing, ForwardDiff.hessian(func, x) + end + return hess +end + +function reshape_∂G∂θ(H) + d = size(H, 2) + return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) +end + +function prepare_sample(ℓπ, initial_θ, λ) + _Hfunc = MCMCLogDensityProblems.gen_hess(x -> -ℓπ(x), initial_θ) # x -> (value, gradient, hessian) + Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug + + fstabilize = H -> H + λ * I + Gfunc = x -> begin + H = fstabilize(Hfunc(x)[3]) + all(isfinite, H) ? H : diagm(ones(length(x))) + end + _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) + ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) + + return Gfunc, ∂G∂θfunc +end + @testset "Multi variate Normal with Riemannian HMC" begin # Set the number of samples to draw and warmup iterations n_samples = 2_000 rng = MersenneTwister(1110) initial_θ = rand(rng, D) λ = 1e-2 + G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) # Define a Hamiltonian system - metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λ) + metric = DenseRiemannianMetric((D,), G, ∂G∂θ) kinetic = GaussianKinetic() - hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ) + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) # Define a leapfrog solver, with the initial step size chosen heuristically initial_ϵ = 0.01 @@ -35,11 +79,13 @@ end n_samples = 2_000 rng = MersenneTwister(1110) initial_θ = rand(rng, D) + λ = 1e-2 + G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) # Define a Hamiltonian system - metric = DenseRiemannianMetric((D,), ℓπ, initial_θ, λSoftAbsMap(20.0)) + metric = DenseRiemannianMetric((D,), G, ∂G∂θ, λSoftAbsMap(20.0)) kinetic = GaussianKinetic() - hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∇ℓπ) + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) # Define a leapfrog solver, with the initial step size chosen heuristically initial_ϵ = 0.01 From 9bdec6ff1211eb4741824648c8a16e667b06ce36 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Tue, 29 Apr 2025 06:41:32 +0800 Subject: [PATCH 5/8] Fix JET detected issues --- src/AdvancedHMC.jl | 3 ++- src/riemannian/hamiltonian.jl | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index e4ee65a5..33a1bf92 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -13,7 +13,8 @@ using LinearAlgebra: cholesky, UniformScaling, logdet, - tr + tr, + eigen using StatsFuns: logaddexp, logsumexp, loghalf using Random: Random, AbstractRNG using ProgressMeter: ProgressMeter diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 166d1b7b..f8acc797 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -89,16 +89,13 @@ function ∂H∂θ_cache( d = length(∂ℓπ∂θ) D = diagm((Q' * r) ./ softabsλ) term_2_cached = Q * D * J * D * Q' - g = if isdiag - -(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ)) - else + g = -mapreduce(vcat, 1:d) do i ∂H∂θᵢ = ∂H∂θ[:, :, i] # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) # NOTE Some further optimization can be done here: cache the 1st product all together ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly end - end dv = DualValue(ℓπ, g) return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv From 574d864c983b08b529748edd11ea0616ca3e088b Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Fri, 2 May 2025 07:09:41 +0800 Subject: [PATCH 6/8] Delete for now --- src/riemannian/hamiltonian.jl | 126 ---------------------------------- 1 file changed, 126 deletions(-) delete mode 100644 src/riemannian/hamiltonian.jl diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl deleted file mode 100644 index f8acc797..00000000 --- a/src/riemannian/hamiltonian.jl +++ /dev/null @@ -1,126 +0,0 @@ -#! Eq (14) of Girolami & Calderhead (2011) -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, - θ::AbstractVecOrMat, - r::AbstractVecOrMat, -) - H = h.metric.G(θ) - G = h.metric.map(H) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't -end - -function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}, -) where {T} - ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) - G = h.metric.map(h.metric.G(θ)) - invG = inv(G) - ∂G∂θ = h.metric.∂G∂θ(θ) - d = length(∂ℓπ∂θ) - return DualValue( - ℓπ, - #! Eq (15) of Girolami & Calderhead (2011) - -mapreduce(vcat, 1:d) do i - ∂G∂θᵢ = ∂G∂θ[:, :, i] - ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r - # Gr = G \ r - # ∂ℓπ∂θ[i] - 1 / 2 * tr(G \ ∂G∂θᵢ) + 1 / 2 * Gr' * ∂G∂θᵢ * Gr - # 1 / 2 * tr(invG * ∂G∂θᵢ) - # 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r - end, - ) -end - -# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 -#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" -dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2 - -#! J as defined in middle of the right column of Page 3 of Betancourt (2012) -function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} - d = length(λ) - J = Matrix{T}(undef, d, d) - for i in 1:d, j in 1:d - J[i, j] = if (λ[i] == λ[j]) - dsoftabsdλ(α, λ[i]) - else - ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])) - end - end - return J -end - -function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}, -) where {T} - return ∂H∂θ_cache(h, θ, r) -end -function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}; - return_cache=false, - cache=nothing, -) where {T} - # Terms that only dependent on θ can be cached in θ-unchanged loops - if isnothing(cache) - ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) - H = h.metric.G(θ) - ∂H∂θ = h.metric.∂G∂θ(θ) - - G, Q, λ, softabsλ = softabs(H, h.metric.map.α) - - R = diagm(1 ./ softabsλ) - - # softabsΛ = diagm(softabsλ) - # M = inv(softabsΛ) * Q' * r - # M = R * Q' * r # equiv to above but avoid inv - - J = make_J(λ, h.metric.map.α) - - #! Based on the two equations from the right column of Page 3 of Betancourt (2012) - term_1_cached = Q * (R .* J) * Q' - else - ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache - end - d = length(∂ℓπ∂θ) - D = diagm((Q' * r) ./ softabsλ) - term_2_cached = Q * D * J * D * Q' - g = - -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly - end - - dv = DualValue(ℓπ, g) - return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv -end - -# QUES Do we want to change everything to position dependent by default? -# Add θ to ∂H∂r for DenseRiemannianMetric -function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, - θ::T, - r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), -) where {T<:AbstractVecOrMat} - return PhasePoint(θ, r, ℓπ, ℓκ) -end - -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 -end From 32fc1b4f22b17c961656b86fc511003a4d2b404f Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Fri, 2 May 2025 07:13:21 +0800 Subject: [PATCH 7/8] Use git mv to preserve history --- research/src/riemannian_hmc.jl => src/riemannian/hamiltonian.jl | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename research/src/riemannian_hmc.jl => src/riemannian/hamiltonian.jl (100%) diff --git a/research/src/riemannian_hmc.jl b/src/riemannian/hamiltonian.jl similarity index 100% rename from research/src/riemannian_hmc.jl rename to src/riemannian/hamiltonian.jl From 081cf7fa54fbdd142e64a97e168318b900678df2 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Fri, 2 May 2025 07:14:37 +0800 Subject: [PATCH 8/8] Real change here --- src/riemannian/hamiltonian.jl | 298 ++++------------------------------ 1 file changed, 33 insertions(+), 265 deletions(-) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index feddb411..f8acc797 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,257 +1,16 @@ -using Random - -### integrator.jl - -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size - -""" -$(TYPEDEF) - -Generalized leapfrog integrator with fixed step size `ϵ`. - -# Fields - -$(TYPEDFIELDS) -""" -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} - "Step size." - ϵ::T - n::Int -end -function Base.show(io::IO, l::GeneralizedLeapfrog) - return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") -end - -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv -end - -# TODO Make sure vectorization works -# TODO Check if tempering is valid -function step( - lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, - z::P, - n_steps::Int=1; - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 - full_trajectory::Val{FullTraj}=Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} - n_steps = abs(n_steps) # to support `n_steps < 0` cases - - ϵ = fwd ? step_size(lf) : -step_size(lf) - ϵ = ϵ' - - res = if FullTraj - Vector{P}(undef, n_steps) - else - z - end - - for i in 1:n_steps - θ_init, r_init = z.θ, z.r - # Tempering - #r = temper(lf, r, (i=i, is_half=true), n_steps) - #! Eq (16) of Girolami & Calderhead (2011) - r_half = copy(r_init) - local cache - for j in 1:(lf.n) - # Reuse cache for the first iteration - if j == 1 - (; value, gradient) = z.ℓπ - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) - (; value, gradient) = retval - else # reuse cache - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) - end - r_half = r_init - ϵ / 2 * gradient - # println("r_half: ", r_half) - end - #! Eq (17) of Girolami & Calderhead (2011) - θ_full = copy(θ_init) - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop - for j in 1:(lf.n) - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) - # println("θ_full :", θ_full) - end - #! Eq (18) of Girolami & Calderhead (2011) - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) - r_full = r_half - ϵ / 2 * gradient - # println("r_full: ", r_full) - # Tempering - #r = temper(lf, r, (i=i, is_half=false), n_steps) - # Create a new phase point by caching the logdensity and gradient - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - # Update result - if FullTraj - res[i] = z - else - res = z - end - if !isfinite(z) - # Remove undef - if FullTraj - res = res[isassigned.(Ref(res), 1:n_steps)] - end - break - end - # @assert false - end - return res -end - -# TODO Make the order of θ and r consistent with neg_energy -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) - -### hamiltonian.jl - -import AdvancedHMC: refresh, phasepoint -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric - -# To change L180 of hamiltonian.jl -function phasepoint( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - θ::AbstractVecOrMat{T}, - h::Hamiltonian, -) where {T<:Real} - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) -end - -# To change L191 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ::FullMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) -end - -# To change L215 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ref::PartialMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint( - h, - z.θ, - ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ), - ) -end - -### metric.jl - -import AdvancedHMC: _rand -using AdvancedHMC: AbstractMetric -using LinearAlgebra: eigen, cholesky, Symmetric - -abstract type AbstractRiemannianMetric <: AbstractMetric end - -abstract type AbstractHessianMap end - -struct IdentityMap <: AbstractHessianMap end - -(::IdentityMap)(x) = x - -struct SoftAbsMap{T} <: AbstractHessianMap - α::T -end - -# TODO Register softabs with ReverseDiff -#! The definition of SoftAbs from Page 3 of Betancourt (2012) -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` - Q = hcat(F.vectors) - λ = F.values - softabsλ = λ .* coth.(α * λ) - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ -end - -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] - -struct DenseRiemannianMetric{ - T, - TM<:AbstractHessianMap, - A<:Union{Tuple{Int},Tuple{Int,Int}}, - AV<:AbstractVecOrMat{T}, - TG, - T∂G∂θ, -} <: AbstractRiemannianMetric - size::A - G::TG # TODO store G⁻¹ here instead - ∂G∂θ::T∂G∂θ - map::TM - _temp::AV -end - -# TODO Make dense mass matrix support matrix-mode parallel -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} - _temp = Vector{Float64}(undef, size[1]) - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) -end -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) - -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) - -Base.size(e::DenseRiemannianMetric) = e.size -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") - -function rand_momentum( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - metric::DenseRiemannianMetric{T}, - kinetic, +#! Eq (14) of Girolami & Calderhead (2011) +function ∂H∂r( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, -) where {T} - r = _randn(rng, T, size(metric)...) - G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) - return r -end - -### hamiltonian.jl - -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r -using LinearAlgebra: logabsdet, tr - -# QUES Do we want to change everything to position dependent by default? -# Add θ to ∂H∂r for DenseRiemannianMetric -function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, - θ::T, - r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), -) where {T<:AbstractVecOrMat} - return PhasePoint(θ, r, ℓπ, ℓκ) -end - -# Negative kinetic energy -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 + r::AbstractVecOrMat, +) + H = h.metric.G(θ) + G = h.metric.map(H) + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't end -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} @@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} end function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} return ∂H∂θ_cache(h, θ, r) end function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}; return_cache=false, @@ -342,17 +101,26 @@ function ∂H∂θ_cache( return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv end -#! Eq (14) of Girolami & Calderhead (2011) -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat -) - H = h.metric.G(θ) - # if !all(isfinite, H) - # println("θ: ", θ) - # println("H: ", H) - # end - G = h.metric.map(H) - # return inv(G) * r - # println("G \ r: ", G \ r) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't +# QUES Do we want to change everything to position dependent by default? +# Add θ to ∂H∂r for DenseRiemannianMetric +function phasepoint( + h::Hamiltonian{<:DenseRiemannianMetric}, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), +) where {T<:AbstractVecOrMat} + return PhasePoint(θ, r, ℓπ, ℓκ) +end + +#! Eq (13) of Girolami & Calderhead (2011) +function neg_energy( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T +) where {T<:AbstractVecOrMat} + G = h.metric.map(h.metric.G(θ)) + D = size(G, 1) + # Need to consider the normalizing term as it is no longer same for different θs + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined + mul!(h.metric._temp, inv(G), r) + return -logZ - dot(r, h.metric._temp) / 2 end