From 1c85ef73d4342b96abd89ce272d28508726fc80d Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 11 May 2025 23:33:19 +0800 Subject: [PATCH] Add relativistic kinetic --- Project.toml | 2 + src/AdvancedHMC.jl | 7 +++ src/relativistic/hamiltonian.jl | 71 +++++++++++++++++++++++++++++++ src/relativistic/metric.jl | 75 +++++++++++++++++++++++++++++++++ test/relativistic.jl | 26 ++++++++++++ test/runtests.jl | 1 + 6 files changed, 182 insertions(+) create mode 100644 src/relativistic/hamiltonian.jl create mode 100644 src/relativistic/metric.jl create mode 100644 test/relativistic.jl diff --git a/Project.toml b/Project.toml index bae9e5dd..78173853 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.8.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -33,6 +34,7 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" [compat] ADTypes = "1" AbstractMCMC = "5.6" +AdaptiveRejectionSampling = "0.2.1" ArgCheck = "1, 2" ComponentArrays = "0.15" CUDA = "3, 4, 5" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index b25710d5..e4cfc9d8 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -21,6 +21,8 @@ using AbstractMCMC: AbstractMCMC, LogDensityModel import StatsBase: sample +using AdaptiveRejectionSampling: RejectionSampler, run_sampler! + const DEFAULT_FLOAT_TYPE = typeof(float(0)) include("utilities.jl") @@ -63,6 +65,11 @@ export Trajectory, MultinomialTS, find_good_stepsize +include("relativistic/hamiltonian.jl") +export RelativisticKinetic, DimensionwiseRelativisticKinetic + +include("relativistic/metric.jl") + # Useful defaults @deprecate find_good_eps find_good_stepsize diff --git a/src/relativistic/hamiltonian.jl b/src/relativistic/hamiltonian.jl new file mode 100644 index 00000000..1a05344f --- /dev/null +++ b/src/relativistic/hamiltonian.jl @@ -0,0 +1,71 @@ +abstract type AbstractRelativisticKinetic{T} <: AbstractKinetic end + +struct RelativisticKinetic{T} <: AbstractRelativisticKinetic{T} + "Mass" + m::T + "Speed of light" + c::T +end + +function relativistic_mass(kinetic::RelativisticKinetic, r, r′=r) + return kinetic.m * sqrt(dot(r, r′) / (kinetic.m^2 * kinetic.c^2) + 1) +end +function relativistic_energy(kinetic::RelativisticKinetic, r, r′=r) + return sum(kinetic.c^2 * relativistic_mass(kinetic, r, r′)) +end + +struct DimensionwiseRelativisticKinetic{T} <: AbstractRelativisticKinetic{T} + "Mass" + m::T + "Speed of light" + c::T +end + +function relativistic_mass(kinetic::DimensionwiseRelativisticKinetic, r, r′=r) + return kinetic.m .* sqrt.(r .* r′ ./ (kinetic.m .^ 2 .* kinetic.c .^ 2) .+ 1) +end +function relativistic_energy(kinetic::DimensionwiseRelativisticKinetic, r, r′=r) + return sum(kinetic.c .^ 2 .* relativistic_mass(kinetic, r, r′)) +end + +function ∂H∂r( + h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic}, r::AbstractVecOrMat +) + mass = relativistic_mass(h.kinetic, r) + return r ./ mass +end +function ∂H∂r( + h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic}, r::AbstractVecOrMat +) + r = h.metric.sqrtM⁻¹ .* r + mass = relativistic_mass(h.kinetic, r) + red_term = r ./ mass # red part of (15) + return h.metric.sqrtM⁻¹ .* red_term # (15) +end +function ∂H∂r( + h::Hamiltonian{<:DenseEuclideanMetric,<:AbstractRelativisticKinetic}, + r::AbstractVecOrMat, +) + r = h.metric.cholM⁻¹ * r + mass = relativistic_mass(h.kinetic, r) + red_term = r ./ mass + return h.metric.cholM⁻¹' * red_term +end + +function neg_energy( + h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T +) where {T<:AbstractVector} + return -relativistic_energy(h.kinetic, r) +end +function neg_energy( + h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T +) where {T<:AbstractVector} + r = h.metric.sqrtM⁻¹ .* r + return -relativistic_energy(h.kinetic, r) +end +function neg_energy( + h::Hamiltonian{<:DenseEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T +) where {T<:AbstractVector} + r = h.metric.cholM⁻¹ * r + return -relativistic_energy(h.kinetic, r) +end diff --git a/src/relativistic/metric.jl b/src/relativistic/metric.jl new file mode 100644 index 00000000..cb27719d --- /dev/null +++ b/src/relativistic/metric.jl @@ -0,0 +1,75 @@ +function rand_angles(rng::AbstractRNG, dim) + return rand(rng, dim - 1) .* vcat(fill(π, dim - 2), 2 * π) +end + +"Special case of `polar2spherical` with dimension equal 2" +polar2cartesian(θ, d) = d * [cos(θ), sin(θ)] + +# ref: https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates +function polar2spherical(θs, d) + cos_lst, sin_lst = cos.(θs), sin.(θs) + suffixed_cos_lst = vcat(cos_lst, 1) # [cos(θ[1]), cos(θ[2]), ..., cos(θ[d-1]), 1] + prefixed_cumprod_sin_lst = vcat(1, cumprod(sin_lst)) # [1, sin(θ[1]), sin(θ[1]) * sin(θ[2]), ..., sin(θ[1]) * ... * sin(θ[d-1])] + return d * prefixed_cumprod_sin_lst .* suffixed_cos_lst +end + +momentum_mode(m, c) = sqrt((1 / c^2 + sqrt(1 / c^2 + 4 * m^2)) / 2) # mode of the momentum distribution + +function rand_momentum( + rng::AbstractRNG, + metric::UnitEuclideanMetric{T}, + kinetic::RelativisticKinetic{T}, + ::AbstractVecOrMat, +) where {T} + densityfunc = x -> exp(-relativistic_energy(kinetic, [x])) * x + mm = momentum_mode(kinetic.m, kinetic.c) + sampler = RejectionSampler(densityfunc, (0.0, Inf), (mm / 2, mm * 2); max_segments=5) + sz = size(metric) + θs = rand_angles(rng, prod(sz)) + d = only(run_sampler!(rng, sampler, 1)) + r = polar2spherical(θs, d * rand(rng, [-1, +1])) # TODO Double check if +/- is needed + r = reshape(r, sz) + return r +end + +# TODO Support AbstractVector{<:AbstractRNG} +# FIXME Unit-test this using slice sampler or HMC sampler +function rand_momentum( + rng::AbstractRNG, + metric::UnitEuclideanMetric{T}, + kinetic::DimensionwiseRelativisticKinetic{T}, + ::AbstractVecOrMat, +) where {T} + h_temp = Hamiltonian(metric, kinetic, identity, identity) + densityfunc = x -> exp(neg_energy(h_temp, [x], [x])) + sampler = RejectionSampler(densityfunc, (-Inf, Inf); max_segments=5) + sz = size(metric) + r = run_sampler!(rng, sampler, prod(sz)) + r = reshape(r, sz) + return r +end + +# TODO Support AbstractVector{<:AbstractRNG} +function rand_momentum( + rng::AbstractRNG, + metric::DiagEuclideanMetric{T}, + kinetic::AbstractRelativisticKinetic{T}, + θ::AbstractVecOrMat, +) where {T} + r = rand_momentum(rng, UnitEuclideanMetric(size(metric)), kinetic, θ) + # p' = A p where A = sqrtM + r ./= metric.sqrtM⁻¹ + return r +end +# TODO Support AbstractVector{<:AbstractRNG} +function rand_momentum( + rng::AbstractRNG, + metric::DenseEuclideanMetric{T}, + kinetic::AbstractRelativisticKinetic{T}, + θ::AbstractVecOrMat, +) where {T} + r = rand_momentum(rng, UnitEuclideanMetric(size(metric)), kinetic, θ) + # p' = A p where A = cholM + ldiv!(metric.cholM⁻¹, r) + return r +end diff --git a/test/relativistic.jl b/test/relativistic.jl new file mode 100644 index 00000000..50c31fc5 --- /dev/null +++ b/test/relativistic.jl @@ -0,0 +1,26 @@ +using ReTest, Random, AdvancedHMC + +@testset "Relativistic kinetic construction" begin + f = x -> dot(x, x) + g = x -> 2x + metric = UnitEuclideanMetric(10) + h = Hamiltonian(metric, RelativisticKinetic(1.0, 1.0), f, g) + @test h.kinetic isa RelativisticKinetic +end + +@testset "Sampling with relativistic kinetic" begin + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(D) + metric = DiagEuclideanMetric(D) + for kineticT in [RelativisticKinetic, DimensionwiseRelativisticKinetic] + kinetic = kineticT(1.0, 1.0) + h = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + initial_ϵ = find_good_stepsize(h, initial_θ) + integrator = Leapfrog(initial_ϵ) + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + samples, stats = sample(rng, h, kernel, initial_θ, n_samples; progress=true) + @test length(samples) == n_samples + @test length(stats) == n_samples + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 642f76d3..c5fbcbbf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,6 +32,7 @@ if GROUP == "All" || GROUP == "AdvancedHMC" include("abstractmcmc.jl") include("mcmcchains.jl") include("constructors.jl") + include("relativistic.jl") Comonicon.@main function runtests(patterns...; dry::Bool=false) return retest(patterns...; dry=dry, verbose=Inf)