Skip to content

Commit 5c8fc59

Browse files
penelopeysmmhauru
andauthored
Remove Turing integration tests (#733)
* Remove Turing integration tests * Don't run removed tests in CI * Add demo_assume_literal_observe + rename demo_assume_observe_literal -> demo_assume_multivariate_observe_literal * Rename models * Implement make_chain_from_prior([rng,] model, n_iters) * Bump minor version * Re-add almost all integration tests into DPPL test suite proper * Remove deprecated `link!` / `invlink!` * Add comments to value_iterator_from_chain test * Remove outdated comment Co-authored-by: Markus Hauru <[email protected]> * Remove link!/invlink! from threadsafe.jl * Update API docs on transformations * Update API docs page --------- Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Markus Hauru <[email protected]>
1 parent 0f07520 commit 5c8fc59

21 files changed

+532
-1076
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ jobs:
7373

7474
- uses: julia-actions/julia-runtest@v1
7575
env:
76-
GROUP: All
7776
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
7877

7978
- uses: julia-actions/julia-processcoverage@v1

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ jobs:
1414
env:
1515
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1616
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
17-
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
17+
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'

.github/workflows/JuliaPre.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,3 @@ jobs:
2525
- uses: julia-actions/cache@v2
2626
- uses: julia-actions/julia-buildpkg@v1
2727
- uses: julia-actions/julia-runtest@v1
28-
env:
29-
GROUP: DynamicPPL

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.5"
3+
version = "0.32.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,9 @@ VarInfo
279279
TypedVarInfo
280280
```
281281

282-
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
283-
284-
```@docs
285-
link!
286-
invlink!
287-
```
282+
One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md).
283+
The [Transformations section below](#Transformations) describes the methods used for this.
284+
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
288285

289286
```@docs
290287
set_flag!

src/test_utils/models.jl

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
323323
return [@varname(s), @varname(m)]
324324
end
325325

326-
@model function demo_assume_observe_literal()
327-
# `assume` and literal `observe`
326+
@model function demo_assume_multivariate_observe_literal()
327+
# multivariate `assume` and literal `observe`
328328
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
329329
m ~ MvNormal(zeros(2), Diagonal(s))
330330
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
331331

332332
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
333333
end
334-
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
334+
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
335335
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
336336
m_dist = MvNormal(zeros(2), Diagonal(s))
337337
return logpdf(s_dist, s) + logpdf(m_dist, m)
338338
end
339-
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
339+
function loglikelihood_true(
340+
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
341+
)
340342
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
341343
end
342344
function logprior_true_with_logabsdet_jacobian(
343-
model::Model{typeof(demo_assume_observe_literal)}, s, m
345+
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
344346
)
345347
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
346348
end
347-
function varnames(model::Model{typeof(demo_assume_observe_literal)})
349+
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
348350
return [@varname(s), @varname(m)]
349351
end
350352

@@ -377,26 +379,50 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
377379
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
378380
end
379381

380-
@model function demo_assume_literal_dot_observe()
382+
@model function demo_assume_observe_literal()
383+
# univariate `assume` and literal `observe`
384+
s ~ InverseGamma(2, 3)
385+
m ~ Normal(0, sqrt(s))
386+
1.5 ~ Normal(m, sqrt(s))
387+
2.0 ~ Normal(m, sqrt(s))
388+
389+
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
390+
end
391+
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
392+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
393+
end
394+
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
395+
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
396+
end
397+
function logprior_true_with_logabsdet_jacobian(
398+
model::Model{typeof(demo_assume_observe_literal)}, s, m
399+
)
400+
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
401+
end
402+
function varnames(model::Model{typeof(demo_assume_observe_literal)})
403+
return [@varname(s), @varname(m)]
404+
end
405+
406+
@model function demo_assume_dot_observe_literal()
381407
# `assume` and literal `dot_observe`
382408
s ~ InverseGamma(2, 3)
383409
m ~ Normal(0, sqrt(s))
384410
[1.5, 2.0] .~ Normal(m, sqrt(s))
385411

386412
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
387413
end
388-
function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
414+
function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
389415
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
390416
end
391-
function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
417+
function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
392418
return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0])
393419
end
394420
function logprior_true_with_logabsdet_jacobian(
395-
model::Model{typeof(demo_assume_literal_dot_observe)}, s, m
421+
model::Model{typeof(demo_assume_dot_observe_literal)}, s, m
396422
)
397423
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
398424
end
399-
function varnames(model::Model{typeof(demo_assume_literal_dot_observe)})
425+
function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
400426
return [@varname(s), @varname(m)]
401427
end
402428

