Skip to content

Add relativistic kinetic #445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -33,6 +34,7 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
[compat]
ADTypes = "1"
AbstractMCMC = "5.6"
AdaptiveRejectionSampling = "0.2.1"
ArgCheck = "1, 2"
ComponentArrays = "0.15"
CUDA = "3, 4, 5"
Expand Down
7 changes: 7 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ using AbstractMCMC: AbstractMCMC, LogDensityModel

import StatsBase: sample

using AdaptiveRejectionSampling: RejectionSampler, run_sampler!

const DEFAULT_FLOAT_TYPE = typeof(float(0))

include("utilities.jl")
Expand Down Expand Up @@ -63,6 +65,11 @@ export Trajectory,
MultinomialTS,
find_good_stepsize

include("relativistic/hamiltonian.jl")
export RelativisticKinetic, DimensionwiseRelativisticKinetic

include("relativistic/metric.jl")

# Useful defaults

@deprecate find_good_eps find_good_stepsize
Expand Down
71 changes: 71 additions & 0 deletions src/relativistic/hamiltonian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
abstract type AbstractRelativisticKinetic{T} <: AbstractKinetic end

struct RelativisticKinetic{T} <: AbstractRelativisticKinetic{T}
"Mass"
m::T
"Speed of light"
c::T
end

function relativistic_mass(kinetic::RelativisticKinetic, r, r′=r)
return kinetic.m * sqrt(dot(r, r′) / (kinetic.m^2 * kinetic.c^2) + 1)
end
function relativistic_energy(kinetic::RelativisticKinetic, r, r′=r)
return sum(kinetic.c^2 * relativistic_mass(kinetic, r, r′))
end

struct DimensionwiseRelativisticKinetic{T} <: AbstractRelativisticKinetic{T}
"Mass"
m::T
"Speed of light"
c::T
end

function relativistic_mass(kinetic::DimensionwiseRelativisticKinetic, r, r′=r)
return kinetic.m .* sqrt.(r .* r′ ./ (kinetic.m .^ 2 .* kinetic.c .^ 2) .+ 1)
end
function relativistic_energy(kinetic::DimensionwiseRelativisticKinetic, r, r′=r)
return sum(kinetic.c .^ 2 .* relativistic_mass(kinetic, r, r′))
end

function ∂H∂r(
h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic}, r::AbstractVecOrMat
)
mass = relativistic_mass(h.kinetic, r)
return r ./ mass
end
function ∂H∂r(
h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic}, r::AbstractVecOrMat
)
r = h.metric.sqrtM⁻¹ .* r
mass = relativistic_mass(h.kinetic, r)
red_term = r ./ mass # red part of (15)
return h.metric.sqrtM⁻¹ .* red_term # (15)
end
function ∂H∂r(
h::Hamiltonian{<:DenseEuclideanMetric,<:AbstractRelativisticKinetic},
r::AbstractVecOrMat,
)
r = h.metric.cholM⁻¹ * r
mass = relativistic_mass(h.kinetic, r)
red_term = r ./ mass
return h.metric.cholM⁻¹' * red_term
end

function neg_energy(
h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T
) where {T<:AbstractVector}
return -relativistic_energy(h.kinetic, r)
end
function neg_energy(
h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T
) where {T<:AbstractVector}
r = h.metric.sqrtM⁻¹ .* r
return -relativistic_energy(h.kinetic, r)
end
function neg_energy(
h::Hamiltonian{<:DenseEuclideanMetric,<:AbstractRelativisticKinetic}, r::T, θ::T
) where {T<:AbstractVector}
r = h.metric.cholM⁻¹ * r
return -relativistic_energy(h.kinetic, r)
end
75 changes: 75 additions & 0 deletions src/relativistic/metric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
function rand_angles(rng::AbstractRNG, dim)
return rand(rng, dim - 1) .* vcat(fill(π, dim - 2), 2 * π)
end

"Special case of `polar2spherical` with dimension equal 2"
polar2cartesian(θ, d) = d * [cos(θ), sin(θ)]

