Skip to content

Commit c6f7198

Browse files
torfjeldeyebai
andauthored
missing kwarg not handled correctly (#617)
* kwargs should also be added to the missings * added test * bump patch version * Fix turing tests (#618) * fix turing tests * fix missing imports --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 10874bf commit c6f7198

File tree

5 files changed

+20
-6
lines changed

5 files changed

+20
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.27.1"
3+
version = "0.27.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/model.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,16 @@ model with different arguments.
6767
@generated function Model(
6868
f::F,
6969
args::NamedTuple{argnames,Targs},
70-
defaults::NamedTuple,
70+
defaults::NamedTuple{kwargnames,Tkwargs},
7171
context::AbstractContext=DefaultContext(),
72-
) where {F,argnames,Targs}
73-
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
74-
return :(Model{$missings}(f, args, defaults, context))
72+
) where {F,argnames,Targs,kwargnames,Tkwargs}
73+
missing_args = Tuple(
74+
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
75+
)
76+
missing_kwargs = Tuple(
77+
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
78+
)
79+
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
7580
end
7681

7782
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)

test/model.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
243243
@test length(test_defaults(missing, 2)()) == 2
244244
end
245245

246+
@testset "missing kwarg" begin
247+
@model test_missing_kwarg(; x=missing) = x ~ Normal(0, 1)
248+
@test :x in keys(rand(test_missing_kwarg()))
249+
end
250+
246251
@testset "extract priors" begin
247252
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
248253
priors = extract_priors(model)

test/turing/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
2+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
23
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
34
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
46
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
57
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
68
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -9,5 +11,5 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
911
[compat]
1012
DynamicPPL = "0.24, 0.25, 0.26, 0.27"
1113
ReverseDiff = "1.15"
12-
Turing = "0.31, 0.32, 0.33"
14+
Turing = "0.33"
1315
julia = "1.7"

test/turing/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Random.seed!(100)
1414
include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl"))
1515
include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl"))
1616

17+
using .NumericalTests: check_numerical
18+
1719
@testset "Turing" begin
1820
include("compiler.jl")
1921
include("loglikelihoods.jl")

0 commit comments

Comments
 (0)