-
Notifications
You must be signed in to change notification settings - Fork 226
sample with LogDensityFunction: part 2 - ess.jl
+ mh.jl
#2590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: py/ldf-hmc
Are you sure you want to change the base?
Conversation
Turing.jl documentation for PR #2590 is available at: |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## py/ldf-hmc #2590 +/- ##
==============================================
- Coverage 42.32% 36.19% -6.13%
==============================================
Files 22 22
Lines 1498 1478 -20
==============================================
- Hits 634 535 -99
- Misses 864 943 +79 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
84e8b7e
to
393234b
Compare
fa0f068
to
22c2e9a
Compare
ess.jl
ess.jl
+ mh.jl
ctx = if ldf.context isa SamplingContext | ||
ldf.context | ||
else | ||
SamplingContext(rng, spl) | ||
SamplingContext(rng, spl, ldf.context) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
otherwise the existing context won't be obeyed
# Some of the proposals require working in unconstrained space. | ||
transform_maybe(proposal::AMH.Proposal) = proposal | ||
function transform_maybe(proposal::AMH.RandomWalkProposal) | ||
return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) | ||
end | ||
|
||
function MH(model::Model; proposal_type=AMH.StaticProposal) | ||
priors = DynamicPPL.extract_priors(model) | ||
props = Tuple([proposal_type(prop) for prop in values(priors)]) | ||
vars = Tuple(map(Symbol, collect(keys(priors)))) | ||
priors = map(transform_maybe, NamedTuple{vars}(props)) | ||
return AMH.MetropolisHastings(priors) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this code was broken for a long time and nobody bothered to update it. the idea would be that this would return an AMH.MetropolisHastings
which then needed to be wrapped in an ExternalSampler
. however the external sampler would break (see tests below) I just removed it because it wasn't documented and wasn't working.
# s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) | ||
# c6 = sample(gdemo_default, s6, N) | ||
|
||
# NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls | ||
# it with `NamedTuple` instead of `AbstractVector`. | ||
# s7 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) | ||
# c7 = sample(gdemo_default, s7, N) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these were the corresponding tests
@testset "MH link/invlink" begin | ||
vi_base = DynamicPPL.VarInfo(gdemo_default) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these tests covered by requires_unconstrained_space
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0]; atol=0.3) | ||
@test mean(chain[:a]) ≈ 0.0 atol = 0.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this test was annoyingly difficult to get right. with the gdemo model i had to bump this up to 50000 samples to get it to work reasonably often. I replaced it with a simpler model for which it's obvious when it's being sampled from the prior (a = 0) rather than the posterior (a = 2.5).
See #2555
Pending #2588