Skip to content

Commit 2344689

Browse files
penelopeysmtorfjeldegithub-actions[bot]
authored
Fixing CI from #711 (#729)
* Fix wrong function being called * Don't test setchildcontext on leaf contexts * Update src/test_utils/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed method ambiguities for DebugContext * test the context interface for DebugContext on multiple models * Update src/debug_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d635a17 commit 2344689

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

src/debug_utils.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
332332
record_post_tilde_assume!(context, vn, right, value, logp, vi)
333333
return value, logp, vi
334334
end
335-
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
335+
function DynamicPPL.tilde_assume(
336+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
337+
)
336338
record_pre_tilde_assume!(context, vn, right, vi)
337339
value, logp, vi = DynamicPPL.tilde_assume(
338340
rng, childcontext(context), sampler, right, vn, vi
@@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
425427
end
426428

427429
function DynamicPPL.dot_tilde_assume(
428-
rng, context::DebugContext, sampler, right, left, vn, vi
430+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
429431
)
430432
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
431433
value, logp, vi = DynamicPPL.dot_tilde_assume(

src/test_utils/contexts.jl

+15-15
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,6 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
7171
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
7272
throw(ValueError("Invalid NodeTrait: $node_trait"))
7373

74-
# The interface methods.
75-
if node_trait isa DynamicPPL.IsParent
76-
# `childcontext` and `setchildcontext`
77-
# With new child context
78-
childcontext_new = TestParentContext()
79-
@test DynamicPPL.childcontext(
80-
DynamicPPL.setchildcontext(context, childcontext_new)
81-
) == childcontext_new
82-
end
83-
8474
# To see change, let's make sure we're using a different leaf context than the current.
8575
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
8676
PriorContext()
@@ -90,11 +80,21 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
9080
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
9181
leafcontext_new
9282

93-
# Setting the child context to a leaf should now change the leafcontext accordingly.
94-
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
95-
@test DynamicPPL.setchildcontext(context_with_new_leaf) ===
96-
DynamicPPL.setleafcontext(context_with_new_leaf) ===
97-
leafcontext_new
83+
# The interface methods.
84+
if node_trait isa DynamicPPL.IsParent
85+
# `childcontext` and `setchildcontext`
86+
# With new child context
87+
childcontext_new = TestParentContext()
88+
@test DynamicPPL.childcontext(
89+
DynamicPPL.setchildcontext(context, childcontext_new)
90+
) == childcontext_new
91+
# Setting the child context to a leaf should now change the leafcontext
92+
# accordingly.
93+
context_with_new_leaf = DynamicPPL.setchildcontext(context, leafcontext_new)
94+
@test DynamicPPL.childcontext(context_with_new_leaf) ===
95+
DynamicPPL.leafcontext(context_with_new_leaf) ===
96+
leafcontext_new
97+
end
9898

9999
# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
100100
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.

test/debug_utils.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
@testset "check_model" begin
22
@testset "context interface" begin
3-
# HACK: Require a model to instantiate it, so let's just grab one.
4-
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
5-
context = DynamicPPL.DebugUtils.DebugContext(model)
6-
DynamicPPL.TestUtils.test_context(context, model)
3+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
4+
context = DynamicPPL.DebugUtils.DebugContext(model)
5+
DynamicPPL.TestUtils.test_context(context, model)
6+
end
77
end
88

99
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@@ -14,7 +14,7 @@
1414
# Check that the trace contains all the variables in the model.
1515
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
1616
for vn in DynamicPPL.TestUtils.varnames(model)
17-
@test vn in varnames_in_trace
17+
@test vn in varnames_in_traces
1818
end
1919

2020
# Quick checks for `show` of trace.

0 commit comments

Comments
 (0)