Skip to content

Commit c1d3636

Browse files
devmotiontorfjeldegithub-actions[bot]
authored
Fix type inference of eltype(vi, spl) (Turing#2151) (#568)
* Fix type inference of `eltype(vi, spl)` (Turing#2151) * Update Project.toml * Add test * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix test --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 39751b1 commit c1d3636

File tree

6 files changed

+44
-4
lines changed

6 files changed

+44
-4
lines changed

Project.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
4-
version = "0.24.4"
3+
version = "0.24.5"
54

65
[deps]
76
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/abstract_varinfo.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,14 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
363363
This method is considered legacy, and is likely to be deprecated in the future.
364364
"""
365365
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
366-
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
366+
T = Base.promote_op(getindex, typeof(vi), typeof(spl))
367+
if T === Union{}
368+
# In this case `getindex(vi, spl)` errors
369+
# Let us throw a more descriptive error message
370+
# Ref https://github.com/TuringLang/Turing.jl/issues/2151
371+
return eltype(vi[spl])
372+
end
373+
return eltype(T)
367374
end
368375

369376
# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert

src/varinfo.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,7 @@ function getindex(vi::TypedVarInfo, spl::Sampler)
14441444
# Gets the ranges as a NamedTuple
14451445
ranges = _getranges(vi, spl)
14461446
# Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi`
1447-
return vcat(_getindex(vi.metadata, ranges)...)
1447+
return reduce(vcat, _getindex(vi.metadata, ranges))
14481448
end
14491449
# Recursively builds a tuple of the `vals` of all the symbols
14501450
@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names}

test/turing/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
33
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
44
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
56
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
67
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
78

89
[compat]
910
DynamicPPL = "0.24"
11+
ReverseDiff = "1.15"
1012
Turing = "0.30"
1113
julia = "1.7"

test/turing/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DynamicPPL
22
using Turing
33
using LinearAlgebra
4+
using ReverseDiff
45

56
using Random
67
using Test

test/turing/varinfo.jl

+31
Original file line numberDiff line numberDiff line change
@@ -311,4 +311,35 @@
311311
@test vi.metadata.w.gids[1] == Set([hmc.selector])
312312
@test vi.metadata.u.gids[1] == Set([hmc.selector]) =#
313313
end
314+
315+
@testset "Turing#2151: eltype(vi, spl)" begin
316+
# build data
317+
t = 1:0.05:8
318+
σ = 0.3
319+
y = @. rand(sin(t) + Normal(0, σ))
320+
321+
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
322+
# Priors
323+
α ~ Normal(y[1], 0.001)
324+
τ ~ Exponential(1)
325+
η ~ filldist(Normal(0, 1), TT - 1)
326+
σ ~ Exponential(1)
327+
328+
# create latent variable
329+
x = Vector{T}(undef, TT)
330+
x[1] = α
331+
for t in 2:TT
332+
x[t] = x[t - 1] + η[t - 1] * τ
333+
end
334+
335+
# measurement model
336+
y ~ MvNormal(x, σ^2 * I)
337+
338+
return x
339+
end
340+
341+
n = 10
342+
model = state_space(y, length(t))
343+
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
344+
end
314345
end

0 commit comments

Comments
 (0)