Skip to content

Commit 5f61d88

Browse files
torfjeldegithub-actions[bot]yebai
authored
Allow usage of array literals on LHS of ~ (#261)
* treat vectors on LHS as literals * version bump * added test for array literals * simplified the isliteral check * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * dont allow empty literals * added some tests for isliteral * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed typo in tests * bump Bijectors for tests Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent 101bd0d commit 5f61d88

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.11.2"
3+
version = "0.11.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ end
3636
# failsafe: a literal is never an assumption
3737
isassumption(expr) = :(false)
3838

39+
"""
40+
isliteral(expr)
41+
42+
Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise.
43+
"""
44+
isliteral(e) = false
45+
isliteral(::Number) = true
46+
isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)
47+
3948
"""
4049
check_tilde_rhs(x)
4150
@@ -240,7 +249,7 @@ variables.
240249
"""
241250
function generate_tilde(left, right)
242251
# If the LHS is a literal, it is always an observation
243-
if !(left isa Symbol || left isa Expr)
252+
if isliteral(left)
244253
return quote
245254
$(DynamicPPL.tilde_observe)(
246255
__context__,
@@ -290,7 +299,7 @@ Generate the expression that replaces `left .~ right` in the model body.
290299
"""
291300
function generate_dot_tilde(left, right)
292301
# If the LHS is a literal, it is always an observation
293-
if !(left isa Symbol || left isa Expr)
302+
if isliteral(left)
294303
return quote
295304
$(DynamicPPL.dot_tilde_observe)(
296305
__context__,

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2121
[compat]
2222
AbstractMCMC = "2.1, 3.0"
2323
AbstractPPL = "0.1.3"
24-
Bijectors = "0.8.2, 0.9"
24+
Bijectors = "0.9.5"
2525
Distributions = "0.24, 0.25"
2626
DistributionsAD = "0.6.3"
2727
Documenter = "0.26.1"

test/compiler.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,4 +423,30 @@ end
423423
x = [Laplace(), Normal(), MvNormal(3, 1.0)]
424424
@test DynamicPPL.check_tilde_rhs(x) === x
425425
end
426+
@testset "isliteral" begin
427+
@test DynamicPPL.isliteral(:([1.0]))
428+
@test DynamicPPL.isliteral(:([[1.0], 1.0]))
429+
@test DynamicPPL.isliteral(:((1.0, 1.0)))
430+
431+
@test !(DynamicPPL.isliteral(:([x])))
432+
@test !(DynamicPPL.isliteral(:([[x], 1.0])))
433+
@test !(DynamicPPL.isliteral(:((x, 1.0))))
434+
end
435+
436+
@testset "array literals" begin
437+
# Verify that we indeed can parse this.
438+
@test @model(function array_literal_model()
439+
# `assume` and literal `observe`
440+
m ~ MvNormal(2, 1.0)
441+
return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
442+
end) isa Function
443+
444+
@model function array_literal_model()
445+
# `assume` and literal `observe`
446+
m ~ MvNormal(2, 1.0)
447+
return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
448+
end
449+
450+
@test array_literal_model()() == [10.0, 10.0]
451+
end
426452
end

0 commit comments

Comments
 (0)