20
20
│ 1 │ m │ 0.824853 │
21
21
```
22
22
"""
23
- struct ESS <: InferenceAlgorithm end
23
+ struct ESS <: AbstractSampler end
24
+
25
+ DynamicPPL. initialsampler (:: ESS ) = DynamicPPL. SampleFromPrior ()
26
+ update_sample_kwargs (:: ESS , :: Integer , kwargs) = kwargs
27
+ get_adtype (:: ESS ) = nothing
28
+ requires_unconstrained_space (:: ESS ) = false
24
29
25
30
# always accept in the first step
26
- function DynamicPPL. initialstep (
27
- rng:: AbstractRNG , model:: Model , spl:: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
28
- )
31
+ function AbstractMCMC. step (rng:: AbstractRNG , ldf:: LogDensityFunction , spl:: ESS ; kwargs... )
32
+ vi = ldf. varinfo
29
33
for vn in keys (vi)
30
34
dist = getdist (vi, vn)
31
35
EllipticalSliceSampling. isgaussian (typeof (dist)) ||
32
36
error (" ESS only supports Gaussian prior distributions" )
33
37
end
34
- return Transition (model, vi), vi
38
+ return Transition (ldf . model, vi), vi
35
39
end
36
40
37
41
function AbstractMCMC. step (
38
- rng:: AbstractRNG , model :: Model , spl:: Sampler{<: ESS} , vi:: AbstractVarInfo ; kwargs...
42
+ rng:: AbstractRNG , ldf :: LogDensityFunction , spl:: ESS , vi:: AbstractVarInfo ; kwargs...
39
43
)
40
44
# obtain previous sample
41
45
f = vi[:]
@@ -45,14 +49,13 @@ function AbstractMCMC.step(
45
49
oldstate = EllipticalSliceSampling. ESSState (f, getlogp (vi), nothing )
46
50
47
51
# compute next state
52
+ # Note: `f_loglikelihood` effectively calculates the log-likelihood (not
53
+ # log-joint, despite the use of `LDP.logdensity`) because `tilde_assume` is
54
+ # overloaded on `SamplingContext(rng, ESS(), ...)` below.
55
+ f_loglikelihood = Base. Fix1 (LogDensityProblems. logdensity, ldf)
48
56
sample, state = AbstractMCMC. step (
49
57
rng,
50
- EllipticalSliceSampling. ESSModel (
51
- ESSPrior (model, spl, vi),
52
- DynamicPPL. LogDensityFunction (
53
- model, vi, DynamicPPL. SamplingContext (spl, DynamicPPL. DefaultContext ())
54
- ),
55
- ),
58
+ EllipticalSliceSampling. ESSModel (ESSPrior (ldf. model, spl, vi), f_loglikelihood),
56
59
EllipticalSliceSampling. ESS (),
57
60
oldstate,
58
61
)
@@ -61,67 +64,57 @@ function AbstractMCMC.step(
61
64
vi = DynamicPPL. unflatten (vi, sample)
62
65
vi = setlogp!! (vi, state. loglikelihood)
63
66
64
- return Transition (model, vi), vi
67
+ return Transition (ldf . model, vi), vi
65
68
end
66
69
67
70
# Prior distribution of considered random variable
68
- struct ESSPrior{M<: Model ,S <: Sampler{<:ESS} , V<: AbstractVarInfo ,T}
71
+ struct ESSPrior{M<: Model ,V<: AbstractVarInfo ,T}
69
72
model:: M
70
- sampler:: S
73
+ sampler:: ESS
71
74
varinfo:: V
72
75
μ:: T
73
76
74
- function ESSPrior {M,S, V} (
75
- model:: M , sampler:: S , varinfo:: V
76
- ) where {M<: Model ,S <: Sampler{<:ESS} , V<: AbstractVarInfo }
77
+ function ESSPrior {M,V} (
78
+ model:: M , sampler:: ESS , varinfo:: V
79
+ ) where {M<: Model ,V<: AbstractVarInfo }
77
80
vns = keys (varinfo)
78
81
μ = mapreduce (vcat, vns) do vn
79
82
dist = getdist (varinfo, vn)
80
83
EllipticalSliceSampling. isgaussian (typeof (dist)) ||
81
84
error (" [ESS] only supports Gaussian prior distributions" )
82
85
DynamicPPL. tovec (mean (dist))
83
86
end
84
- return new {M,S, V,typeof(μ)} (model, sampler, varinfo, μ)
87
+ return new {M,V,typeof(μ)} (model, sampler, varinfo, μ)
85
88
end
86
89
end
87
90
88
- function ESSPrior (model:: Model , sampler:: Sampler{<: ESS} , varinfo:: AbstractVarInfo )
89
- return ESSPrior {typeof(model),typeof(sampler),typeof( varinfo)} (model, sampler, varinfo)
91
+ function ESSPrior (model:: Model , sampler:: ESS , varinfo:: AbstractVarInfo )
92
+ return ESSPrior {typeof(model),typeof(varinfo)} (model, sampler, varinfo)
90
93
end
91
94
92
95
# Ensure that the prior is a Gaussian distribution (checked in the constructor)
93
96
EllipticalSliceSampling. isgaussian (:: Type{<:ESSPrior} ) = true
94
97
95
98
# Only define out-of-place sampling
96
99
function Base. rand (rng:: Random.AbstractRNG , p:: ESSPrior )
97
- sampler = p . sampler
98
- varinfo = p . varinfo
99
- # TODO : Surely there's a better way of doing this now that we have `SamplingContext`?
100
- vns = keys (varinfo)
101
- for vn in vns
102
- set_flag! (varinfo, vn, " del" )
100
+ # TODO (penelopeysm): This is ugly -- need to set 'del' flag because
101
+ # otherwise DynamicPPL.SampleWithPrior will just use the existing
102
+ # parameters in the varinfo. In general SampleWithPrior etc. need to be
103
+ # reworked.
104
+ for vn in keys (p . varinfo)
105
+ set_flag! (p . varinfo, vn, " del" )
103
106
end
104
- p. model (rng, varinfo, sampler)
105
- return varinfo [:]
107
+ _, vi = DynamicPPL . evaluate!! ( p. model, p . varinfo, SamplingContext (rng, p . sampler) )
108
+ return vi [:]
106
109
end
107
110
108
111
# Mean of prior distribution
109
112
Distributions. mean (p:: ESSPrior ) = p. μ
110
113
111
- # Evaluate log-likelihood of proposals
112
- const ESSLogLikelihood{M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo } =
113
- DynamicPPL. LogDensityFunction{M,V,<: DynamicPPL.SamplingContext{<:S} ,AD} where {AD}
114
-
115
- (ℓ:: ESSLogLikelihood )(f:: AbstractVector ) = LogDensityProblems. logdensity (ℓ, f)
116
-
117
114
function DynamicPPL. tilde_assume (
118
- rng:: Random.AbstractRNG , :: DefaultContext , :: Sampler{<: ESS} , right, vn, vi
115
+ rng:: Random.AbstractRNG , :: DefaultContext , :: ESS , right, vn, vi
119
116
)
120
117
return DynamicPPL. tilde_assume (
121
118
rng, LikelihoodContext (), SampleFromPrior (), right, vn, vi
122
119
)
123
120
end
124
-
125
- function DynamicPPL. tilde_observe (ctx:: DefaultContext , :: Sampler{<:ESS} , right, left, vi)
126
- return DynamicPPL. tilde_observe (ctx, SampleFromPrior (), right, left, vi)
127
- end
0 commit comments