Skip to content

Commit 681e472

Browse files
Replace instances of @submodel with to_submodel() (#751)
* Replace remaining instances of @SubModel * Implement tilde_assume!! for Sampleable / pointwise contexts * Fix typos in test * Bump patch version * Re-add some minimal tests for deprecated @SubModel * Format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8972b98 commit 681e472

8 files changed

+102
-51
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.32.0"
3+
version = "0.32.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/contexts.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ adds the `Prefix` to all parameters.
244244
This context is useful in nested models to ensure that the names of the parameters are
245245
unique.
246246
247-
See also: [`@submodel`](@ref)
247+
See also: [`to_submodel`](@ref)
248248
"""
249249
struct PrefixContext{Prefix,C} <: AbstractContext
250250
context::C

src/pointwise_logdensities.jl

+20
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,26 @@ function _pointwise_tilde_observe(
146146
end
147147
end
148148

149+
# Note on submodels (penelopeysm)
150+
#
151+
# We don't need to overload tilde_observe!! for Sampleables (yet), because it
152+
# is currently not possible to evaluate a model with a Sampleable on the RHS
153+
# of an observe statement.
154+
#
155+
# Note that calling tilde_assume!! on a Sampleable does not necessarily imply
156+
# that there are no observe statements inside the Sampleable. There could well
157+
# be likelihood terms in there, which must be included in the returned logp.
158+
# See e.g. the `demo_dot_assume_observe_submodel` demo model.
159+
#
160+
# This is handled by passing the same context to rand_like!!, which figures out
161+
# which terms to include using the context, and also mutates the context and vi
162+
# appropriately. Thus, we don't need to check against _include_prior(context)
163+
# here.
164+
function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi)
165+
value, vi = DynamicPPL.rand_like!!(right, context, vi)
166+
return value, vi
167+
end
168+
149169
function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
150170
!_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi))
151171
value, logp, vi = tilde_assume(context.context, right, vn, vi)

src/test_utils/models.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ end
437437

438438
@model function demo_assume_submodel_observe_index_literal()
439439
# Submodel prior
440-
@submodel s, m = _prior_dot_assume()
440+
priors ~ to_submodel(_prior_dot_assume(), false)
441+
s, m = priors
441442
1.5 ~ Normal(m[1], sqrt(s[1]))
442443
2.0 ~ Normal(m[2], sqrt(s[2]))
443444

@@ -462,7 +463,7 @@ function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal
462463
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
463464
end
464465

465-
@model function _likelihood_mltivariate_observe(s, m, x)
466+
@model function _likelihood_multivariate_observe(s, m, x)
466467
return x ~ MvNormal(m, Diagonal(s))
467468
end
468469

@@ -475,7 +476,9 @@ end
475476
m .~ Normal.(0, sqrt.(s))
476477

477478
# Submodel likelihood
478-
@submodel _likelihood_mltivariate_observe(s, m, x)
479+
# With to_submodel, we have to have a left-hand side variable to
480+
# capture the result, so we just use a dummy variable
481+
_ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))
479482

480483
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
481484
end

test/compiler.jl

+11-41
Original file line numberDiff line numberDiff line change
@@ -382,34 +382,13 @@ module Issue537 end
382382
@test demo2()() == 42
383383
end
384384

385-
@testset "@submodel is deprecated" begin
386-
@model inner() = x ~ Normal()
387-
@model outer() = @submodel x = inner()
388-
@test_logs(
389-
(
390-
:warn,
391-
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
392-
),
393-
outer()()
394-
)
395-
396-
@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
397-
@test_logs(
398-
(
399-
:warn,
400-
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
401-
),
402-
outer_with_prefix()()
403-
)
404-
end
405-
406-
@testset "submodel" begin
385+
@testset "to_submodel" begin
407386
# No prefix, 1 level.
408387
@model function demo1(x)
409388
return x ~ Normal()
410389
end
411390
@model function demo2(x, y)
412-
@submodel demo1(x)
391+
_ignore ~ to_submodel(demo1(x), false)
413392
return y ~ Uniform()
414393
end
415394
# No observation.
@@ -441,7 +420,7 @@ module Issue537 end
441420

442421
# Check values makes sense.
443422
@model function demo3(x, y)
444-
@submodel demo1(x)
423+
_ignore ~ to_submodel(demo1(x), false)
445424
return y ~ Normal(x)
446425
end
447426
m = demo3(1000.0, missing)
@@ -453,12 +432,10 @@ module Issue537 end
453432
x ~ Normal()
454433
return x
455434
end
456-
457435
@model function demo_useval(x, y)
458-
@submodel prefix = "sub1" x1 = demo_return(x)
459-
@submodel prefix = "sub2" x2 = demo_return(y)
460-
461-
return z ~ Normal(x1 + x2 + 100, 1.0)
436+
sub1 ~ to_submodel(demo_return(x))
437+
sub2 ~ to_submodel(demo_return(y))
438+
return z ~ Normal(sub1 + sub2 + 100, 1.0)
462439
end
463440
m = demo_useval(missing, missing)
464441
vi = VarInfo(m)
@@ -472,21 +449,18 @@ module Issue537 end
472449
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
473450
η ~ MvNormal(zeros(num_steps), I)
474451
δ = sqrt(1 - α^2)
475-
476452
x = TV(undef, num_steps)
477453
x[1] = η[1]
478454
@inbounds for t in 2:num_steps
479455
x[t] = @. α * x[t - 1] + δ * η[t]
480456
end
481-
482457
return @. μ + σ * x
483458
end
484459

485460
@model function demo(y)
486461
α ~ Uniform()
487462
μ ~ Normal()
488463
σ ~ truncated(Normal(), 0, Inf)
489-
490464
num_steps = length(y[1])
491465
num_obs = length(y)
492466
@inbounds for i in 1:num_obs
@@ -613,14 +587,11 @@ module Issue537 end
613587
@model demo() = x ~ Normal()
614588
retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext())
615589

616-
# Return-value when using `@submodel`
590+
# Return-value when using `to_submodel`
617591
@model inner() = x ~ Normal()
618-
# Without assignment.
619-
@model outer() = @submodel inner()
620-
@test outer()() isa Real
621-
622-
# With assignment.
623-
@model outer() = @submodel x = inner()
592+
@model function outer()
593+
return _ignore ~ to_submodel(inner())
594+
end
624595
@test outer()() isa Real
625596

626597
# Edge-cases.
@@ -720,8 +691,7 @@ module Issue537 end
720691
return (; x, y)
721692
end
722693
@model function demo_tracked_submodel()
723-
@submodel (x, y) = demo_tracked()
724-
return (; x, y)
694+
return vals ~ to_submodel(demo_tracked(), false)
725695
end
726696
for model in [demo_tracked(), demo_tracked_submodel()]
727697
# Make sure it's runnable and `y` is present in the return-value.

test/deprecated.jl

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
@testset "deprecated" begin
2+
@testset "@submodel" begin
3+
@testset "is deprecated" begin
4+
@model inner() = x ~ Normal()
5+
@model outer() = @submodel x = inner()
6+
@test_logs(
7+
(
8+
:warn,
9+
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
10+
),
11+
outer()()
12+
)
13+
14+
@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
15+
@test_logs(
16+
(
17+
:warn,
18+
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
19+
),
20+
outer_with_prefix()()
21+
)
22+
end
23+
24+
@testset "prefixing still works correctly" begin
25+
@model inner() = x ~ Normal()
26+
@model function outer()
27+
a = @submodel inner()
28+
b = @submodel prefix = "sub" inner()
29+
return a, b
30+
end
31+
@test outer()() isa Tuple{Float64,Float64}
32+
vi = VarInfo(outer())
33+
@test @varname(x) in keys(vi)
34+
@test @varname(var"sub.x") in keys(vi)
35+
end
36+
37+
@testset "logp is still accumulated properly" begin
38+
@model inner_assume() = x ~ Normal()
39+
@model inner_observe(x, y) = y ~ Normal(x)
40+
@model function outer(b)
41+
a = @submodel inner_assume()
42+
@submodel inner_observe(a, b)
43+
end
44+
y_val = 1.0
45+
model = outer(y_val)
46+
@test model() == y_val
47+
48+
x_val = 1.5
49+
vi = VarInfo(outer(y_val))
50+
DynamicPPL.setindex!!(vi, x_val, @varname(x))
51+
@test logprior(model, vi) logpdf(Normal(), x_val)
52+
@test loglikelihood(model, vi) logpdf(Normal(x_val), y_val)
53+
@test logjoint(model, vi)
54+
logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val)
55+
end
56+
end
57+
end

test/pointwise_logdensities.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(
1414
model, example_values...
1515
)
16-
logp_true = logprior(model, vi)
16+
logprior_true = logprior(model, vi)
1717

1818
# Compute the pointwise loglikelihoods.
1919
lls = pointwise_loglikelihoods(model, vi)
@@ -30,18 +30,18 @@
3030
lps_prior = pointwise_prior_logdensities(model, vi)
3131
@test :x DynamicPPL.getsym.(keys(lps_prior))
3232
logp = sum(sum, values(lps_prior))
33-
@test logp logp_true
33+
@test logp logprior_true
3434

3535
# Compute both likelihood and logdensity of prior
36-
# using the default DefaultContex
36+
# using the default DefaultContext
3737
lps = pointwise_logdensities(model, vi)
3838
logp = sum(sum, values(lps))
39-
@test logp (logp_true + loglikelihood_true)
39+
@test logp (logprior_true + loglikelihood_true)
4040

4141
# Test that modifications of Setup are picked up
4242
lps = pointwise_logdensities(model, vi, mod_ctx2)
4343
logp = sum(sum, values(lps))
44-
@test logp (logp_true + loglikelihood_true) * 1.2 * 1.4
44+
@test logp (logprior_true + loglikelihood_true) * 1.2 * 1.4
4545
end
4646
end
4747

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ include("test_util.jl")
5959
include("serialization.jl")
6060
include("pointwise_logdensities.jl")
6161
include("lkj.jl")
62+
include("deprecated.jl")
6263
end
6364

6465
if GROUP == "All" || GROUP == "Group2"

0 commit comments

Comments
 (0)