Skip to content

Commit 5347acd

Browse files
torfjeldedevmotionyebai
authored
Fix for #596 (#597)
* initial work on `TypeWrap` * uncommented tests that are now working nicely * added `matchingvalue` impl for `TypeWrap` so we will correctly convert * simplified `transform_args` a tiny bit * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * bump patch version * updated Turing.jl tests * fixed typo in tests * also add `TypeWrap` to kwargs in model * added proper testing for TypeWrap in addition to fix to evaluatordef * fixed tests * make breaking release --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent bf6a5b1 commit 5347acd

File tree

5 files changed

+73
-15
lines changed

5 files changed

+73
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.25.4"
3+
version = "0.26.0"
44

55

66
[deps]

src/compiler.jl

+43-3
Original file line numberDiff line numberDiff line change
@@ -602,10 +602,42 @@ hasmissing(::Type{>:Missing}) = true
602602
hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA)
603603
hasmissing(::Type{Union{}}) = false # issue #368
604604

605+
"""
606+
TypeWrap{T}
607+
608+
A wrapper type used internally to make expressions such as `::Type{TV}` in the model arguments
609+
not ending up as a `DataType`.
610+
"""
611+
struct TypeWrap{T} end
612+
613+
function arg_type_is_type(e)
614+
return Meta.isexpr(e, :curly) && length(e.args) > 1 && e.args[1] === :Type
615+
end
616+
605617
function splitarg_to_expr((arg_name, arg_type, is_splat, default))
606618
return is_splat ? :($arg_name...) : arg_name
607619
end
608620

621+
"""
622+
transform_args(args)
623+
624+
Return transformed `args` used in both the model constructor and evaluator.
625+
626+
Specifically, this replaces expressions of the form `::Type{TV}=Vector{Float64}`
627+
with `::TypeWrap{TV}=TypeWrap{Vector{Float64}}()` to avoid introducing `DataType`.
628+
"""
629+
function transform_args(args)
630+
splitargs = map(args) do arg
631+
arg_name, arg_type, is_splat, default = MacroTools.splitarg(arg)
632+
return if arg_type_is_type(arg_type)
633+
arg_name, :($TypeWrap{$(arg_type.args[2])}), is_splat, :($TypeWrap{$default}())
634+
else
635+
arg_name, arg_type, is_splat, default
636+
end
637+
end
638+
return map(Base.splat(MacroTools.combinearg), splitargs)
639+
end
640+
609641
function namedtuple_from_splitargs(splitargs)
610642
names = map(splitargs) do (arg_name, arg_type, is_splat, default)
611643
is_splat ? Symbol("#splat#$(arg_name)") : arg_name
@@ -623,8 +655,12 @@ is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")
623655
Builds the output expression.
624656
"""
625657
function build_output(modeldef, linenumbernode)
626-
args = modeldef[:args]
627-
kwargs = modeldef[:kwargs]
658+
args = transform_args(modeldef[:args])
659+
kwargs = transform_args(modeldef[:kwargs])
660+
661+
# Need to update `args` and `kwargs` since we might have added `TypeWrap` to the types.
662+
modeldef[:args] = args
663+
modeldef[:kwargs] = kwargs
628664

629665
## Build the anonymous evaluator from the user-provided model definition.
630666
evaluatordef = copy(modeldef)
@@ -713,9 +749,13 @@ function matchingvalue(sampler, vi, value)
713749
return value
714750
end
715751
end
752+
# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`.
716753
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
717754
return get_matching_type(sampler, vi, value)
718755
end
756+
function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T}
757+
return TypeWrap{get_matching_type(sampler, vi, T)}()
758+
end
719759

720760
function matchingvalue(context::AbstractContext, vi, value)
721761
return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
@@ -731,7 +771,7 @@ function matchingvalue(context::SamplingContext, vi, value)
731771
end
732772

733773
"""
734-
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
774+
get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T}
735775
736776
Get the specialized version of type `T` for sampler `spl`.
737777

test/compiler.jl

+12
Original file line numberDiff line numberDiff line change
@@ -717,4 +717,16 @@ module Issue537 end
717717
@test haskey(values, @varname(y))
718718
end
719719
end
720+
721+
@testset "signature parsing + TypeWrap" begin
722+
@model function demo_typewrap(
723+
a, b=1, ::Type{T1}=Float64; c, d=2, t::Type{T2}=Int
724+
) where {T1,T2}
725+
return (; a, b, c, d, t)
726+
end
727+
728+
model = demo_typewrap(1; c=2)
729+
res = model()
730+
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
731+
end
720732
end

test/model.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
350350

351351
@testset "Type stability of models" begin
352352
models_to_test = [
353-
# FIXME: Fix issues with type-stability in `DEMO_MODELS`.
354-
# DynamicPPL.TestUtils.DEMO_MODELS...,
355-
DynamicPPL.TestUtils.demo_lkjchol(2),
353+
DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2)
356354
]
357355
@testset "$(model.f)" for model in models_to_test
358356
vns = DynamicPPL.TestUtils.varnames(model)

test/turing/compiler.jl

+16-8
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@
9595
@test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10)
9696
@test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10)
9797

98-
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
99-
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)
98+
@test_throws ErrorException chain = sample(
99+
gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10
100+
)
101+
@test_throws ErrorException chain = sample(
102+
gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10
103+
)
100104
end
101105
@testset "new interface" begin
102106
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
@@ -310,31 +314,35 @@
310314
end
311315

312316
t_loop = @elapsed res = sample(vdemo1(), alg, 250)
313-
t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250)
317+
t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250)
314318

315319
vdemo1kw(; T) = vdemo1(T)
316-
t_loop = @elapsed res = sample(vdemo1kw(; T=Float64), alg, 250)
320+
t_loop = @elapsed res = sample(
321+
vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250
322+
)
317323

318324
@model function vdemo2(::Type{T}=Float64) where {T<:Real}
319325
x = Vector{T}(undef, N)
320326
@. x ~ Normal(0, 2)
321327
end
322328

323329
t_vec = @elapsed res = sample(vdemo2(), alg, 250)
324-
t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250)
330+
t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250)
325331

326332
vdemo2kw(; T) = vdemo2(T)
327-
t_vec = @elapsed res = sample(vdemo2kw(; T=Float64), alg, 250)
333+
t_vec = @elapsed res = sample(
334+
vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250
335+
)
328336

329337
@model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
330338
x = TV(undef, N)
331339
@. x ~ InverseGamma(2, 3)
332340
end
333341

334342
sample(vdemo3(), alg, 250)
335-
sample(vdemo3(Vector{Float64}), alg, 250)
343+
sample(vdemo3(DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250)
336344

337345
vdemo3kw(; T) = vdemo3(T)
338-
sample(vdemo3kw(; T=Vector{Float64}), alg, 250)
346+
sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250)
339347
end
340348
end

0 commit comments

Comments
 (0)