Skip to content

Commit 1744ba7

Browse files
torfjeldedevmotiongithub-actions[bot]yebaiphipsgabler
authored
Reopening of previous PR (#309)
* fixed some signatures for Model * fixed a method call * fixed method signatures * sort of fixed the matchingvalue functionality for model * formatting * removed redundant _tilde method * removed left-over acclogp! that should not be here anymore * export SamplingContext * use context instead of ctx to refer to contexts * formatting * use context instead of ctx for variables * use context instead of ctx to refer to contexts * Update src/compiler.jl Co-authored-by: David Widmann <[email protected]> * Update src/context_implementations.jl Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * added some whitespace to some docstrings * deprecated tilde and dot_tilde plus exported new versions * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * minor version bump * added impl of matchingvalue for contexts * reverted the change that makes assume always resample * removed the inds arguments from assume and dot_assume to stay non-breaking * Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added missing sampler arg to tilde_observe * added missing sampler argument in dot_tilde_observe * fixed order of arguments in some dot_assume calls * formatting * formatting * added missing sampler argument in tilde_observe for SamplingContext * added missing word in a docstring * updated submodel macro * removed unwrap_childcontext and related since its not needed for this PR * updated submodel macro * fixed evaluation implementations of dot_assume * updated pointwise_loglikelihoods and related * added proper tests for pointwise_loglikelihoods * updated DPPL tests to reflect recent changes * bump minor version since this will be breaking * formatting * formatting * renamed mean_of_mean_models used in tests * bumped dppl version in integration tests * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * fixed ambiguity error * Introduction of `SamplingContext`: keeping it simple (#259) This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc. Co-authored-by: Hong Ge <[email protected]> * Update src/DynamicPPL.jl Co-authored-by: David Widmann <[email protected]> * added initial impl of SimpleVarInfo * remove unnecessary debug statements to be compat with Zygote * make reconstruct slightly more generic * added a couple of convenience constructors * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * small fix * return var_info from tilde-statements, allowing impl of immutable versions * allow usage of non-Ref types in SimpleVarInfo * update submodel-macro * formatting and docstring for submodel-macro * attempt at supporting implicit returns too * added a small comment * simplifed submodel macro a bit * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed typo * use bang-bang convention * updated PointwiseLikelihoodContext * fixed issue where we unnecessarily replace the return-statement * check subtype in the retval * formatting * fixed type-instability in retval check * introduced evaluate method for model * remove unnecessary type-requirement * make return-value check much nicer * removed redundant creation of anonymous function * dont use UnionAll in return_values * updated tests for submodel to reflect new syntax * moved to using BangBang-convention for most methods * remove SimpleVarInfo from this branch * added a comment * reverted submodel macro to use = rather than ~ * updated SimpleVarInfo impl * added a couple of missing deprecations * updated tests * updated implementations of logjoint and others * formatting * added eltype impl for SimpleVarInfo * formatting * fixed eltype for SimpleVarInfo * implement setindex!! in prep for allowing sampling with immutable vi * formatting * initial work on allowing sampling using SimpleVarInfo * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add constructor for SimpleVarInfo using model * improved leftover to_namedtuple_expr, fixing a bug when used with Zygote * bumped patch version * fixed set_flag!! * forgot the return in the replace_returns * bigboy update to benchmarks * fixed some issues and added support for usage of Dict in SimpleVarInfo * added docstring and improved indexing behvaior for SimpleVarInfo * formatting * dont allow sampling with indexing when using SimpleVarInfo with NamedTuple unless shapes are specified * _setval_kernel and others are only supported by VarInfo atm * fixed typo in comment * added more values_as impls * removed redundant values_from_metadata * fixed bug in push!! for SimpleVarInfo * forgot which branch Im on * added handling of short defs in replace_returns and more docstrings * fixed bug in generate_tilde introduced in a merge * fixed a bug in isfuncdef * fixed tests * formatting * uncomment mistakenly commented code * bumped version * updated doctests * dont carry over bang-bang versions that we dont need for general varinfos * Apply suggestions from @phipsgabler Co-authored-by: Philipp Gabler <[email protected]> * updated tests * removed unnecessary BangBang methods * fixed zygote rule for dot_observe * fixed Setfield.jl + returning VarInfo bug in model-macro * updated tests * fixed docs * formatting * fixed issues when using ThreadSafeVarInfo * fixed _pointwise_observe for ThreadSafeVarInfo * updated ThreadSafeVarInfo * made SimpleVarInfo compat with ThreadSafeVarInfo and added show * added some tests for return-values of models * formatting * fixed doctest for SimpleVarInfo * formatting * removed comparison of show from doctest for SimpleVarInfo * Update src/compiler.jl Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * removed OrderedCollections from docs * some additional fixes * fixed method ambiguity and some ill-defined map * renamed evaluate to evaluate!! * added implementations of haskey, getindex and setindex!! for SimpleVarInfo * formatting * dropped redundant definition * use getproperty instead of getindex * fixed method-ambiguity and added some comments * fixed docstring of SimpleVarInfo * fixed docstrings * fixed Project.toml for docs * fixed docstring of canview * fixed docstrings * another attempt at fixing docstrings * added a TODO comment * remove some output from docstring of SimpleVarInfo * fixed haskey and hasvalue for AbstractDict * updated some comments * updated some errors * added sampling dot_assume for SimpleVarInfo * added true versions of density computations to TestUtils * added tests specific for SimpleVarInfo * also document TestUtils * added TestUtils to docs * fixed setindex!! for SimpleVarInfo using AbstractDict * added more tests * formatting * dont use BangBang for setall! * revert unnecessary changes to settrans! * revert unnecessary changes to set_flag! * revert some changes to docstrings * fixed some comments and docstrings * added more convenient logjoint, logprior, and loglikelihood methods * removed unnecessary export * fixed export * use the Setfield impl of getindex, etc. as default and specialize on AbstractDict * fixed docstrings of logjoint, etc. * Apply suggestions from code review Co-authored-by: Philipp Gabler <[email protected]> * fixed docstring for model * replaced return_values by capturing return-value from tilde-statements instead * added some tests for return-value of model * added broadcast_foreach * Apply suggestions from @devmotion Co-authored-by: David Widmann <[email protected]> * remove broadcast_foreach for now * some fixes to ThreadSafeVarInfo * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * fixed docstrings * forgot qualification for set * formatting * added comment about why we cant use MacroTools.isdef * remove unnecessary deprecation * udpated some docstrings * fixed more docstrings * make overloads of BangBang methods qualified * remove overloading of values and instead use values_as without the type specified * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * renamed hasvalue for SimpleVarInfo to _haskey * revert changes from previous commit * minor version bump * fixed sampling with ThreadSafeVarInfo * fixed setindex!! for ThreadSafeVarInfo * fixed eltype for ThreadSafeVarInfo wrapping a SimpleVarInfo * fixed a test * relax atol in serialization tests a bit * temporarily disable Julia 1.3 * relax atol for a prior check * Improvements to `@submodel` in #309 (#348) * added prefix keyword argument to submodel-macro * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * converted example in docs into test * fixed docstring * Apply suggestions from code review Co-authored-by: Philipp Gabler <[email protected]> * removed redundant prefix_submodel_context def and added another example to docstring * fixed doctests * attempt at fixing doctests * another attempt at fixing doctests * had a typo in docstring Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Philipp Gabler <[email protected]> * fixed a test case using submodel * improved docstring according to comments by @devmotion Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Philipp Gabler <[email protected]>
1 parent 12f3b36 commit 1744ba7

28 files changed

+1563
-264
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
matrix:
1818
version:
19-
- '1.3' # minimum supported version
19+
# - '1.3' # minimum supported version
2020
- '1' # current stable version
2121
os:
2222
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.16.2"
3+
version = "0.17.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
45
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
56

67
[compat]
78
Distributions = "0.25"
89
Documenter = "0.27"
10+
Setfield = "0.7.1, 0.8"
911
StableRNGs = "1"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ makedocs(;
88
sitename="DynamicPPL",
99
format=Documenter.HTML(),
1010
modules=[DynamicPPL],
11-
pages=["Home" => "index.md"],
11+
pages=["Home" => "index.md", "TestUtils" => "test_utils.md"],
1212
strict=true,
1313
checkdocs=:exports,
1414
doctestfilters=[

docs/src/test_utils.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# DynamicPPL.TestUtils
2+
3+
```@autodocs
4+
Modules = [DynamicPPL.TestUtils]
5+
```

src/DynamicPPL.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore
1010
using MacroTools: MacroTools
1111
using ZygoteRules: ZygoteRules
1212
using BangBang: BangBang
13+
using Setfield: Setfield
1314

1415
using Setfield: Setfield
1516
using BangBang: BangBang
@@ -31,15 +32,23 @@ import Base:
3132
keys,
3233
haskey
3334

35+
using BangBang: push!!, empty!!, setindex!!
36+
3437
# VarInfo
3538
export AbstractVarInfo,
3639
VarInfo,
3740
UntypedVarInfo,
3841
TypedVarInfo,
42+
SimpleVarInfo,
43+
push!!,
44+
empty!!,
3945
getlogp,
4046
setlogp!,
4147
acclogp!,
4248
resetlogp!,
49+
setlogp!!,
50+
acclogp!!,
51+
resetlogp!!,
4352
get_num_produce,
4453
set_num_produce!,
4554
reset_num_produce!,
@@ -139,13 +148,32 @@ include("distribution_wrappers.jl")
139148
include("contexts.jl")
140149
include("varinfo.jl")
141150
include("threadsafe.jl")
151+
include("simple_varinfo.jl")
142152
include("context_implementations.jl")
143153
include("compiler.jl")
144154
include("prob_macro.jl")
145155
include("compat/ad.jl")
146156
include("loglikelihoods.jl")
147157
include("submodel_macro.jl")
148-
149158
include("test_utils.jl")
150159

160+
# Deprecations
161+
@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo)
162+
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(
163+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution
164+
)
165+
@deprecate push!(
166+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler
167+
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler)
168+
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(
169+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector
170+
)
171+
@deprecate push!(
172+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}
173+
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector})
174+
175+
@deprecate setlogp!(vi, logp) setlogp!!(vi, logp)
176+
@deprecate acclogp!(vi, logp) acclogp!!(vi, logp)
177+
@deprecate resetlogp!(vi) resetlogp!!(vi)
178+
151179
end # module

src/compat/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# See https://github.com/TuringLang/Turing.jl/issues/1199
2-
ChainRulesCore.@non_differentiable push!(
2+
ChainRulesCore.@non_differentiable push!!(
33
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
44
)
55

@@ -16,7 +16,7 @@ ZygoteRules.@adjoint function dot_observe(
1616
)
1717
function dot_observe_fallback(spl, dists, value, vi)
1818
increment_num_produce!(vi)
19-
return sum(map(Distributions.loglikelihood, dists, value))
19+
return sum(map(Distributions.loglikelihood, dists, value)), vi
2020
end
2121
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
2222
end

src/compiler.jl

Lines changed: 100 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,12 @@ end
355355

356356
function generate_tilde_literal(left, right)
357357
# If the LHS is a literal, it is always an observation
358+
@gensym value
358359
return quote
359-
$(DynamicPPL.tilde_observe!)(
360+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
360361
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
361362
)
363+
$value
362364
end
363365
end
364366

@@ -373,7 +375,7 @@ function generate_tilde(left, right)
373375

374376
# Otherwise it is determined by the model or its value,
375377
# if the LHS represents an observation
376-
@gensym vn isassumption
378+
@gensym vn isassumption value
377379

378380
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379381
# that in DynamicPPL we the entire function body. Instead we should be
@@ -389,32 +391,38 @@ function generate_tilde(left, right)
389391
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
390392
end
391393

392-
$(DynamicPPL.tilde_observe!)(
394+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
393395
__context__,
394396
$(DynamicPPL.check_tilde_rhs)($right),
395397
$(maybe_view(left)),
396398
$vn,
397399
__varinfo__,
398400
)
401+
$value
399402
end
400403
end
401404
end
402405

403406
function generate_tilde_assume(left, right, vn)
404-
expr = :(
405-
$left = $(DynamicPPL.tilde_assume!)(
407+
# HACK: Because the Setfield.jl macro does not support assignment
408+
# with multiple arguments on the LHS, we need to capture the return-values
409+
# and then update the LHS variables one by one.
410+
@gensym value
411+
expr = :($left = $value)
412+
if left isa Expr
413+
expr = AbstractPPL.drop_escape(
414+
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
415+
)
416+
end
417+
418+
return quote
419+
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
406420
__context__,
407421
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
408422
__varinfo__,
409423
)
410-
)
411-
412-
return if left isa Expr
413-
AbstractPPL.drop_escape(
414-
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
415-
)
416-
else
417-
return expr
424+
$expr
425+
$value
418426
end
419427
end
420428

@@ -428,7 +436,7 @@ function generate_dot_tilde(left, right)
428436

429437
# Otherwise it is determined by the model or its value,
430438
# if the LHS represents an observation
431-
@gensym vn isassumption
439+
@gensym vn isassumption value
432440
return quote
433441
$vn = $(AbstractPPL.drop_escape(varname(left)))
434442
$isassumption = $(DynamicPPL.isassumption(left))
@@ -440,13 +448,14 @@ function generate_dot_tilde(left, right)
440448
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
441449
end
442450

443-
$(DynamicPPL.dot_tilde_observe!)(
451+
$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
444452
__context__,
445453
$(DynamicPPL.check_tilde_rhs)($right),
446454
$(maybe_view(left)),
447455
$vn,
448456
__varinfo__,
449457
)
458+
$value
450459
end
451460
end
452461
end
@@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn)
455464
# We don't need to use `Setfield.@set` here since
456465
# `.=` is always going to be inplace + needs `left` to
457466
# be something that supports `.=`.
458-
return :(
459-
$left .= $(DynamicPPL.dot_tilde_assume!)(
467+
@gensym value
468+
return quote
469+
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
460470
__context__,
461471
$(DynamicPPL.unwrap_right_left_vns)(
462472
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
463473
)...,
464474
__varinfo__,
465475
)
466-
)
476+
$left .= $value
477+
$value
478+
end
479+
end
480+
481+
# Note that we cannot use `MacroTools.isdef` because
482+
# of https://github.com/FluxML/MacroTools.jl/issues/154.
483+
"""
484+
isfuncdef(expr)
485+
486+
Return `true` if `expr` is any form of function definition, and `false` otherwise.
487+
"""
488+
function isfuncdef(e::Expr)
489+
return if Meta.isexpr(e, :function)
490+
# Classic `function f(...)`
491+
true
492+
elseif Meta.isexpr(e, :->)
493+
# Anonymous functions/lambdas, e.g. `do` blocks or `->` defs.
494+
true
495+
elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call)
496+
# Short function defs, e.g. `f(args...) = ...`.
497+
true
498+
else
499+
false
500+
end
501+
end
502+
503+
"""
504+
replace_returns(expr)
505+
506+
Return `Expr` with all `return ...` statements replaced with
507+
`return ..., DynamicPPL.return_values(__varinfo__)`.
508+
509+
Note that this method will _not_ replace `return` statements within function
510+
definitions. This is checked using [`isfuncdef`](@ref).
511+
"""
512+
replace_returns(e) = e
513+
function replace_returns(e::Expr)
514+
if isfuncdef(e)
515+
return e
516+
end
517+
518+
if Meta.isexpr(e, :return)
519+
# NOTE: `return` always has an argument. In the case of
520+
# an empty `return`, the lowered expression will be `return nothing`.
521+
# Hence we don't need any special handling for empty returns.
522+
retval_expr = if length(e.args) > 1
523+
Expr(:tuple, e.args...)
524+
else
525+
e.args[1]
526+
end
527+
528+
return :(return ($retval_expr, __varinfo__))
529+
end
530+
531+
return Expr(e.head, map(replace_returns, e.args)...)
532+
end
533+
534+
# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`.
535+
make_returns_explicit!(body) = Expr(:return, body)
536+
function make_returns_explicit!(body::Expr)
537+
# If the last statement is a return-statement, we don't do anything.
538+
# Otherwise we replace the last statement with a `return` statement.
539+
if !Meta.isexpr(body.args[end], :return)
540+
body.args[end] = Expr(:return, body.args[end])
541+
end
542+
return body
467543
end
468544

469545
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
@@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode)
496572
# Replace the user-provided function body with the version created by DynamicPPL.
497573
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
498574
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
499-
# to the call site
575+
# to the call site.
576+
# NOTE: We need to replace statements of the form `return ...` with
577+
# `return (..., __varinfo__)` to ensure that the second
578+
# element in the returned value is always the most up-to-date `__varinfo__`.
579+
# See the docstrings of `replace_returns` for more info.
500580
evaluatordef[:body] = MacroTools.@q begin
501581
$(linenumbernode)
502-
$(modelinfo[:body])
582+
$(replace_returns(make_returns_explicit!(modelinfo[:body])))
503583
end
504584

505585
## Build the model function.

0 commit comments

Comments
 (0)