@@ -574,8 +600,9 @@ const DemoModels = Union{
574600
Model{typeof(demo_assume_multivariate_observe)},
575601
Model{typeof(demo_dot_assume_observe_index)},
576602
Model{typeof(demo_assume_dot_observe)},
577-
Model{typeof(demo_assume_literal_dot_observe)},
603+
Model{typeof(demo_assume_dot_observe_literal)},
578604
Model{typeof(demo_assume_observe_literal)},
605+
Model{typeof(demo_assume_multivariate_observe_literal)},
579606
Model{typeof(demo_dot_assume_observe_index_literal)},
580607
Model{typeof(demo_assume_submodel_observe_index_literal)},
581608
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -585,7 +612,9 @@ const DemoModels = Union{
585612
}
586613

587614
const UnivariateAssumeDemoModels = Union{
588-
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
615+
Model{typeof(demo_assume_dot_observe)},
616+
Model{typeof(demo_assume_dot_observe_literal)},
617+
Model{typeof(demo_assume_observe_literal)},
589618
}
590619
function posterior_mean(model::UnivariateAssumeDemoModels)
591620
return (s=49 / 24, m=7 / 6)
@@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{
609638
Model{typeof(demo_assume_index_observe)},
610639
Model{typeof(demo_assume_multivariate_observe)},
611640
Model{typeof(demo_dot_assume_observe_index)},
612-
Model{typeof(demo_assume_observe_literal)},
641+
Model{typeof(demo_assume_multivariate_observe_literal)},
613642
Model{typeof(demo_dot_assume_observe_index_literal)},
614643
Model{typeof(demo_assume_submodel_observe_index_literal)},
615644
Model{typeof(demo_dot_assume_observe_submodel)},
@@ -759,9 +788,10 @@ const DEMO_MODELS = (
759788
demo_assume_multivariate_observe(),
760789
demo_dot_assume_observe_index(),
761790
demo_assume_dot_observe(),
762-
demo_assume_observe_literal(),
791+
demo_assume_multivariate_observe_literal(),
763792
demo_dot_assume_observe_index_literal(),
764-
demo_assume_literal_dot_observe(),
793+
demo_assume_dot_observe_literal(),
794+
demo_assume_observe_literal(),
765795
demo_assume_submodel_observe_index_literal(),
766796
demo_dot_assume_observe_submodel(),
767797
demo_dot_assume_dot_observe_matrix(),

src/threadsafe.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
7979
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
8080
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
8181

82-
link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl)
83-
invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl)
8482
islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
8583

8684
function link!!(

src/varinfo.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,27 +1221,6 @@ function link!!(
12211221
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model)
12221222
end
12231223

1224-
"""
1225-
link!(vi::VarInfo, spl::Sampler)
1226-
1227-
Transform the values of the random variables sampled by `spl` in `vi` from the support
1228-
of their distributions to the Euclidean space and set their corresponding `"trans"`
1229-
flag values to `true`.
1230-
"""
1231-
function link!(vi::VarInfo, spl::AbstractSampler)
1232-
Base.depwarn(
1233-
"`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
1234-
:link!,
1235-
)
1236-
return _link!(vi, spl)
1237-
end
1238-
function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
1239-
Base.depwarn(
1240-
"`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
1241-
:link!,
1242-
)
1243-
return _link!(vi, spl, spaceval)
1244-
end
12451224
function _link!(vi::UntypedVarInfo, spl::AbstractSampler)
12461225
# TODO: Change to a lazy iterator over `vns`
12471226
vns = _getvns(vi, spl)
@@ -1319,29 +1298,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode
13191298
return maybe_invlink_before_eval!!(t, vi, context, model)
13201299
end
13211300

1322-
"""
1323-
invlink!(vi::VarInfo, spl::AbstractSampler)
1324-
1325-
Transform the values of the random variables sampled by `spl` in `vi` from the
1326-
Euclidean space back to the support of their distributions and sets their corresponding
1327-
`"trans"` flag values to `false`.
1328-
"""
1329-
function invlink!(vi::VarInfo, spl::AbstractSampler)
1330-
Base.depwarn(
1331-
"`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
1332-
:invlink!,
1333-
)
1334-
return _invlink!(vi, spl)
1335-
end
1336-
1337-
function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
1338-
Base.depwarn(
1339-
"`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
1340-
:invlink!,
1341-
)
1342-
return _invlink!(vi, spl, spaceval)
1343-
end
1344-
13451301
function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
13461302
vns = _getvns(vi, spl)
13471303
if istrans(vi, vns[1])

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
2020
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2121
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2222
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
23+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2324
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2425
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2526
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

test/ad.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,43 @@
3434
end
3535
end
3636
end
37+
38+
@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
39+
# Failing model
40+
t = 1:0.05:8
41+
σ = 0.3
42+
y = @. rand(sin(t) + Normal(0, σ))
43+
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
44+
# Priors
45+
α ~ Normal(y[1], 0.001)
46+
τ ~ Exponential(1)
47+
η ~ filldist(Normal(0, 1), TT - 1)
48+
σ ~ Exponential(1)
49+
# create latent variable
50+
x = Vector{T}(undef, TT)
51+
x[1] = α
52+
for t in 2:TT
53+
x[t] = x[t - 1] + η[t - 1] * τ
54+
end
55+
# measurement model
56+
y ~ MvNormal(x, σ^2 * I)
57+
return x
58+
end
59+
model = state_space(y, length(t))
60+
61+
# Dummy sampling algorithm for testing. The test case can only be replicated
62+
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
63+
# overload assume so that model evaluation doesn't fail due to a lack
64+
# of implementation
65+
struct MyEmptyAlg end
66+
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
67+
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
68+
DynamicPPL.assume(dist, vn, vi)
69+
70+
# Compiling the ReverseDiff tape used to fail here
71+
spl = Sampler(MyEmptyAlg())
72+
vi = VarInfo(model)
73+
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
74+
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
75+
end
3776
end

0 commit comments

Comments
 (0)