From 1d9a1a324d706c369df384d3913199ef0b64adef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:46:27 +0100 Subject: [PATCH 01/45] renamed rv to result in forward --- src/bijectors/leaky_relu.jl | 6 +++--- src/bijectors/named_bijector.jl | 18 +++++++++--------- src/bijectors/normalise.jl | 10 +++++----- src/bijectors/planar_layer.jl | 2 +- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 4 ++-- src/bijectors/stacked.jl | 6 +++--- src/transformed_distribution.jl | 22 +++++++++++----------- 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 65060c14..c9128f01 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -50,7 +50,7 @@ logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> loga function forward(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) - return (rv=J * x, logabsdetjac=log(abs(J))) + return (result=J * x, logabsdetjac=log(abs(J))) end # Batched version @@ -58,7 +58,7 @@ function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. 
(x < z) * b.α + (x > z) * o end - return (rv=J .* x, logabsdetjac=log.(abs.(J))) + return (result=J .* x, logabsdetjac=log.(abs.(J))) end # (N=1) Multivariate case @@ -97,5 +97,5 @@ function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) end y = J .* x - return (rv=y, logabsdetjac=logjac) + return (result=y, logabsdetjac=logjac) end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 36f8691c..500044b9 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -forward(b::AbstractNamedBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x)) +forward(b::AbstractNamedBijector, x) = (result = b(x), logabsdetjac = logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -125,7 +125,7 @@ function logabsdetjac(cb::NamedComposition, x) y, logjac = forward(cb.bs[1], x) for i = 2:length(cb.bs) res = forward(cb.bs[i], y) - y = res.rv + y = res.result logjac += res.logabsdetjac end @@ -141,7 +141,7 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` @@ -154,14 +154,14 @@ end function forward(cb::NamedComposition, x) - rv, logjac = forward(cb.bs[1], x) + result, logjac = forward(cb.bs[1], x) for t in cb.bs[2:end] - res = forward(t, rv) - rv = res.rv + res = forward(t, result) + result = res.result logjac = res.logabsdetjac + logjac end - return (rv=rv, logabsdetjac=logjac) + return (result=result, logabsdetjac=logjac) end @@ -171,10 +171,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end - 
push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (result = y, logabsdetjac = logjac))) return expr end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 81496468..16b48a29 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -72,16 +72,16 @@ function forward(bn::InvertibleBatchNorm, x) v = reshape(bn.v, as...) end - rv = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b + result = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (rv=rv, logabsdetjac=logabsdetjac) + return (result=result, logabsdetjac=logabsdetjac) end logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac -(bn::InvertibleBatchNorm)(x) = forward(bn, x).rv +(bn::InvertibleBatchNorm)(x) = forward(bn, x).result function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." @@ -94,10 +94,10 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) v = reshape(bn.v, as...) 
x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m - return (rv=x, logabsdetjac=-logabsdetjac(bn, x)) + return (result=x, logabsdetjac=-logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).rv +(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).result function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index b82a7274..9d1a73d0 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -105,7 +105,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index b678ed96..a5cb5f34 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -67,7 +67,7 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 826a4017..a0907b82 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -346,7 +346,7 @@ function rqs_forward( T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) if (x ≤ -widths[end]) || (x ≥ widths[end]) - return (rv = one(T) * x, logabsdetjac = zero(T) * x) + return (result = one(T) * x, logabsdetjac = zero(T) * x) end # Find 
which bin `x` is in @@ -379,7 +379,7 @@ function rqs_forward( numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ)) y = h_k + numerator_y / denominator - return (rv = y, logabsdetjac = logjac) + return (result = y, logabsdetjac = logjac) end function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 0131a76c..144645fd 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -131,7 +131,7 @@ end # logjac = sum(_logjac) # (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) # logjac += sum(_logjac) -# return (rv = vcat(y_1, y_2), logabsdetjac = logjac) +# return (result = vcat(y_1, y_2), logabsdetjac = logjac) # end @generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple} expr = Expr(:block) @@ -151,7 +151,7 @@ end push!(y_names, y_name) end - push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac))) + push!(expr.args, :(return (result = vcat($(y_names...)), logabsdetjac = logjac))) return expr end @@ -163,5 +163,5 @@ function forward(sb::Stacked{<:AbstractArray, N}, x::AbstractVector) where {N} logjac += sum(l) y end - return (rv = vcat(yinit, ys), logabsdetjac = logjac) + return (result = vcat(yinit, ys), logabsdetjac = logjac) end diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 6a30f7eb..d6177c72 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -85,14 +85,14 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + 
res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) @@ -100,12 +100,12 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -113,12 +113,12 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with # res = forward(inv(td.transform), y) -# logpdf(td.dist, res.rv) .- res.logabsdetjac +# logpdf(td.dist, res.result) .- res.logabsdetjac function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end @@ -163,18 +163,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. 
""" function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -182,14 +182,14 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac return (lp, res.logabsdetjac) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return (logpdf_with_trans(td.dist, res.result, true), res.logabsdetjac) end """ From 0717e3e7af25a0e63a791cf4f54674b36fe3cd4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:48:34 +0100 Subject: [PATCH 02/45] added abstrac type Transform and removed dimensionality from Bijector --- src/interface.jl | 212 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 155 insertions(+), 57 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9454956b..9ace3b38 
100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -31,16 +31,120 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t ###################### # Bijector interface # ###################### -"Abstract type for a bijector." -abstract type AbstractBijector end +""" + +Abstract type for a transformation. + +## Implementing + +A subtype of `Transform` of should at least implement `transform(b, x)`. + +If the `Transform` is also invertible: +- Required: + - [`invertible`](@ref): should return [`Invertible`](@ref). + - _Either_ of the following: + - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. + - `Base.inv(b::MyTransform)`: returns an existing `Transform`. + - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. +- Optional: + - [`forward`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we + can exploit shared computation in the two. + +For the above methods, there are mutating versions which can _optionally_ be implemented: +- [`transform!`](@ref) +- [`logabsdetjac!`](@ref) +- [`forward!`](@ref) + +Finally, there are _batched_ versions of the above methods which can _optionally_ be implemented: +- [`transform_batch`](@ref) +- [`logabsdetjac_batch`](@ref) +- [`forward_batch`](@ref) + +and similarly for the mutating versions. Default implementations depends on the type of `xs`. +Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` +falls back to `transform_batch(b, x)`. +""" +abstract type Transform end + +Broadcast.broadcastable(b::Transform) = Ref(b) + +# Invertibility "trait". +struct NotInvertible end +struct Invertible end + +# Useful for checking if compositions, etc. are invertible or not. 
+Base.:+(::NotInvertible, ::Invertible) = NotInvertible() +Base.:+(::Invertible, ::NotInvertible) = NotInvertible() +Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() +Base.:+(::Invertible, ::Invertible) = Invertible() + +invertible(::Transform) = NotInvertible() + +""" + inv(t::Transform[, ::Invertible]) + +Returns the inverse of transform `t`. +""" +Base.inv(t::Transform) = Base.inv(t, invertible(t)) +Base.inv(t::Transform, ::NotInvertible) = error("$(t) is not invertible") + +""" + transform(b, x) + +Transform `x` using `b`. + +Alternatively, one can just call `b`, i.e. `b(x)`. +""" +transform +(t::Transform)(x) = transform(t, x) + +""" + transform!(b, x, y) + +Transforms `x` using `b`, storing the result in `y`. +""" +transform!(b, x, y) = (y .= transform(b, x)) + +""" + logabsdetjac(b, x) + +Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. +""" +logabsdetjac + +""" + logabsdetjac!(b, x, logjac) + +Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform, +_accumulating_ the result in `logjac`. +""" +logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) + +""" + forward(b, x) + +Computes both `transform` and `logabsdetjac` in one forward pass, and +returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. + +This defaults to the call above, but often one can re-use computation +in the computation of the forward pass and the computation of the +`logabsdetjac`. `forward` allows the user to take advantange of such +efficiencies, if they exist. +""" +forward(b, x) = (result = transform(b, x), logabsdetjac = logabsdetjac(b, x)) -"Abstract type of bijectors with fixed dimensionality." 
-abstract type Bijector{N} <: AbstractBijector end +function forward!(b, x, out) + y, logjac = forward(b, x) + out.result .= y + out.logabsdetjac .+= logjac -dimension(b::Bijector{N}) where {N} = N -dimension(b::Type{<:Bijector{N}}) where {N} = N + return out +end -Broadcast.broadcastable(b::Bijector) = Ref(b) +"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." +abstract type Bijector <: Transform end + +invertible(::Bijector) = Invertible() """ isclosedform(b::Bijector)::bool @@ -61,81 +165,75 @@ isclosedform(b::Bijector) = true A `Bijector` representing the inverse transform of `b`. """ -struct Inverse{B <: Bijector, N} <: Bijector{N} +struct Inverse{B<:Bijector} <: Bijector orig::B - - Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b) end # field contains nested numerical parameters Functors.@functor Inverse -up1(b::Inverse) = Inverse(up1(b.orig)) - inv(b::Bijector) = Inverse(b) inv(ib::Inverse{<:Bijector}) = ib.orig Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig -""" - logabsdetjac(b::Bijector, x) - logabsdetjac(ib::Inverse{<:Bijector}, y) +logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. -Similarily for the inverse-transform. +""" + logabsdetjacinv(b::Bijector, y) -Default implementation for `Inverse{<:Bijector}` is implemented as -`- logabsdetjac` of original `Bijector`. +Just an alias for `logabsdetjac(inv(b), y)`. 
""" -logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) +logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) + +############################## +# Example bijector: Identity # +############################## + +struct Identity <: Bijector end +inv(b::Identity) = b + +transform(::Identity, x) = copy(x) +transform!(::Identity, x, y) = (y .= x; return y) +logabsdetjac(::Identity, x) = zero(eltype(x)) +logabsdetjac!(::Identity, x, logjac) = logjac + +#################### +# Batched versions # +#################### +# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. +# so we can actually reference them. Since we just did this for `Identity`, we're good. +Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) +Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) + """ - forward(b::Bijector, x) + transform_batch(b, xs) -Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. +Transform `xs` by `b`, treating `xs` as a "batch", i.e. a collection of independent inputs. -This defaults to the call above, but often one can re-use computation -in the computation of the forward pass and the computation of the -`logabsdetjac`. `forward` allows the user to take advantange of such -efficiencies, if they exist. +See also: [`transform`](@ref) """ -forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +transform_batch """ - logabsdetjacinv(b::Bijector, y) + logabsdetjac_batch(b, xs) -Just an alias for `logabsdetjac(inv(b), y)`. +Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. 
+ +See also: [`logabsdetjac`](@ref) """ -logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) +logabsdetjac_batch -############################## -# Example bijector: Identity # -############################## +""" + forward_batch(b, xs) -struct Identity{N} <: Bijector{N} end -(::Identity)(x) = copy(x) -inv(b::Identity) = b -up1(::Identity{N}) where {N} = Identity{N + 1}() - -logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) -@generated function logabsdetjac( - b::Identity{N1}, - x::AbstractArray{T2, N2} -) where {N1, T2, N2} - if N1 == N2 - return :(zero(eltype(x))) - elseif N1 + 1 == N2 - return :(zeros(eltype(x), size(x, $N2))) - else - return :(throw(MethodError(logabsdetjac, (b, x)))) - end -end -logabsdetjac(::Identity{2}, x::AbstractArray{<:AbstractMatrix}) = zeros(eltype(x[1]), size(x)) +Computes `forward(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. -######################## -# Convenient constants # -######################## -const ZeroOrOneDimBijector = Union{Bijector{0}, Bijector{1}} +See also: [`transform`](@ref) +""" +forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) ###################### # Bijectors includes # From 81a2ed600078d2ab20e22ec3dcdae20ee79eb2f1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:48:51 +0100 Subject: [PATCH 03/45] updated Composed to new interface --- src/bijectors/composed.jl | 155 +++++++++++++++++++++++++++++--------- 1 file changed, 119 insertions(+), 36 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index f05b819c..33f9551e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -85,57 +85,50 @@ true ``` """ -struct Composed{A, N} <: Bijector{N} +struct Composed{A} <: Transform ts::A end -Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs) -Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs) - # field 
contains nested numerical parameters Functors.@functor Composed +invertible(cb::Composed) = sum(map(invertible, cb.ts)) + isclosedform(b::Composed) = all(isclosedform, b.ts) -up1(b::Composed) = Composed(up1.(b.ts)) -function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N} + +function Base.:(==)(b1::Composed, b2::Composed) ts1, ts2 = b1.ts, b2.ts return length(ts1) == length(ts2) && all(x == y for (x, y) in zip(ts1, ts2)) end """ - composel(ts::Bijector...)::Composed{<:Tuple} + composel(ts::Transform...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied left-to-right. """ -composel(ts::Bijector{N}...) where {N} = Composed(ts) +composel(ts::Transform...) = Composed(ts) """ - composer(ts::Bijector...)::Composed{<:Tuple} + composer(ts::Transform...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied right-to-left. """ -composer(ts::Bijector{N}...) where {N} = Composed(reverse(ts)) +composer(ts::Transform...) = Composed(reverse(ts)) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -@generated function ∘(b1::Bijector{N1}, b2::Bijector{N2}) where {N1, N2} - if N1 == N2 - return :(composel(b2, b1)) - else - return :(throw(DimensionMismatch("$(typeof(b1)) expects $(N1)-dim but $(typeof(b2)) expects $(N2)-dim"))) - end -end +∘(b1::Transform, b2::Transform) = composel(b2, b1) # type-stable composition rules -∘(b1::Composed{<:Tuple}, b2::Bijector) = composel(b2, b1.ts...) -∘(b1::Bijector, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) +∘(b1::Composed{<:Tuple}, b2::Transform) = composel(b2, b1.ts...) +∘(b1::Transform, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) ∘(b1::Composed{<:Tuple}, b2::Composed{<:Tuple}) = composel(b2.ts..., b1.ts...) 
# type-unstable composition rules -∘(b1::Composed{<:AbstractArray}, b2::Bijector) = Composed(pushfirst!(copy(b1.ts), b2)) -∘(b1::Bijector, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) +∘(b1::Composed{<:AbstractArray}, b2::Transform) = Composed(pushfirst!(copy(b1.ts), b2)) +∘(b1::Transform, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) function ∘(b1::Composed{<:AbstractArray}, b2::Composed{<:AbstractArray}) return Composed(append!(copy(b2.ts), copy(b1.ts))) end @@ -149,14 +142,14 @@ function ∘(b1::T1, b2::T2) where {T1<:Composed{<:AbstractArray}, T2<:Composed{ end -∘(::Identity{N}, ::Identity{N}) where {N} = Identity{N}() -∘(::Identity{N}, b::Bijector{N}) where {N} = b -∘(b::Bijector{N}, ::Identity{N}) where {N} = b +∘(::Identity, ::Identity) = Identity() +∘(::Identity, b::Transform) = b +∘(b::Transform, ::Identity) = b -inv(ct::Composed) = Composed(reverse(map(inv, ct.ts))) +inv(ct::Composed, ::Invertible) = Composed(reverse(map(inv, ct.ts))) -# # TODO: should arrays also be using recursive implementation instead? -function (cb::Composed{<:AbstractArray{<:Bijector}})(x) +# TODO: should arrays also be using recursive implementation instead? 
+function transform(cb::Composed, x) @assert length(cb.ts) > 0 res = cb.ts[1](x) for b ∈ Base.Iterators.drop(cb.ts, 1) @@ -166,7 +159,17 @@ function (cb::Composed{<:AbstractArray{<:Bijector}})(x) return res end -@generated function (cb::Composed{T})(x) where {T<:Tuple} +function transform_batch(cb::Composed, x) + @assert length(cb.ts) > 0 + res = cb.ts[1].(x) + for b ∈ Base.Iterators.drop(cb.ts, 1) + res = b.(res) + end + + return res +end + +@generated function transform(cb::Composed{T}, x) where {T<:Tuple} @assert length(T.parameters) > 0 expr = :(x) for i in 1:length(T.parameters) @@ -175,17 +178,36 @@ end return expr end +@generated function transform_batch(cb::Composed{T}, x) where {T<:Tuple} + @assert length(T.parameters) > 0 + expr = :(x) + for i in 1:length(T.parameters) + expr = :(transform_batch(cb.ts[$i], $expr)) + end + return expr +end + function logabsdetjac(cb::Composed, x) y, logjac = forward(cb.ts[1], x) for i = 2:length(cb.ts) res = forward(cb.ts[i], y) - y = res.rv + y = res.result logjac += res.logabsdetjac end return logjac end +function logabsdetjac_batch(cb::Composed, x) + init = forward(cb.ts[1], x) + result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b + return forward(b, y) + end + + return result.logabsdetjac +end + + @generated function logabsdetjac(cb::Composed{T}, x) where {T<:Tuple} N = length(T.parameters) @@ -195,7 +217,7 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` @@ -206,16 +228,60 @@ end return expr end +""" + logabsdetjac_batch(cb::Composed{<:Tuple}, x) + +Generates something of the form +```julia +quote + (y, logjac_1) = forward_batch(cb.ts[1], x) + logjac_2 = logabsdetjac_batch(cb.ts[2], y) + return logjac_1 + logjac_2 +end +``` +""" +@generated function 
logabsdetjac_batch(cb::Composed{T}, x) where {T<:Tuple} + N = length(T.parameters) + + expr = Expr(:block) + push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) + + for i = 2:N - 1 + temp = gensym(:res) + push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) + push!(expr.args, :(y = $temp.result)) + push!(expr.args, :($(Symbol("logjac_$i")) = $temp.logabsdetjac)) + end + # don't need to evaluate the last bijector, only it's `logabsdetjac` + push!(expr.args, :($(Symbol("logjac_$N")) = logabsdetjac_batch(cb.ts[$N], y))) + + sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) + push!(expr.args, :(return $(sum_expr))) + + return expr +end + function forward(cb::Composed, x) - rv, logjac = forward(cb.ts[1], x) + result, logjac = forward(cb.ts[1], x) for t in cb.ts[2:end] - res = forward(t, rv) - rv = res.rv + res = forward(t, result) + result = res.result logjac = res.logabsdetjac + logjac end - return (rv=rv, logabsdetjac=logjac) + return (result=result, logabsdetjac=logjac) +end + +function forward_batch(cb::Composed, x) + result, logjac = forward_batch(cb.ts[1], x) + + for t in cb.ts[2:end] + res = forward_batch(t, result) + result = res.result + logjac = res.logabsdetjac + logjac + end + return (result=result, logabsdetjac=logjac) end @@ -225,10 +291,27 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (result = y, logabsdetjac = logjac))) + + return expr +end + +@generated function forward_batch(cb::Composed{T}, x) where {T<:Tuple} + N = length(T.parameters) + expr = Expr(:block) + push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) + for i = 2:N + temp = gensym(:temp) + push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) + push!(expr.args, :(y = 
$temp.result)) + push!(expr.args, :($(Symbol("logjac_$i")) = $temp.logabsdetjac)) + end + + sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) + push!(expr.args, :(return (result = y, logabsdetjac = $(sum_expr)))) return expr end From 251ab9c47bceafd9453b6c0c378a60f70301e23b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:51:11 +0100 Subject: [PATCH 04/45] updated Exp and Log to new interface --- src/bijectors/exp_log.jl | 51 +++++++++------------------------------- 1 file changed, 11 insertions(+), 40 deletions(-) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 0e2da142..d3794a59 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -2,48 +2,19 @@ # Exp & Log # ############# -struct Exp{N} <: Bijector{N} end -struct Log{N} <: Bijector{N} end -up1(::Exp{N}) where {N} = Exp{N + 1}() -up1(::Log{N}) where {N} = Log{N + 1}() +struct Exp <: Bijector end +struct Log <: Bijector end -Exp() = Exp{0}() -Log() = Log{0}() +inv(b::Exp) = Log() +inv(b::Log) = Exp() -(b::Exp{0})(y::Real) = exp(y) -(b::Log{0})(x::Real) = log(x) +transform(b::Exp, y) = exp.(y) +transform(b::Log, x) = log.(x) -(b::Exp{0})(y::AbstractArray{<:Real}) = exp.(y) -(b::Log{0})(x::AbstractArray{<:Real}) = log.(x) +logabsdetjac(b::Exp, x) = sum(x) +logabsdetjac(b::Log, x) = -sum(log, x) -(b::Exp{1})(y::AbstractVector{<:Real}) = exp.(y) -(b::Exp{1})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{1})(x::AbstractVector{<:Real}) = log.(x) -(b::Log{1})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{2})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) -(b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) - -inv(b::Exp{N}) where {N} = Log{N}() -inv(b::Log{N}) where {N} = Exp{N}() - -logabsdetjac(b::Exp{0}, x::Real) = x -logabsdetjac(b::Exp{0}, x::AbstractVector) = x -logabsdetjac(b::Exp{1}, x::AbstractVector) = 
sum(x) -logabsdetjac(b::Exp{1}, x::AbstractMatrix) = vec(sum(x; dims = 1)) -logabsdetjac(b::Exp{2}, x::AbstractMatrix) = sum(x) -logabsdetjac(b::Exp{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end - -logabsdetjac(b::Log{0}, x::Real) = -log(x) -logabsdetjac(b::Log{0}, x::AbstractVector) = .-log.(x) -logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log, x) -logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log, x; dims = 1)) -logabsdetjac(b::Log{2}, x::AbstractMatrix) = - sum(log, x) -logabsdetjac(b::Log{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) +function forward(b::Log, x) + y = transform(b, x) + return (result = y, logabsdetjac = -sum(y)) end From f1ef968d809a5a029bd8cacf9ed080d4711ddf72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:51:18 +0100 Subject: [PATCH 05/45] updated Logit to new interface --- src/bijectors/logit.jl | 32 +++++--------------------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index ff1cf554..9d26de5d 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -3,43 +3,21 @@ ###################### using StatsFuns: logit, logistic -struct Logit{N, T<:Real} <: Bijector{N} +struct Logit{T} <: Bijector a::T b::T end -Logit(a::Real, b::Real) = Logit{0}(a, b) -Logit(a::AbstractArray{<:Real, N}, b::AbstractArray{<:Real, N}) where {N} = Logit{N}(a, b) -function Logit{N}(a, b) where {N} - T = promote_type(typeof(a), typeof(b)) - Logit{N, T}(a, b) -end -# fields are numerical parameters -function Functors.functor(::Type{<:Logit{N}}, x) where N - function reconstruct_logit(xs) - T = promote_type(typeof(xs.a), typeof(xs.b)) - return Logit{N,T}(xs.a, xs.b) - end - return (a = x.a, b = x.b,), reconstruct_logit -end +Functors.@functor Logit -up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b) # For equality of Logit with Float64 fields to one with Duals 
Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -(b::Logit)(x) = _logit.(x, b.a, b.b) -(b::Logit)(x::AbstractArray{<:AbstractArray}) = map(b, x) +transform(b::Logit, x) = _logit.(x, b.a, b.b) _logit(x, a, b) = logit((x - a) / (b - a)) -(ib::Inverse{<:Logit})(y) = _ilogit.(y, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit})(x::AbstractArray{<:AbstractArray}) = map(ib, x) +transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) _ilogit(y, a, b) = (b - a) * logistic(y) + a -logabsdetjac(b::Logit{0}, x) = logit_logabsdetjac.(x, b.a, b.b) -logabsdetjac(b::Logit{1}, x::AbstractVector) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{1}, x::AbstractMatrix) = vec(sum(logit_logabsdetjac.(x, b.a, b.b), dims = 1)) -logabsdetjac(b::Logit{2}, x::AbstractMatrix) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end +logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) From 45ff3646d0aec0d0dc49a2fc9ba5c437bd0cb1d2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:53:29 +0100 Subject: [PATCH 06/45] removed something that shouldnt be there --- src/interface.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9ace3b38..bd37555a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -200,14 +200,6 @@ logabsdetjac!(::Identity, x, logjac) = logjac #################### # Batched versions # #################### -# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. -# so we can actually reference them. Since we just did this for `Identity`, we're good. 
-Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) -Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) - - """ transform_batch(b, xs) From eb94e001517cd15d6fed35bb4c7a19765249e3f4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 07:04:54 +0100 Subject: [PATCH 07/45] removed false statement in docstring of Transform --- src/interface.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index bd37555a..8c9b2b56 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,8 +61,6 @@ Finally, there are _batched_ versions of the above methods which can _optionally - [`forward_batch`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. -Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` -falls back to `transform_batch(b, x)`. 
""" abstract type Transform end From 0d8783f37108a9ad56a81cb243319ac53f2a13db Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:22:58 +0100 Subject: [PATCH 08/45] fixed a typo in implementation of logabsdetjac_batch --- src/bijectors/composed.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 33f9551e..5cce6d20 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -199,9 +199,9 @@ function logabsdetjac(cb::Composed, x) end function logabsdetjac_batch(cb::Composed, x) - init = forward(cb.ts[1], x) + init = forward_batch(cb.ts[1], x) result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b - return forward(b, y) + return forward_batch(b, y) end return result.logabsdetjac From 8f9988e1de65649c1be8e03cb6caab6d16e58739 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:31:48 +0100 Subject: [PATCH 09/45] added types for representing batches --- src/Bijectors.jl | 2 ++ src/batch.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 3 ++ 3 files changed, 93 insertions(+) create mode 100644 src/batch.jl diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 955d0be7..94b10e43 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -33,6 +33,7 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays +using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular import Functors @@ -244,6 +245,7 @@ function getlogp(d::InverseWishart, Xcf, X) end include("utils.jl") +include("batch.jl") include("interface.jl") include("chainrules.jl") diff --git a/src/batch.jl b/src/batch.jl new file mode 100644 index 00000000..4c5a512b --- /dev/null +++ b/src/batch.jl @@ -0,0 +1,88 @@ +abstract type AbstractBatch{T} <: AbstractVector{T} end + +""" + value(x) + +Returns the underlying storage used for the entire batch. 
+ +If `x` is not `AbstractBatch`, then this is the identity function. +""" +value(x) = x + +struct Batch{V, T} <: AbstractBatch{T} + value::V +end + +value(x::Batch) = x.value + +# Convenient aliases +const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} +const VectorBatch = Batch{<:AbstractVector{<:AbstractArray{<:Real}}} + +# Constructor for `ArrayBatch`. +Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) +function Batch(x::AbstractArray{<:Real}) + V = typeof(x) + # HACK: This assumes the batch is non-empty. + T = typeof(getindex_for_last(x, 1)) + return Batch{V, T}(x) +end + +# Constructor for `VectorBatch`. +Batch(x::AbstractVector{<:AbstractArray}) = Batch{typeof(x), eltype(x)}(x) + +# `AbstractVector` interface. +Base.size(batch::Batch{<:AbstractArray{<:Any, N}}) where {N} = (size(value(batch), N), ) + +# Since impl inherited from `AbstractVector` doesn't do exactly what we want. +Base.similar(b::AbstractBatch) = reconstruct(b, similar(value(b))) + +# For `VectorBatch` +Base.getindex(batch::VectorBatch, i::Int) = value(batch)[i] +Base.getindex(batch::VectorBatch, i::CartesianIndex{1}) = value(batch)[i] +Base.getindex(batch::VectorBatch, i) = Batch(value(batch)[i]) +function Base.setindex!(batch::VectorBatch, v, i) + # `v` can also be a `Batch`. + Base.setindex!(value(batch), value(v), i) + return batch +end + +# For `ArrayBatch` +@generated function getindex_for_last(x::AbstractArray{<:Any, N}, inds) where {N} + e = Expr(:call) + push!(e.args, :(Base.view)) + push!(e.args, :x) + + for i = 1:N - 1 + push!(e.args, :(:)) + end + + push!(e.args, :(inds)) + + return e +end + +@generated function setindex_for_last!(out::AbstractArray{<:Any, N}, x, inds) where {N} + e = Expr(:call) + push!(e.args, :(Base.setindex!)) + push!(e.args, :out) + push!(e.args, :x) + + for i = 1:N - 1 + push!(e.args, :(:)) + end + + push!(e.args, :(inds)) + + return e +end + +# General arrays. 
+Base.getindex(batch::ArrayBatch, i::Int) = getindex_for_last(value(batch), i) +Base.getindex(batch::ArrayBatch, i::CartesianIndex{1}) = getindex_for_last(value(batch), i) +Base.getindex(batch::ArrayBatch, i) = Batch(getindex_for_last(value(batch), i)) +function Base.setindex!(batch::ArrayBatch, v, i) + # `v` can also be a `Batch`. + setindex_for_last!(value(batch), value(v), i) + return batch +end diff --git a/src/utils.jl b/src/utils.jl index 8203e1b4..d1034dbc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,6 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x + +# Useful for reconstructing objects. +reconstruct(b, args...) = constructorof(typeof(b))(args...) From 9fa37d1f66858460e85a12bf06aa21a07475372a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:33:05 +0100 Subject: [PATCH 10/45] make it possible to use broadcasting for working with batches --- src/interface.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 8c9b2b56..9ace3b38 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,6 +61,8 @@ Finally, there are _batched_ versions of the above methods which can _optionally - [`forward_batch`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. +Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` +falls back to `transform_batch(b, x)`. """ abstract type Transform end @@ -198,6 +200,14 @@ logabsdetjac!(::Identity, x, logjac) = logjac #################### # Batched versions # #################### +# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. +# so we can actually reference them. Since we just did this for `Identity`, we're good. 
+Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) +Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) + + """ transform_batch(b, xs) From 168dd43ed5964a2fb0af27d276ad4c2b2a4d04ad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:22:59 +0100 Subject: [PATCH 11/45] updated SimplexBijector to new interface, I think --- src/bijectors/simplex.jl | 119 ++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 69 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 0cf0ff38..f9f83d85 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -1,20 +1,16 @@ #################### # Simplex bijector # #################### -struct SimplexBijector{N, T} <: Bijector{N} end -SimplexBijector() = SimplexBijector{1}() -SimplexBijector{N}() where {N} = SimplexBijector{N,true}() +struct SimplexBijector{T} <: Bijector end +SimplexBijector() = SimplexBijector{true}() -# Special case `N = 1` -SimplexBijector{true}() = SimplexBijector{1,true}() -SimplexBijector{false}() = SimplexBijector{1,false}() +transform(b::SimplexBijector, x) = _simplex_bijector(x, b) +transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) -(b::SimplexBijector{1})(x::AbstractVector) = _simplex_bijector(x, b) -(b::SimplexBijector{1})(y::AbstractVector, x::AbstractVector) = _simplex_bijector!(y, x, b) -function _simplex_bijector(x::AbstractVector, b::SimplexBijector{1}) - return _simplex_bijector!(similar(x), x, b) -end -function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) where {proj} +_simplex_bijector(x::AbstractArray, b::SimplexBijector) = _simplex_bijector!(similar(x), x, b) + +# Vector implementation. 
+function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} K = length(x) @assert K > 1 "x needs to be of length greater than 1" T = eltype(x) @@ -39,24 +35,19 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) wh return y end -# Vectorised implementation of the above. -function (b::SimplexBijector{1})(X::AbstractMatrix) - _simplex_bijector(X, b) -end -function (b::SimplexBijector{1})( - Y::AbstractMatrix, - X::AbstractMatrix, -) - _simplex_bijector!(Y, X, b) +function transform_batch(b::SimplexBijector, X::ArrayBatch{2}) + Batch(_simplex_bijector(value(X), b)) end -function (b::SimplexBijector{2, proj})(X::AbstractMatrix) where {proj} - SimplexBijector{1, proj}()(X) -end -(b::SimplexBijector{2})(X::AbstractArray{<:AbstractMatrix}) = map(b, X) -function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) - _simplex_bijector!(similar(X), X, b) +function transform_batch!( + b::SimplexBijector, + Y::Batch{<:AbstractMatrix{T}}, + X::Batch{<:AbstractMatrix{T}}, +) where {T} + Batch(_simplex_bijector!(value(Y), value(X), b)) end -function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) where {proj} + +# Matrix implementation. +function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} K, N = size(X, 1), size(X, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(X) @@ -81,19 +72,19 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) wh return Y end -function (ib::Inverse{<:SimplexBijector{1, proj}})(y::AbstractVector{T}) where {T, proj} - _simplex_inv_bijector(y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - x::AbstractVector{T}, - y::AbstractVector{T}, +# Inverse. 
+transform(ib::Inverse{<:SimplexBijector}, y::AbstractArray) = _simplex_inv_bijector(y, ib.orig) +function transform!( + ib::Inverse{<:SimplexBijector}, + x::AbstractArray{T}, + y::AbstractArray{T}, ) where {T} - _simplex_inv_bijector!(x, y, ib.orig) + return _simplex_inv_bijector!(x, y, ib.orig) end -function _simplex_inv_bijector(y::AbstractVector, b::SimplexBijector{1}) - return _simplex_inv_bijector!(similar(y), y, b) -end -function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj}) where {proj} + +_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b) + +function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj} K = length(y) @assert K > 1 "x needs to be of length greater than 1" T = eltype(y) @@ -116,27 +107,17 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj return x end -# Vectorised implementation of the above. -function (ib::Inverse{<:SimplexBijector{1}})(Y::AbstractMatrix) - _simplex_inv_bijector(Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - X::AbstractMatrix{T}, - Y::AbstractMatrix{T}, -) where {T <: Real} - _simplex_inv_bijector!(X, Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(Y) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(X::AbstractMatrix, Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(X, Y) +# Batched versions. 
+transform_batch(ib::Inverse{<:SimplexBijector}, Y::ArrayBatch{2}) = _simplex_inv_bijector(Y, ib.orig) +function transform_batch!( + ib::Inverse{<:SimplexBijector} + X::Batch{<:AbstractMatrix{T}}, + Y::Batch{<:AbstractMatrix{T}}, +) where {T<:Real} + return Batch(_simplex_inv_bijector!(value(X), value(Y), ib.orig)) end -(ib::Inverse{<:SimplexBijector{2}})(Y::AbstractArray{<:AbstractMatrix}) = map(ib, Y) -function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - _simplex_inv_bijector!(similar(Y), Y, b) -end -function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj}) where {proj} + +function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} K, N = size(Y, 1), size(Y, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(Y) @@ -160,7 +141,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj return X end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector{T}) where {T} +function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} ϵ = _eps(T) lp = zero(T) @@ -211,7 +192,9 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) end return g end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix{T}) where {T} +function logabsdetjac_batch(b::SimplexBijector, x::Batch{<:AbstractMatrix{T}}) where {T} + x = value(x) + ϵ = _eps(T) nlp = similar(x, T, size(x, 2)) nlp .= zero(T) @@ -227,14 +210,12 @@ function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix{T}) where {T} nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) + log(max(one(T) - sum_tmp, ϵ)) end end - return nlp + return Batch(nlp) end -function logabsdetjac(b::SimplexBijector{2, proj}, x::AbstractMatrix) where {proj} - return sum(logabsdetjac(SimplexBijector{1, proj}(), x)) -end -function logabsdetjac(b::SimplexBijector{2}, x::AbstractArray{<:AbstractMatrix}) - return map(x -> logabsdetjac(b, x), x) +function logabsdetjac(b::SimplexBijector, 
x::AbstractMatrix) + return sum(value(logabsdetjac(b, Batch(x)))) end + function simplex_logabsdetjac_gradient(x::AbstractMatrix) T = eltype(x) ϵ = _eps(T) @@ -303,7 +284,7 @@ function simplex_link_jacobian( end return UpperTriangular(dydxt)' end -function jacobian(b::SimplexBijector{1, proj}, x::AbstractVector{T}) where {proj, T} +function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj, T} return simplex_link_jacobian(x, Val(proj)) end @@ -425,7 +406,7 @@ function simplex_invlink_jacobian( return LowerTriangular(dxdy) end # jacobian -function jacobian(ib::Inverse{<:SimplexBijector{1, proj}}, y::AbstractVector{T}) where {proj, T} +function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj, T} return simplex_invlink_jacobian(y, Val(proj)) end From d44cf420180d1387e48f498571a8e0d56ec5bb3c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:23:22 +0100 Subject: [PATCH 12/45] updated PDBijector to new interface --- src/bijectors/pd.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 332a8b36..deac1bd6 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -1,4 +1,4 @@ -struct PDBijector <: Bijector{2} end +struct PDBijector <: Bijector end # This function has custom adjoints defined for Tracker, Zygote and ReverseDiff. 
# I couldn't find a mutation-free implementation that maintains TrackedArrays in Tracker @@ -7,19 +7,17 @@ function replace_diag(f, X) g(i, j) = ifelse(i == j, f(X[i, i]), X[i, j]) return g.(1:size(X, 1), (1:size(X, 2))') end -(b::PDBijector)(X::AbstractMatrix{<:Real}) = pd_link(X) +transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) Y = lower(parent(cholesky(X; check = true).L)) return replace_diag(log, Y) end -(b::PDBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) -function (ib::Inverse{<:PDBijector})(Y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end -(ib::Inverse{<:PDBijector})(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, X) getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) @@ -37,7 +35,3 @@ function logabsdetjac(b::PDBijector, Xcf::Cholesky) d = size(U, 1) return - sum((d .- (1:d) .+ 2) .* log.(diag(U))) - d * log(T(2)) end - -logabsdetjac(b::PDBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) = mapvcat(X) do x - logabsdetjac(b, x) -end From c719b076dcb33c4aa1a61d07b32e8d5ee6342f40 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:24:42 +0100 Subject: [PATCH 13/45] use transform_batch rather than broadcasting --- src/bijectors/composed.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 5cce6d20..3a0e5a9e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -163,7 +163,7 @@ function transform_batch(cb::Composed, x) @assert length(cb.ts) > 0 res = cb.ts[1].(x) for b ∈ Base.Iterators.drop(cb.ts, 1) - res = b.(res) + res = transform_batch(b, res) end return res From 0a62e96f2dff33d796436b40849f7078f635a4c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde 
Date: Sat, 5 Jun 2021 11:46:19 +0100 Subject: [PATCH 14/45] added default implementations for batches --- src/interface.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9ace3b38..cfcaf7c8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -215,7 +215,19 @@ Transform `xs` by `b`, treating `xs` as a "batch", i.e. a collection of independ See also: [`transform`](@ref) """ -transform_batch +transform_batch(b, xs) = _transform_batch(b, xs) +# Default implementations uses private methods to avoid method ambiguity. +_transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) +function _transform_batch(b, xs::ArrayBatch{2}) + # TODO: Check if we can avoid using these custom methods. + return eachcolmaphcat(b, x) +end +function _transform_batch(b, xs::ArrayBatch{N}) where {N} + res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x + cat(acc, x; dims = N) + end + return reconstruct(xs, res) +end """ logabsdetjac_batch(b, xs) @@ -224,7 +236,12 @@ Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of See also: [`logabsdetjac`](@ref) """ -logabsdetjac_batch +logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) +# Default implementations uses private methods to avoid method ambiguity. 
+_logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) +function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) +end """ forward_batch(b, xs) From 21a66ab419bdea329a7a784169f77e44d7313260 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:35 +0100 Subject: [PATCH 15/45] updated ADBijector to new interface --- src/bijectors/adbijector.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl index b7596b1d..135f2bba 100644 --- a/src/bijectors/adbijector.jl +++ b/src/bijectors/adbijector.jl @@ -2,7 +2,7 @@ Abstract type for a `Bijector{N}` making use of auto-differentation (AD) to implement `jacobian` and, by impliciation, `logabsdetjac`. """ -abstract type ADBijector{AD, N} <: Bijector{N} end +abstract type ADBijector{AD} <: Bijector end struct SingularJacobianException{B<:Bijector} <: Exception b::B @@ -24,4 +24,4 @@ end function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b)) -end \ No newline at end of file +end From 0b04d1497f3dac7db102dcdeeb8c86d0f65f697a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:42 +0100 Subject: [PATCH 16/45] updated CorrBijector to new interface --- src/bijectors/corr.jl | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f7eeaee2..a051ef17 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -61,9 +61,9 @@ Note: The implementation doesn't follow their "manageable expression" directly, because their equation seems wrong (7/30/2020). Insteadly it follows definition above the "manageable expression" directly, which is also described in above doc. 
""" -struct CorrBijector <: Bijector{2} end +struct CorrBijector <: Bijector end -function (b::CorrBijector)(x::AbstractMatrix{<:Real}) +function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) @@ -71,16 +71,12 @@ function (b::CorrBijector)(x::AbstractMatrix{<:Real}) # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end -(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) - -function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) return w' * w end -(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y) - -function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) +function logabsdetjac(::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) K = LinearAlgebra.checksquare(y) result = float(zero(eltype(y))) @@ -100,12 +96,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) =# return -logabsdetjac(inv(b), (b(X))) end -function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) - return mapvcat(X) do x - logabsdetjac(b, x) - end -end - function _inv_link_chol_lkj(y) K = LinearAlgebra.checksquare(y) From 2cfd24b27b4b858023342d0979a1b32085dee70c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:49 +0100 Subject: [PATCH 17/45] updated Coupling to new interface --- src/bijectors/coupling.jl | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 088caf2a..8cd3490a 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -167,7 +167,7 @@ Shift{Array{Float64,1},1}([2.0]) # References [1] Kobyzev, I., Prince, S., & Brubaker, M. 
A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} +struct Coupling{F, M} <: Bijector where {F, M <: PartitionMask} θ::F mask::M end @@ -195,7 +195,7 @@ function couple(cl::Coupling, x::AbstractVector) return b end -function (cl::Coupling)(x::AbstractVector) +function transform(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) @@ -205,10 +205,8 @@ function (cl::Coupling)(x::AbstractVector) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x) - -function (icl::Inverse{<:Coupling})(y::AbstractVector) +function transform(icl::Inverse{<:Coupling}, y::AbstractVector) cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) @@ -218,7 +216,6 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) return combine(cl.mask, ib(y_1), y_2, y_3) end -(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y) function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) @@ -228,7 +225,3 @@ function logabsdetjac(cl::Coupling, x::AbstractVector) # therefore we sum to ensure such a thing does not happen return sum(logabsdetjac(b, x_1)) end - -function logabsdetjac(cl::Coupling, x::AbstractMatrix) - return [logabsdetjac(cl, view(x, :, i)) for i in axes(x, 2)] -end From c272cd45d84044b2d026bcf0b5f9ecd35c3f584b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:59 +0100 Subject: [PATCH 18/45] updated LeakyReLU to new interface --- src/bijectors/leaky_relu.jl | 82 ++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index c9128f01..e7913038 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -1,5 +1,5 @@ """ - LeakyReLU{T, N}(α::T) 
<: Bijector{N} + LeakyReLU{T}(α::T) <: Bijector Defines the invertible mapping @@ -7,95 +7,95 @@ Defines the invertible mapping where α > 0. """ -struct LeakyReLU{T, N} <: Bijector{N} +struct LeakyReLU{T} <: Bijector α::T end -LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) -LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) +Functors.@functor LeakyReLU -# field is a numerical parameter -function Functors.functor(::Type{LeakyReLU{<:Any,N}}, x) where N - function reconstruct_leakyrelu(xs) - return LeakyReLU{typeof(xs.α),N}(xs.α) - end - return (α = x.α,), reconstruct_leakyrelu -end - -up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) +Base.inv(b::LeakyReLU) = LeakyReLU(inv.(b.α)) # (N=0) Univariate case -function (b::LeakyReLU{<:Any, 0})(x::Real) +function transform(b::LeakyReLU, x::Real) mask = x < zero(x) return mask * b.α * x + !mask * x end -(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) -function Base.inv(b::LeakyReLU{<:Any,N}) where N - invα = inv.(b.α) - return LeakyReLU{typeof(invα),N}(invα) -end - -function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) +function logabsdetjac(b::LeakyReLU, x::Real) mask = x < zero(x) J = mask * b.α + (1 - mask) * one(x) return log(abs(J)) end -logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) - # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. 
-function forward(b::LeakyReLU{<:Any, 0}, x::Real) +function forward(b::LeakyReLU, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) return (result=J * x, logabsdetjac=log(abs(J))) end -# Batched version -function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function forward_batch(b::LeakyReLU, xs::Batch{<:AbstractVector}) + x = value(xs) + J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - return (result=J .* x, logabsdetjac=log.(abs.(J))) + return (result=Batch(J .* x), logabsdetjac=Batch(log.(abs.(J)))) end -# (N=1) Multivariate case -function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) +# Array inputs. +function transform(b::LeakyReLU, x::AbstractArray) return let z = zero(eltype(x)) @. (x < z) * b.α * x + (x > z) * x end end -function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function logabsdetjac(b::LeakyReLU, x::AbstractArray) + return sum(value(logabsdetjac_batch(b, Batch(x)))) +end + +function logabsdetjac_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} + x = value(xs) + # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - - if x isa AbstractVector - return sum(log.(abs.(J))) - elseif x isa AbstractMatrix - return vec(sum(log.(abs.(J)); dims = 1)) # sum along column + + logjac = if N ≤ 1 + sum(log ∘ abs, J) + else + vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) end + + return Batch(logjac) end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. 
-function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function forward(b::LeakyReLU, x::AbstractArray) + y, logjac = forward_batch(b, Batch(x)) + + return (result = value(y), logabsdetjac = sum(value(logjac))) +end + +function forward_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} + x = value(xs) + # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - if x isa AbstractVector - logjac = sum(log.(abs.(J))) - elseif x isa AbstractMatrix - logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column + logjac = if N ≤ 1 + sum(log ∘ abs, J) + else + vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) end y = J .* x - return (result=y, logabsdetjac=logjac) + return (result=Batch(y), logabsdetjac=Batch(logjac)) end From 5e2a585be79a3b28033b889975c2c254dcaa70ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:05 +0100 Subject: [PATCH 19/45] updated NamedBijector to new interface --- src/bijectors/named_bijector.jl | 120 +------------------------------- 1 file changed, 2 insertions(+), 118 deletions(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 500044b9..776e6e96 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,4 +1,4 @@ -abstract type AbstractNamedBijector <: AbstractBijector end +abstract type AbstractNamedBijector <: Bijector end forward(b::AbstractNamedBijector, x) = (result = b(x), logabsdetjac = logabsdetjac(b, x)) @@ -64,122 +64,6 @@ end return :(+($(exprs...))) end - -###################### -### `NamedInverse` ### -###################### -""" - NamedInverse <: AbstractNamedBijector - -Represents the inverse of a `AbstractNamedBijector`, similarily to `Inverse` for `Bijector`. 
- -See also: [`Inverse`](@ref) -""" -struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector - orig::B -end -Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb) -Base.inv(ni::NamedInverse) = ni.orig - -logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inv(ni), ni(y)) - -########################## -### `NamedComposition` ### -########################## -""" - NamedComposition <: AbstractNamedBijector - -Wraps a tuple of array of `AbstractNamedBijector` and implements their composition. - -This is very similar to `Composed` for `Bijector`, with the exception that we do not require -the inputs to have the same "dimension", which in this case refers to the *symbols* for the -`NamedTuple` that this takes as input. - -See also: [`Composed`](@ref) -""" -struct NamedComposition{Bs} <: AbstractNamedBijector - bs::Bs -end - -# Essentially just copy-paste from impl of composition for 'standard' bijectors, -# with minor changes here and there. -composel(bs::AbstractNamedBijector...) = NamedComposition(bs) -composer(bs::AbstractNamedBijector...) 
= NamedComposition(reverse(bs)) -∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) - -inv(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) - -function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) - @assert length(cb.bs) > 0 - res = cb.bs[1](x) - for b ∈ Base.Iterators.drop(cb.bs, 1) - res = b(res) - end - - return res -end - -(cb::NamedComposition{<:Tuple})(x) = foldl(|>, cb.bs; init=x) - -function logabsdetjac(cb::NamedComposition, x) - y, logjac = forward(cb.bs[1], x) - for i = 2:length(cb.bs) - res = forward(cb.bs[i], y) - y = res.result - logjac += res.logabsdetjac - end - - return logjac -end - -@generated function logabsdetjac(cb::NamedComposition{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) - - for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) - - push!(expr.args, :(return logjac)) - - return expr -end - - -function forward(cb::NamedComposition, x) - result, logjac = forward(cb.bs[1], x) - - for t in cb.bs[2:end] - res = forward(t, result) - result = res.result - logjac = res.logabsdetjac + logjac - end - return (result=result, logabsdetjac=logjac) -end - - -@generated function forward(cb::NamedComposition{T}, x) where {T<:Tuple} - expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) - for i = 2:length(T.parameters) - temp = gensym(:temp) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) - end - push!(expr.args, :(return (result = y, logabsdetjac = logjac))) - - return expr -end - - ############################ ### 
`NamedCouplingLayer` ### ############################ @@ -227,7 +111,7 @@ deps(b::NamedCoupling{<:Any, Deps}) where {Deps} = Deps end end -@generated function (ni::NamedInverse{<:NamedCoupling{target, deps, F}})( +@generated function (ni::Inverse{<:NamedCoupling{target, deps, F}})( x::NamedTuple ) where {target, deps, F} return quote From 8e41a50303730a0fa97bd73c850e38448cd1b0fe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:12 +0100 Subject: [PATCH 20/45] updated BatchNormalisation to new interface --- src/bijectors/normalise.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 16b48a29..7f9a0a67 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -6,7 +6,7 @@ using Statistics: mean istraining() = false -mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1} +mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector b :: T1 # bias logs :: T1 # log-scale m :: T2 # moving mean @@ -80,8 +80,7 @@ function forward(bn::InvertibleBatchNorm, x) end logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac - -(bn::InvertibleBatchNorm)(x) = forward(bn, x).result +transform(bn::InvertibleBatchNorm, x) = forward(bn, x).result function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." 
@@ -97,7 +96,7 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) return (result=x, logabsdetjac=-logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).result +transform(bn::Inverse{<:InvertibleBatchNorm}, y) = forward(bn, y).result function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") From 9f45b16ebcffeb0cfc154d780e9304961a3a994c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:29 +0100 Subject: [PATCH 21/45] updated Permute to new interface --- src/bijectors/permute.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index d4fdef7b..7ac1d454 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -81,7 +81,7 @@ julia> inv(b1)(b1([1., 2., 3.])) 3.0 ``` """ -struct Permute{A} <: Bijector{1} +struct Permute{A} <: Bijector A::A end @@ -150,8 +150,8 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) 
end -@inline (b::Permute)(x::AbstractVecOrMat) = b.A * x +@inline transform(b::Permute, x::AbstractVecOrMat) = b.A * x @inline inv(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) -logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) +logabsdetjac_batch(b::Permute, x::Batch) = zero(eltype(x), length(x)) From f793dadbebef7928d8476ea4ad71f0adb1f35a55 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:35 +0100 Subject: [PATCH 22/45] updated PlanarLayer to new interface --- src/bijectors/planar_layer.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 9d1a73d0..38bce1a7 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -14,7 +14,7 @@ using NNlib: softplus # TODO: add docstring -struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1} +struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector w::T1 u::T1 b::T2 @@ -78,7 +78,7 @@ function _transform(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) return (transformed = transformed, wT_û = wT_û, wT_z = wT_z) end -(b::PlanarLayer)(z) = _transform(b, z).transformed +transform(b::PlanarLayer, z) = _transform(b, z).transformed #= Log-determinant of the Jacobian of the planar layer @@ -108,7 +108,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) +function transform(ib::Inverse{<:PlanarLayer}, y::AbstractVecOrMat{<:Real}) flow = ib.orig w = flow.w b = first(flow.b) From 19d1ef1e94d1e03c81262267ed56a09d7a391520 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:42 +0100 Subject: [PATCH 23/45] updated RadialLayer to new interface --- 
src/bijectors/radial_layer.jl | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index a5cb5f34..1d95c4ef 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -12,7 +12,7 @@ using NNlib: softplus # RadialLayer # ############### -mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector{1} +mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector α_::T1 β::T1 z_0::T2 @@ -50,8 +50,10 @@ function _radial_transform(α_, β, z_0, z) return (transformed = transformed, α = α, β_hat = β_hat, r = r) end -(b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed -(b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractMatrix{<:Real}) = _transform(b, z).transformed + +transform_batch(b::RadialLayer, z::ArrayBatch{2}) = Batch(transform(b, value(z))) function forward(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) @@ -70,7 +72,12 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) +function forward_batch(b::RadialLayer, z::ArrayBatch{2}) + result, logjac = forward(b, value(z)) + return (result = Batch(result), logabsdetjac = Batch(logjac)) +end + +function transform(ib::Inverse{<:RadialLayer}, y::AbstractVector{<:Real}) flow = ib.orig z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 @@ -84,7 +91,7 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) return z0 .+ γ .* y_minus_z0 end -function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) +function 
transform(ib::Inverse{<:RadialLayer}, y::AbstractMatrix{<:Real}) flow = ib.orig z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 @@ -100,6 +107,10 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) return z0 .+ γ .* y_minus_z0 end +function transform_batch(ib::Inverse{<:RadialLayer}, y::ArrayBatch{2}) + return Batch(transform(ib, value(y))) +end + """ compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) @@ -120,7 +131,7 @@ where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test/β_hat) γ = norm(y_minus_z0) a = α_plus_β_hat - γ r = (sqrt(a^2 + 4 * α * γ) - a) / 2 @@ -128,3 +139,7 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) end logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac + +function logabsdetjac_batch(flow::RadialLayer, x::ArrayBatch{2}) + return Batch(logabsdetjac(flow, value(x))) +end From 50155514e4a660807b2658ac63756176a9af6f35 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:50 +0100 Subject: [PATCH 24/45] updated RationalQuadraticSpline to new interface --- src/bijectors/rational_quadratic_spline.jl | 43 ++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index a0907b82..0dea3a7b 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -1,8 +1,7 @@ using NNlib """ - RationalQuadraticSpline{T, 0} <: Bijector{0} - RationalQuadraticSpline{T, 1} <: Bijector{1} + RationalQuadraticSpline{T} <: Bijector Implementation of the Rational Quadratic Spline flow [1]. 
@@ -77,7 +76,7 @@ julia> b([-1., 5.]) # References [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). """ -struct RationalQuadraticSpline{T, N} <: Bijector{N} +struct RationalQuadraticSpline{T} <: Bijector widths::T # K widths heights::T # K heights derivatives::T # K derivatives, with endpoints being ones @@ -179,18 +178,18 @@ end # univariate -function (b::RationalQuadraticSpline{<:AbstractVector, 0})(x::Real) +function transform(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_univariate(b.widths, b.heights, b.derivatives, x) end -(b::RationalQuadraticSpline{<:AbstractVector, 0})(x::AbstractVector) = b.(x) +function transform_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) + return Batch(transform.(b, value(x))) +end # multivariate -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractVector) +# TODO: Improve. +function transform(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return [rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x)] end -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractMatrix) - return eachcolmaphcat(b, x) -end ########################## ### Inverse evaluation ### @@ -234,18 +233,18 @@ function rqs_univariate_inverse(widths, heights, derivatives, y::Real) return ξ * w + w_k end -function (ib::Inverse{<:RationalQuadraticSpline, 0})(y::Real) +function transform(ib::Inverse{<:RationalQuadraticSpline}, y::Real) return rqs_univariate_inverse(ib.orig.widths, ib.orig.heights, ib.orig.derivatives, y) end -(ib::Inverse{<:RationalQuadraticSpline, 0})(y::AbstractVector) = ib.(y) +function transform_batch(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) + return Batch(transform.(ib, value(y))) +end -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractVector) +# TODO: Improve. 
+function transform(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) b = ib.orig return [rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) for i = 1:length(y)] end -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractMatrix) - return eachcolmaphcat(ib, y) -end ###################### ### `logabsdetjac` ### @@ -315,21 +314,19 @@ function rqs_logabsdetjac( return log(numerator) - 2 * log(denominator) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_logabsdetjac(b.widths, b.heights, b.derivatives, x) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::AbstractVector) - return logabsdetjac.(b, x) +function logabsdetjac_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) + return Batch(logabsdetjac.(b, value(x))) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractVector) +# TODO: Improve. 
+function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return sum([ rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x) ]) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractMatrix) - return mapvcat(x -> logabsdetjac(b, x), eachcol(x)) -end ################# ### `forward` ### @@ -382,6 +379,6 @@ function rqs_forward( return (result = y, logabsdetjac = logjac) end -function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function forward(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end From 3b755267aa3354a40d013fbf3b15a66e20e01afc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:59 +0100 Subject: [PATCH 25/45] updated Scale to new interface --- src/bijectors/scale.jl | 77 ++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index e19b33f5..f3f15a15 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -1,57 +1,46 @@ -struct Scale{T, N} <: Bijector{N} +struct Scale{T} <: Bijector a::T end -Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Scale, b2::Scale) = b1.a == b2.a -function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Scale{typeof(a), D}(a) -end - -# field is a numerical parameter -function Functors.functor(::Type{<:Scale{<:Any,N}}, x) where N - function reconstruct_scale(xs) - return Scale{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_scale -end +Functors.@functor Scale -up1(b::Scale{T, N}) where {N, T} = Scale{T, N + 1}(a) - -(b::Scale)(x) = b.a .* x -(b::Scale{<:AbstractMatrix, 1})(x::AbstractVecOrMat) = b.a * x -(b::Scale{<:AbstractMatrix, 2})(x::AbstractMatrix) = b.a * x -(ib::Inverse{<:Scale})(y) = Scale(inv(ib.orig.a))(y) 
-(ib::Inverse{<:Scale{<:AbstractVector}})(y) = Scale(inv.(ib.orig.a))(y) -(ib::Inverse{<:Scale{<:AbstractMatrix, 1}})(y::AbstractVecOrMat) = ib.orig.a \ y -(ib::Inverse{<:Scale{<:AbstractMatrix, 2}})(y::AbstractMatrix) = ib.orig.a \ y +transform(b::Scale, x) = b.a .* x +transform(b::Scale{<:AbstractMatrix}, x::AbstractVecOrMat) = b.a * x +transform(ib::Inverse{<:Scale}, y) = transform(Scale(inv(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractVector}}, y) = transform(Scale(inv.(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractMatrix}}, y::AbstractVecOrMat) = ib.orig.a \ y # We're going to implement custom adjoint for this -logabsdetjac(b::Scale{<:Any, N}, x) where {N} = _logabsdetjac_scale(b.a, x, Val(N)) +logabsdetjac(b::Scale, x::Real) = _logabsdetjac_scale(b.a, x, Val(0)) +function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_scale(b.a, x, Val(N)) +end +function logabsdetjac_batch(b::Scale, x::ArrayBatch{N}) where {N} + return Batch(_logabsdetjac_scale(b, value(x), Val(N - 1))) +end + +# Scalar: single input. 
_logabsdetjac_scale(a::Real, x::Real, ::Val{0}) = log(abs(a)) -_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{1}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{2}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end -_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(x -> log(abs(x)), a), size(x, 2)) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end + +# Scalar: batch. +_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) +_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) + +# Vector: single input. +_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(log ∘ abs, a) +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log ∘ abs, a) + +# Vector: batch. +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(log ∘ abs, a), size(x, 2)) + +# Matrix: single input. 
_logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1] -_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1] -function _logabsdetjac_scale( - a::AbstractMatrix, - x::AbstractArray{<:AbstractMatrix}, - ::Val{2}, -) - map(x) do x - _logabsdetjac_scale(a, x, Val(2)) - end -end \ No newline at end of file + +# Matrix: batch. +_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) From 195a107f5fa64598c09641fa20d2751fd5c11a18 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:04 +0100 Subject: [PATCH 26/45] updated Shift to new interface --- src/bijectors/shift.jl | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index f6c39f82..9aaade56 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -1,33 +1,26 @@ ################# # Shift & Scale # ################# -struct Shift{T, N} <: Bijector{N} +struct Shift{T} <: Bijector a::T end Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a -function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Shift{typeof(a), D}(a) -end - -# field is a numerical parameter -function Functors.functor(::Type{<:Shift{<:Any,N}}, x) where N - function reconstruct_shift(xs) - return Shift{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_shift -end - -up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) +Functors.@functor Shift -(b::Shift)(x) = b.a .+ x -(b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) +transform(b::Shift, x) = b.a .+ x -inv(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +inv(b::Shift) = Shift(-b.a) # FIXME: implement custom adjoint to ensure we don't get 
tracking -logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) +function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_shift(b.a, x, Val(N)) +end + +function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_shift(b.a, x, Val(N - 1)) +end _logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) _logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real} = zeros(T, length(x)) From f56ed6ab365a693230e84902b0557dc5bf0250df Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:13 +0100 Subject: [PATCH 27/45] updated Stacked to new interface --- src/bijectors/stacked.jl | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 34116616..20104246 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -21,7 +21,7 @@ b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{Bs, Rs} <: Bijector{1} +struct Stacked{Bs, Rs} <: Transform bs::Bs ranges::Rs end @@ -48,13 +48,15 @@ end isclosedform(b::Stacked) = all(isclosedform, b.bs) -stack(bs::Bijector{0}...) = Stacked(bs) +invertible(b::Stacked) = sum(map(invertible, b.bs)) + +stack(bs::Bijector...) = Stacked(bs) # For some reason `inv.(sb.bs)` was unstable... This works though. 
-inv(sb::Stacked) = Stacked(map(inv, sb.bs), sb.ranges) +inv(sb::Stacked, ::Invertible) = Stacked(map(inv, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inv(sb::Stacked{A}) where {A <: Tuple} +@generated function inv(sb::Stacked{A}, ::Invertible) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) push!(exprs, :(inv(sb.bs[$i]))) @@ -79,7 +81,7 @@ function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real}) +function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) N == 1 && return sb.bs[1](x[sb.ranges[1]]) @@ -90,7 +92,6 @@ function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real}) return y end -(sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x) function logabsdetjac( b::Stacked, x::AbstractVector{<:Real} @@ -122,12 +123,6 @@ function logabsdetjac(b::Stacked{<:Tuple{<:Bijector}}, x::AbstractVector{<:Real} return sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) end -function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real}) - return map(eachcol(x)) do c - logabsdetjac(b, c) - end -end - # Generates something similar to: # # quote From 72d68b8bfa28c72acf372b9323012de3da8e01f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:19 +0100 Subject: [PATCH 28/45] updated TruncatedBijector to new interface --- src/bijectors/truncated.jl | 106 ++++--------------------------------- 1 file changed, 11 insertions(+), 95 deletions(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index de5997b2..be8f4f10 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -1,60 +1,22 @@ ####################################################### # Constrained to unconstrained distribution bijectors # 
####################################################### -struct TruncatedBijector{N, T1, T2} <: Bijector{N} +struct TruncatedBijector{T1, T2} <: Bijector lb::T1 ub::T2 end -TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub) -function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} - return TruncatedBijector{N, T1, T2}(lb, ub) -end -# field are numerical parameters -function Functors.functor(::Type{<:TruncatedBijector{N}}, x) where N - function reconstruct_truncatedbijector(xs) - return TruncatedBijector{N}(xs.lb, xs.ub) - end - return (lb = x.lb, ub = x.ub,), reconstruct_truncatedbijector -end - -up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub) +Functors.@functor TruncatedBijector function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector) return b1.lb == b2.lb && b1.ub == b2.ub end -function (b::TruncatedBijector{0})(x::Real) +function transform(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_link(_clamp(x, a, b), a, b) + return truncated_link.(_clamp.(x, a, b), a, b) end -function (b::TruncatedBijector{0})(x::AbstractArray{<:Real}) - a, b = b.lb, b.ub - truncated_link.(_clamp.(x, a, b), a, b) -end -function (b::TruncatedBijector{1})(x::AbstractVecOrMat{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -function (b::TruncatedBijector{2})(x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractMatrix - @assert b isa AbstractMatrix - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -(b::TruncatedBijector{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) + function truncated_link(x::Real, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -68,37 +30,11 @@ function 
truncated_link(x::Real, a, b) end end -function (ib::Inverse{<:TruncatedBijector{0}})(y::Real) +function transform(ib::Inverse{<:TruncatedBijector}, y) a, b = ib.orig.lb, ib.orig.ub - _clamp(truncated_invlink(y, a, b), a, b) + return _clamp.(truncated_invlink.(y, a, b), a, b) end -function (ib::Inverse{<:TruncatedBijector{0}})(y::AbstractArray{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - _clamp.(truncated_invlink.(y, a, b), a, b) -end -function (ib::Inverse{<:TruncatedBijector{1}})(y::AbstractVecOrMat{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - _clamp.(truncated_invlink.(y, a, b), a, b) - end -end -function (ib::Inverse{<:TruncatedBijector{2}})(y::AbstractMatrix{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractMatrix - @assert b isa AbstractMatrix - return maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - return _clamp.(truncated_invlink.(y, a, b), a, b) - end -end -(ib::Inverse{<:TruncatedBijector{2}})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, y) + function truncated_invlink(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -112,31 +48,11 @@ function truncated_invlink(y, a, b) end end -function logabsdetjac(b::TruncatedBijector{0}, x::Real) - a, b = b.lb, b.ub - truncated_logabsdetjac(_clamp(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{0}, x::AbstractArray{<:Real}) +function logabsdetjac(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_logabsdetjac.(_clamp.(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractVector{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - vec(sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b), 
dims = 1)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractArray{<:AbstractMatrix{<:Real}}) - map(x) do x - logabsdetjac(b, x) - end + return truncated_logabsdetjac.(_clamp.(x, a, b), a, b) end + function truncated_logabsdetjac(x, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded From 214aa92943cfd6c0375a089c450368e0d6018844 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:02:20 +0100 Subject: [PATCH 29/45] added ConstructionBase as dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a86495c3..5abd4e67 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.4" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 836e152b74fff2f019af3f3b63eaab500ffe5833 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:02:31 +0100 Subject: [PATCH 30/45] fixed a bunch of small typos and errors from previous commits --- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 4 ++-- src/bijectors/shift.jl | 2 +- src/bijectors/simplex.jl | 2 +- src/interface.jl | 2 +- src/transformed_distribution.jl | 22 +++++++++++----------- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 1d95c4ef..35b98b34 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -131,7 +131,7 @@ where ``γ = \\|y_minus_z0\\|_2``. 
For details see appendix A.2 of the reference D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test/β_hat) +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test) γ = norm(y_minus_z0) a = α_plus_β_hat - γ r = (sqrt(a^2 + 4 * α * γ) - a) / 2 diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 0dea3a7b..faa91c5d 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -90,7 +90,7 @@ struct RationalQuadraticSpline{T} <: Bijector @assert length(widths) == length(heights) == length(derivatives) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 0}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end function RationalQuadraticSpline( @@ -100,7 +100,7 @@ struct RationalQuadraticSpline{T} <: Bijector ) where {T<:AbstractMatrix} @assert size(widths, 2) == size(heights, 2) == size(derivatives, 2) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 1}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 9aaade56..2773de0d 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -5,7 +5,7 @@ struct Shift{T} <: Bijector a::T end -Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a Functors.@functor Shift diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index f9f83d85..3e990592 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -110,7 +110,7 @@ end # Batched versions. 
transform_batch(ib::Inverse{<:SimplexBijector}, Y::ArrayBatch{2}) = _simplex_inv_bijector(Y, ib.orig) function transform_batch!( - ib::Inverse{<:SimplexBijector} + ib::Inverse{<:SimplexBijector}, X::Batch{<:AbstractMatrix{T}}, Y::Batch{<:AbstractMatrix{T}}, ) where {T<:Real} diff --git a/src/interface.jl b/src/interface.jl index cfcaf7c8..d9ed8ba9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -220,7 +220,7 @@ transform_batch(b, xs) = _transform_batch(b, xs) _transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) function _transform_batch(b, xs::ArrayBatch{2}) # TODO: Check if we can avoid using these custom methods. - return eachcolmaphcat(b, x) + return Batch(eachcolmaphcat(b, value(xs))) end function _transform_batch(b, xs::ArrayBatch{N}) where {N} res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index d6177c72..ae5ef03f 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -2,12 +2,12 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} dist::D transform::B - - TransformedDistribution(d::UnivariateDistribution, b::Bijector{0}) = new{typeof(d), typeof(b), Univariate}(d, b) - TransformedDistribution(d::MultivariateDistribution, b::Bijector{1}) = new{typeof(d), typeof(b), Multivariate}(d, b) - TransformedDistribution(d::MatrixDistribution, b::Bijector{2}) = new{typeof(d), typeof(b), Matrixvariate}(d, b) end +TransformedDistribution(d::UnivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Univariate}(d, b) +TransformedDistribution(d::MultivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Multivariate}(d, b) +TransformedDistribution(d::MatrixDistribution, b::Bijector) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + # fields may contain nested numerical parameters Functors.@functor TransformedDistribution @@ -38,9 +38,9 @@ 
Returns the constrained-to-unconstrained bijector for distribution `d`. bijector(d::DiscreteUnivariateDistribution) = Identity{0}() bijector(d::DiscreteMultivariateDistribution) = Identity{1}() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) -bijector(d::Product{Discrete}) = Identity{1}() +bijector(d::Product{Discrete}) = Identity() function bijector(d::Product{Continuous}) - return TruncatedBijector{1}(_minmax(d.v)...) + return TruncatedBijector(_minmax(d.v)...) end @generated function _minmax(d::AbstractArray{T}) where {T} try @@ -51,11 +51,11 @@ end end end -bijector(d::Normal) = Identity{0}() -bijector(d::Distributions.AbstractMvNormal) = Identity{1}() -bijector(d::Distributions.AbstractMvLogNormal) = Log{1}() -bijector(d::PositiveDistribution) = Log{0}() -bijector(d::SimplexDistribution) = SimplexBijector{1}() +bijector(d::Normal) = Identity() +bijector(d::Distributions.AbstractMvNormal) = Identity() +bijector(d::Distributions.AbstractMvLogNormal) = Log() +bijector(d::PositiveDistribution) = Log() +bijector(d::SimplexDistribution) = SimplexBijector() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) From ff6b756a1c3215c9a8dbf08ccc9462c4f733d4c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:07:33 +0100 Subject: [PATCH 31/45] forgot to wrap some in Batch --- src/bijectors/scale.jl | 2 +- src/bijectors/shift.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index f3f15a15..ad6df463 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -19,7 +19,7 @@ function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} end function logabsdetjac_batch(b::Scale, x::ArrayBatch{N}) where {N} - return Batch(_logabsdetjac_scale(b, value(x), Val(N - 1))) + return Batch(_logabsdetjac_scale(b.a, value(x), Val(N - 1))) end # Scalar: single input. 
diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 2773de0d..6bd0fcd0 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -19,7 +19,7 @@ function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} end function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return _logabsdetjac_shift(b.a, x, Val(N - 1)) + return Batch(_logabsdetjac_shift(b.a, x, Val(N - 1))) end _logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) From 989aaa87b0144ce588e15a1d12ac0141ee5f3526 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:32:07 +0100 Subject: [PATCH 32/45] allow inverses of non-bijectors --- src/interface.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d9ed8ba9..c6617a59 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -79,6 +79,7 @@ Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() Base.:+(::Invertible, ::Invertible) = Invertible() invertible(::Transform) = NotInvertible() +isinvertible(t::Transform) = invertible(t) isa Invertible """ inv(t::Transform[, ::Invertible]) @@ -160,21 +161,31 @@ requires an iterative procedure to evaluate. isclosedform(b::Bijector) = true """ - inv(b::Bijector) - Inverse(b::Bijector) + inv(b::Transform) + Inverse(b::Transform) -A `Bijector` representing the inverse transform of `b`. +A `Transform` representing the inverse transform of `b`. 
""" -struct Inverse{B<:Bijector} <: Bijector - orig::B +struct Inverse{T<:Transform} <: Transform + orig::T + + function Inverse(orig::Transform) + if !isinvertible(orig) + error("$(orig) is not invertible") + end + + return new{typeof(orig)}(orig) + end end -# field contains nested numerical parameters Functors.@functor Inverse -inv(b::Bijector) = Inverse(b) -inv(ib::Inverse{<:Bijector}) = ib.orig -Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig +inv(b, ::Invertible) = Inverse(b) +inv(ib::Inverse) = ib.orig + +invertible(ib::Inverse) = Invertible() + +Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) From 4d2388263d65a82f011fe411ce6e7884f82a8a9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:32:19 +0100 Subject: [PATCH 33/45] relax definition of VectorBatch so Vector{<:Real} is covered --- src/batch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batch.jl b/src/batch.jl index 4c5a512b..1e1f7700 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -17,7 +17,7 @@ value(x::Batch) = x.value # Convenient aliases const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} -const VectorBatch = Batch{<:AbstractVector{<:AbstractArray{<:Real}}} +const VectorBatch = Batch{<:AbstractVector} # Constructor for `ArrayBatch`. 
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) From 0f9d334f352a0e77b7a31c2b8ba0f2a8f2f9644d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:40:57 +0100 Subject: [PATCH 34/45] just perform invertibility check in Inverse rather than inv --- src/bijectors/composed.jl | 2 +- src/interface.jl | 74 +++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 3a0e5a9e..1fbaaed0 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -146,7 +146,7 @@ end ∘(::Identity, b::Transform) = b ∘(b::Transform, ::Identity) = b -inv(ct::Composed, ::Invertible) = Composed(reverse(map(inv, ct.ts))) +Base.inv(ct::Composed) = Composed(reverse(map(inv, ct.ts))) # TODO: should arrays also be using recursive implementation instead? function transform(cb::Composed, x) diff --git a/src/interface.jl b/src/interface.jl index c6617a59..330a28f6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -68,27 +68,6 @@ abstract type Transform end Broadcast.broadcastable(b::Transform) = Ref(b) -# Invertibility "trait". -struct NotInvertible end -struct Invertible end - -# Useful for checking if compositions, etc. are invertible or not. -Base.:+(::NotInvertible, ::Invertible) = NotInvertible() -Base.:+(::Invertible, ::NotInvertible) = NotInvertible() -Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() -Base.:+(::Invertible, ::Invertible) = Invertible() - -invertible(::Transform) = NotInvertible() -isinvertible(t::Transform) = invertible(t) isa Invertible - -""" - inv(t::Transform[, ::Invertible]) - -Returns the inverse of transform `t`. -""" -Base.inv(t::Transform) = Base.inv(t, invertible(t)) -Base.inv(t::Transform, ::NotInvertible) = error("$(t) is not invertible") - """ transform(b, x) @@ -142,23 +121,18 @@ function forward!(b, x, out) return out end -"Abstract type of a bijector, i.e. 
differentiable bijection with differentiable inverse." -abstract type Bijector <: Transform end - -invertible(::Bijector) = Invertible() - -""" - isclosedform(b::Bijector)::bool - isclosedform(b⁻¹::Inverse{<:Bijector})::bool +# Invertibility "trait". +struct NotInvertible end +struct Invertible end -Returns `true` or `false` depending on whether or not evaluation of `b` -has a closed-form implementation. +# Useful for checking if compositions, etc. are invertible or not. +Base.:+(::NotInvertible, ::Invertible) = NotInvertible() +Base.:+(::Invertible, ::NotInvertible) = NotInvertible() +Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() +Base.:+(::Invertible, ::Invertible) = Invertible() -Most bijectors have closed-form evaluations, but there are cases where -this is not the case. For example the *inverse* evaluation of `PlanarLayer` -requires an iterative procedure to evaluate. -""" -isclosedform(b::Bijector) = true +invertible(::Transform) = NotInvertible() +isinvertible(t::Transform) = invertible(t) isa Invertible """ inv(b::Transform) @@ -180,13 +154,37 @@ end Functors.@functor Inverse -inv(b, ::Invertible) = Inverse(b) -inv(ib::Inverse) = ib.orig +""" + inv(t::Transform[, ::Invertible]) + +Returns the inverse of transform `t`. +""" +Base.inv(t::Transform) = Inverse(t) +Base.inv(ib::Inverse) = ib.orig invertible(ib::Inverse) = Invertible() Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig +"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." +abstract type Bijector <: Transform end + +invertible(::Bijector) = Invertible() + +""" + isclosedform(b::Bijector)::bool + isclosedform(b⁻¹::Inverse{<:Bijector})::bool + +Returns `true` or `false` depending on whether or not evaluation of `b` +has a closed-form implementation. + +Most bijectors have closed-form evaluations, but there are cases where +this is not the case. For example the *inverse* evaluation of `PlanarLayer` +requires an iterative procedure to evaluate. 
+""" +isclosedform(b::Bijector) = true + +# Default implementation for inverse of a `Bijector`. logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) """ From af0b24b18f1f0969cff939a604059bdae67fa79a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:42:48 +0100 Subject: [PATCH 35/45] moved some code around --- src/interface.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 330a28f6..3a985868 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -121,6 +121,19 @@ function forward!(b, x, out) return out end +""" + isclosedform(b::Transform)::bool + isclosedform(b⁻¹::Inverse{<:Transform})::bool + +Returns `true` or `false` depending on whether or not evaluation of `b` +has a closed-form implementation. + +Most transformations have closed-form evaluations, but there are cases where +this is not the case. For example the *inverse* evaluation of `PlanarLayer` +requires an iterative procedure to evaluate. +""" +isclosedform(b::Transform) = true + # Invertibility "trait". struct NotInvertible end struct Invertible end @@ -171,19 +184,6 @@ abstract type Bijector <: Transform end invertible(::Bijector) = Invertible() -""" - isclosedform(b::Bijector)::bool - isclosedform(b⁻¹::Inverse{<:Bijector})::bool - -Returns `true` or `false` depending on whether or not evaluation of `b` -has a closed-form implementation. - -Most bijectors have closed-form evaluations, but there are cases where -this is not the case. For example the *inverse* evaluation of `PlanarLayer` -requires an iterative procedure to evaluate. -""" -isclosedform(b::Bijector) = true - # Default implementation for inverse of a `Bijector`. 
logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) From 0777fab7afba473181a35f4ede367ac17f310b22 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 06:00:09 +0100 Subject: [PATCH 36/45] added docstrings and default impls for mutating batched methods --- src/interface.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 3a985868..011ab2c6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -238,6 +238,27 @@ function _transform_batch(b, xs::ArrayBatch{N}) where {N} return reconstruct(xs, res) end +""" + transform_batch!(b, xs, ys) + +Transform `xs` by `b` treating `xs` as a "batch", i.e. a collection of independent inputs, +and storing the result in `ys`. + +See also: [`transform!`](@ref) +""" +transform_batch!(b, xs, ys) = _transform_batch!(b, xs, ys) +function _transform_batch!(b, xs, ys) + for i = 1:length(xs) + if eltype(ys) <: Real + ys[i] = transform(b, xs[i]) + else + transform!(b, xs[i], ys[i]) + end + end + + return ys +end + """ logabsdetjac_batch(b, xs) @@ -252,6 +273,27 @@ function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) end +""" + logabsdetjac_batch!(b, xs, logjacs) + +Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs, +accumulating the result in `logjacs`. 
+ +See also: [`logabsdetjac!`](@ref) +""" +logabsdetjac_batch!(b, xs, logjacs) = _logabsdetjac_batch!(b, xs, logjacs) +function _logabsdetjac_batch!(b, xs, logjacs) + for i = 1:length(xs) + if eltype(logjacs) <: Real + logjacs[i] += logabsdetjac(b, xs[i]) + else + logabsdetjac!(b, xs[i], logjacs[i]) + end + end + + return logjacs +end + """ forward_batch(b, xs) @@ -261,6 +303,20 @@ See also: [`transform`](@ref) """ forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) +""" + forward_batch!(b, xs, out) + +Computes `forward(b, xs)` in place, treating `xs` as a "batch", i.e. a collection of independent inputs. + +See also: [`forward!`](@ref) +""" +function forward_batch!(b, xs, out) + transform_batch!(b, xs, out.result) + logabsdetjac_batch!(b, xs, out.logabsdetjac) + + return out +end + ###################### # Bijectors includes # ###################### From 42839c373a9f0f57b2de23762cf2e3a5376c0c3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:29:02 +0100 Subject: [PATCH 37/45] add element type to VectorBatch --- src/batch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batch.jl b/src/batch.jl index 1e1f7700..0089618a 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -17,7 +17,7 @@ value(x::Batch) = x.value # Convenient aliases const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} -const VectorBatch = Batch{<:AbstractVector} +const VectorBatch{T} = Batch{<:AbstractVector{T}} # Constructor for `ArrayBatch`. 
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) From 2ce74d4e45570b534dae8422e9d73e4f540336ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:29:28 +0100 Subject: [PATCH 38/45] simplify Shift bijector --- src/bijectors/shift.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 6bd0fcd0..85050e5c 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -14,17 +14,13 @@ transform(b::Shift, x) = b.a .+ x inv(b::Shift) = Shift(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking -function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return _logabsdetjac_shift(b.a, x, Val(N)) +function logabsdetjac(b::Shift, x::Union{Real, AbstractArray{<:Real}}) + return _logabsdetjac_shift(b.a, x) end -function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return Batch(_logabsdetjac_shift(b.a, x, Val(N - 1))) +function logabsdetjac_batch(b::Shift, x::ArrayBatch) + return Batch(_logabsdetjac_shift_array_batch(b.a, value(x))) end -_logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) -_logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real} = zeros(T, length(x)) -_logabsdetjac_shift(a::T1, x::AbstractVector{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x, 2)) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractArray{<:AbstractMatrix{T2}}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x)) +_logabsdetjac_shift(a, x) = zero(eltype(x)) +_logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x))) From 926ef2716dce342f5a73c20d44d14f51727e6815 Mon Sep 
17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:06 +0100 Subject: [PATCH 39/45] added rrules for logabsdetjac_shift --- src/chainrules.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index b7f2b51a..7ed33599 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -6,3 +6,11 @@ ChainRulesCore.@scalar_rule( ), (x, - tanh(Ω + b) * x, x - 1), ) + +function ChainRulesCore.rrule(::typeof(_logabsdetjac_shift), a, x) + return _logabsdetjac_shift(a, x), Δ -> (ChainRulesCore.NO_FIELDS, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent()) +end + +function ChainRulesCore.rrule(::typeof(_logabsdetjac_shift_array_batch), a, x) + return _logabsdetjac_shift_array_batch(a, x), Δ -> (ChainRulesCore.NO_FIELDS, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent()) +end From c3745d76229e81f70cfa029ce850fdd641c24eec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:36 +0100 Subject: [PATCH 40/45] use type-stable implementation of eachslice --- src/interface.jl | 7 +++++-- src/utils.jl | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 011ab2c6..7e1cb171 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -232,7 +232,7 @@ function _transform_batch(b, xs::ArrayBatch{2}) return Batch(eachcolmaphcat(b, value(xs))) end function _transform_batch(b, xs::ArrayBatch{N}) where {N} - res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x + res = reduce(map(b, eachslice(value(xs), Val{N}()))) do acc, x cat(acc, x; dims = N) end return reconstruct(xs, res) @@ -269,8 +269,11 @@ See also: [`logabsdetjac`](@ref) logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) # Default implementations uses private methods to avoid method ambiguity. 
_logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) +function _logabsdetjac_batch(b, xs::ArrayBatch{2}) + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachcol(value(xs)))) +end function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), Val{N}()))) end """ diff --git a/src/utils.jl b/src/utils.jl index d1034dbc..dcb99eb7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,3 +9,17 @@ _vec(x::Real) = x # Useful for reconstructing objects. reconstruct(b, args...) = constructorof(typeof(b))(args...) + +# Despite kwargs using `NamedTuple` in Julia 1.6, I'm still running +# into type-instability issues when using `eachslice`. So we define our own. +# https://github.com/JuliaLang/julia/issues/39639 +# TODO: Check if this is the case in Julia 1.7. +# Adapted from https://github.com/JuliaLang/julia/blob/ef673537f8622fcaea92ac85e07962adcc17745b/base/abstractarraymath.jl#L506-L513. +# FIXME: It seems to break Zygote though. Which is weird because normal `eachslice` does not. +@inline function Base.eachslice(A::AbstractArray, ::Val{N}) where {N} + dim = N + dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions")) + inds_before = ntuple(_ -> Colon(), dim-1) + inds_after = ntuple(_ -> Colon(), ndims(A)-dim) + return (view(A, inds_before..., i, inds_after...) 
for i in axes(A, dim)) +end From 16009860be18e7dfddea1c47a74fd2fa8f67f475 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:55 +0100 Subject: [PATCH 41/45] initial work on adding proper testing --- test/ad/utils.jl | 12 ++- test/bijectors/utils.jl | 210 ++++++++++++++++++++++++---------------- 2 files changed, 137 insertions(+), 85 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6c9bf4a4..e87ba9b4 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -136,14 +136,19 @@ function test_ad(dist::DistSpec; kwargs...) end end -function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) +function test_ad(_f, _x, broken = (); rtol = 1e-6, atol = 1e-6) + f = _x isa Real ? _f ∘ first : _f + x = [_x;] + finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] if AD == "All" || AD == "Tracker" if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol else - @test Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + ∇tracker = Tracker.gradient(f, x)[1] + @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol + @test Tracker.istracked(∇tracker) end end @@ -159,7 +164,8 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) if :Zygote in broken @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol else - @test Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + ∇zygote = Zygote.gradient(f, x)[1] + @test (all(finitediff .== 0) && ∇zygote === nothing) || isapprox(∇zygote, finitediff, rtol=rtol, atol=atol) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index a0fdb6f2..539408d4 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,8 +1,11 @@ -function test_bijector_reals( - b::Bijector{0}, - x_true::Real, - y_true::Real, - logjac_true::Real; +import Bijectors: AbstractBatch + + +function test_bijector_single( + b::Bijector, + x_true, + y_true, + logjac_true; isequal = 
true, tol = 1e-6 ) @@ -17,67 +20,66 @@ function test_bijector_reals( ires = isequal ? @inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) # Always want the following to hold - @test ires.rv ≈ x_true atol=tol + @test ires.result ≈ x_true atol=tol @test ires.logabsdetjac ≈ -logjac atol=tol if isequal @test y ≈ y_true atol=tol # forward @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse @test logjac ≈ logjac_true # logjac forward - @test res.rv ≈ y_true atol=tol # forward using `forward` + @test res.result ≈ y_true atol=tol # forward using `forward` @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` else @test y ≠ y_true # forward @test (@inferred ib(y)) ≈ x_true atol=tol # inverse @test logjac ≠ logjac_true # logjac forward - @test res.rv ≠ y_true # forward using `forward` + @test res.result ≠ y_true # forward using `forward` @test res.logabsdetjac ≠ logjac_true # logjac using `forward` end end -function test_bijector_arrays( +function test_bijector_batch( b::Bijector, - xs_true::AbstractArray{<:Real}, - ys_true::AbstractArray{<:Real}, - logjacs_true::Union{Real, AbstractArray{<:Real}}; + xs_true::AbstractBatch, + ys_true::AbstractBatch, + logjacs_true; isequal = true, tol = 1e-6 -) + ib = @inferred inv(b) - ys = @inferred b(xs_true) - logjacs = @inferred logabsdetjac(b, xs_true) - res = @inferred forward(b, xs_true) + ys = @inferred broadcast(b, xs_true) + logjacs = @inferred broadcast(logabsdetjac, b, xs_true) + res = @inferred broadcast(forward, b, xs_true) # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inv(b), ys_true)) : @inferred(forward(inv(b), ys)) + ires = isequal ? 
@inferred(broadcast(forward, inv(b), ys_true)) : @inferred(broadcast(forward, inv(b), ys)) # always want the following to hold - @test ys isa typeof(ys_true) - @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires.rv - xs_true) ≤ tol - @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + @test ys isa AbstractBatch + @test logjacs isa AbstractBatch + @test mean(norm, ires.result - xs_true) ≤ tol + @test mean(norm, ires.logabsdetjac + logjacs) ≤ tol if isequal - @test mean(abs, ys - ys_true) ≤ tol # forward - @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse - @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` + @test mean(norm, ys - ys_true) ≤ tol # forward + @test mean(norm, broadcast(ib, ys_true) - xs_true) ≤ tol # inverse + @test mean(abs, logjacs .- logjacs_true) ≤ tol # logjac forward + @test mean(norm, res.result - ys_true) ≤ tol # forward using `forward` @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` else # Don't want the following to be equal to their "true" values - @test mean(abs, ys - ys_true) > tol # forward + @test mean(norm, ys - ys_true) > tol # forward @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + @test mean(norm, res.result - ys_true) > tol # forward using `forward` # Still want the following to be equal to the COMPUTED values - @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse + @test mean(norm, broadcast(ib, ys) - xs_true) ≤ tol # inverse @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` end end """ - test_bijector(b::Bijector, xs::Array; kwargs...) test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) 
Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` @@ -103,82 +105,110 @@ treated as "counter-examples", i.e. values NOT to match. - `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check arrays where we check that the L1-norm is sufficiently small. """ -function test_bijector(b::Bijector{0}, xs::AbstractVector{<:Real}) - return test_bijector(b, xs, zeros(length(xs)), zeros(length(xs)); isequal = false) -end - -function test_bijector(b::Bijector{1}, xs::AbstractMatrix{<:Real}) - return test_bijector(b, xs, zeros(size(xs)), zeros(size(xs, 2)); isequal = false) -end - function test_bijector( - b::Bijector{0}, - xs_true::AbstractVector{<:Real}, - ys_true::AbstractVector{<:Real}, - logjacs_true::AbstractVector{<:Real}; + b::Bijector, + xs_true::AbstractBatch, + ys_true::AbstractBatch, + logjacs_true::AbstractBatch; kwargs... ) - ib = inv(b) - - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) - # Test `logabsdetjac` against jacobians test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) + + if Bijectors.isinvertible(b) + ib = inv(b) + test_logabsdetjac(ib, ys_true) + end for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) - test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...) + # Test validity of single input. + test_bijector_single(b, x_true, y_true, logjac_true; kwargs...) - # Test AD - test_ad(x -> b(first(x)), [x_true, ]) + # Test AD wrt. inputs. + test_ad(x -> sum(b(x)), x_true) + test_ad(x -> logabsdetjac(b, x), x_true) + + if Bijectors.isinvertible(b) + y = b(x_true) + test_ad(x -> sum(ib(x)), y) + end + end - y = b(x_true) - test_ad(x -> ib(first(x)), [y, ]) + # Test AD wrt. parameters. + test_bijector_parameter_gradient(b, xs_true[1], ys_true[1]) - test_ad(x -> logabsdetjac(b, first(x)), [x_true, ]) + # Test validity of collection of inputs. + test_bijector_batch(b, xs_true, ys_true, logjacs_true; kwargs...) + + # AD testing for batch. 
+ f, arg = make_gradient_function(x -> sum(sum(b.(x))), xs_true) + test_ad(f, arg) + f, arg = make_gradient_function(x -> sum(logabsdetjac.(b, x)), xs_true) + test_ad(f, arg) + + if Bijectors.isinvertible(b) + ys = b.(xs_true) + f, arg = make_gradient_function(y -> sum(sum(ib.(y))), ys) + test_ad(f, arg) end end +function make_gradient_function(f, xs::ArrayBatch) + s = size(Bijectors.value(xs)) -function test_bijector( - b::Bijector{1}, - xs_true::AbstractMatrix{<:Real}, - ys_true::AbstractMatrix{<:Real}, - logjacs_true::AbstractVector{<:Real}; - kwargs... -) - ib = inv(b) + function g(x) + x_batch = Bijectors.reconstruct(xs, reshape(x, s)) + return f(x_batch) + end - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + return g, vec(Bijectors.value(xs)) +end - # Test `logabsdetjac` against jacobians - test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) - - for (x_true, y_true, logjac_true) in zip(eachcol(xs_true), eachcol(ys_true), logjacs_true) - # HACK: collect to avoid dealing with sub-arrays and thus allowing us to compare the - # type of the computed output to the "true" output. - test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...) 
+function make_gradient_function(f, xs::VectorBatch{<:AbstractArray{<:Real}}) + xs_new = vcat(map(vec, Bijectors.value(xs))) + n = length(xs_new) - # Test AD - test_ad(x -> sum(b(x)), collect(x_true)) - y = b(x_true) - test_ad(x -> sum(ib(x)), y) + s = size(Bijectors.value(xs[1])) + stride = n ÷ length(xs) - test_ad(x -> logabsdetjac(b, x), x_true) + function g(x) + x_vec = map(1:stride:n) do i + reshape(x[i:i + stride - 1], s) + end + + x_batch = Bijectors.reconstruct(xs, x_vec) + return f(x_batch) end + + return g, xs_new end -function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6) - logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol +make_jacobian_function(f, xs::AbstractVector) = f, xs +function make_jacobian_function(f, xs::AbstractArray) + xs_new = vec(xs) + s = size(xs) + + function g(x) + return vec(f(reshape(x, s))) + end + + return g, xs_new end -function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6) - logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol +function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:AbstractArray}; tol=1e-6) + f, _ = make_jacobian_function(b, xs[1]) + logjac_ad = map(xs) do x + first(logabsdet(ForwardDiff.jacobian(f, x))) + end + + @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol +end + +function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:Real}; tol=1e-6) + logjac_ad = map(xs) do x + log(abs(ForwardDiff.derivative(b, x))) + end + @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol end # Check if `Functors.functor` works properly @@ -187,3 +217,19 @@ function test_functor(x, xs) @test x == re(_xs) @test _xs == xs end + +function test_bijector_parameter_gradient(b::Transform, x, y = b(x)) + args, re = Functors.functor(b) + recon(k, param) = re(merge(args, NamedTuple{(k, )}((param, )))) + + # Compute the gradient wrt. 
one argument at the time. + for (k, v) in pairs(args) + test_ad(p -> sum(transform(recon(k, p), x)), v) + test_ad(p -> logabsdetjac(recon(k, p), x), v) + + if Bijectors.isinvertible(b) + test_ad(p -> sum(transform(inv(recon(k, p)), y)), v) + test_ad(p -> logabsdetjac(inv(recon(k, p)), y), v) + end + end +end From 2f4d32822a9b0203c7e8e851b389d336e9344b31 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:31:18 +0100 Subject: [PATCH 42/45] make Batch compatible with Zygote --- src/compat/zygote.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index b16ef58b..c971597b 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -291,3 +291,20 @@ end end +# Otherwise Zygote complains. +Zygote.@adjoint function Batch(x) + return Batch(x), function(_Δ) + # Sometimes `Batch` has been extraced using `value`, in which case + # we get back a `NamedTuple`. + # Other times the value was extracted by iterating over `Batch`, + # in which case we don't get a `NamedTuple`. + Δ = _Δ isa NamedTuple{(:value, )} ? _Δ.value : _Δ + return if Δ isa AbstractArray{<:Real} + (Δ, ) + else + Δ_new = similar(x, eltype(Δ[1])) + Δ_new[:] .= vcat(Δ...) + (Δ_new, ) + end + end +end From e5439f586becbd6973f44e94ff5d703166771f4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 06:43:27 +0100 Subject: [PATCH 43/45] updated OrderedBijector --- src/bijectors/ordered.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index d1bfd8f0..ec2b3ed8 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -7,7 +7,7 @@ A bijector mapping ordered vectors in ℝᵈ to unordered vectors in ℝᵈ. - [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html) - Note that this transformation and its inverse are the _opposite_ of in this reference. 
""" -struct OrderedBijector <: Bijector{1} end +struct OrderedBijector <: Bijector end """ ordered(d::Distribution) @@ -16,7 +16,7 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in """ ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) -(b::OrderedBijector)(y::AbstractVecOrMat) = _transform_ordered(y) +transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y) function _transform_ordered(y::AbstractVector) x = similar(y) @@ -45,8 +45,7 @@ function _transform_ordered(y::AbstractMatrix) return x end -(ib::Inverse{<:OrderedBijector})(x::AbstractVecOrMat) = _transform_inverse_ordered(x) - +transform(ib::Inverse{<:OrderedBijector}, x::AbstractVecOrMat) = _transform_inverse_ordered(x) function _transform_inverse_ordered(x::AbstractVector) y = similar(x) @assert !isempty(y) From 5681358caf4250534507e6ed3e94528b7dd5e41b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Jan 2022 18:12:34 +0000 Subject: [PATCH 44/45] temporary stuff --- Project.toml | 1 + src/Bijectors.jl | 56 +++-- src/bijectors/composed.jl | 307 +++-------------------- src/bijectors/exp_log.jl | 91 +++++-- src/bijectors/logit.jl | 24 +- src/interface.jl | 497 +++++++++++++++++++++++++------------- 6 files changed, 491 insertions(+), 485 deletions(-) diff --git a/Project.toml b/Project.toml index 4d25f7cb..90b3abac 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.10.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +Batching = "a4738bf4-2ff1-4f03-87a2-28fa6f9d5d14" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 45151107..5ba75b9c 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,6 +36,9 @@ using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +using 
Batching +export batch + import ChangesOfVariables: with_logabsdet_jacobian import InverseFunctions: inverse @@ -55,10 +58,13 @@ export TransformDistribution, logpdf_with_trans, isclosedform, transform, + transform!, with_logabsdet_jacobian, inverse, forward, + forward!, logabsdetjac, + logabsdetjac!, logabsdetjacinv, Bijector, ADBijector, @@ -77,7 +83,9 @@ export TransformDistribution, PlanarLayer, RadialLayer, CouplingLayer, - InvertibleBatchNorm + InvertibleBatchNorm, + Elementwise, + elementwise if VERSION < v"1.1" using Compat: eachcol @@ -252,13 +260,13 @@ function getlogp(d::InverseWishart, Xcf, X) end include("utils.jl") -include("batch.jl") +# include("batch.jl") include("interface.jl") -include("chainrules.jl") +# include("chainrules.jl") -Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) +Base.@deprecate forward(b::Transform, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) -@noinline function Base.inv(b::AbstractBijector) +@noinline function Base.inv(b::Transform) Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) inverse(b) end @@ -267,24 +275,24 @@ end maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) -# optional dependencies -function __init__() - @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin - function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x...)) - end - function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x2, x...)) - end - function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) 
- return copy(f.(x1, x2, x3, x...)) - end - end - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") - @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") -end +# # optional dependencies +# function __init__() +# @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin +# function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x...)) +# end +# function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x2, x...)) +# end +# function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x2, x3, x...)) +# end +# end +# @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") +# @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") +# @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") +# @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") +# @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") +# end end # module diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index dab9e519..a3c07158 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -1,290 +1,59 @@ -############### -# Composition # -############### +invertible(cb::ComposedFunction) = invertible(cb.inner) + invertible(cb.outer) +isclosedform(cb::ComposedFunction) = isclosedform(cb.inner) && isclosedform(cb.outer) -""" - Composed(ts::A) +transform_single(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) 
+transform_multiple(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) - ∘(b1::Bijector{N}, b2::Bijector{N})::Composed{<:Tuple} - composel(ts::Bijector{N}...)::Composed{<:Tuple} - composer(ts::Bijector{N}...)::Composed{<:Tuple} - -where `A` refers to either -- `Tuple{Vararg{<:Bijector{N}}}`: a tuple of bijectors of dimensionality `N` -- `AbstractArray{<:Bijector{N}}`: an array of bijectors of dimensionality `N` - -A `Bijector` representing composition of bijectors. `composel` and `composer` results in a -`Composed` for which application occurs from left-to-right and right-to-left, respectively. - -Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors. -This ensures type-stability of implementations of all relating methdos, e.g. `inverse`. - -If you want to use an `Array` as the container instead you can do - - Composed([b1, b2, ...]) - -In general this is not advised since you lose type-stability, but there might be cases -where this is desired, e.g. if you have a insanely large number of bijectors to compose. - -# Examples -## Simple example -Let's consider a simple example of `Exp`: -```julia-repl -julia> using Bijectors: Exp - -julia> b = Exp() -Exp{0}() - -julia> b ∘ b -Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())) - -julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation -true - -julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion -true - -julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian -3.718281828459045 -``` - -# Notes -## Order -It's important to note that `∘` does what is expected mathematically, which means that the -bijectors are applied to the input right-to-left, e.g. 
first applying `b2` and then `b1`: -```julia -(b1 ∘ b2)(x) == b1(b2(x)) # => true -``` -But in the `Composed` struct itself, we store the bijectors left-to-right, so that -```julia -cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) -cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) -cb1(x) == cb2(x) == b1(b2(x)) # => true -``` - -## Structure -`∘` will result in "flatten" the composition structure while `composel` and -`composer` preserve the compositional structure. This is most easily seen by an example: -```julia-repl -julia> b = Exp() -Exp{0}() - -julia> cb1 = b ∘ b; cb2 = b ∘ b; - -julia> (cb1 ∘ cb2).ts # <= different -(Exp{0}(), Exp{0}(), Exp{0}(), Exp{0}()) - -julia> (cb1 ∘ cb2).ts isa NTuple{4, Exp{0}} -true - -julia> Bijectors.composer(cb1, cb2).ts -(Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())), Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}()))) - -julia> Bijectors.composer(cb1, cb2).ts isa Tuple{Composed, Composed} -true -``` - -""" -struct Composed{A} <: Transform - ts::A -end - -# field contains nested numerical parameters -Functors.@functor Composed - -invertible(cb::Composed) = sum(map(invertible, cb.ts)) - -isclosedform(b::Composed) = all(isclosedform, b.ts) - -function Base.:(==)(b1::Composed, b2::Composed) - ts1, ts2 = b1.ts, b2.ts - return length(ts1) == length(ts2) && all(x == y for (x, y) in zip(ts1, ts2)) -end - -""" - composel(ts::Transform...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied left-to-right. -""" -composel(ts::Transform...) = Composed(ts) - -""" - composer(ts::Transform...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied right-to-left. -""" -composer(ts::Transform...) = Composed(reverse(ts)) - -# The transformation of `Composed` applies functions left-to-right -# but in mathematics we usually go from right-to-left; this reversal ensures that -# when we use the mathematical composition ∘ we get the expected behavior. -# TODO: change behavior of `transform` of `Composed`? 
-∘(b1::Transform, b2::Transform) = composel(b2, b1) - -# type-stable composition rules -∘(b1::Composed{<:Tuple}, b2::Transform) = composel(b2, b1.ts...) -∘(b1::Transform, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) -∘(b1::Composed{<:Tuple}, b2::Composed{<:Tuple}) = composel(b2.ts..., b1.ts...) - -# type-unstable composition rules -∘(b1::Composed{<:AbstractArray}, b2::Transform) = Composed(pushfirst!(copy(b1.ts), b2)) -∘(b1::Transform, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) -function ∘(b1::Composed{<:AbstractArray}, b2::Composed{<:AbstractArray}) - return Composed(append!(copy(b2.ts), copy(b1.ts))) -end - -# if combining type-unstable and type-stable, return type-unstable -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:Tuple}, T2<:Composed{<:AbstractArray}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:AbstractArray}, T2<:Composed{<:Tuple}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end - - -∘(::Identity, ::Identity) = Identity() -∘(::Identity, b::Transform) = b -∘(b::Transform, ::Identity) = b - -inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) - -# TODO: should arrays also be using recursive implementation instead? 
-function transform(cb::Composed, x) - @assert length(cb.ts) > 0 - res = cb.ts[1](x) - for b ∈ Base.Iterators.drop(cb.ts, 1) - res = b(res) - end - - return res +function transform_single!(cb::ComposedFunction, x, y) + transform!(cb.inner, x, y) + return transform!(cb.outer, y, y) end -function transform_batch(cb::Composed, x) - @assert length(cb.ts) > 0 - res = cb.ts[1].(x) - for b ∈ Base.Iterators.drop(cb.ts, 1) - res = transform_batch(b, res) - end - - return res +function transform_multiple!(cb::ComposedFunction, x, y) + transform!(cb.inner, x, y) + return transform!(cb.outer, y, y) end -@generated function transform(cb::Composed{T}, x) where {T<:Tuple} - @assert length(T.parameters) > 0 - expr = :(x) - for i in 1:length(T.parameters) - expr = :(cb.ts[$i]($expr)) - end - return expr +function logabsdetjac_single(cb::ComposedFunction, x) + y, logjac = forward(cb.inner, x) + return logabsdetjac(cb.outer, y) + logjac end -@generated function transform_batch(cb::Composed{T}, x) where {T<:Tuple} - @assert length(T.parameters) > 0 - expr = :(x) - for i in 1:length(T.parameters) - expr = :(transform_batch(cb.ts[$i], $expr)) - end - return expr +function logabsdetjac_multiple(cb::ComposedFunction, x) + y, logjac = forward(cb.inner, x) + return logabsdetjac(cb.outer, y) + logjac end -function logabsdetjac(cb::Composed, x) - y, logjac = with_logabsdet_jacobian(cb.ts[1], x) - for i = 2:length(cb.ts) - y, res_logjac = with_logabsdet_jacobian(cb.ts[i], y) - logjac += res_logjac - end - - return logjac +function logabsdetjac_single!(cb::ComposedFunction, x, logjac) + y = similar(x) + forward!(cb.inner, x, y, logjac) + return logabdetjac!(cb.outer, y, y, logjac) end -function logabsdetjac_batch(cb::Composed, x) - init = forward_batch(cb.ts[1], x) - result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b - return forward_batch(b, y) - end - - return result.logabsdetjac -end - - -@generated function logabsdetjac(cb::Composed{T}, x) where {T<:Tuple} - N = 
length(T.parameters) - - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:N - 1 - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - sym_ladj, sym_tmp_ladj = gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :($sym_tmp_ladj = logabsdetjac(cb.ts[$N], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - push!(expr.args, :(return $sym_ladj)) - - return expr +function logabsdetjac_multiple!(cb::ComposedFunction, x, logjac) + y = similar(x) + forward!(cb.inner, x, y, logjac) + return logabdetjac!(cb.outer, y, y, logjac) end -""" - logabsdetjac_batch(cb::Composed{<:Tuple}, x) - -Generates something of the form -```julia -quote - (y, logjac_1) = forward_batch(cb.ts[1], x) - logjac_2 = logabsdetjac_batch(cb.ts[2], y) - return logjac_1 + logjac_2 +function forward_single(cb::ComposedFunction, x) + y1, logjac1 = forward(cb.inner, x) + y2, logjac2 = forward(cb.outer, y1) + return y2, logjac1 + logjac2 end -``` -""" -@generated function logabsdetjac_batch(cb::Composed{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) - - for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :($(Symbol("logjac_$i")) = $temp.logabsdetjac)) - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :($(Symbol("logjac_$N")) = 
logabsdetjac_batch(cb.ts[$N], y))) - sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) - push!(expr.args, :(return $(sum_expr))) - - return expr +function forward_multiple(cb::ComposedFunction, x) + y1, logjac1 = forward(cb.inner, x) + y2, logjac2 = forward(cb.outer, y1) + return y2, logjac1 + logjac2 end - -function with_logabsdet_jacobian(cb::Composed, x) - rv, logjac = with_logabsdet_jacobian(cb.ts[1], x) - - for t in cb.ts[2:end] - rv, res_logjac = with_logabsdet_jacobian(t, rv) - logjac += res_logjac - end - return (rv, logjac) +function forward_single!(cb::ComposedFunction, x, y, logjac) + forward!(cb.inner, x, y, logjac) + return forward!(cb.outer, y, y, logjac) end -@generated function with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:length(T.parameters) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end - push!(expr.args, :(return ($sym_y, $sym_ladj))) - - return expr +function forward_multiple!(cb::ComposedFunction, x, y, logjac) + forward!(cb.inner, x, y, logjac) + return forward!(cb.outer, y, y, logjac) end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index d3794a59..92197f80 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,20 +1,83 @@ -############# -# Exp & Log # -############# +transform_single!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) -struct Exp <: Bijector end -struct Log <: Bijector end +transform_multiple(b::Union{typeof(log), typeof(exp)}, x) = 
b.(x) +transform_multiple!(b::Union{typeof(log), typeof(exp)}, x, y) = broadcast!(b, y, x) -inv(b::Exp) = Log() -inv(b::Log) = Exp() +logabsdetjac_single(b::typeof(exp), x::Real) = x +logabsdetjac_single(b::Elementwise{typeof(exp)}, x) = sum(x) -transform(b::Exp, y) = exp.(y) -transform(b::Log, x) = log.(x) +logabsdetjac_single(b::typeof(log), x::Real) = -log(x) +logabsdetjac_single(b::Elementwise{typeof(log)}, x) = -sum(log, x) -logabsdetjac(b::Exp, x) = sum(x) -logabsdetjac(b::Log, x) = -sum(log, x) +logabsdetjac_multiple(b::typeof(exp), xs) = xs +logabsdetjac_multiple(b::Elementwise{typeof(exp)}, xs) = map(sum, xs) -function forward(b::Log, x) - y = transform(b, x) - return (result = y, logabsdetjac = -sum(y)) +logabsdetjac_multiple(b::typeof(log), xs) = -map(log, xs) +logabsdetjac_multiple(b::Elementwise{typeof(log)}, xs) = -map(sum ∘ log, xs) + +function forward_single(b::typeof(exp), x::Real) + y = b(x) + return y, x +end +function forward_single(b::Elementwise{typeof(exp)}, x) + y = b(x) + return y, sum(x) +end + +function forward_multiple(b::typeof(exp), xs::AbstractBatch{<:Real}) + ys = transform(b, xs) + return ys, xs +end +function forward_multiple!( + b::typeof(exp), + xs::AbstractBatch{<:Real}, + ys::AbstractBatch{<:Real}, + logjacs::AbstractBatch{<:Real} +) + transform!(b, xs, ys) + logjacs += xs + return ys, logjacs +end +function forward_multiple(b::Elementwise{typeof(exp)}, xs) + ys = transform(b, xs) + return ys, map(sum, xs) +end +function forward_multiple!(b::Elementwise{typeof(exp)}, xs, ys, logjacs) + # Do this before `transform!` in case `xs === ys`. 
+ logjacs += map(sum, xs) + transform!(b, xs, ys) + return ys, logjacs +end + +function forward_single(b::typeof(log), y::Real) + x = transform(b, y) + return x, -x +end +function forward_single(b::Elementwise{typeof(log)}, y) + x = transform(b, y) + return x, -sum(x) +end + +function forward_multiple(b::typeof(log), ys::AbstractBatch{<:Real}) + xs = transform(b, ys) + return xs, -xs +end +function forward_multiple!( + b::typeof(log), + ys::AbstractBatch{<:Real}, + xs::AbstractBatch{<:Real}, + logjacs::AbstractBatch{<:Real} +) + transform!(b, ys, xs) + logjacs -= xs + return xs, logjacs +end +function forward_multiple(b::Elementwise{typeof(log)}, ys) + xs = transform(b, ys) + return xs, -map(sum, xs) +end +function forward_multiple!(b::Elementwise{typeof(log)}, ys, xs, logjacs) + transform!(b, ys, xs) + logjacs -= map(sum, xs) + return xs, logjacs end diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 9d26de5d..17afd125 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -1,8 +1,6 @@ ###################### # Logit and Logistic # ###################### -using StatsFuns: logit, logistic - struct Logit{T} <: Bijector a::T b::T @@ -13,11 +11,23 @@ Functors.@functor Logit # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -transform(b::Logit, x) = _logit.(x, b.a, b.b) -_logit(x, a, b) = logit((x - a) / (b - a)) +# TODO: Implement `forward` and batched versions. 
+ +# Evaluation +_logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a)) +transform_single(b::Logit, x) = _logit.(x, b.a, b.b) +function transform_multiple(b::Logit, xs::Batching.ArrayBatch{<:Real}) + return batch_like(xs, _logit.(Batching.value(xs))) +end -transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) -_ilogit(y, a, b) = (b - a) * logistic(y) + a +# Inverse +_ilogit(y, a, b) = (b - a) * LogExpFunctions.logistic(y) + a + +transform_single(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) +function transform_multiple(ib::Inverse{<:Logit}, ys::Batching.ArrayBatch) + return batch_like(ys, _ilogit.(Batching.value(ys))) +end -logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) +# `logabsdetjac` logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) +logabsdetjac_single(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) diff --git a/src/interface.jl b/src/interface.jl index 809022df..8e96a27f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -3,6 +3,21 @@ import Base: ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf +const Elementwise{F} = Base.Fix1{<:Union{typeof(map),typeof(broadcast)}, F} +""" + elementwise(f) + +Alias for `Base.Fix1(broadcast, f)`. + +In the case where `f::ComposedFunction`, the result is +`Base.Fix1(broadcast, f.outer) ∘ Base.Fix1(broadcast, f.inner)` rather than +`Base.Fix1(broadcast, f)`. +""" +elementwise(f) = Base.Fix1(broadcast, f) +# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really +# the way to go. +elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) + ####################################### # AD stuff "extracted" from Turing.jl # ####################################### @@ -35,90 +50,321 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t Abstract type for a transformation. 
-## Implementing +# Implementing -A subtype of `Transform` of should at least implement `transform(b, x)`. +A subtype of `Transform` of should at least implement [`transform(b, x)`](@ref). If the `Transform` is also invertible: - Required: - [`invertible`](@ref): should return [`Invertible`](@ref). - _Either_ of the following: - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. - - `Base.inv(b::MyTransform)`: returns an existing `Transform`. + - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. - Optional: - [`forward`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we can exploit shared computation in the two. For the above methods, there are mutating versions which can _optionally_ be implemented: -- [`transform!`](@ref) -- [`logabsdetjac!`](@ref) -- [`forward!`](@ref) +- [`transform_single!`](@ref) +- [`logabsdetjac_single!`](@ref) +- [`forward_single!`](@ref) Finally, there are _batched_ versions of the above methods which can _optionally_ be implemented: -- [`transform_batch`](@ref) -- [`logabsdetjac_batch`](@ref) -- [`forward_batch`](@ref) +- [`transform_multiple`](@ref) +- [`logabsdetjac_multiple`](@ref) +- [`forward_multiple`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. -Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` -falls back to `transform_batch(b, x)`. """ abstract type Transform end +(t::Transform)(x) = transform(t, x) + Broadcast.broadcastable(b::Transform) = Ref(b) """ transform(b, x) +Transform `x` using `b`, treating `x` as a single input. + +Alias for [`transform_single(b, x)`](@ref). +""" +transform(b, x) = transform_single(b, x) + +""" + transform_single(b, x) + Transform `x` using `b`. -Alternatively, one can just call `b`, i.e. `b(x)`. 
+Defaults to `b(x)` if `b isa Function` and `first(forward(b, x))` otherwise. """ -transform -(t::Transform)(x) = transform(t, x) +transform_single(b, x) = first(forward(b, x)) +transform_single(b::Function, x) = b(x) + +""" + transform(b, xs::AbstractBatch) + +Transform `xs` using `b`, treating `xs` as a collection of inputs. +Alias for [`transform_multiple(b, x)`](@ref). """ - transform!(b, x, y) +transform(b, xs::AbstractBatch) = transform_multiple(b, xs) -Transforms `x` using `b`, storing the result in `y`. """ -transform!(b, x, y) = (y .= transform(b, x)) + transform_multiple(b, xs) + +Transform `xs` using `b`, treating `xs` as a collection of inputs. + +Defaults to `map(Base.Fix1(transform, b), xs)`. +""" +transform_multiple(b, xs) = map(Base.Fix1(transform, b), xs) +function transform_multiple(b::Elementwise, x::Batching.ArrayBatch) + return batch_like(x, transform(b, Batching.value(x))) +end + + +""" + transform!(b, x[, y]) + +Transform `x` using `b`, storing the result in `y`. + +If `y` is not provided, `x` is used as the output. + +Alias for [`transform_single!(b, x, y)`](@ref). +""" +transform!(b, x, y=x) = transform_single!(b, x, y) + +""" + transform_single!(b, x, y) + +Transform `x` using `b`, storing the result in `y`. +""" +transform_single!(b, x, y) = (y .= transform(b, x)) + +""" + transform!(b, xs::AbstractBatch[, ys::AbstractBatch]) + +Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. + +If `ys` is not provided, `xs` is used as the output. + +Alias for [`transform_multiple!(b, xs, ys)`](@ref). +""" +transform!(b, xs::AbstractBatch, ys::AbstractBatch=xs) = transform_multiple!(b, xs, ys) + +""" + transform_multiple!(b, xs::AbstractBatch[, ys::AbstractBatch]) + +Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. 
+""" +transform_multiple!(b, xs, ys) = broadcast!(Base.Fix1(transform, b), xs, ys) +function transform_multiple!(b::Elementwise, x::Batching.ArrayBatch, y::Batching.ArrayBatch) + broadcast!(b, Batching.value(y), Batching.value(x)) + return y +end """ logabsdetjac(b, x) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. +Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single`](@ref). + +See also: [`logabsdetjac(b, xs::AbstractBatch)`](@ref). +""" +logabsdetjac(b, x) = logabsdetjac_single(b, x) + +""" + logabsdetjac_single(b, x) + +Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. + +Defaults to `last(forward(b, x))`. +""" +logabsdetjac_single(b, x) = last(forward(b, x)) + +""" + logabsdetjac(b, xs::AbstractBatch) + +Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, +where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`Bijectors.logabsdetjac_multiple`](@ref). + +See also: [`logabsdetjac(b, x)`](@ref). +""" +logabsdetjac(b, xs::AbstractBatch) = logabsdetjac_multiple(b, xs) + +""" + logabsdetjac_multiple(b, xs) + +Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, +where `J(b, x)` is the jacobian of `b` at `x`. + +Defaults to `map(Base.Fix1(logabsdetjac, b), xs)`. +""" +logabsdetjac_multiple(b, xs) = map(Base.Fix1(logabsdetjac, b), xs) + +""" + logabsdetjac!(b, x[, logjac]) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). """ -logabsdetjac +logabsdetjac!(b, x, logjac=zero(eltype(x))) = logabsdetjac_single!(b, x, logjac) """ - logabsdetjac!(b, x, logjac) + logabsdetjac_single!(b, x[, logjac]) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform, -_accumulating_ the result in `logjac`. 
+Compute `log(abs(det(J(b, x))))` and accumulate the result in `logjac`, +where `J(b, x)` is the jacobian of `b` at `x`. """ -logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) +logabsdetjac_single!(b, x, logjac) = (logjac += logabsdetjac(b, x)) + +""" + logabsdetjac!(b, xs::AbstractBatch[, logjacs::AbstractBatch]) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for +every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). +""" +function logabsdetjac!( + b, + xs::AbstractBatch, + logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) +) + return logabsdetjac_multiple!(b, xs, logjacs) +end + +""" + logabsdetjac_multiple!(b, xs::AbstractBatch, logjacs::AbstractBatch) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for +every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. +""" +logabsdetjac_multiple!(b, xs, logjacs) = (logjacs .+= logabsdetjac(b, xs)) """ forward(b, x) -Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(b(x), logabsdetjac(b, x))`. +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as single input. + +Alias for [`forward_single(b, x)`](@ref). + +See also: [`forward(b, xs::AbstractBatch)`](@ref). +""" +forward(b, x) = forward_single(b, x) + +""" + forward_single(b, x) + +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as a single input. + +Defaults to `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. +""" +forward_single(b, x) = with_logabsdet_jacobian(b, x) + +""" + forward(b, xs::AbstractBatch) + +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as +collection of inputs. -This defaults to the call above, but often one can re-use computation -in the computation of the forward pass and the computation of the -`logabsdetjac`. `forward` allows the user to take advantange of such -efficiencies, if they exist. 
+Alias for [`forward_multiple(b, xs)`](@ref). + +See also: [`forward(b, x)`](@ref). +""" +forward(b, xs::AbstractBatch) = forward_multiple(b, xs) + +""" + forward_multiple(b, xs) + +Return `(transform(b, xs), logabsdetjac(b, xs))` treating +`xs` as a batch, i.e. a collection inputs. + +See also: [`forward_single(b, x)`](@ref). +""" +function forward_multiple(b, xs) + # If `b` doesn't have its own definition of `forward_multiple` + # we just broadcast `forward_single`, resulting in a batch of `(y, logjac)` + # pairs which we then unwrap. + results = forward.(Ref(b), xs) + ys = map(first, results) + logjacs = map(last, results) + return batch_like(xs, ys, logjacs) +end + +# `logjac` as an argument doesn't make too much sense for `forward_single!` +# when the inputs have `eltype` `Real`. +""" + forward!(b, x[, y, logjac]) + +Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result +in `y` and `logjac`, respetively. + +If `y` is not provided, then `x` will be used in its place. + +Alias for [`forward_single!(b, x, y, logjac)`](@ref). + +See also: [`forward!(b, xs::AbstractBatch, ys::AbstractBatch, logjacs::AbstractBatch)`](@ref). """ -forward(b, x) = (result = transform(b, x), logabsdetjac = logabsdetjac(b, x)) +forward!(b, x, y=x, logjac=zero(eltype(x))) = forward_single!(b, x, y, logjac) -function forward!(b, x, out) - y, logjac = forward(b, x) - out.result .= y - out.logabsdetjac .+= logjac +""" + forward_single!(b, x, y, logjac) + +Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result +in `y` and `logjac`, respetively. - return out +Defaults to calling `forward(b, x)` and updating `y` and `logjac` with the result. +""" +function forward_single!(b, x, y, logjac) + y_, logjac_ = forward(b, x) + y .= y_ + return (y, logjac + logjac_) +end + +""" + forward!(b, xs[, ys, logjacs]) + +Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, +storing the results in `ys` and `logjacs`, respetively. 
+ +If `ys` is not provided, then `xs` will be used in its place. + +Alias for [`forward_multiple!(b, xs, ys, logjacs)`](@ref). + +See also: [`forward!(b, x, y, logjac)`](@ref). +""" +function forward!( + b, + xs::AbstractBatch, + ys::AbstractBatch=xs, + logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) +) + return forward_multiple!(b, xs, ys, logjacs) +end + +""" + forward!(b, xs, ys, logjacs) + +Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, +storing the results in `ys` and `logjacs`, respetively. + +Defaults to iterating through the `xs` and calling [`forward_single(b, x)`](@ref) +for every `x` in `xs`. +""" +function forward_multiple!(b, xs, ys, logjacs) + for i in eachindex(ys) + res = forward_single(b, xs[i]) + ys[i] .= first(res) + logjacs[i] += last(res) + end + + return ys, logjacs end """ @@ -132,7 +378,7 @@ Most transformations have closed-form evaluations, but there are cases where this is not the case. For example the *inverse* evaluation of `PlanarLayer` requires an iterative procedure to evaluate. """ -isclosedform(b::Transform) = true +isclosedform(t::Transform) = true # Invertibility "trait". struct NotInvertible end @@ -144,11 +390,22 @@ Base.:+(::Invertible, ::NotInvertible) = NotInvertible() Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() Base.:+(::Invertible, ::Invertible) = Invertible() +""" + invertible(t) + +Return `Invertible()` if `t` is invertible, and `NotInvertible()` otherwise. +""" invertible(::Transform) = NotInvertible() + +""" + isinvertible(t) + +Return `true` if `t` is invertible, and `false` otherwise. +""" isinvertible(t::Transform) = invertible(t) isa Invertible """ - inv(b::Transform) + inverse(b::Transform) Inverse(b::Transform) A `Transform` representing the inverse transform of `b`. 
@@ -168,12 +425,12 @@ end Functors.@functor Inverse """ - inv(t::Transform[, ::Invertible]) + inverse(t::Transform[, ::Invertible]) Returns the inverse of transform `t`. """ -Base.inv(t::Transform) = Inverse(t) -Base.inv(ib::Inverse) = ib.orig +inverse(t::Transform) = Inverse(t) +inverse(ib::Inverse) = ib.orig invertible(ib::Inverse) = Invertible() @@ -197,157 +454,55 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) ############################## # Example bijector: Identity # ############################## +# Here we don't need to separate between batched version and non-batched, and so +# we can just overload `transform`, etc. directly. +transform(::typeof(identity), x) = copy(x) +transform!(::typeof(identity), x, y) = copy!(y, x) -struct Identity <: Bijector end -inv(b::Identity) = b +logabsdetjac_single(::typeof(identity), x) = zero(eltype(x)) +logabsdetjac_multiple(::typeof(identity), x) = batch_like(x, zeros(eltype(eltype(x)), length(x))) -transform(::Identity, x) = copy(x) -transform!(::Identity, x, y) = (y .= x; return y) -logabsdetjac(::Identity, x) = zero(eltype(x)) -logabsdetjac!(::Identity, x, logjac) = logjac +logabsdetjac_single!(::typeof(identity), x, logjac) = logjac +logabsdetjac_multiple!(::typeof(identity), x, logjac) = logjac #################### # Batched versions # #################### # NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. # so we can actually reference them. Since we just did this for `Identity`, we're good. -Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) -Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) - - -""" - transform_batch(b, xs) - -Transform `xs` by `b`, treating `xs` as a "batch", i.e. 
a collection of independent inputs. - -See also: [`transform`](@ref) -""" -transform_batch(b, xs) = _transform_batch(b, xs) -# Default implementations uses private methods to avoid method ambiguity. -_transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) -function _transform_batch(b, xs::ArrayBatch{2}) - # TODO: Check if we can avoid using these custom methods. - return Batch(eachcolmaphcat(b, value(xs))) -end -function _transform_batch(b, xs::ArrayBatch{N}) where {N} - res = reduce(map(b, eachslice(value(xs), Val{N}()))) do acc, x - cat(acc, x; dims = N) - end - return reconstruct(xs, res) -end - -""" - transform_batch!(b, xs, ys) - -Transform `xs` by `b` treating `xs` as a "batch", i.e. a collection of independent inputs, -and storing the result in `ys`. - -See also: [`transform!`](@ref) -""" -transform_batch!(b, xs, ys) = _transform_batch!(b, xs, ys) -function _transform_batch!(b, xs, ys) - for i = 1:length(xs) - if eltype(ys) <: Real - ys[i] = transform(b, xs[i]) - else - transform!(b, xs[i], ys[i]) - end - end - - return ys -end - -""" - logabsdetjac_batch(b, xs) - -Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`logabsdetjac`](@ref) -""" -logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) -# Default implementations uses private methods to avoid method ambiguity. -_logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) -function _logabsdetjac_batch(b, xs::ArrayBatch{2}) - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachcol(value(xs)))) -end -function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), Val{N}()))) -end - -""" - logabsdetjac_batch!(b, xs, logjacs) - -Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs, -accumulating the result in `logjacs`. 
- -See also: [`logabsdetjac!`](@ref) -""" -logabsdetjac_batch!(b, xs, logjacs) = _logabsdetjac_batch!(b, xs, logjacs) -function _logabsdetjac_batch!(b, xs, logjacs) - for i = 1:length(xs) - if eltype(logjacs) <: Real - logjacs[i] += logabsdetjac(b, xs[i]) - else - logabsdetjac!(b, xs[i], logjacs[i]) - end - end - - return logjacs -end - -""" - forward_batch(b, xs) - -Computes `forward(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`transform`](@ref) -""" -forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) - -""" - forward_batch!(b, xs, out) - -Computes `forward(b, xs)` in place, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`forward!`](@ref) -""" -function forward_batch!(b, xs, out) - transform_batch!(b, xs, out.result) - logabsdetjac_batch!(b, xs, out.logabsdetjac) - - return out -end +# Broadcast.broadcasted(b::Transform, xs::Batch) = transform_multiple(b, xs) +# Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_multiple(b, xs) +# Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_multiple(b, xs) +# Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_multiple(b, xs) ###################### # Bijectors includes # ###################### # General -include("bijectors/adbijector.jl") +# include("bijectors/adbijector.jl") include("bijectors/composed.jl") -include("bijectors/stacked.jl") +# include("bijectors/stacked.jl") # Specific include("bijectors/exp_log.jl") include("bijectors/logit.jl") -include("bijectors/scale.jl") -include("bijectors/shift.jl") -include("bijectors/permute.jl") -include("bijectors/simplex.jl") -include("bijectors/pd.jl") -include("bijectors/corr.jl") -include("bijectors/truncated.jl") -include("bijectors/named_bijector.jl") -include("bijectors/ordered.jl") +# include("bijectors/scale.jl") +# include("bijectors/shift.jl") +# 
include("bijectors/permute.jl") +# include("bijectors/simplex.jl") +# include("bijectors/pd.jl") +# include("bijectors/corr.jl") +# include("bijectors/truncated.jl") +# include("bijectors/named_bijector.jl") +# include("bijectors/ordered.jl") # Normalizing flow related -include("bijectors/planar_layer.jl") -include("bijectors/radial_layer.jl") -include("bijectors/leaky_relu.jl") -include("bijectors/coupling.jl") -include("bijectors/normalise.jl") -include("bijectors/rational_quadratic_spline.jl") +# include("bijectors/planar_layer.jl") +# include("bijectors/radial_layer.jl") +# include("bijectors/leaky_relu.jl") +# include("bijectors/coupling.jl") +# include("bijectors/normalise.jl") +# include("bijectors/rational_quadratic_spline.jl") ################## # Other includes # From 306aa662f3fe54c7d73d93e2837f5580bb49ae3c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Jan 2022 18:20:55 +0000 Subject: [PATCH 45/45] added docs --- docs/Project.toml | 3 + docs/make.jl | 15 +++++ docs/src/index.md | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 docs/Project.toml create mode 100644 docs/make.jl create mode 100644 docs/src/index.md diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..7e7c8131 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,3 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..54fb95c5 --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,15 @@ +using Documenter +using Bijectors + +makedocs( + sitename = "Bijectors", + format = Documenter.HTML(), + modules = [Bijectors] +) + +# Documenter can also automatically deploy documentation to gh-pages. +# See "Hosting Documentation" and deploydocs() in the Documenter manual +# for more information. 
+#=deploydocs( + repo = "" +)=# diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..ce0f1b0a --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,144 @@ +# Bijectors.jl + +Documentation for Bijectors.jl + +## Usage + +A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: +- The inverse of `exp` is `log`. +- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. + +```@repl usage +using Bijectors +transform(exp, 1.0) +logabsdetjac(exp, 1.0) +forward(exp, 1.0) +``` + +If you want to instead transform a collection of inputs, you can use the `batch` method from Batching.jl to inform Bijectors.jl that the input now represents a collection of inputs rather than a single input: + +```@repl usage +xs = batch(ones(2)); +transform(exp, xs) +logabsdetjac(exp, xs) +forward(exp, xs) +``` + +Some transformations are well-defined for different types of inputs, e.g. `exp` can also act elementwise on an `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: + +```@repl usage +x = ones(2, 2) +transform(elementwise(exp), x) +logabsdetjac(elementwise(exp), x) +forward(elementwise(exp), x) +``` + +And batched versions: + +```@repl usage +xs = batch(ones(2, 2, 3)); +transform(elementwise(exp), xs) +logabsdetjac(elementwise(exp), xs) +forward(elementwise(exp), xs) +``` + +These methods also work nicely for compositions of transformations: + +```@repl usage +transform(elementwise(log ∘ exp), xs) +``` + +Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g.
`Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: + +```@repl usage +using Bijectors: Logit + +f = Logit(0.0, 1.0) +f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` +``` + +## User-facing methods + +Without mutation: + +```@docs +transform +logabsdetjac +forward(b, x) +``` + +With mutation: + +```@docs +transform! +logabsdetjac! +forward! +``` + +## Implementing a transformation + +Any callable can be made into a bijector by providing an implementation of [`forward(b, x)`](@ref), which is done by overloading + +```@docs +Bijectors.forward_single +``` + +where + +```@docs +Bijectors.transform_single +Bijectors.logabsdetjac_single +``` + +You can then optionally implement `transform` and `logabsdetjac` to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. + +Note that a _user_ of the bijector should generally be using [`forward(b, x)`](@ref) rather than calling [`forward_single`](@ref) directly. + +To implement "batched" versions of the above functionalities, i.e. methods which act on a _collection_ of inputs rather than a single input, you can overload the following method: + +```@docs +Bijectors.forward_multiple +``` + +And similarly, if you want to specialize [`transform`](@ref) and [`logabsdetjac`](@ref), you can implement + +```@docs +Bijectors.transform_multiple +Bijectors.logabsdetjac_multiple +``` + +### Mutability + +There are also _mutable_ versions of all of the above: + +```@docs +Bijectors.forward_single! +Bijectors.forward_multiple! +Bijectors.transform_single! +Bijectors.transform_multiple! +Bijectors.logabsdetjac_single! +Bijectors.logabsdetjac_multiple! 
+``` + +## Working with Distributions.jl + +```@docs +Bijectors.bijector +Bijectors.transformed(d::Distribution, b::Bijector) +``` + +## Utilities + +```@docs +Bijectors.elementwise +Bijectors.isinvertible +Bijectors.isclosedform(t::Bijectors.Transform) +``` + +## API + +```@docs +Bijectors.Transform +Bijectors.Bijector +Bijectors.Inverse +```