Skip to content

Commit 5b24ceb

Browse files
authored
Don't get stuck in an infinite loop if HMC can't find an initial point (#2392)
* Error after 1000 attempts at finding initial parameters * Add a test * Fix missing import * Bump Project.toml
1 parent 470f447 commit 5b24ceb

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/mcmc/hmc.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ function DynamicPPL.initialstep(
181181
if init_attempt_count == 10
182182
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
183183
end
184+
if init_attempt_count == 1000
185+
error(
186+
"failed to find valid initial parameters in $(init_attempt_count) tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues",
187+
)
188+
end
184189

185190
# NOTE: This will sample in the unconstrained space.
186191
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))

test/mcmc/hmc.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Random
1515
using StableRNGs: StableRNG
1616
using StatsFuns: logistic
1717
import Mooncake
18-
using Test: @test, @test_logs, @testset
18+
using Test: @test, @test_logs, @testset, @test_throws
1919
using Turing
2020

2121
@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
@@ -272,6 +272,15 @@ using Turing
272272
end
273273
end
274274

275+
@testset "error for impossible model" begin
276+
@model function demo_impossible()
277+
x ~ Normal()
278+
Turing.@addlogprob! -Inf
279+
end
280+
281+
@test_throws ErrorException sample(demo_impossible(), NUTS(; adtype=adbackend), 5)
282+
end
283+
275284
@testset "(partially) issue: #2095" begin
276285
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
277286
xs = Vector{TV}(undef, 2)

0 commit comments

Comments
 (0)