Skip to content

Commit cb4ea95

Browse files
authored
Fix isnan for NamedTuple distributions (#897)
* Fix isnan for NamedTuple distributions * Add a test
1 parent 60ee68e commit cb4ea95

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

HISTORY.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.35.9
4+
5+
Fixed the `isnan` check introduced in 0.35.7 for distributions which returned NamedTuple.
6+
37
## 0.35.8
48

59
Added the `DynamicPPL.TestUtils.AD.run_ad` function to test the correctness and/or benchmark the performance of an automatic differentiation backend on DynamicPPL models.

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.8"
3+
version = "0.35.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/debug_utils.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ function _has_missings(x::AbstractArray)
242242
return false
243243
end
244244

245+
_has_nans(x::NamedTuple) = any(_has_nans, x)
246+
_has_nans(x::AbstractArray) = any(_has_nans, x)
247+
_has_nans(x) = isnan(x)
248+
245249
# assume
246250
function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo)
247251
record_varname!(context, vn, dist)
@@ -291,7 +295,7 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo)
291295
)
292296
end
293297
# Check for NaN's as well
294-
if any(isnan, left)
298+
if _has_nans(left)
295299
error(
296300
"Encountered a NaN value on the left-hand side of an" *
297301
" observe statement; this may indicate that your data" *

test/debug_utils.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,15 @@
130130
x[i] ~ Normal(a)
131131
end
132132
end
133-
model = demo_nan_in_data([1.0, NaN])
134-
@test_throws ErrorException check_model(model; error_on_failure=true)
133+
m = demo_nan_in_data([1.0, NaN])
134+
@test_throws ErrorException check_model(m; error_on_failure=true)
135+
# Test NamedTuples with nested arrays, see #898
136+
@model function demo_nan_complicated(nt)
137+
nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4])))
138+
return x ~ Normal()
139+
end
140+
m = demo_nan_complicated((x=1.0, y=[NaN, 0.5]))
141+
@test_throws ErrorException check_model(m; error_on_failure=true)
135142
end
136143

137144
@testset "incorrect use of condition" begin

0 commit comments

Comments
 (0)