# ref: https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates
function polar2spherical(θs, d)
cos_lst, sin_lst = cos.(θs), sin.(θs)
suffixed_cos_lst = vcat(cos_lst, 1) # [cos(θ[1]), cos(θ[2]), ..., cos(θ[d-1]), 1]
prefixed_cumprod_sin_lst = vcat(1, cumprod(sin_lst)) # [1, sin(θ[1]), sin(θ[1]) * sin(θ[2]), ..., sin(θ[1]) * ... * sin(θ[d-1])]
return d * prefixed_cumprod_sin_lst .* suffixed_cos_lst
end

momentum_mode(m, c) = sqrt((1 / c^2 + sqrt(1 / c^2 + 4 * m^2)) / 2) # mode of the momentum distribution

function rand_momentum(
rng::AbstractRNG,
metric::UnitEuclideanMetric{T},
kinetic::RelativisticKinetic{T},
::AbstractVecOrMat,
) where {T}
densityfunc = x -> exp(-relativistic_energy(kinetic, [x])) * x
mm = momentum_mode(kinetic.m, kinetic.c)
sampler = RejectionSampler(densityfunc, (0.0, Inf), (mm / 2, mm * 2); max_segments=5)
sz = size(metric)
θs = rand_angles(rng, prod(sz))
d = only(run_sampler!(rng, sampler, 1))
r = polar2spherical(θs, d * rand(rng, [-1, +1])) # TODO Double check if +/- is needed
r = reshape(r, sz)
return r
end

# TODO Support AbstractVector{<:AbstractRNG}
# FIXME Unit-test this using slice sampler or HMC sampler
function rand_momentum(
rng::AbstractRNG,
metric::UnitEuclideanMetric{T},
kinetic::DimensionwiseRelativisticKinetic{T},
::AbstractVecOrMat,
) where {T}
h_temp = Hamiltonian(metric, kinetic, identity, identity)
densityfunc = x -> exp(neg_energy(h_temp, [x], [x]))
sampler = RejectionSampler(densityfunc, (-Inf, Inf); max_segments=5)
sz = size(metric)
r = run_sampler!(rng, sampler, prod(sz))
r = reshape(r, sz)
return r
end

# TODO Support AbstractVector{<:AbstractRNG}
function rand_momentum(
rng::AbstractRNG,
metric::DiagEuclideanMetric{T},
kinetic::AbstractRelativisticKinetic{T},
θ::AbstractVecOrMat,
) where {T}
r = rand_momentum(rng, UnitEuclideanMetric(size(metric)), kinetic, θ)
# p' = A p where A = sqrtM
r ./= metric.sqrtM⁻¹
return r
end
# TODO Support AbstractVector{<:AbstractRNG}
function rand_momentum(
rng::AbstractRNG,
metric::DenseEuclideanMetric{T},
kinetic::AbstractRelativisticKinetic{T},
θ::AbstractVecOrMat,
) where {T}
r = rand_momentum(rng, UnitEuclideanMetric(size(metric)), kinetic, θ)
# p' = A p where A = cholM
ldiv!(metric.cholM⁻¹, r)
return r
end
26 changes: 26 additions & 0 deletions test/relativistic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using ReTest, Random, AdvancedHMC

@testset "Relativistic kinetic construction" begin
f = x -> dot(x, x)
g = x -> 2x
metric = UnitEuclideanMetric(10)
h = Hamiltonian(metric, RelativisticKinetic(1.0, 1.0), f, g)
@test h.kinetic isa RelativisticKinetic
end

@testset "Sampling with relativistic kinetic" begin
n_samples = 2_000
rng = MersenneTwister(1110)
initial_θ = rand(D)
metric = DiagEuclideanMetric(D)
for kineticT in [RelativisticKinetic, DimensionwiseRelativisticKinetic]
kinetic = kineticT(1.0, 1.0)
h = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
initial_ϵ = find_good_stepsize(h, initial_θ)
integrator = Leapfrog(initial_ϵ)
kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8)))
samples, stats = sample(rng, h, kernel, initial_θ, n_samples; progress=true)
@test length(samples) == n_samples
@test length(stats) == n_samples
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ if GROUP == "All" || GROUP == "AdvancedHMC"
include("abstractmcmc.jl")
include("mcmcchains.jl")
include("constructors.jl")
include("relativistic.jl")

Comonicon.@main function runtests(patterns...; dry::Bool=false)
return retest(patterns...; dry=dry, verbose=Inf)
Expand Down
Loading