Skip to content

Commit 01d31dd

Browse files
authored
Allow specifying initial step-size (#434)
1 parent 4e84b48 commit 01d31dd

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

src/trajectory.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -746,12 +746,24 @@ function A(h, z, ϵ)
746746
return z′, H′
747747
end
748748

749-
"Find a good initial leap-frog step-size via heuristic search."
749+
"""
750+
find_good_stepsize(h::Hamiltonian, θ::AbstractVector; initial_step_size = 1//10, max_n_iters::Int=100)
751+
find_good_stepsize(rng::AbstractRNG, h::Hamiltonian, θ::AbstractVector; initial_step_size = 1//10, max_n_iters::Int=100)
752+
753+
Find a good initial leap-frog step-size via heuristic search.
754+
755+
- `initial_step_size`: Custom initial step size, default as 1//10
756+
- `max_n_iters`: Maximum number of iteration for searching a good step-size, default as 100
757+
"""
750758
function find_good_stepsize(
751-
rng::AbstractRNG, h::Hamiltonian, θ::AbstractVector{T}; max_n_iters::Int=100
759+
rng::AbstractRNG,
760+
h::Hamiltonian,
761+
θ::AbstractVector{T};
762+
initial_step_size=1//10,
763+
max_n_iters::Int=100,
752764
) where {T<:Real}
753765
# Initialize searching parameters
754-
ϵ′ = ϵ = T(1//10)
766+
ϵ′ = ϵ = T(initial_step_size)
755767
# minimal, crossing, maximal log accept ratio
756768
log_a_min = 2 * T(loghalf)
757769
log_a_cross = T(loghalf)
@@ -815,9 +827,18 @@ function find_good_stepsize(
815827
end
816828

817829
function find_good_stepsize(
818-
h::Hamiltonian, θ::AbstractVector{<:AbstractFloat}; max_n_iters::Int=100
830+
h::Hamiltonian,
831+
θ::AbstractVector{<:AbstractFloat};
832+
initial_step_size=1//10,
833+
max_n_iters::Int=100,
819834
)
820-
return find_good_stepsize(Random.default_rng(), h, θ; max_n_iters=max_n_iters)
835+
return find_good_stepsize(
836+
Random.default_rng(),
837+
h,
838+
θ;
839+
initial_step_size=initial_step_size,
840+
max_n_iters=max_n_iters,
841+
)
821842
end
822843

823844
"Perform MH acceptance based on energy, i.e. negative log probability."

0 commit comments

Comments
 (0)