Skip to content

Commit cba7ed7

Browse files
committed
Add demo_assume_literal_observe + rename demo_assume_observe_literal -> demo_assume_multivariate_observe_literal
1 parent a53e37f commit cba7ed7

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

src/test_utils/models.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,28 +323,28 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
323323
return [@varname(s), @varname(m)]
324324
end
325325

326-
@model function demo_assume_observe_literal()
327-
# `assume` and literal `observe`
326+
@model function demo_assume_multivariate_observe_literal()
327+
# multivariate `assume` and literal `observe`
328328
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
329329
m ~ MvNormal(zeros(2), Diagonal(s))
330330
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
331331

332332
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
333333
end
334-
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
334+
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
335335
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
336336
m_dist = MvNormal(zeros(2), Diagonal(s))
337337
return logpdf(s_dist, s) + logpdf(m_dist, m)
338338
end
339-
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
339+
function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
340340
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
341341
end
342342
function logprior_true_with_logabsdet_jacobian(
343-
model::Model{typeof(demo_assume_observe_literal)}, s, m
343+
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
344344
)
345345
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
346346
end
347-
function varnames(model::Model{typeof(demo_assume_observe_literal)})
347+
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
348348
return [@varname(s), @varname(m)]
349349
end
350350

@@ -377,6 +377,30 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
377377
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
378378
end
379379

380+
@model function demo_assume_literal_observe()
381+
# univariate `assume` and literal `observe`
382+
s ~ InverseGamma(2, 3)
383+
m ~ Normal(0, sqrt(s))
384+
1.5 ~ Normal(m, sqrt(s))
385+
2.0 ~ Normal(m, sqrt(s))
386+
387+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
388+
end
389+
function logprior_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
390+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
391+
end
392+
function loglikelihood_true(model::Model{typeof(demo_assume_literal_observe)}, s, m)
393+
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
394+
end
395+
function logprior_true_with_logabsdet_jacobian(
396+
model::Model{typeof(demo_assume_literal_observe)}, s, m
397+
)
398+
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
399+
end
400+
function varnames(model::Model{typeof(demo_assume_literal_observe)})
401+
return [@varname(s), @varname(m)]
402+
end
403+
380404
@model function demo_assume_literal_dot_observe()
381405
# `assume` and literal `dot_observe`
382406
s ~ InverseGamma(2, 3)
@@ -575,7 +599,8 @@ const DemoModels = Union{
575599
Model{typeof(demo_dot_assume_observe_index)},
576600
Model{typeof(demo_assume_dot_observe)},
577601
Model{typeof(demo_assume_literal_dot_observe)},
578-
Model{typeof(demo_assume_observe_literal)},
602+
Model{typeof(demo_assume_literal_observe)},
603+
Model{typeof(demo_assume_multivariate_observe_literal)},
579604
Model{typeof(demo_dot_assume_observe_index_literal)},
580605
Model{typeof(demo_assume_submodel_observe_index_literal)},
581606
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -585,7 +610,9 @@ const DemoModels = Union{
585610
}
586611

587612
const UnivariateAssumeDemoModels = Union{
588-
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
613+
Model{typeof(demo_assume_dot_observe)},
614+
Model{typeof(demo_assume_literal_dot_observe)}
615+
Model{typeof(demo_assume_literal_observe)}
589616
}
590617
function posterior_mean(model::UnivariateAssumeDemoModels)
591618
return (s=49 / 24, m=7 / 6)
@@ -609,7 +636,7 @@ const MultivariateAssumeDemoModels = Union{
609636
Model{typeof(demo_assume_index_observe)},
610637
Model{typeof(demo_assume_multivariate_observe)},
611638
Model{typeof(demo_dot_assume_observe_index)},
612-
Model{typeof(demo_assume_observe_literal)},
639+
Model{typeof(demo_assume_multivariate_observe_literal)},
613640
Model{typeof(demo_dot_assume_observe_index_literal)},
614641
Model{typeof(demo_assume_submodel_observe_index_literal)},
615642
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -759,9 +786,10 @@ const DEMO_MODELS = (
759786
demo_assume_multivariate_observe(),
760787
demo_dot_assume_observe_index(),
761788
demo_assume_dot_observe(),
762-
demo_assume_observe_literal(),
789+
demo_assume_multivariate_observe_literal(),
763790
demo_dot_assume_observe_index_literal(),
764791
demo_assume_literal_dot_observe(),
792+
demo_assume_literal_observe(),
765793
demo_assume_submodel_observe_index_literal(),
766794
demo_dot_assume_observe_submodel(),
767795
demo_dot_assume_dot_observe_matrix(),

0 commit comments

Comments
 (0)