Skip to content

Commit 18af48a

Browse files
authored
Special-case ReshapeTransform for singleton inputs (#699)
* Special-case ReshapeTransform for singleton inputs * Use fill(x[], ()) instead
1 parent bd9f465 commit 18af48a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30.1"
3+
version = "0.30.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/utils.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,15 @@ function (f::ReshapeTransform)(x)
286286
if size(x) != f.input_size
287287
throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))"))
288288
end
289-
# The call to `tovec` is only needed in case `x` is a scalar.
290-
return reshape(tovec(x), f.output_size)
289+
if f.output_size == ()
290+
# Specially handle the case where x is a singleton array, see
291+
# https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and
292+
# https://github.com/TuringLang/DynamicPPL.jl/issues/698
293+
return fill(x[], ())
294+
else
295+
# The call to `tovec` is only needed in case `x` is a scalar.
296+
return reshape(tovec(x), f.output_size)
297+
end
291298
end
292299

293300
function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)

0 commit comments

Comments
 (0)