From 1d9a1a324d706c369df384d3913199ef0b64adef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:46:27 +0100 Subject: [PATCH 01/98] renamed rv to result in forward --- src/bijectors/leaky_relu.jl | 6 +++--- src/bijectors/named_bijector.jl | 18 +++++++++--------- src/bijectors/normalise.jl | 10 +++++----- src/bijectors/planar_layer.jl | 2 +- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 4 ++-- src/bijectors/stacked.jl | 6 +++--- src/transformed_distribution.jl | 22 +++++++++++----------- 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 65060c14..c9128f01 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -50,7 +50,7 @@ logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> loga function forward(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) - return (rv=J * x, logabsdetjac=log(abs(J))) + return (result=J * x, logabsdetjac=log(abs(J))) end # Batched version @@ -58,7 +58,7 @@ function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - return (rv=J .* x, logabsdetjac=log.(abs.(J))) + return (result=J .* x, logabsdetjac=log.(abs.(J))) end # (N=1) Multivariate case @@ -97,5 +97,5 @@ function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) end y = J .* x - return (rv=y, logabsdetjac=logjac) + return (result=y, logabsdetjac=logjac) end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 36f8691c..500044b9 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -forward(b::AbstractNamedBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x)) +forward(b::AbstractNamedBijector, x) = (result = b(x), logabsdetjac = logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -125,7 +125,7 @@ function logabsdetjac(cb::NamedComposition, x) y, logjac = forward(cb.bs[1], x) for i = 2:length(cb.bs) res = forward(cb.bs[i], y) - y = res.rv + y = res.result logjac += res.logabsdetjac end @@ -141,7 +141,7 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` @@ -154,14 +154,14 @@ end function forward(cb::NamedComposition, x) - rv, logjac = forward(cb.bs[1], x) + result, logjac = forward(cb.bs[1], x) for t in cb.bs[2:end] - res = forward(t, rv) - rv = res.rv + res = forward(t, result) + result = res.result logjac = res.logabsdetjac + logjac end - return (rv=rv, logabsdetjac=logjac) + return (result=result, logabsdetjac=logjac) end @@ -171,10 +171,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (result = y, logabsdetjac = logjac))) return expr end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 81496468..16b48a29 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -72,16 +72,16 @@ function 
forward(bn::InvertibleBatchNorm, x) v = reshape(bn.v, as...) end - rv = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b + result = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (rv=rv, logabsdetjac=logabsdetjac) + return (result=result, logabsdetjac=logabsdetjac) end logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac -(bn::InvertibleBatchNorm)(x) = forward(bn, x).rv +(bn::InvertibleBatchNorm)(x) = forward(bn, x).result function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." @@ -94,10 +94,10 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) v = reshape(bn.v, as...) x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m - return (rv=x, logabsdetjac=-logabsdetjac(bn, x)) + return (result=x, logabsdetjac=-logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).rv +(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).result function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index b82a7274..9d1a73d0 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -105,7 +105,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index b678ed96..a5cb5f34 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -67,7 +67,7 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (result = transformed, logabsdetjac = log_det_jacobian) end function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 826a4017..a0907b82 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -346,7 +346,7 @@ function rqs_forward( T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) if (x ≤ -widths[end]) || (x ≥ widths[end]) - return (rv = one(T) * x, logabsdetjac = zero(T) * x) + return (result = one(T) * x, logabsdetjac = zero(T) * x) end # Find which bin `x` is in @@ -379,7 +379,7 @@ function rqs_forward( numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ)) y = h_k + numerator_y / denominator - return (rv = y, logabsdetjac = logjac) + return (result = y, logabsdetjac = logjac) end function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 0131a76c..144645fd 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -131,7 +131,7 @@ end # logjac = sum(_logjac) # (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) # logjac += sum(_logjac) -# return (rv = vcat(y_1, y_2), logabsdetjac = logjac) +# return (result = vcat(y_1, y_2), logabsdetjac = logjac) # end @generated function forward(b::Stacked{T, N}, 
x::AbstractVector) where {N, T<:Tuple} expr = Expr(:block) @@ -151,7 +151,7 @@ end push!(y_names, y_name) end - push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac))) + push!(expr.args, :(return (result = vcat($(y_names...)), logabsdetjac = logjac))) return expr end @@ -163,5 +163,5 @@ function forward(sb::Stacked{<:AbstractArray, N}, x::AbstractVector) where {N} logjac += sum(l) y end - return (rv = vcat(yinit, ys), logabsdetjac = logjac) + return (result = vcat(yinit, ys), logabsdetjac = logjac) end diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 6a30f7eb..d6177c72 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -85,14 +85,14 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) @@ -100,12 +100,12 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res.result) + res.logabsdetjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -113,12 +113,12 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with # res = forward(inv(td.transform), y) -# logpdf(td.dist, res.rv) .- res.logabsdetjac +# logpdf(td.dist, res.result) .- res.logabsdetjac function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end @@ -163,18 +163,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. 
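# Examples

A minimal sketch, not a doctest (using `transformed` and `Exp` as defined elsewhere in this package):

```julia
using Bijectors, Distributions

td = transformed(Normal(), Exp())       # univariate transformed distribution
lp, logjac = logpdf_with_jac(td, 2.0)   # log-density at 2.0 and the log-abs-det-jacobian term
```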
""" function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.result) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -182,14 +182,14 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.result)) + res.logabsdetjac return (lp, res.logabsdetjac) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return (logpdf_with_trans(td.dist, res.result, true), res.logabsdetjac) end """ From 0717e3e7af25a0e63a791cf4f54674b36fe3cd4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:48:34 +0100 Subject: [PATCH 02/98] added abstrac type Transform and removed dimensionality from Bijector --- src/interface.jl | 212 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 155 insertions(+), 57 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9454956b..9ace3b38 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -31,16 +31,120 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t ###################### # Bijector interface # ###################### -"Abstract type for a bijector." -abstract type AbstractBijector end +""" + +Abstract type for a transformation. + +## Implementing + +A subtype of `Transform` of should at least implement `transform(b, x)`. + +If the `Transform` is also invertible: +- Required: + - [`invertible`](@ref): should return [`Invertible`](@ref). + - _Either_ of the following: + - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. + - `Base.inv(b::MyTransform)`: returns an existing `Transform`. + - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. +- Optional: + - [`forward`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we + can exploit shared computation in the two. + +For the above methods, there are mutating versions which can _optionally_ be implemented: +- [`transform!`](@ref) +- [`logabsdetjac!`](@ref) +- [`forward!`](@ref) + +Finally, there are _batched_ versions of the above methods which can _optionally_ be implemented: +- [`transform_batch`](@ref) +- [`logabsdetjac_batch`](@ref) +- [`forward_batch`](@ref) + +and similarly for the mutating versions. Default implementations depends on the type of `xs`. +Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` +falls back to `transform_batch(b, x)`. 
+""" +abstract type Transform end + +Broadcast.broadcastable(b::Transform) = Ref(b) + +# Invertibility "trait". +struct NotInvertible end +struct Invertible end + +# Useful for checking if compositions, etc. are invertible or not. +Base.:+(::NotInvertible, ::Invertible) = NotInvertible() +Base.:+(::Invertible, ::NotInvertible) = NotInvertible() +Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() +Base.:+(::Invertible, ::Invertible) = Invertible() + +invertible(::Transform) = NotInvertible() + +""" + inv(t::Transform[, ::Invertible]) + +Returns the inverse of transform `t`. +""" +Base.inv(t::Transform) = Base.inv(t, invertible(t)) +Base.inv(t::Transform, ::NotInvertible) = error("$(t) is not invertible") + +""" + transform(b, x) + +Transform `x` using `b`. + +Alternatively, one can just call `b`, i.e. `b(x)`. +""" +transform +(t::Transform)(x) = transform(t, x) + +""" + transform!(b, x, y) + +Transforms `x` using `b`, storing the result in `y`. +""" +transform!(b, x, y) = (y .= transform(b, x)) + +""" + logabsdetjac(b, x) + +Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. +""" +logabsdetjac + +""" + logabsdetjac!(b, x, logjac) + +Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform, +_accumulating_ the result in `logjac`. +""" +logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) + +""" + forward(b, x) + +Computes both `transform` and `logabsdetjac` in one forward pass, and +returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. + +This defaults to the call above, but often one can re-use computation +in the computation of the forward pass and the computation of the +`logabsdetjac`. `forward` allows the user to take advantange of such +efficiencies, if they exist. +""" +forward(b, x) = (result = transform(b, x), logabsdetjac = logabsdetjac(b, x)) -"Abstract type of bijectors with fixed dimensionality." -abstract type Bijector{N} <: AbstractBijector end +function forward!(b, x, out) + y, logjac = forward(b, x) + out.result .= y + out.logabsdetjac .+= logjac -dimension(b::Bijector{N}) where {N} = N -dimension(b::Type{<:Bijector{N}}) where {N} = N + return out +end -Broadcast.broadcastable(b::Bijector) = Ref(b) +"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." +abstract type Bijector <: Transform end + +invertible(::Bijector) = Invertible() """ isclosedform(b::Bijector)::bool @@ -61,81 +165,75 @@ isclosedform(b::Bijector) = true A `Bijector` representing the inverse transform of `b`. """ -struct Inverse{B <: Bijector, N} <: Bijector{N} +struct Inverse{B<:Bijector} <: Bijector orig::B - - Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b) end # field contains nested numerical parameters Functors.@functor Inverse -up1(b::Inverse) = Inverse(up1(b.orig)) - inv(b::Bijector) = Inverse(b) inv(ib::Inverse{<:Bijector}) = ib.orig Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig -""" - logabsdetjac(b::Bijector, x) - logabsdetjac(ib::Inverse{<:Bijector}, y) +logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. -Similarily for the inverse-transform. +""" + logabsdetjacinv(b::Bijector, y) -Default implementation for `Inverse{<:Bijector}` is implemented as -`- logabsdetjac` of original `Bijector`. +Just an alias for `logabsdetjac(inv(b), y)`. 
""" -logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) +logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) + +############################## +# Example bijector: Identity # +############################## + +struct Identity <: Bijector end +inv(b::Identity) = b + +transform(::Identity, x) = copy(x) +transform!(::Identity, x, y) = (y .= x; return y) +logabsdetjac(::Identity, x) = zero(eltype(x)) +logabsdetjac!(::Identity, x, logjac) = logjac + +#################### +# Batched versions # +#################### +# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. +# so we can actually reference them. Since we just did this for `Identity`, we're good. +Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) +Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) + """ - forward(b::Bijector, x) + transform_batch(b, xs) -Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. +Transform `xs` by `b`, treating `xs` as a "batch", i.e. a collection of independent inputs. -This defaults to the call above, but often one can re-use computation -in the computation of the forward pass and the computation of the -`logabsdetjac`. `forward` allows the user to take advantange of such -efficiencies, if they exist. +See also: [`transform`](@ref) """ -forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +transform_batch """ - logabsdetjacinv(b::Bijector, y) + logabsdetjac_batch(b, xs) -Just an alias for `logabsdetjac(inv(b), y)`. +Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. + +See also: [`logabsdetjac`](@ref) """ -logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) +logabsdetjac_batch -############################## -# Example bijector: Identity # -############################## +""" + forward_batch(b, xs) -struct Identity{N} <: Bijector{N} end -(::Identity)(x) = copy(x) -inv(b::Identity) = b -up1(::Identity{N}) where {N} = Identity{N + 1}() - -logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) -@generated function logabsdetjac( - b::Identity{N1}, - x::AbstractArray{T2, N2} -) where {N1, T2, N2} - if N1 == N2 - return :(zero(eltype(x))) - elseif N1 + 1 == N2 - return :(zeros(eltype(x), size(x, $N2))) - else - return :(throw(MethodError(logabsdetjac, (b, x)))) - end -end -logabsdetjac(::Identity{2}, x::AbstractArray{<:AbstractMatrix}) = zeros(eltype(x[1]), size(x)) +Computes `forward(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. 
-######################## -# Convenient constants # -######################## -const ZeroOrOneDimBijector = Union{Bijector{0}, Bijector{1}} +See also: [`transform`](@ref) +""" +forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) ###################### # Bijectors includes # From 81a2ed600078d2ab20e22ec3dcdae20ee79eb2f1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:48:51 +0100 Subject: [PATCH 03/98] updated Composed to new interface --- src/bijectors/composed.jl | 155 +++++++++++++++++++++++++++++--------- 1 file changed, 119 insertions(+), 36 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index f05b819c..33f9551e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -85,57 +85,50 @@ true ``` """ -struct Composed{A, N} <: Bijector{N} +struct Composed{A} <: Transform ts::A end -Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs) -Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs) - # field contains nested numerical parameters Functors.@functor Composed +invertible(cb::Composed) = sum(map(invertible, cb.ts)) + isclosedform(b::Composed) = all(isclosedform, b.ts) -up1(b::Composed) = Composed(up1.(b.ts)) -function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N} + +function Base.:(==)(b1::Composed, b2::Composed) ts1, ts2 = b1.ts, b2.ts return length(ts1) == length(ts2) && all(x == y for (x, y) in zip(ts1, ts2)) end """ - composel(ts::Bijector...)::Composed{<:Tuple} + composel(ts::Transform...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied left-to-right. """ -composel(ts::Bijector{N}...) where {N} = Composed(ts) +composel(ts::Transform...) = Composed(ts) """ - composer(ts::Bijector...)::Composed{<:Tuple} + composer(ts::Transform...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied right-to-left. """ -composer(ts::Bijector{N}...) where {N} = Composed(reverse(ts)) +composer(ts::Transform...) = Composed(reverse(ts)) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -@generated function ∘(b1::Bijector{N1}, b2::Bijector{N2}) where {N1, N2} - if N1 == N2 - return :(composel(b2, b1)) - else - return :(throw(DimensionMismatch("$(typeof(b1)) expects $(N1)-dim but $(typeof(b2)) expects $(N2)-dim"))) - end -end +∘(b1::Transform, b2::Transform) = composel(b2, b1) # type-stable composition rules -∘(b1::Composed{<:Tuple}, b2::Bijector) = composel(b2, b1.ts...) -∘(b1::Bijector, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) +∘(b1::Composed{<:Tuple}, b2::Transform) = composel(b2, b1.ts...) +∘(b1::Transform, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) ∘(b1::Composed{<:Tuple}, b2::Composed{<:Tuple}) = composel(b2.ts..., b1.ts...) 
# type-unstable composition rules -∘(b1::Composed{<:AbstractArray}, b2::Bijector) = Composed(pushfirst!(copy(b1.ts), b2)) -∘(b1::Bijector, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) +∘(b1::Composed{<:AbstractArray}, b2::Transform) = Composed(pushfirst!(copy(b1.ts), b2)) +∘(b1::Transform, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) function ∘(b1::Composed{<:AbstractArray}, b2::Composed{<:AbstractArray}) return Composed(append!(copy(b2.ts), copy(b1.ts))) end @@ -149,14 +142,14 @@ function ∘(b1::T1, b2::T2) where {T1<:Composed{<:AbstractArray}, T2<:Composed{ end -∘(::Identity{N}, ::Identity{N}) where {N} = Identity{N}() -∘(::Identity{N}, b::Bijector{N}) where {N} = b -∘(b::Bijector{N}, ::Identity{N}) where {N} = b +∘(::Identity, ::Identity) = Identity() +∘(::Identity, b::Transform) = b +∘(b::Transform, ::Identity) = b -inv(ct::Composed) = Composed(reverse(map(inv, ct.ts))) +inv(ct::Composed, ::Invertible) = Composed(reverse(map(inv, ct.ts))) -# # TODO: should arrays also be using recursive implementation instead? -function (cb::Composed{<:AbstractArray{<:Bijector}})(x) +# TODO: should arrays also be using recursive implementation instead? +function transform(cb::Composed, x) @assert length(cb.ts) > 0 res = cb.ts[1](x) for b ∈ Base.Iterators.drop(cb.ts, 1) @@ -166,7 +159,17 @@ function (cb::Composed{<:AbstractArray{<:Bijector}})(x) return res end -@generated function (cb::Composed{T})(x) where {T<:Tuple} +function transform_batch(cb::Composed, x) + @assert length(cb.ts) > 0 + res = cb.ts[1].(x) + for b ∈ Base.Iterators.drop(cb.ts, 1) + res = b.(res) + end + + return res +end + +@generated function transform(cb::Composed{T}, x) where {T<:Tuple} @assert length(T.parameters) > 0 expr = :(x) for i in 1:length(T.parameters) @@ -175,17 +178,36 @@ end return expr end +@generated function transform_batch(cb::Composed{T}, x) where {T<:Tuple} + @assert length(T.parameters) > 0 + expr = :(x) + for i in 1:length(T.parameters) + expr = :(transform_batch(cb.ts[$i], $expr)) + end + return expr +end + function logabsdetjac(cb::Composed, x) y, logjac = forward(cb.ts[1], x) for i = 2:length(cb.ts) res = forward(cb.ts[i], y) - y = res.rv + y = res.result logjac += res.logabsdetjac end return logjac end +function logabsdetjac_batch(cb::Composed, x) + init = forward(cb.ts[1], x) + result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b + return forward(b, y) + end + + return result.logabsdetjac +end + + @generated function logabsdetjac(cb::Composed{T}, x) where {T<:Tuple} N = length(T.parameters) @@ -195,7 +217,7 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` @@ -206,16 +228,60 @@ end return expr end +""" + logabsdetjac_batch(cb::Composed{<:Tuple}, x) + +Generates something of the form +```julia +quote + (y, logjac_1) = forward_batch(cb.ts[1], x) + logjac_2 = logabsdetjac_batch(cb.ts[2], y) + return logjac_1 + logjac_2 +end +``` +""" +@generated function logabsdetjac_batch(cb::Composed{T}, x) where {T<:Tuple} + N = length(T.parameters) + + expr = Expr(:block) + push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) + + for i = 2:N - 1 + temp = gensym(:res) + push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) + push!(expr.args, :(y = $temp.result)) + push!(expr.args, :($(Symbol("logjac_$i")) = 
$temp.logabsdetjac)) + end + # don't need to evaluate the last bijector, only it's `logabsdetjac` + push!(expr.args, :($(Symbol("logjac_$N")) = logabsdetjac_batch(cb.ts[$N], y))) + + sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) + push!(expr.args, :(return $(sum_expr))) + + return expr +end + function forward(cb::Composed, x) - rv, logjac = forward(cb.ts[1], x) + result, logjac = forward(cb.ts[1], x) for t in cb.ts[2:end] - res = forward(t, rv) - rv = res.rv + res = forward(t, result) + result = res.result logjac = res.logabsdetjac + logjac end - return (rv=rv, logabsdetjac=logjac) + return (result=result, logabsdetjac=logjac) +end + +function forward_batch(cb::Composed, x) + result, logjac = forward_batch(cb.ts[1], x) + + for t in cb.ts[2:end] + res = forward_batch(t, result) + result = res.result + logjac = res.logabsdetjac + logjac + end + return (result=result, logabsdetjac=logjac) end @@ -225,10 +291,27 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) + push!(expr.args, :(y = $temp.result)) push!(expr.args, :(logjac += $temp.logabsdetjac)) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (result = y, logabsdetjac = logjac))) + + return expr +end + +@generated function forward_batch(cb::Composed{T}, x) where {T<:Tuple} + N = length(T.parameters) + expr = Expr(:block) + push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) + for i = 2:N + temp = gensym(:temp) + push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) + push!(expr.args, :(y = $temp.result)) + push!(expr.args, :($(Symbol("logjac_$i")) = $temp.logabsdetjac)) + end + + sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) 
+ push!(expr.args, :(return (result = y, logabsdetjac = $(sum_expr)))) return expr end From 251ab9c47bceafd9453b6c0c378a60f70301e23b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:51:11 +0100 Subject: [PATCH 04/98] updated Exp and Log to new interface --- src/bijectors/exp_log.jl | 51 +++++++++------------------------------- 1 file changed, 11 insertions(+), 40 deletions(-) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 0e2da142..d3794a59 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -2,48 +2,19 @@ # Exp & Log # ############# -struct Exp{N} <: Bijector{N} end -struct Log{N} <: Bijector{N} end -up1(::Exp{N}) where {N} = Exp{N + 1}() -up1(::Log{N}) where {N} = Log{N + 1}() +struct Exp <: Bijector end +struct Log <: Bijector end -Exp() = Exp{0}() -Log() = Log{0}() +inv(b::Exp) = Log() +inv(b::Log) = Exp() -(b::Exp{0})(y::Real) = exp(y) -(b::Log{0})(x::Real) = log(x) +transform(b::Exp, y) = exp.(y) +transform(b::Log, x) = log.(x) -(b::Exp{0})(y::AbstractArray{<:Real}) = exp.(y) -(b::Log{0})(x::AbstractArray{<:Real}) = log.(x) +logabsdetjac(b::Exp, x) = sum(x) +logabsdetjac(b::Log, x) = -sum(log, x) -(b::Exp{1})(y::AbstractVector{<:Real}) = exp.(y) -(b::Exp{1})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{1})(x::AbstractVector{<:Real}) = log.(x) -(b::Log{1})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractMatrix{<:Real}) = exp.(y) -(b::Log{2})(x::AbstractMatrix{<:Real}) = log.(x) - -(b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) -(b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) - -inv(b::Exp{N}) where {N} = Log{N}() -inv(b::Log{N}) where {N} = Exp{N}() - -logabsdetjac(b::Exp{0}, x::Real) = x -logabsdetjac(b::Exp{0}, x::AbstractVector) = x -logabsdetjac(b::Exp{1}, x::AbstractVector) = sum(x) -logabsdetjac(b::Exp{1}, x::AbstractMatrix) = vec(sum(x; dims = 1)) -logabsdetjac(b::Exp{2}, x::AbstractMatrix) = sum(x) -logabsdetjac(b::Exp{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end - -logabsdetjac(b::Log{0}, x::Real) = -log(x) -logabsdetjac(b::Log{0}, x::AbstractVector) = .-log.(x) -logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log, x) -logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log, x; dims = 1)) -logabsdetjac(b::Log{2}, x::AbstractMatrix) = - sum(log, x) -logabsdetjac(b::Log{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) +function forward(b::Log, x) + y = transform(b, x) + return (result = y, logabsdetjac = -sum(y)) end From f1ef968d809a5a029bd8cacf9ed080d4711ddf72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:51:18 +0100 Subject: [PATCH 05/98] updated Logit to new interface --- src/bijectors/logit.jl | 32 +++++--------------------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index ff1cf554..9d26de5d 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -3,43 +3,21 @@ ###################### using StatsFuns: logit, logistic -struct Logit{N, T<:Real} <: Bijector{N} +struct Logit{T} <: Bijector a::T b::T end -Logit(a::Real, b::Real) = Logit{0}(a, b) -Logit(a::AbstractArray{<:Real, N}, b::AbstractArray{<:Real, N}) where {N} = Logit{N}(a, b) -function Logit{N}(a, b) where {N} - T = promote_type(typeof(a), typeof(b)) - Logit{N, T}(a, b) -end -# fields are numerical parameters -function Functors.functor(::Type{<:Logit{N}}, x) where N - function reconstruct_logit(xs) - T = 
promote_type(typeof(xs.a), typeof(xs.b)) - return Logit{N,T}(xs.a, xs.b) - end - return (a = x.a, b = x.b,), reconstruct_logit -end +Functors.@functor Logit -up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b) # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -(b::Logit)(x) = _logit.(x, b.a, b.b) -(b::Logit)(x::AbstractArray{<:AbstractArray}) = map(b, x) +transform(b::Logit, x) = _logit.(x, b.a, b.b) _logit(x, a, b) = logit((x - a) / (b - a)) -(ib::Inverse{<:Logit})(y) = _ilogit.(y, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit})(x::AbstractArray{<:AbstractArray}) = map(ib, x) +transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) _ilogit(y, a, b) = (b - a) * logistic(y) + a -logabsdetjac(b::Logit{0}, x) = logit_logabsdetjac.(x, b.a, b.b) -logabsdetjac(b::Logit{1}, x::AbstractVector) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{1}, x::AbstractMatrix) = vec(sum(logit_logabsdetjac.(x, b.a, b.b), dims = 1)) -logabsdetjac(b::Logit{2}, x::AbstractMatrix) = sum(logit_logabsdetjac.(x, b.a, b.b)) -logabsdetjac(b::Logit{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x - logabsdetjac(b, x) -end +logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) From 45ff3646d0aec0d0dc49a2fc9ba5c437bd0cb1d2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 06:53:29 +0100 Subject: [PATCH 06/98] removed something that shouldnt be there --- src/interface.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9ace3b38..bd37555a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -200,14 +200,6 @@ logabsdetjac!(::Identity, x, logjac) = logjac #################### # Batched versions # #################### -# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. -# so we can actually reference them. Since we just did this for `Identity`, we're good. -Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) -Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) - - """ transform_batch(b, xs) From eb94e001517cd15d6fed35bb4c7a19765249e3f4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 07:04:54 +0100 Subject: [PATCH 07/98] removed false statement in docstring of Transform --- src/interface.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index bd37555a..8c9b2b56 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,8 +61,6 @@ Finally, there are _batched_ versions of the above methods which can _optionally - [`forward_batch`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. -Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` -falls back to `transform_batch(b, x)`. 
""" abstract type Transform end From 0d8783f37108a9ad56a81cb243319ac53f2a13db Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:22:58 +0100 Subject: [PATCH 08/98] fixed a typo in implementation of logabsdetjac_batch --- src/bijectors/composed.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 33f9551e..5cce6d20 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -199,9 +199,9 @@ function logabsdetjac(cb::Composed, x) end function logabsdetjac_batch(cb::Composed, x) - init = forward(cb.ts[1], x) + init = forward_batch(cb.ts[1], x) result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b - return forward(b, y) + return forward_batch(b, y) end return result.logabsdetjac From 8f9988e1de65649c1be8e03cb6caab6d16e58739 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:31:48 +0100 Subject: [PATCH 09/98] added types for representing batches --- src/Bijectors.jl | 2 ++ src/batch.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 3 ++ 3 files changed, 93 insertions(+) create mode 100644 src/batch.jl diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 955d0be7..94b10e43 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -33,6 +33,7 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays +using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular import Functors @@ -244,6 +245,7 @@ function getlogp(d::InverseWishart, Xcf, X) end include("utils.jl") +include("batch.jl") include("interface.jl") include("chainrules.jl") diff --git a/src/batch.jl b/src/batch.jl new file mode 100644 index 00000000..4c5a512b --- /dev/null +++ b/src/batch.jl @@ -0,0 +1,88 @@ +abstract type AbstractBatch{T} <: AbstractVector{T} end + +""" + value(x) + +Returns the underlying storage used for the entire batch. + +If `x` is not `AbstractBatch`, then this is the identity function. +""" +value(x) = x + +struct Batch{V, T} <: AbstractBatch{T} + value::V +end + +value(x::Batch) = x.value + +# Convenient aliases +const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} +const VectorBatch = Batch{<:AbstractVector{<:AbstractArray{<:Real}}} + +# Constructor for `ArrayBatch`. +Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) +function Batch(x::AbstractArray{<:Real}) + V = typeof(x) + # HACK: This assumes the batch is non-empty. + T = typeof(getindex_for_last(x, 1)) + return Batch{V, T}(x) +end + +# Constructor for `VectorBatch`. +Batch(x::AbstractVector{<:AbstractArray}) = Batch{typeof(x), eltype(x)}(x) + +# `AbstractVector` interface. +Base.size(batch::Batch{<:AbstractArray{<:Any, N}}) where {N} = (size(value(batch), N), ) + +# Since impl inherited from `AbstractVector` doesn't do exactly what we want. +Base.similar(b::AbstractBatch) = reconstruct(b, similar(value(b))) + +# For `VectorBatch` +Base.getindex(batch::VectorBatch, i::Int) = value(batch)[i] +Base.getindex(batch::VectorBatch, i::CartesianIndex{1}) = value(batch)[i] +Base.getindex(batch::VectorBatch, i) = Batch(value(batch)[i]) +function Base.setindex!(batch::VectorBatch, v, i) + # `v` can also be a `Batch`. 
+ Base.setindex!(value(batch), value(v), i) + return batch +end + +# For `ArrayBatch` +@generated function getindex_for_last(x::AbstractArray{<:Any, N}, inds) where {N} + e = Expr(:call) + push!(e.args, :(Base.view)) + push!(e.args, :x) + + for i = 1:N - 1 + push!(e.args, :(:)) + end + + push!(e.args, :(inds)) + + return e +end + +@generated function setindex_for_last!(out::AbstractArray{<:Any, N}, x, inds) where {N} + e = Expr(:call) + push!(e.args, :(Base.setindex!)) + push!(e.args, :out) + push!(e.args, :x) + + for i = 1:N - 1 + push!(e.args, :(:)) + end + + push!(e.args, :(inds)) + + return e +end + +# General arrays. +Base.getindex(batch::ArrayBatch, i::Int) = getindex_for_last(value(batch), i) +Base.getindex(batch::ArrayBatch, i::CartesianIndex{1}) = getindex_for_last(value(batch), i) +Base.getindex(batch::ArrayBatch, i) = Batch(getindex_for_last(value(batch), i)) +function Base.setindex!(batch::ArrayBatch, v, i) + # `v` can also be a `Batch`. + setindex_for_last!(value(batch), value(v), i) + return batch +end diff --git a/src/utils.jl b/src/utils.jl index 8203e1b4..d1034dbc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,6 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x + +# Useful for reconstructing objects. +reconstruct(b, args...) = constructorof(typeof(b))(args...) From 9fa37d1f66858460e85a12bf06aa21a07475372a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:33:05 +0100 Subject: [PATCH 10/98] make it possible to use broadcasting for working with batches --- src/interface.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 8c9b2b56..9ace3b38 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -61,6 +61,8 @@ Finally, there are _batched_ versions of the above methods which can _optionally - [`forward_batch`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. +Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` +falls back to `transform_batch(b, x)`. """ abstract type Transform end @@ -198,6 +200,14 @@ logabsdetjac!(::Identity, x, logjac) = logjac #################### # Batched versions # #################### +# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. +# so we can actually reference them. Since we just did this for `Identity`, we're good. 
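# The overloads below route broadcasting over a `Batch` to the batched methods,
# e.g. (sketch) `Exp().(Batch(randn(10)))` ends up calling `transform_batch(Exp(), Batch(randn(10)))`.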
+Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) +Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) +Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) + + """ transform_batch(b, xs) From 168dd43ed5964a2fb0af27d276ad4c2b2a4d04ad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:22:59 +0100 Subject: [PATCH 11/98] updated SimplexBijector to new interface, I think --- src/bijectors/simplex.jl | 119 ++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 69 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 0cf0ff38..f9f83d85 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -1,20 +1,16 @@ #################### # Simplex bijector # #################### -struct SimplexBijector{N, T} <: Bijector{N} end -SimplexBijector() = SimplexBijector{1}() -SimplexBijector{N}() where {N} = SimplexBijector{N,true}() +struct SimplexBijector{T} <: Bijector end +SimplexBijector() = SimplexBijector{true}() -# Special case `N = 1` -SimplexBijector{true}() = SimplexBijector{1,true}() -SimplexBijector{false}() = SimplexBijector{1,false}() +transform(b::SimplexBijector, x) = _simplex_bijector(x, b) +transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) -(b::SimplexBijector{1})(x::AbstractVector) = _simplex_bijector(x, b) -(b::SimplexBijector{1})(y::AbstractVector, x::AbstractVector) = _simplex_bijector!(y, x, b) -function _simplex_bijector(x::AbstractVector, b::SimplexBijector{1}) - return _simplex_bijector!(similar(x), x, b) -end -function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) where {proj} +_simplex_bijector(x::AbstractArray, b::SimplexBijector) = _simplex_bijector!(similar(x), x, b) + +# Vector implementation. +function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} K = length(x) @assert K > 1 "x needs to be of length greater than 1" T = eltype(x) @@ -39,24 +35,19 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{1, proj}) wh return y end -# Vectorised implementation of the above. -function (b::SimplexBijector{1})(X::AbstractMatrix) - _simplex_bijector(X, b) -end -function (b::SimplexBijector{1})( - Y::AbstractMatrix, - X::AbstractMatrix, -) - _simplex_bijector!(Y, X, b) +function transform_batch(b::SimplexBijector, X::ArrayBatch{2}) + Batch(_simplex_bijector(value(X), b)) end -function (b::SimplexBijector{2, proj})(X::AbstractMatrix) where {proj} - SimplexBijector{1, proj}()(X) -end -(b::SimplexBijector{2})(X::AbstractArray{<:AbstractMatrix}) = map(b, X) -function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) - _simplex_bijector!(similar(X), X, b) +function transform_batch!( + b::SimplexBijector, + Y::Batch{<:AbstractMatrix{T}}, + X::Batch{<:AbstractMatrix{T}}, +) where {T} + Batch(_simplex_bijector!(value(Y), value(X), b)) end -function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) where {proj} + +# Matrix implementation. 
+function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} K, N = size(X, 1), size(X, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(X) @@ -81,19 +72,19 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{1, proj}) wh return Y end -function (ib::Inverse{<:SimplexBijector{1, proj}})(y::AbstractVector{T}) where {T, proj} - _simplex_inv_bijector(y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - x::AbstractVector{T}, - y::AbstractVector{T}, +# Inverse. +transform(ib::Inverse{<:SimplexBijector}, y::AbstractArray) = _simplex_inv_bijector(y, ib.orig) +function transform!( + ib::Inverse{<:SimplexBijector}, + x::AbstractArray{T}, + y::AbstractArray{T}, ) where {T} - _simplex_inv_bijector!(x, y, ib.orig) + return _simplex_inv_bijector!(x, y, ib.orig) end -function _simplex_inv_bijector(y::AbstractVector, b::SimplexBijector{1}) - return _simplex_inv_bijector!(similar(y), y, b) -end -function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj}) where {proj} + +_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b) + +function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj} K = length(y) @assert K > 1 "x needs to be of length greater than 1" T = eltype(y) @@ -116,27 +107,17 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{1, proj return x end -# Vectorised implementation of the above. -function (ib::Inverse{<:SimplexBijector{1}})(Y::AbstractMatrix) - _simplex_inv_bijector(Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{1}})( - X::AbstractMatrix{T}, - Y::AbstractMatrix{T}, -) where {T <: Real} - _simplex_inv_bijector!(X, Y, ib.orig) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(Y) -end -function (ib::Inverse{<:SimplexBijector{2, proj}})(X::AbstractMatrix, Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(X, Y) +# Batched versions. 
+transform_batch(ib::Inverse{<:SimplexBijector}, Y::ArrayBatch{2}) = _simplex_inv_bijector(Y, ib.orig) +function transform_batch!( + ib::Inverse{<:SimplexBijector} + X::Batch{<:AbstractMatrix{T}}, + Y::Batch{<:AbstractMatrix{T}}, +) where {T<:Real} + return Batch(_simplex_inv_bijector!(value(X), value(Y), ib.orig)) end -(ib::Inverse{<:SimplexBijector{2}})(Y::AbstractArray{<:AbstractMatrix}) = map(ib, Y) -function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - _simplex_inv_bijector!(similar(Y), Y, b) -end -function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj}) where {proj} + +function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} K, N = size(Y, 1), size(Y, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(Y) @@ -160,7 +141,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{1, proj return X end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector{T}) where {T} +function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} ϵ = _eps(T) lp = zero(T) @@ -211,7 +192,9 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) end return g end -function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix{T}) where {T} +function logabsdetjac_batch(b::SimplexBijector, x::Batch{<:AbstractMatrix{T}}) where {T} + x = value(x) + ϵ = _eps(T) nlp = similar(x, T, size(x, 2)) nlp .= zero(T) @@ -227,14 +210,12 @@ function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix{T}) where {T} nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) + log(max(one(T) - sum_tmp, ϵ)) end end - return nlp + return Batch(nlp) end -function logabsdetjac(b::SimplexBijector{2, proj}, x::AbstractMatrix) where {proj} - return sum(logabsdetjac(SimplexBijector{1, proj}(), x)) -end -function logabsdetjac(b::SimplexBijector{2}, x::AbstractArray{<:AbstractMatrix}) - return map(x -> logabsdetjac(b, x), x) +function logabsdetjac(b::SimplexBijector, x::AbstractMatrix) + return sum(value(logabsdetjac(b, Batch(x)))) end + function simplex_logabsdetjac_gradient(x::AbstractMatrix) T = eltype(x) ϵ = _eps(T) @@ -303,7 +284,7 @@ function simplex_link_jacobian( end return UpperTriangular(dydxt)' end -function jacobian(b::SimplexBijector{1, proj}, x::AbstractVector{T}) where {proj, T} +function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj, T} return simplex_link_jacobian(x, Val(proj)) end @@ -425,7 +406,7 @@ function simplex_invlink_jacobian( return LowerTriangular(dxdy) end # jacobian -function jacobian(ib::Inverse{<:SimplexBijector{1, proj}}, y::AbstractVector{T}) where {proj, T} +function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj, T} return simplex_invlink_jacobian(y, Val(proj)) end From d44cf420180d1387e48f498571a8e0d56ec5bb3c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:23:22 +0100 Subject: [PATCH 12/98] updated PDBijector to new interface --- src/bijectors/pd.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 332a8b36..deac1bd6 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -1,4 +1,4 @@ -struct PDBijector <: Bijector{2} end +struct PDBijector <: Bijector end # This function has custom adjoints defined for Tracker, Zygote and ReverseDiff. 
# I couldn't find a mutation-free implementation that maintains TrackedArrays in Tracker @@ -7,19 +7,17 @@ function replace_diag(f, X) g(i, j) = ifelse(i == j, f(X[i, i]), X[i, j]) return g.(1:size(X, 1), (1:size(X, 2))') end -(b::PDBijector)(X::AbstractMatrix{<:Real}) = pd_link(X) +transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) Y = lower(parent(cholesky(X; check = true).L)) return replace_diag(log, Y) end -(b::PDBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) -function (ib::Inverse{<:PDBijector})(Y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end -(ib::Inverse{<:PDBijector})(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, X) getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) @@ -37,7 +35,3 @@ function logabsdetjac(b::PDBijector, Xcf::Cholesky) d = size(U, 1) return - sum((d .- (1:d) .+ 2) .* log.(diag(U))) - d * log(T(2)) end - -logabsdetjac(b::PDBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) = mapvcat(X) do x - logabsdetjac(b, x) -end From c719b076dcb33c4aa1a61d07b32e8d5ee6342f40 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 09:24:42 +0100 Subject: [PATCH 13/98] use transform_batch rather than broadcasting --- src/bijectors/composed.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 5cce6d20..3a0e5a9e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -163,7 +163,7 @@ function transform_batch(cb::Composed, x) @assert length(cb.ts) > 0 res = cb.ts[1].(x) for b ∈ Base.Iterators.drop(cb.ts, 1) - res = b.(res) + res = transform_batch(b, res) end return res From 0a62e96f2dff33d796436b40849f7078f635a4c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:19 +0100 Subject: [PATCH 14/98] added default implementations for batches --- src/interface.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9ace3b38..cfcaf7c8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -215,7 +215,19 @@ Transform `xs` by `b`, treating `xs` as a "batch", i.e. a collection of independ See also: [`transform`](@ref) """ -transform_batch +transform_batch(b, xs) = _transform_batch(b, xs) +# Default implementations uses private methods to avoid method ambiguity. +_transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) +function _transform_batch(b, xs::ArrayBatch{2}) + # TODO: Check if we can avoid using these custom methods. + return eachcolmaphcat(b, x) +end +function _transform_batch(b, xs::ArrayBatch{N}) where {N} + res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x + cat(acc, x; dims = N) + end + return reconstruct(xs, res) +end """ logabsdetjac_batch(b, xs) @@ -224,7 +236,12 @@ Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of See also: [`logabsdetjac`](@ref) """ -logabsdetjac_batch +logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) +# Default implementations uses private methods to avoid method ambiguity. 
+_logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) +function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) +end """ forward_batch(b, xs) From 21a66ab419bdea329a7a784169f77e44d7313260 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:35 +0100 Subject: [PATCH 15/98] updated ADBijector to new interface --- src/bijectors/adbijector.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl index b7596b1d..135f2bba 100644 --- a/src/bijectors/adbijector.jl +++ b/src/bijectors/adbijector.jl @@ -2,7 +2,7 @@ Abstract type for a `Bijector{N}` making use of auto-differentation (AD) to implement `jacobian` and, by impliciation, `logabsdetjac`. """ -abstract type ADBijector{AD, N} <: Bijector{N} end +abstract type ADBijector{AD} <: Bijector end struct SingularJacobianException{B<:Bijector} <: Exception b::B @@ -24,4 +24,4 @@ end function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b)) -end \ No newline at end of file +end From 0b04d1497f3dac7db102dcdeeb8c86d0f65f697a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:42 +0100 Subject: [PATCH 16/98] updated CorrBijector to new interface --- src/bijectors/corr.jl | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f7eeaee2..a051ef17 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -61,9 +61,9 @@ Note: The implementation doesn't follow their "manageable expression" directly, because their equation seems wrong (7/30/2020). Insteadly it follows definition above the "manageable expression" directly, which is also described in above doc. 
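For example, a sketch (not a doctest) assuming `LKJ` from Distributions.jl:

```julia
using Bijectors, Distributions

b = CorrBijector()
X = rand(LKJ(3, 1.0))   # a 3×3 correlation matrix
y = b(X)                # unconstrained (upper-triangular) representation
X ≈ inv(b)(y)           # true up to numerical error
```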
""" -struct CorrBijector <: Bijector{2} end +struct CorrBijector <: Bijector end -function (b::CorrBijector)(x::AbstractMatrix{<:Real}) +function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) @@ -71,16 +71,12 @@ function (b::CorrBijector)(x::AbstractMatrix{<:Real}) # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end -(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) - -function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) return w' * w end -(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y) - -function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) +function logabsdetjac(::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) K = LinearAlgebra.checksquare(y) result = float(zero(eltype(y))) @@ -100,12 +96,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) =# return -logabsdetjac(inv(b), (b(X))) end -function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) - return mapvcat(X) do x - logabsdetjac(b, x) - end -end - function _inv_link_chol_lkj(y) K = LinearAlgebra.checksquare(y) From 2cfd24b27b4b858023342d0979a1b32085dee70c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:49 +0100 Subject: [PATCH 17/98] updated Coupling to new interface --- src/bijectors/coupling.jl | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 088caf2a..8cd3490a 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -167,7 +167,7 @@ Shift{Array{Float64,1},1}([2.0]) # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). 
""" -struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} +struct Coupling{F, M} <: Bijector where {F, M <: PartitionMask} θ::F mask::M end @@ -195,7 +195,7 @@ function couple(cl::Coupling, x::AbstractVector) return b end -function (cl::Coupling)(x::AbstractVector) +function transform(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) @@ -205,10 +205,8 @@ function (cl::Coupling)(x::AbstractVector) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x) - -function (icl::Inverse{<:Coupling})(y::AbstractVector) +function transform(icl::Inverse{<:Coupling}, y::AbstractVector) cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) @@ -218,7 +216,6 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) return combine(cl.mask, ib(y_1), y_2, y_3) end -(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y) function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) @@ -228,7 +225,3 @@ function logabsdetjac(cl::Coupling, x::AbstractVector) # therefore we sum to ensure such a thing does not happen return sum(logabsdetjac(b, x_1)) end - -function logabsdetjac(cl::Coupling, x::AbstractMatrix) - return [logabsdetjac(cl, view(x, :, i)) for i in axes(x, 2)] -end From c272cd45d84044b2d026bcf0b5f9ecd35c3f584b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:46:59 +0100 Subject: [PATCH 18/98] updated LeakyReLU to new interface --- src/bijectors/leaky_relu.jl | 82 ++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index c9128f01..e7913038 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -1,5 +1,5 @@ """ - LeakyReLU{T, N}(α::T) <: Bijector{N} + LeakyReLU{T}(α::T) <: Bijector Defines the invertible mapping @@ -7,95 +7,95 @@ Defines the invertible mapping where α > 0. """ -struct LeakyReLU{T, N} <: Bijector{N} +struct LeakyReLU{T} <: Bijector α::T end -LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) -LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) +Functors.@functor LeakyReLU -# field is a numerical parameter -function Functors.functor(::Type{LeakyReLU{<:Any,N}}, x) where N - function reconstruct_leakyrelu(xs) - return LeakyReLU{typeof(xs.α),N}(xs.α) - end - return (α = x.α,), reconstruct_leakyrelu -end - -up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) +Base.inv(b::LeakyReLU) = LeakyReLU(inv.(b.α)) # (N=0) Univariate case -function (b::LeakyReLU{<:Any, 0})(x::Real) +function transform(b::LeakyReLU, x::Real) mask = x < zero(x) return mask * b.α * x + !mask * x end -(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) -function Base.inv(b::LeakyReLU{<:Any,N}) where N - invα = inv.(b.α) - return LeakyReLU{typeof(invα),N}(invα) -end - -function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) +function logabsdetjac(b::LeakyReLU, x::Real) mask = x < zero(x) J = mask * b.α + (1 - mask) * one(x) return log(abs(J)) end -logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) - # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. 
This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 0}, x::Real) +function forward(b::LeakyReLU, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) return (result=J * x, logabsdetjac=log(abs(J))) end -# Batched version -function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function forward_batch(b::LeakyReLU, xs::Batch{<:AbstractVector}) + x = value(xs) + J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - return (result=J .* x, logabsdetjac=log.(abs.(J))) + return (result=Batch(J .* x), logabsdetjac=Batch(log.(abs.(J)))) end -# (N=1) Multivariate case -function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) +# Array inputs. +function transform(b::LeakyReLU, x::AbstractArray) return let z = zero(eltype(x)) @. (x < z) * b.α * x + (x > z) * x end end -function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function logabsdetjac(b::LeakyReLU, x::AbstractArray) + return sum(value(logabsdetjac_batch(b, Batch(x)))) +end + +function logabsdetjac_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} + x = value(xs) + # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - - if x isa AbstractVector - return sum(log.(abs.(J))) - elseif x isa AbstractMatrix - return vec(sum(log.(abs.(J)); dims = 1)) # sum along column + + logjac = if N ≤ 1 + sum(log ∘ abs, J) + else + vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) end + + return Batch(logjac) end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function forward(b::LeakyReLU, x::AbstractArray) + y, logjac = forward_batch(b, Batch(x)) + + return (result = value(y), logabsdetjac = sum(value(logjac))) +end + +function forward_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} + x = value(xs) + # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. 
(x < z) * b.α + (x > z) * o end - if x isa AbstractVector - logjac = sum(log.(abs.(J))) - elseif x isa AbstractMatrix - logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column + logjac = if N ≤ 1 + sum(log ∘ abs, J) + else + vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) end y = J .* x - return (result=y, logabsdetjac=logjac) + return (result=Batch(y), logabsdetjac=Batch(logjac)) end From 5e2a585be79a3b28033b889975c2c254dcaa70ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:05 +0100 Subject: [PATCH 19/98] updated NamedBijector to new interface --- src/bijectors/named_bijector.jl | 120 +------------------------------- 1 file changed, 2 insertions(+), 118 deletions(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 500044b9..776e6e96 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,4 +1,4 @@ -abstract type AbstractNamedBijector <: AbstractBijector end +abstract type AbstractNamedBijector <: Bijector end forward(b::AbstractNamedBijector, x) = (result = b(x), logabsdetjac = logabsdetjac(b, x)) @@ -64,122 +64,6 @@ end return :(+($(exprs...))) end - -###################### -### `NamedInverse` ### -###################### -""" - NamedInverse <: AbstractNamedBijector - -Represents the inverse of a `AbstractNamedBijector`, similarily to `Inverse` for `Bijector`. - -See also: [`Inverse`](@ref) -""" -struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector - orig::B -end -Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb) -Base.inv(ni::NamedInverse) = ni.orig - -logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inv(ni), ni(y)) - -########################## -### `NamedComposition` ### -########################## -""" - NamedComposition <: AbstractNamedBijector - -Wraps a tuple of array of `AbstractNamedBijector` and implements their composition. - -This is very similar to `Composed` for `Bijector`, with the exception that we do not require -the inputs to have the same "dimension", which in this case refers to the *symbols* for the -`NamedTuple` that this takes as input. - -See also: [`Composed`](@ref) -""" -struct NamedComposition{Bs} <: AbstractNamedBijector - bs::Bs -end - -# Essentially just copy-paste from impl of composition for 'standard' bijectors, -# with minor changes here and there. -composel(bs::AbstractNamedBijector...) = NamedComposition(bs) -composer(bs::AbstractNamedBijector...) 
= NamedComposition(reverse(bs)) -∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) - -inv(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) - -function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) - @assert length(cb.bs) > 0 - res = cb.bs[1](x) - for b ∈ Base.Iterators.drop(cb.bs, 1) - res = b(res) - end - - return res -end - -(cb::NamedComposition{<:Tuple})(x) = foldl(|>, cb.bs; init=x) - -function logabsdetjac(cb::NamedComposition, x) - y, logjac = forward(cb.bs[1], x) - for i = 2:length(cb.bs) - res = forward(cb.bs[i], y) - y = res.result - logjac += res.logabsdetjac - end - - return logjac -end - -@generated function logabsdetjac(cb::NamedComposition{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) - - for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) - - push!(expr.args, :(return logjac)) - - return expr -end - - -function forward(cb::NamedComposition, x) - result, logjac = forward(cb.bs[1], x) - - for t in cb.bs[2:end] - res = forward(t, result) - result = res.result - logjac = res.logabsdetjac + logjac - end - return (result=result, logabsdetjac=logjac) -end - - -@generated function forward(cb::NamedComposition{T}, x) where {T<:Tuple} - expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) - for i = 2:length(T.parameters) - temp = gensym(:temp) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) - end - push!(expr.args, :(return (result = y, logabsdetjac = logjac))) - - return expr -end - - ############################ ### `NamedCouplingLayer` ### ############################ @@ -227,7 +111,7 @@ deps(b::NamedCoupling{<:Any, Deps}) where {Deps} = Deps end end -@generated function (ni::NamedInverse{<:NamedCoupling{target, deps, F}})( +@generated function (ni::Inverse{<:NamedCoupling{target, deps, F}})( x::NamedTuple ) where {target, deps, F} return quote From 8e41a50303730a0fa97bd73c850e38448cd1b0fe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:12 +0100 Subject: [PATCH 20/98] updated BatchNormalisation to new interface --- src/bijectors/normalise.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 16b48a29..7f9a0a67 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -6,7 +6,7 @@ using Statistics: mean istraining() = false -mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1} +mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector b :: T1 # bias logs :: T1 # log-scale m :: T2 # moving mean @@ -80,8 +80,7 @@ function forward(bn::InvertibleBatchNorm, x) end logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac - -(bn::InvertibleBatchNorm)(x) = forward(bn, x).result +transform(bn::InvertibleBatchNorm, x) = forward(bn, x).result function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." 
@@ -97,7 +96,7 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) return (result=x, logabsdetjac=-logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).result +transform(bn::Inverse{<:InvertibleBatchNorm}, y) = forward(bn, y).result function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") From 9f45b16ebcffeb0cfc154d780e9304961a3a994c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:29 +0100 Subject: [PATCH 21/98] updated Permute to new interface --- src/bijectors/permute.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index d4fdef7b..7ac1d454 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -81,7 +81,7 @@ julia> inv(b1)(b1([1., 2., 3.])) 3.0 ``` """ -struct Permute{A} <: Bijector{1} +struct Permute{A} <: Bijector A::A end @@ -150,8 +150,8 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) end -@inline (b::Permute)(x::AbstractVecOrMat) = b.A * x +@inline transform(b::Permute, x::AbstractVecOrMat) = b.A * x @inline inv(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) -logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) +logabsdetjac_batch(b::Permute, x::Batch) = zero(eltype(x), length(x)) From f793dadbebef7928d8476ea4ad71f0adb1f35a55 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:35 +0100 Subject: [PATCH 22/98] updated PlanarLayer to new interface --- src/bijectors/planar_layer.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 9d1a73d0..38bce1a7 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -14,7 +14,7 @@ using NNlib: softplus # TODO: add docstring -struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector{1} +struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector w::T1 u::T1 b::T2 @@ -78,7 +78,7 @@ function _transform(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) return (transformed = transformed, wT_û = wT_û, wT_z = wT_z) end -(b::PlanarLayer)(z) = _transform(b, z).transformed +transform(b::PlanarLayer, z) = _transform(b, z).transformed #= Log-determinant of the Jacobian of the planar layer @@ -108,7 +108,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) +function transform(ib::Inverse{<:PlanarLayer}, y::AbstractVecOrMat{<:Real}) flow = ib.orig w = flow.w b = first(flow.b) From 19d1ef1e94d1e03c81262267ed56a09d7a391520 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:42 +0100 Subject: [PATCH 23/98] updated RadialLayer to new interface --- src/bijectors/radial_layer.jl | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index a5cb5f34..1d95c4ef 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -12,7 +12,7 @@ using NNlib: softplus # RadialLayer # ############### -mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector{1} +mutable struct RadialLayer{T1<:Union{Real, 
AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector α_::T1 β::T1 z_0::T2 @@ -50,8 +50,10 @@ function _radial_transform(α_, β, z_0, z) return (transformed = transformed, α = α, β_hat = β_hat, r = r) end -(b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed -(b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) +transform(b::RadialLayer, z::AbstractMatrix{<:Real}) = _transform(b, z).transformed + +transform_batch(b::RadialLayer, z::ArrayBatch{2}) = Batch(transform(b, value(z))) function forward(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) @@ -70,7 +72,12 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) return (result = transformed, logabsdetjac = log_det_jacobian) end -function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) +function forward_batch(b::RadialLayer, z::ArrayBatch{2}) + result, logjac = forward(b, value(z)) + return (result = Batch(result), logabsdetjac = Batch(logjac)) +end + +function transform(ib::Inverse{<:RadialLayer}, y::AbstractVector{<:Real}) flow = ib.orig z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 @@ -84,7 +91,7 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) return z0 .+ γ .* y_minus_z0 end -function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{<:RadialLayer}, y::AbstractMatrix{<:Real}) flow = ib.orig z0 = flow.z_0 α = softplus(first(flow.α_)) # from A.2 @@ -100,6 +107,10 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) return z0 .+ γ .* y_minus_z0 end +function transform_batch(ib::Inverse{<:RadialLayer}, y::ArrayBatch{2}) + return Batch(transform(ib, value(y))) +end + """ compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) @@ -120,7 +131,7 @@ where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test/β_hat) γ = norm(y_minus_z0) a = α_plus_β_hat - γ r = (sqrt(a^2 + 4 * α * γ) - a) / 2 @@ -128,3 +139,7 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) end logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac + +function logabsdetjac_batch(flow::RadialLayer, x::ArrayBatch{2}) + return Batch(logabsdetjac(flow, value(x))) +end From 50155514e4a660807b2658ac63756176a9af6f35 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:50 +0100 Subject: [PATCH 24/98] updated RationalQuadraticSpline to new interface --- src/bijectors/rational_quadratic_spline.jl | 43 ++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index a0907b82..0dea3a7b 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -1,8 +1,7 @@ using NNlib """ - RationalQuadraticSpline{T, 0} <: Bijector{0} - RationalQuadraticSpline{T, 1} <: Bijector{1} + RationalQuadraticSpline{T} <: Bijector Implementation of the Rational Quadratic Spline flow [1]. 
@@ -77,7 +76,7 @@ julia> b([-1., 5.]) # References [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). """ -struct RationalQuadraticSpline{T, N} <: Bijector{N} +struct RationalQuadraticSpline{T} <: Bijector widths::T # K widths heights::T # K heights derivatives::T # K derivatives, with endpoints being ones @@ -179,18 +178,18 @@ end # univariate -function (b::RationalQuadraticSpline{<:AbstractVector, 0})(x::Real) +function transform(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_univariate(b.widths, b.heights, b.derivatives, x) end -(b::RationalQuadraticSpline{<:AbstractVector, 0})(x::AbstractVector) = b.(x) +function transform_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) + return Batch(transform.(b, value(x))) +end # multivariate -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractVector) +# TODO: Improve. +function transform(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return [rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x)] end -function (b::RationalQuadraticSpline{<:AbstractMatrix, 1})(x::AbstractMatrix) - return eachcolmaphcat(b, x) -end ########################## ### Inverse evaluation ### @@ -234,18 +233,18 @@ function rqs_univariate_inverse(widths, heights, derivatives, y::Real) return ξ * w + w_k end -function (ib::Inverse{<:RationalQuadraticSpline, 0})(y::Real) +function transform(ib::Inverse{<:RationalQuadraticSpline}, y::Real) return rqs_univariate_inverse(ib.orig.widths, ib.orig.heights, ib.orig.derivatives, y) end -(ib::Inverse{<:RationalQuadraticSpline, 0})(y::AbstractVector) = ib.(y) +function transform_batch(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) + return Batch(transform.(ib, value(y))) +end -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractVector) +# TODO: Improve. +function transform(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) b = ib.orig return [rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) for i = 1:length(y)] end -function (ib::Inverse{<:RationalQuadraticSpline, 1})(y::AbstractMatrix) - return eachcolmaphcat(ib, y) -end ###################### ### `logabsdetjac` ### @@ -315,21 +314,19 @@ function rqs_logabsdetjac( return log(numerator) - 2 * log(denominator) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_logabsdetjac(b.widths, b.heights, b.derivatives, x) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::AbstractVector) - return logabsdetjac.(b, x) +function logabsdetjac_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) + return Batch(logabsdetjac.(b, value(x))) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractVector) +# TODO: Improve. 
+function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return sum([ rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x) ]) end -function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix, 1}, x::AbstractMatrix) - return mapvcat(x -> logabsdetjac(b, x), eachcol(x)) -end ################# ### `forward` ### @@ -382,6 +379,6 @@ function rqs_forward( return (result = y, logabsdetjac = logjac) end -function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function forward(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end From 3b755267aa3354a40d013fbf3b15a66e20e01afc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:47:59 +0100 Subject: [PATCH 25/98] updated Scale to new interface --- src/bijectors/scale.jl | 77 ++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index e19b33f5..f3f15a15 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -1,57 +1,46 @@ -struct Scale{T, N} <: Bijector{N} +struct Scale{T} <: Bijector a::T end -Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Scale, b2::Scale) = b1.a == b2.a -function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Scale{typeof(a), D}(a) -end - -# field is a numerical parameter -function Functors.functor(::Type{<:Scale{<:Any,N}}, x) where N - function reconstruct_scale(xs) - return Scale{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_scale -end +Functors.@functor Scale -up1(b::Scale{T, N}) where {N, T} = Scale{T, N + 1}(a) - -(b::Scale)(x) = b.a .* x -(b::Scale{<:AbstractMatrix, 1})(x::AbstractVecOrMat) = b.a * x -(b::Scale{<:AbstractMatrix, 2})(x::AbstractMatrix) = b.a * x -(ib::Inverse{<:Scale})(y) = Scale(inv(ib.orig.a))(y) -(ib::Inverse{<:Scale{<:AbstractVector}})(y) = Scale(inv.(ib.orig.a))(y) -(ib::Inverse{<:Scale{<:AbstractMatrix, 1}})(y::AbstractVecOrMat) = ib.orig.a \ y -(ib::Inverse{<:Scale{<:AbstractMatrix, 2}})(y::AbstractMatrix) = ib.orig.a \ y +transform(b::Scale, x) = b.a .* x +transform(b::Scale{<:AbstractMatrix}, x::AbstractVecOrMat) = b.a * x +transform(ib::Inverse{<:Scale}, y) = transform(Scale(inv(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractVector}}, y) = transform(Scale(inv.(ib.orig.a)), y) +transform(ib::Inverse{<:Scale{<:AbstractMatrix}}, y::AbstractVecOrMat) = ib.orig.a \ y # We're going to implement custom adjoint for this -logabsdetjac(b::Scale{<:Any, N}, x) where {N} = _logabsdetjac_scale(b.a, x, Val(N)) +logabsdetjac(b::Scale, x::Real) = _logabsdetjac_scale(b.a, x, Val(0)) +function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_scale(b.a, x, Val(N)) +end +function logabsdetjac_batch(b::Scale, x::ArrayBatch{N}) where {N} + return Batch(_logabsdetjac_scale(b, value(x), Val(N - 1))) +end + +# Scalar: single input. 
_logabsdetjac_scale(a::Real, x::Real, ::Val{0}) = log(abs(a)) -_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{1}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{2}) = log(abs(a)) * length(x) -_logabsdetjac_scale(a::Real, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end -_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(x -> log(abs(x)), a), size(x, 2)) -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(x -> log(abs(x)), a) -_logabsdetjac_scale(a::AbstractVector, x::AbstractArray{<:AbstractMatrix}, ::Val{2}) = map(x) do x - _logabsdetjac_scale(a, x, Val(2)) -end + +# Scalar: batch. +_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) +_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) + +# Vector: single input. +_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(log ∘ abs, a) +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log ∘ abs, a) + +# Vector: batch. +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(log ∘ abs, a), size(x, 2)) + +# Matrix: single input. _logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1] -_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1] -function _logabsdetjac_scale( - a::AbstractMatrix, - x::AbstractArray{<:AbstractMatrix}, - ::Val{2}, -) - map(x) do x - _logabsdetjac_scale(a, x, Val(2)) - end -end \ No newline at end of file + +# Matrix: batch. 
+_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) From 195a107f5fa64598c09641fa20d2751fd5c11a18 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:04 +0100 Subject: [PATCH 26/98] updated Shift to new interface --- src/bijectors/shift.jl | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index f6c39f82..9aaade56 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -1,33 +1,26 @@ ################# # Shift & Scale # ################# -struct Shift{T, N} <: Bijector{N} +struct Shift{T} <: Bijector a::T end Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a -function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D - return Shift{typeof(a), D}(a) -end - -# field is a numerical parameter -function Functors.functor(::Type{<:Shift{<:Any,N}}, x) where N - function reconstruct_shift(xs) - return Shift{typeof(xs.a),N}(xs.a) - end - return (a = x.a,), reconstruct_shift -end - -up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) +Functors.@functor Shift -(b::Shift)(x) = b.a .+ x -(b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) +transform(b::Shift, x) = b.a .+ x -inv(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +inv(b::Shift) = Shift(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking -logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) +function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_shift(b.a, x, Val(N)) +end + +function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} + return _logabsdetjac_shift(b.a, x, Val(N - 1)) +end _logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) _logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real} = zeros(T, length(x)) From f56ed6ab365a693230e84902b0557dc5bf0250df Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:13 +0100 Subject: [PATCH 27/98] updated Stacked to new interface --- src/bijectors/stacked.jl | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 34116616..20104246 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -21,7 +21,7 @@ b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{Bs, Rs} <: Bijector{1} +struct Stacked{Bs, Rs} <: Transform bs::Bs ranges::Rs end @@ -48,13 +48,15 @@ end isclosedform(b::Stacked) = all(isclosedform, b.bs) -stack(bs::Bijector{0}...) = Stacked(bs) +invertible(b::Stacked) = sum(map(invertible, b.bs)) + +stack(bs::Bijector...) = Stacked(bs) # For some reason `inv.(sb.bs)` was unstable... This works though. 
-inv(sb::Stacked) = Stacked(map(inv, sb.bs), sb.ranges) +inv(sb::Stacked, ::Invertible) = Stacked(map(inv, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inv(sb::Stacked{A}) where {A <: Tuple} +@generated function inv(sb::Stacked{A}, ::Invertible) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) push!(exprs, :(inv(sb.bs[$i]))) @@ -79,7 +81,7 @@ function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real}) +function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) N == 1 && return sb.bs[1](x[sb.ranges[1]]) @@ -90,7 +92,6 @@ function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real}) return y end -(sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x) function logabsdetjac( b::Stacked, x::AbstractVector{<:Real} @@ -122,12 +123,6 @@ function logabsdetjac(b::Stacked{<:Tuple{<:Bijector}}, x::AbstractVector{<:Real} return sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) end -function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real}) - return map(eachcol(x)) do c - logabsdetjac(b, c) - end -end - # Generates something similar to: # # quote From 72d68b8bfa28c72acf372b9323012de3da8e01f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 11:48:19 +0100 Subject: [PATCH 28/98] updated TruncatedBijector to new interface --- src/bijectors/truncated.jl | 106 ++++--------------------------------- 1 file changed, 11 insertions(+), 95 deletions(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index de5997b2..be8f4f10 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -1,60 +1,22 @@ ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### -struct TruncatedBijector{N, T1, T2} <: Bijector{N} +struct TruncatedBijector{T1, T2} <: Bijector lb::T1 ub::T2 end -TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub) -function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} - return TruncatedBijector{N, T1, T2}(lb, ub) -end -# field are numerical parameters -function Functors.functor(::Type{<:TruncatedBijector{N}}, x) where N - function reconstruct_truncatedbijector(xs) - return TruncatedBijector{N}(xs.lb, xs.ub) - end - return (lb = x.lb, ub = x.ub,), reconstruct_truncatedbijector -end - -up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub) +Functors.@functor TruncatedBijector function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector) return b1.lb == b2.lb && b1.ub == b2.ub end -function (b::TruncatedBijector{0})(x::Real) +function transform(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_link(_clamp(x, a, b), a, b) + return truncated_link.(_clamp.(x, a, b), a, b) end -function (b::TruncatedBijector{0})(x::AbstractArray{<:Real}) - a, b = b.lb, b.ub - truncated_link.(_clamp.(x, a, b), a, b) -end -function (b::TruncatedBijector{1})(x::AbstractVecOrMat{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -function (b::TruncatedBijector{2})(x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - if a isa AbstractMatrix - @assert b isa 
AbstractMatrix - maporbroadcast(x, a, b) do x, a, b - truncated_link(_clamp(x, a, b), a, b) - end - else - truncated_link.(_clamp.(x, a, b), a, b) - end -end -(b::TruncatedBijector{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) + function truncated_link(x::Real, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -68,37 +30,11 @@ function truncated_link(x::Real, a, b) end end -function (ib::Inverse{<:TruncatedBijector{0}})(y::Real) +function transform(ib::Inverse{<:TruncatedBijector}, y) a, b = ib.orig.lb, ib.orig.ub - _clamp(truncated_invlink(y, a, b), a, b) + return _clamp.(truncated_invlink.(y, a, b), a, b) end -function (ib::Inverse{<:TruncatedBijector{0}})(y::AbstractArray{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - _clamp.(truncated_invlink.(y, a, b), a, b) -end -function (ib::Inverse{<:TruncatedBijector{1}})(y::AbstractVecOrMat{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractVector - @assert b isa AbstractVector - maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - _clamp.(truncated_invlink.(y, a, b), a, b) - end -end -function (ib::Inverse{<:TruncatedBijector{2}})(y::AbstractMatrix{<:Real}) - a, b = ib.orig.lb, ib.orig.ub - if a isa AbstractMatrix - @assert b isa AbstractMatrix - return maporbroadcast(y, a, b) do y, a, b - _clamp(truncated_invlink(y, a, b), a, b) - end - else - return _clamp.(truncated_invlink.(y, a, b), a, b) - end -end -(ib::Inverse{<:TruncatedBijector{2}})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, y) + function truncated_invlink(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -112,31 +48,11 @@ function truncated_invlink(y, a, b) end end -function logabsdetjac(b::TruncatedBijector{0}, x::Real) - a, b = b.lb, b.ub - truncated_logabsdetjac(_clamp(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{0}, x::AbstractArray{<:Real}) +function logabsdetjac(b::TruncatedBijector, x) a, b = b.lb, b.ub - truncated_logabsdetjac.(_clamp.(x, a, b), a, b) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractVector{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{1}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - vec(sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b), dims = 1)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractMatrix{<:Real}) - a, b = b.lb, b.ub - sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) -end -function logabsdetjac(b::TruncatedBijector{2}, x::AbstractArray{<:AbstractMatrix{<:Real}}) - map(x) do x - logabsdetjac(b, x) - end + return truncated_logabsdetjac.(_clamp.(x, a, b), a, b) end + function truncated_logabsdetjac(x, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded From 214aa92943cfd6c0375a089c450368e0d6018844 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:02:20 +0100 Subject: [PATCH 29/98] added ConstructionBase as dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a86495c3..5abd4e67 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.4" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" 
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 836e152b74fff2f019af3f3b63eaab500ffe5833 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:02:31 +0100 Subject: [PATCH 30/98] fixed a bunch of small typos and errors from previous commits --- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 4 ++-- src/bijectors/shift.jl | 2 +- src/bijectors/simplex.jl | 2 +- src/interface.jl | 2 +- src/transformed_distribution.jl | 22 +++++++++++----------- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 1d95c4ef..35b98b34 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -131,7 +131,7 @@ where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test/β_hat) +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test) γ = norm(y_minus_z0) a = α_plus_β_hat - γ r = (sqrt(a^2 + 4 * α * γ) - a) / 2 diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 0dea3a7b..faa91c5d 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -90,7 +90,7 @@ struct RationalQuadraticSpline{T} <: Bijector @assert length(widths) == length(heights) == length(derivatives) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 0}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end function RationalQuadraticSpline( @@ -100,7 +100,7 @@ struct RationalQuadraticSpline{T} <: Bijector ) where {T<:AbstractMatrix} @assert size(widths, 2) == size(heights, 2) == size(derivatives, 2) @assert all(derivatives .> 0) "derivatives need to be positive" - return new{T, 1}(widths, heights, derivatives) + return new{T}(widths, heights, derivatives) end end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 9aaade56..2773de0d 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -5,7 +5,7 @@ struct Shift{T} <: Bijector a::T end -Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a +Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a Functors.@functor Shift diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index f9f83d85..3e990592 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -110,7 +110,7 @@ end # Batched versions. transform_batch(ib::Inverse{<:SimplexBijector}, Y::ArrayBatch{2}) = _simplex_inv_bijector(Y, ib.orig) function transform_batch!( - ib::Inverse{<:SimplexBijector} + ib::Inverse{<:SimplexBijector}, X::Batch{<:AbstractMatrix{T}}, Y::Batch{<:AbstractMatrix{T}}, ) where {T<:Real} diff --git a/src/interface.jl b/src/interface.jl index cfcaf7c8..d9ed8ba9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -220,7 +220,7 @@ transform_batch(b, xs) = _transform_batch(b, xs) _transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) function _transform_batch(b, xs::ArrayBatch{2}) # TODO: Check if we can avoid using these custom methods. 
- return eachcolmaphcat(b, x) + return Batch(eachcolmaphcat(b, value(xs))) end function _transform_batch(b, xs::ArrayBatch{N}) where {N} res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index d6177c72..ae5ef03f 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -2,12 +2,12 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} dist::D transform::B - - TransformedDistribution(d::UnivariateDistribution, b::Bijector{0}) = new{typeof(d), typeof(b), Univariate}(d, b) - TransformedDistribution(d::MultivariateDistribution, b::Bijector{1}) = new{typeof(d), typeof(b), Multivariate}(d, b) - TransformedDistribution(d::MatrixDistribution, b::Bijector{2}) = new{typeof(d), typeof(b), Matrixvariate}(d, b) end +TransformedDistribution(d::UnivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Univariate}(d, b) +TransformedDistribution(d::MultivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Multivariate}(d, b) +TransformedDistribution(d::MatrixDistribution, b::Bijector) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + # fields may contain nested numerical parameters Functors.@functor TransformedDistribution @@ -38,9 +38,9 @@ Returns the constrained-to-unconstrained bijector for distribution `d`. bijector(d::DiscreteUnivariateDistribution) = Identity{0}() bijector(d::DiscreteMultivariateDistribution) = Identity{1}() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) -bijector(d::Product{Discrete}) = Identity{1}() +bijector(d::Product{Discrete}) = Identity() function bijector(d::Product{Continuous}) - return TruncatedBijector{1}(_minmax(d.v)...) + return TruncatedBijector(_minmax(d.v)...) end @generated function _minmax(d::AbstractArray{T}) where {T} try @@ -51,11 +51,11 @@ end end end -bijector(d::Normal) = Identity{0}() -bijector(d::Distributions.AbstractMvNormal) = Identity{1}() -bijector(d::Distributions.AbstractMvLogNormal) = Log{1}() -bijector(d::PositiveDistribution) = Log{0}() -bijector(d::SimplexDistribution) = SimplexBijector{1}() +bijector(d::Normal) = Identity() +bijector(d::Distributions.AbstractMvNormal) = Identity() +bijector(d::Distributions.AbstractMvLogNormal) = Log() +bijector(d::PositiveDistribution) = Log() +bijector(d::SimplexDistribution) = SimplexBijector() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) From ff6b756a1c3215c9a8dbf08ccc9462c4f733d4c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 12:07:33 +0100 Subject: [PATCH 31/98] forgot to wrap some in Batch --- src/bijectors/scale.jl | 2 +- src/bijectors/shift.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index f3f15a15..ad6df463 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -19,7 +19,7 @@ function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} end function logabsdetjac_batch(b::Scale, x::ArrayBatch{N}) where {N} - return Batch(_logabsdetjac_scale(b, value(x), Val(N - 1))) + return Batch(_logabsdetjac_scale(b.a, value(x), Val(N - 1))) end # Scalar: single input. 
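For reference (not part of the diff): a minimal usage sketch of the batched entry points this commit touches, assuming this branch's `Batch` wrapper, its `value` accessor, and a `Batch` constructor that accepts a matrix of column-wise inputs; the snippet is illustrative only, not a committed test.

```julia
using Bijectors: Scale, Batch, value, transform, logabsdetjac, logabsdetjac_batch

b  = Scale(2.0)
x  = randn(3)                    # a single length-3 input
xs = Batch(randn(3, 10))         # a batch of 10 such inputs, stored as columns

transform(b, x)                  # elementwise scaling: 2.0 .* x
logabsdetjac(b, x)               # log(abs(2.0)) * 3 for a single vector input
value(logabsdetjac_batch(b, xs)) # length-10 vector, one log-jacobian per column
```
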
diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 2773de0d..6bd0fcd0 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -19,7 +19,7 @@ function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} end function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return _logabsdetjac_shift(b.a, x, Val(N - 1)) + return Batch(_logabsdetjac_shift(b.a, x, Val(N - 1))) end _logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) From 989aaa87b0144ce588e15a1d12ac0141ee5f3526 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:32:07 +0100 Subject: [PATCH 32/98] allow inverses of non-bijectors --- src/interface.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d9ed8ba9..c6617a59 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -79,6 +79,7 @@ Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() Base.:+(::Invertible, ::Invertible) = Invertible() invertible(::Transform) = NotInvertible() +isinvertible(t::Transform) = invertible(t) isa Invertible """ inv(t::Transform[, ::Invertible]) @@ -160,21 +161,31 @@ requires an iterative procedure to evaluate. isclosedform(b::Bijector) = true """ - inv(b::Bijector) - Inverse(b::Bijector) + inv(b::Transform) + Inverse(b::Transform) -A `Bijector` representing the inverse transform of `b`. +A `Transform` representing the inverse transform of `b`. """ -struct Inverse{B<:Bijector} <: Bijector - orig::B +struct Inverse{T<:Transform} <: Transform + orig::T + + function Inverse(orig::Transform) + if !isinvertible(orig) + error("$(orig) is not invertible") + end + + return new{typeof(orig)}(orig) + end end -# field contains nested numerical parameters Functors.@functor Inverse -inv(b::Bijector) = Inverse(b) -inv(ib::Inverse{<:Bijector}) = ib.orig -Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig +inv(b, ::Invertible) = Inverse(b) +inv(ib::Inverse) = ib.orig + +invertible(ib::Inverse) = Invertible() + +Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) From 4d2388263d65a82f011fe411ce6e7884f82a8a9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:32:19 +0100 Subject: [PATCH 33/98] relax definition of VectorBatch so Vector{<:Real} is covered --- src/batch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batch.jl b/src/batch.jl index 4c5a512b..1e1f7700 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -17,7 +17,7 @@ value(x::Batch) = x.value # Convenient aliases const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} -const VectorBatch = Batch{<:AbstractVector{<:AbstractArray{<:Real}}} +const VectorBatch = Batch{<:AbstractVector} # Constructor for `ArrayBatch`. 
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) From 0f9d334f352a0e77b7a31c2b8ba0f2a8f2f9644d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:40:57 +0100 Subject: [PATCH 34/98] just perform invertibility check in Inverse rather than inv --- src/bijectors/composed.jl | 2 +- src/interface.jl | 74 +++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 3a0e5a9e..1fbaaed0 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -146,7 +146,7 @@ end ∘(::Identity, b::Transform) = b ∘(b::Transform, ::Identity) = b -inv(ct::Composed, ::Invertible) = Composed(reverse(map(inv, ct.ts))) +Base.inv(ct::Composed) = Composed(reverse(map(inv, ct.ts))) # TODO: should arrays also be using recursive implementation instead? function transform(cb::Composed, x) diff --git a/src/interface.jl b/src/interface.jl index c6617a59..330a28f6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -68,27 +68,6 @@ abstract type Transform end Broadcast.broadcastable(b::Transform) = Ref(b) -# Invertibility "trait". -struct NotInvertible end -struct Invertible end - -# Useful for checking if compositions, etc. are invertible or not. -Base.:+(::NotInvertible, ::Invertible) = NotInvertible() -Base.:+(::Invertible, ::NotInvertible) = NotInvertible() -Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() -Base.:+(::Invertible, ::Invertible) = Invertible() - -invertible(::Transform) = NotInvertible() -isinvertible(t::Transform) = invertible(t) isa Invertible - -""" - inv(t::Transform[, ::Invertible]) - -Returns the inverse of transform `t`. -""" -Base.inv(t::Transform) = Base.inv(t, invertible(t)) -Base.inv(t::Transform, ::NotInvertible) = error("$(t) is not invertible") - """ transform(b, x) @@ -142,23 +121,18 @@ function forward!(b, x, out) return out end -"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." -abstract type Bijector <: Transform end - -invertible(::Bijector) = Invertible() - -""" - isclosedform(b::Bijector)::bool - isclosedform(b⁻¹::Inverse{<:Bijector})::bool +# Invertibility "trait". +struct NotInvertible end +struct Invertible end -Returns `true` or `false` depending on whether or not evaluation of `b` -has a closed-form implementation. +# Useful for checking if compositions, etc. are invertible or not. +Base.:+(::NotInvertible, ::Invertible) = NotInvertible() +Base.:+(::Invertible, ::NotInvertible) = NotInvertible() +Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() +Base.:+(::Invertible, ::Invertible) = Invertible() -Most bijectors have closed-form evaluations, but there are cases where -this is not the case. For example the *inverse* evaluation of `PlanarLayer` -requires an iterative procedure to evaluate. -""" -isclosedform(b::Bijector) = true +invertible(::Transform) = NotInvertible() +isinvertible(t::Transform) = invertible(t) isa Invertible """ inv(b::Transform) @@ -180,13 +154,37 @@ end Functors.@functor Inverse -inv(b, ::Invertible) = Inverse(b) -inv(ib::Inverse) = ib.orig +""" + inv(t::Transform[, ::Invertible]) + +Returns the inverse of transform `t`. +""" +Base.inv(t::Transform) = Inverse(t) +Base.inv(ib::Inverse) = ib.orig invertible(ib::Inverse) = Invertible() Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig +"Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." 
+abstract type Bijector <: Transform end + +invertible(::Bijector) = Invertible() + +""" + isclosedform(b::Bijector)::bool + isclosedform(b⁻¹::Inverse{<:Bijector})::bool + +Returns `true` or `false` depending on whether or not evaluation of `b` +has a closed-form implementation. + +Most bijectors have closed-form evaluations, but there are cases where +this is not the case. For example the *inverse* evaluation of `PlanarLayer` +requires an iterative procedure to evaluate. +""" +isclosedform(b::Bijector) = true + +# Default implementation for inverse of a `Bijector`. logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) """ From af0b24b18f1f0969cff939a604059bdae67fa79a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 05:42:48 +0100 Subject: [PATCH 35/98] moved some code arround --- src/interface.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 330a28f6..3a985868 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -121,6 +121,19 @@ function forward!(b, x, out) return out end +""" + isclosedform(b::Transform)::bool + isclosedform(b⁻¹::Inverse{<:Transform})::bool + +Returns `true` or `false` depending on whether or not evaluation of `b` +has a closed-form implementation. + +Most transformations have closed-form evaluations, but there are cases where +this is not the case. For example the *inverse* evaluation of `PlanarLayer` +requires an iterative procedure to evaluate. +""" +isclosedform(b::Transform) = true + # Invertibility "trait". struct NotInvertible end struct Invertible end @@ -171,19 +184,6 @@ abstract type Bijector <: Transform end invertible(::Bijector) = Invertible() -""" - isclosedform(b::Bijector)::bool - isclosedform(b⁻¹::Inverse{<:Bijector})::bool - -Returns `true` or `false` depending on whether or not evaluation of `b` -has a closed-form implementation. - -Most bijectors have closed-form evaluations, but there are cases where -this is not the case. For example the *inverse* evaluation of `PlanarLayer` -requires an iterative procedure to evaluate. -""" -isclosedform(b::Bijector) = true - # Default implementation for inverse of a `Bijector`. logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y)) From 0777fab7afba473181a35f4ede367ac17f310b22 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 06:00:09 +0100 Subject: [PATCH 36/98] added docstrings and default impls for mutating batched methods --- src/interface.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 3a985868..011ab2c6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -238,6 +238,27 @@ function _transform_batch(b, xs::ArrayBatch{N}) where {N} return reconstruct(xs, res) end +""" + transform_batch!(b, xs, ys) + +Transform `xs` by `b` treating `xs` as a "batch", i.e. a collection of independent inputs, +and storing the result in `ys`. 
+ +See also: [`transform!`](@ref) +""" +transform_batch!(b, xs, ys) = _transform_batch!(b, xs, ys) +function _transform_batch!(b, xs, ys) + for i = 1:length(xs) + if eltype(ys) <: Real + ys[i] = transform(b, xs[i]) + else + transform!(b, xs[i], ys[i]) + end + end + + return ys +end + """ logabsdetjac_batch(b, xs) @@ -252,6 +273,27 @@ function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) end +""" + logabsdetjac_batch!(b, xs, logjacs) + +Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs, +accumulating the result in `logjacs`. + +See also: [`logabsdetjac!`](@ref) +""" +logabsdetjac_batch!(b, xs, logjacs) = _logabsdetjac_batch!(b, xs, logjacs) +function _logabsdetjac_batch!(b, xs, logjacs) + for i = 1:length(xs) + if eltype(logjacs) <: Real + logjacs[i] += logabsdetjac(b, xs[i]) + else + logabsdetjac!(b, xs[i], logjacs[i]) + end + end + + return logjacs +end + """ forward_batch(b, xs) @@ -261,6 +303,20 @@ See also: [`transform`](@ref) """ forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) +""" + forward_batch!(b, xs, out) + +Computes `forward(b, xs)` in place, treating `xs` as a "batch", i.e. a collection of independent inputs. + +See also: [`forward!`](@ref) +""" +function forward_batch!(b, xs, out) + transform_batch!(b, xs, out.result) + logabsdetjac_batch!(b, xs, out.logabsdetjac) + + return out +end + ###################### # Bijectors includes # ###################### From 42839c373a9f0f57b2de23762cf2e3a5376c0c3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:29:02 +0100 Subject: [PATCH 37/98] add elementype to VectorBatch --- src/batch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batch.jl b/src/batch.jl index 1e1f7700..0089618a 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -17,7 +17,7 @@ value(x::Batch) = x.value # Convenient aliases const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} -const VectorBatch = Batch{<:AbstractVector} +const VectorBatch{T} = Batch{<:AbstractVector{T}} # Constructor for `ArrayBatch`. 
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) From 2ce74d4e45570b534dae8422e9d73e4f540336ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:29:28 +0100 Subject: [PATCH 38/98] simplify Shift bijector --- src/bijectors/shift.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 6bd0fcd0..85050e5c 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -14,17 +14,13 @@ transform(b::Shift, x) = b.a .+ x inv(b::Shift) = Shift(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking -function logabsdetjac(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return _logabsdetjac_shift(b.a, x, Val(N)) +function logabsdetjac(b::Shift, x::Union{Real, AbstractArray{<:Real}}) + return _logabsdetjac_shift(b.a, x) end -function logabsdetjac_batch(b::Shift, x::AbstractArray{<:Real, N}) where {N} - return Batch(_logabsdetjac_shift(b.a, x, Val(N - 1))) +function logabsdetjac_batch(b::Shift, x::ArrayBatch) + return Batch(_logabsdetjac_shift_array_batch(b.a, value(x))) end -_logabsdetjac_shift(a::Real, x::Real, ::Val{0}) = zero(eltype(x)) -_logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Val{0}) where {T<:Real} = zeros(T, length(x)) -_logabsdetjac_shift(a::T1, x::AbstractVector{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{1}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x, 2)) -_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) -_logabsdetjac_shift(a::T1, x::AbstractArray{<:AbstractMatrix{T2}}, ::Val{2}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x)) +_logabsdetjac_shift(a, x) = zero(eltype(x)) +_logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x))) From 926ef2716dce342f5a73c20d44d14f51727e6815 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:06 +0100 Subject: [PATCH 39/98] added rrules for logabsdetjac_shift --- src/chainrules.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index b7f2b51a..7ed33599 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -6,3 +6,11 @@ ChainRulesCore.@scalar_rule( ), (x, - tanh(Ω + b) * x, x - 1), ) + +function ChainRulesCore.rrule(::typeof(_logabsdetjac_shift), a, x) + return _logabsdetjac_shift(a, x), Δ -> (ChainRulesCore.NO_FIELDS, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent()) +end + +function ChainRulesCore.rrule(::typeof(_logabsdetjac_shift_array_batch), a, x) + return _logabsdetjac_shift_array_batch(a, x), Δ -> (ChainRulesCore.NO_FIELDS, ChainRulesCore.ZeroTangent(), ChainRulesCore.ZeroTangent()) +end From c3745d76229e81f70cfa029ce850fdd641c24eec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:36 +0100 Subject: [PATCH 40/98] use type-stable implementation of eachslice --- src/interface.jl | 7 +++++-- src/utils.jl | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 011ab2c6..7e1cb171 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -232,7 +232,7 @@ function _transform_batch(b, xs::ArrayBatch{2}) return Batch(eachcolmaphcat(b, value(xs))) end function _transform_batch(b, xs::ArrayBatch{N}) where {N} - res = reduce(map(b, eachslice(value(xs), dims=N))) do acc, x + res = reduce(map(b, 
eachslice(value(xs), Val{N}()))) do acc, x cat(acc, x; dims = N) end return reconstruct(xs, res) @@ -269,8 +269,11 @@ See also: [`logabsdetjac`](@ref) logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) # Default implementations uses private methods to avoid method ambiguity. _logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) +function _logabsdetjac_batch(b, xs::ArrayBatch{2}) + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachcol(value(xs)))) +end function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), dims=N))) + return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), Val{N}()))) end """ diff --git a/src/utils.jl b/src/utils.jl index d1034dbc..dcb99eb7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,3 +9,17 @@ _vec(x::Real) = x # Useful for reconstructing objects. reconstruct(b, args...) = constructorof(typeof(b))(args...) + +# Despite kwargs using `NamedTuple` in Julia 1.6, I'm still running +# into type-instability issues when using `eachslice`. So we define our own. +# https://github.com/JuliaLang/julia/issues/39639 +# TODO: Check if this is the case in Julia 1.7. +# Adapted from https://github.com/JuliaLang/julia/blob/ef673537f8622fcaea92ac85e07962adcc17745b/base/abstractarraymath.jl#L506-L513. +# FIXME: It seems to break Zygote though. Which is weird because normal `eachslice` does not. +@inline function Base.eachslice(A::AbstractArray, ::Val{N}) where {N} + dim = N + dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions")) + inds_before = ntuple(_ -> Colon(), dim-1) + inds_after = ntuple(_ -> Colon(), ndims(A)-dim) + return (view(A, inds_before..., i, inds_after...) for i in axes(A, dim)) +end From 16009860be18e7dfddea1c47a74fd2fa8f67f475 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:30:55 +0100 Subject: [PATCH 41/98] initial work on adding proper testing --- test/ad/utils.jl | 12 ++- test/bijectors/utils.jl | 210 ++++++++++++++++++++++++---------------- 2 files changed, 137 insertions(+), 85 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6c9bf4a4..e87ba9b4 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -136,14 +136,19 @@ function test_ad(dist::DistSpec; kwargs...) end end -function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) +function test_ad(_f, _x, broken = (); rtol = 1e-6, atol = 1e-6) + f = _x isa Real ? 
_f ∘ first : _f + x = [_x;] + finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] if AD == "All" || AD == "Tracker" if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol else - @test Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + ∇tracker = Tracker.gradient(f, x)[1] + @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol + @test Tracker.istracked(∇tracker) end end @@ -159,7 +164,8 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) if :Zygote in broken @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol else - @test Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + ∇zygote = Zygote.gradient(f, x)[1] + @test (all(finitediff .== 0) && ∇zygote === nothing) || isapprox(∇zygote, finitediff, rtol=rtol, atol=atol) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index a0fdb6f2..539408d4 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,8 +1,11 @@ -function test_bijector_reals( - b::Bijector{0}, - x_true::Real, - y_true::Real, - logjac_true::Real; +import Bijectors: AbstractBatch + + +function test_bijector_single( + b::Bijector, + x_true, + y_true, + logjac_true; isequal = true, tol = 1e-6 ) @@ -17,67 +20,66 @@ function test_bijector_reals( ires = isequal ? @inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) # Always want the following to hold - @test ires.rv ≈ x_true atol=tol + @test ires.result ≈ x_true atol=tol @test ires.logabsdetjac ≈ -logjac atol=tol if isequal @test y ≈ y_true atol=tol # forward @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse @test logjac ≈ logjac_true # logjac forward - @test res.rv ≈ y_true atol=tol # forward using `forward` + @test res.result ≈ y_true atol=tol # forward using `forward` @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` else @test y ≠ y_true # forward @test (@inferred ib(y)) ≈ x_true atol=tol # inverse @test logjac ≠ logjac_true # logjac forward - @test res.rv ≠ y_true # forward using `forward` + @test res.result ≠ y_true # forward using `forward` @test res.logabsdetjac ≠ logjac_true # logjac using `forward` end end -function test_bijector_arrays( +function test_bijector_batch( b::Bijector, - xs_true::AbstractArray{<:Real}, - ys_true::AbstractArray{<:Real}, - logjacs_true::Union{Real, AbstractArray{<:Real}}; + xs_true::AbstractBatch, + ys_true::AbstractBatch, + logjacs_true; isequal = true, tol = 1e-6 -) + ib = @inferred inv(b) - ys = @inferred b(xs_true) - logjacs = @inferred logabsdetjac(b, xs_true) - res = @inferred forward(b, xs_true) + ys = @inferred broadcast(b, xs_true) + logjacs = @inferred broadcast(logabsdetjac, b, xs_true) + res = @inferred broadcast(forward, b, xs_true) # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inv(b), ys_true)) : @inferred(forward(inv(b), ys)) + ires = isequal ? 
@inferred(broadcast(forward, inv(b), ys_true)) : @inferred(broadcast(forward, inv(b), ys)) # always want the following to hold - @test ys isa typeof(ys_true) - @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires.rv - xs_true) ≤ tol - @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + @test ys isa AbstractBatch + @test logjacs isa AbstractBatch + @test mean(norm, ires.result - xs_true) ≤ tol + @test mean(norm, ires.logabsdetjac + logjacs) ≤ tol if isequal - @test mean(abs, ys - ys_true) ≤ tol # forward - @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse - @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` + @test mean(norm, ys - ys_true) ≤ tol # forward + @test mean(norm, broadcast(ib, ys_true) - xs_true) ≤ tol # inverse + @test mean(abs, logjacs .- logjacs_true) ≤ tol # logjac forward + @test mean(norm, res.result - ys_true) ≤ tol # forward using `forward` @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` else # Don't want the following to be equal to their "true" values - @test mean(abs, ys - ys_true) > tol # forward + @test mean(norm, ys - ys_true) > tol # forward @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + @test mean(norm, res.result - ys_true) > tol # forward using `forward` # Still want the following to be equal to the COMPUTED values - @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse + @test mean(norm, broadcast(ib, ys) - xs_true) ≤ tol # inverse @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` end end """ - test_bijector(b::Bijector, xs::Array; kwargs...) test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` @@ -103,82 +105,110 @@ treated as "counter-examples", i.e. values NOT to match. - `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check arrays where we check that the L1-norm is sufficiently small. """ -function test_bijector(b::Bijector{0}, xs::AbstractVector{<:Real}) - return test_bijector(b, xs, zeros(length(xs)), zeros(length(xs)); isequal = false) -end - -function test_bijector(b::Bijector{1}, xs::AbstractMatrix{<:Real}) - return test_bijector(b, xs, zeros(size(xs)), zeros(size(xs, 2)); isequal = false) -end - function test_bijector( - b::Bijector{0}, - xs_true::AbstractVector{<:Real}, - ys_true::AbstractVector{<:Real}, - logjacs_true::AbstractVector{<:Real}; + b::Bijector, + xs_true::AbstractBatch, + ys_true::AbstractBatch, + logjacs_true::AbstractBatch; kwargs... ) - ib = inv(b) - - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) - # Test `logabsdetjac` against jacobians test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) + + if Bijectors.isinvertible(b) + ib = inv(b) + test_logabsdetjac(ib, ys_true) + end for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) - test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...) + # Test validity of single input. + test_bijector_single(b, x_true, y_true, logjac_true; kwargs...) - # Test AD - test_ad(x -> b(first(x)), [x_true, ]) + # Test AD wrt. inputs. 
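        # (Each `test_ad` call below compares the AD backends against a
        # finite-difference reference gradient; see `test_ad` in `test/ad/utils.jl`.)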
+ test_ad(x -> sum(b(x)), x_true) + test_ad(x -> logabsdetjac(b, x), x_true) + + if Bijectors.isinvertible(b) + y = b(x_true) + test_ad(x -> sum(ib(x)), y) + end + end - y = b(x_true) - test_ad(x -> ib(first(x)), [y, ]) + # Test AD wrt. parameters. + test_bijector_parameter_gradient(b, xs_true[1], ys_true[1]) - test_ad(x -> logabsdetjac(b, first(x)), [x_true, ]) + # Test validity of collection of inputs. + test_bijector_batch(b, xs_true, ys_true, logjacs_true; kwargs...) + + # AD testing for batch. + f, arg = make_gradient_function(x -> sum(sum(b.(x))), xs_true) + test_ad(f, arg) + f, arg = make_gradient_function(x -> sum(logabsdetjac.(b, x)), xs_true) + test_ad(f, arg) + + if Bijectors.isinvertible(b) + ys = b.(xs_true) + f, arg = make_gradient_function(y -> sum(sum(ib.(y))), ys) + test_ad(f, arg) end end +function make_gradient_function(f, xs::ArrayBatch) + s = size(Bijectors.value(xs)) -function test_bijector( - b::Bijector{1}, - xs_true::AbstractMatrix{<:Real}, - ys_true::AbstractMatrix{<:Real}, - logjacs_true::AbstractVector{<:Real}; - kwargs... -) - ib = inv(b) + function g(x) + x_batch = Bijectors.reconstruct(xs, reshape(x, s)) + return f(x_batch) + end - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + return g, vec(Bijectors.value(xs)) +end - # Test `logabsdetjac` against jacobians - test_logabsdetjac(b, xs_true) - test_logabsdetjac(b, ys_true) - - for (x_true, y_true, logjac_true) in zip(eachcol(xs_true), eachcol(ys_true), logjacs_true) - # HACK: collect to avoid dealing with sub-arrays and thus allowing us to compare the - # type of the computed output to the "true" output. - test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...) +function make_gradient_function(f, xs::VectorBatch{<:AbstractArray{<:Real}}) + xs_new = vcat(map(vec, Bijectors.value(xs))) + n = length(xs_new) - # Test AD - test_ad(x -> sum(b(x)), collect(x_true)) - y = b(x_true) - test_ad(x -> sum(ib(x)), y) + s = size(Bijectors.value(xs[1])) + stride = n ÷ length(xs) - test_ad(x -> logabsdetjac(b, x), x_true) + function g(x) + x_vec = map(1:stride:n) do i + reshape(x[i:i + stride - 1], s) + end + + x_batch = Bijectors.reconstruct(xs, x_vec) + return f(x_batch) end + + return g, xs_new end -function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6) - logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol +make_jacobian_function(f, xs::AbstractVector) = f, xs +function make_jacobian_function(f, xs::AbstractArray) + xs_new = vec(xs) + s = size(xs) + + function g(x) + return vec(f(reshape(x, s))) + end + + return g, xs_new end -function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6) - logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs] - @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol +function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:AbstractArray}; tol=1e-6) + f, _ = make_jacobian_function(b, xs[1]) + logjac_ad = map(xs) do x + first(logabsdet(ForwardDiff.jacobian(f, x))) + end + + @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol +end + +function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:Real}; tol=1e-6) + logjac_ad = map(xs) do x + log(abs(ForwardDiff.derivative(b, x))) + end + @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol end # Check if `Functors.functor` works properly @@ -187,3 +217,19 @@ function test_functor(x, xs) @test x == re(_xs) @test _xs == xs end + +function 
test_bijector_parameter_gradient(b::Transform, x, y = b(x)) + args, re = Functors.functor(b) + recon(k, param) = re(merge(args, NamedTuple{(k, )}((param, )))) + + # Compute the gradient wrt. one argument at the time. + for (k, v) in pairs(args) + test_ad(p -> sum(transform(recon(k, p), x)), v) + test_ad(p -> logabsdetjac(recon(k, p), x), v) + + if Bijectors.isinvertible(b) + test_ad(p -> sum(transform(inv(recon(k, p)), y)), v) + test_ad(p -> logabsdetjac(inv(recon(k, p)), y), v) + end + end +end From 2f4d32822a9b0203c7e8e851b389d336e9344b31 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 6 Jun 2021 16:31:18 +0100 Subject: [PATCH 42/98] make Batch compatible with Zygote --- src/compat/zygote.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index b16ef58b..c971597b 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -291,3 +291,20 @@ end end +# Otherwise Zygote complains. +Zygote.@adjoint function Batch(x) + return Batch(x), function(_Δ) + # Sometimes `Batch` has been extraced using `value`, in which case + # we get back a `NamedTuple`. + # Other times the value was extracted by iterating over `Batch`, + # in which case we don't get a `NamedTuple`. + Δ = _Δ isa NamedTuple{(:value, )} ? _Δ.value : _Δ + return if Δ isa AbstractArray{<:Real} + (Δ, ) + else + Δ_new = similar(x, eltype(Δ[1])) + Δ_new[:] .= vcat(Δ...) + (Δ_new, ) + end + end +end From e5439f586becbd6973f44e94ff5d703166771f4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 06:43:27 +0100 Subject: [PATCH 43/98] updated OrderedBijector --- src/bijectors/ordered.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index d1bfd8f0..ec2b3ed8 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -7,7 +7,7 @@ A bijector mapping ordered vectors in ℝᵈ to unordered vectors in ℝᵈ. - [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html) - Note that this transformation and its inverse are the _opposite_ of in this reference. 
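
A minimal usage sketch (illustrative only; `transform` maps an unconstrained vector to one with strictly increasing components, as implemented by `_transform_ordered` below):

```julia-repl
julia> b = OrderedBijector();

julia> y = randn(3);                     # unconstrained input

julia> x = transform(b, y); issorted(x)  # ordered output
true
```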
""" -struct OrderedBijector <: Bijector{1} end +struct OrderedBijector <: Bijector end """ ordered(d::Distribution) @@ -16,7 +16,7 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in """ ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) -(b::OrderedBijector)(y::AbstractVecOrMat) = _transform_ordered(y) +transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y) function _transform_ordered(y::AbstractVector) x = similar(y) @@ -45,8 +45,7 @@ function _transform_ordered(y::AbstractMatrix) return x end -(ib::Inverse{<:OrderedBijector})(x::AbstractVecOrMat) = _transform_inverse_ordered(x) - +transform(ib::Inverse{<:OrderedBijector}, x::AbstractVecOrMat) = _transform_inverse_ordered(x) function _transform_inverse_ordered(x::AbstractVector) y = similar(x) @assert !isempty(y) From 5681358caf4250534507e6ed3e94528b7dd5e41b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Jan 2022 18:12:34 +0000 Subject: [PATCH 44/98] temporary stuff --- Project.toml | 1 + src/Bijectors.jl | 56 +++-- src/bijectors/composed.jl | 307 +++-------------------- src/bijectors/exp_log.jl | 91 +++++-- src/bijectors/logit.jl | 24 +- src/interface.jl | 497 +++++++++++++++++++++++++------------- 6 files changed, 491 insertions(+), 485 deletions(-) diff --git a/Project.toml b/Project.toml index 4d25f7cb..90b3abac 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.10.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +Batching = "a4738bf4-2ff1-4f03-87a2-28fa6f9d5d14" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 45151107..5ba75b9c 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,6 +36,9 @@ using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +using Batching +export batch + import ChangesOfVariables: with_logabsdet_jacobian import InverseFunctions: inverse @@ -55,10 +58,13 @@ export TransformDistribution, logpdf_with_trans, isclosedform, transform, + transform!, with_logabsdet_jacobian, inverse, forward, + forward!, logabsdetjac, + logabsdetjac!, logabsdetjacinv, Bijector, ADBijector, @@ -77,7 +83,9 @@ export TransformDistribution, PlanarLayer, RadialLayer, CouplingLayer, - InvertibleBatchNorm + InvertibleBatchNorm, + Elementwise, + elementwise if VERSION < v"1.1" using Compat: eachcol @@ -252,13 +260,13 @@ function getlogp(d::InverseWishart, Xcf, X) end include("utils.jl") -include("batch.jl") +# include("batch.jl") include("interface.jl") -include("chainrules.jl") +# include("chainrules.jl") -Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) +Base.@deprecate forward(b::Transform, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) -@noinline function Base.inv(b::AbstractBijector) +@noinline function Base.inv(b::Transform) Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) inverse(b) end @@ -267,24 +275,24 @@ end maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) -# optional dependencies -function __init__() - @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin - function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) 
- return copy(f.(x1, x...)) - end - function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x2, x...)) - end - function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) - return copy(f.(x1, x2, x3, x...)) - end - end - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") - @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") -end +# # optional dependencies +# function __init__() +# @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin +# function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x...)) +# end +# function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x2, x...)) +# end +# function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) +# return copy(f.(x1, x2, x3, x...)) +# end +# end +# @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") +# @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") +# @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") +# @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") +# @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") +# end end # module diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index dab9e519..a3c07158 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -1,290 +1,59 @@ -############### -# Composition # -############### +invertible(cb::ComposedFunction) = invertible(cb.inner) + invertible(cb.outer) +isclosedform(cb::ComposedFunction) = isclosedform(cb.inner) && isclosedform(cb.outer) -""" - Composed(ts::A) +transform_single(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) +transform_multiple(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) - ∘(b1::Bijector{N}, b2::Bijector{N})::Composed{<:Tuple} - composel(ts::Bijector{N}...)::Composed{<:Tuple} - composer(ts::Bijector{N}...)::Composed{<:Tuple} - -where `A` refers to either -- `Tuple{Vararg{<:Bijector{N}}}`: a tuple of bijectors of dimensionality `N` -- `AbstractArray{<:Bijector{N}}`: an array of bijectors of dimensionality `N` - -A `Bijector` representing composition of bijectors. `composel` and `composer` results in a -`Composed` for which application occurs from left-to-right and right-to-left, respectively. - -Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors. -This ensures type-stability of implementations of all relating methdos, e.g. `inverse`. - -If you want to use an `Array` as the container instead you can do - - Composed([b1, b2, ...]) - -In general this is not advised since you lose type-stability, but there might be cases -where this is desired, e.g. if you have a insanely large number of bijectors to compose. 
- -# Examples -## Simple example -Let's consider a simple example of `Exp`: -```julia-repl -julia> using Bijectors: Exp - -julia> b = Exp() -Exp{0}() - -julia> b ∘ b -Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())) - -julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation -true - -julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion -true - -julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian -3.718281828459045 -``` - -# Notes -## Order -It's important to note that `∘` does what is expected mathematically, which means that the -bijectors are applied to the input right-to-left, e.g. first applying `b2` and then `b1`: -```julia -(b1 ∘ b2)(x) == b1(b2(x)) # => true -``` -But in the `Composed` struct itself, we store the bijectors left-to-right, so that -```julia -cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) -cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) -cb1(x) == cb2(x) == b1(b2(x)) # => true -``` - -## Structure -`∘` will result in "flatten" the composition structure while `composel` and -`composer` preserve the compositional structure. This is most easily seen by an example: -```julia-repl -julia> b = Exp() -Exp{0}() - -julia> cb1 = b ∘ b; cb2 = b ∘ b; - -julia> (cb1 ∘ cb2).ts # <= different -(Exp{0}(), Exp{0}(), Exp{0}(), Exp{0}()) - -julia> (cb1 ∘ cb2).ts isa NTuple{4, Exp{0}} -true - -julia> Bijectors.composer(cb1, cb2).ts -(Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())), Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}()))) - -julia> Bijectors.composer(cb1, cb2).ts isa Tuple{Composed, Composed} -true -``` - -""" -struct Composed{A} <: Transform - ts::A -end - -# field contains nested numerical parameters -Functors.@functor Composed - -invertible(cb::Composed) = sum(map(invertible, cb.ts)) - -isclosedform(b::Composed) = all(isclosedform, b.ts) - -function Base.:(==)(b1::Composed, b2::Composed) - ts1, ts2 = b1.ts, b2.ts - return length(ts1) == length(ts2) && all(x == y for (x, y) in zip(ts1, ts2)) -end - -""" - composel(ts::Transform...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied left-to-right. -""" -composel(ts::Transform...) = Composed(ts) - -""" - composer(ts::Transform...)::Composed{<:Tuple} - -Constructs `Composed` such that `ts` are applied right-to-left. -""" -composer(ts::Transform...) = Composed(reverse(ts)) - -# The transformation of `Composed` applies functions left-to-right -# but in mathematics we usually go from right-to-left; this reversal ensures that -# when we use the mathematical composition ∘ we get the expected behavior. -# TODO: change behavior of `transform` of `Composed`? -∘(b1::Transform, b2::Transform) = composel(b2, b1) - -# type-stable composition rules -∘(b1::Composed{<:Tuple}, b2::Transform) = composel(b2, b1.ts...) -∘(b1::Transform, b2::Composed{<:Tuple}) = composel(b2.ts..., b1) -∘(b1::Composed{<:Tuple}, b2::Composed{<:Tuple}) = composel(b2.ts..., b1.ts...) 
- -# type-unstable composition rules -∘(b1::Composed{<:AbstractArray}, b2::Transform) = Composed(pushfirst!(copy(b1.ts), b2)) -∘(b1::Transform, b2::Composed{<:AbstractArray}) = Composed(push!(copy(b2.ts), b1)) -function ∘(b1::Composed{<:AbstractArray}, b2::Composed{<:AbstractArray}) - return Composed(append!(copy(b2.ts), copy(b1.ts))) -end - -# if combining type-unstable and type-stable, return type-unstable -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:Tuple}, T2<:Composed{<:AbstractArray}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end -function ∘(b1::T1, b2::T2) where {T1<:Composed{<:AbstractArray}, T2<:Composed{<:Tuple}} - error("Cannot compose compositions of different container-types; ($T1, $T2)") -end - - -∘(::Identity, ::Identity) = Identity() -∘(::Identity, b::Transform) = b -∘(b::Transform, ::Identity) = b - -inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) - -# TODO: should arrays also be using recursive implementation instead? -function transform(cb::Composed, x) - @assert length(cb.ts) > 0 - res = cb.ts[1](x) - for b ∈ Base.Iterators.drop(cb.ts, 1) - res = b(res) - end - - return res +function transform_single!(cb::ComposedFunction, x, y) + transform!(cb.inner, x, y) + return transform!(cb.outer, y, y) end -function transform_batch(cb::Composed, x) - @assert length(cb.ts) > 0 - res = cb.ts[1].(x) - for b ∈ Base.Iterators.drop(cb.ts, 1) - res = transform_batch(b, res) - end - - return res +function transform_multiple!(cb::ComposedFunction, x, y) + transform!(cb.inner, x, y) + return transform!(cb.outer, y, y) end -@generated function transform(cb::Composed{T}, x) where {T<:Tuple} - @assert length(T.parameters) > 0 - expr = :(x) - for i in 1:length(T.parameters) - expr = :(cb.ts[$i]($expr)) - end - return expr +function logabsdetjac_single(cb::ComposedFunction, x) + y, logjac = forward(cb.inner, x) + return logabsdetjac(cb.outer, y) + logjac end -@generated function transform_batch(cb::Composed{T}, x) where {T<:Tuple} - @assert length(T.parameters) > 0 - expr = :(x) - for i in 1:length(T.parameters) - expr = :(transform_batch(cb.ts[$i], $expr)) - end - return expr +function logabsdetjac_multiple(cb::ComposedFunction, x) + y, logjac = forward(cb.inner, x) + return logabsdetjac(cb.outer, y) + logjac end -function logabsdetjac(cb::Composed, x) - y, logjac = with_logabsdet_jacobian(cb.ts[1], x) - for i = 2:length(cb.ts) - y, res_logjac = with_logabsdet_jacobian(cb.ts[i], y) - logjac += res_logjac - end - - return logjac +function logabsdetjac_single!(cb::ComposedFunction, x, logjac) + y = similar(x) + forward!(cb.inner, x, y, logjac) + return logabdetjac!(cb.outer, y, y, logjac) end -function logabsdetjac_batch(cb::Composed, x) - init = forward_batch(cb.ts[1], x) - result = reduce(cb.ts[2:end]; init = init) do (y, logjac), b - return forward_batch(b, y) - end - - return result.logabsdetjac -end - - -@generated function logabsdetjac(cb::Composed{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:N - 1 - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end 
- # don't need to evaluate the last bijector, only it's `logabsdetjac` - sym_ladj, sym_tmp_ladj = gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :($sym_tmp_ladj = logabsdetjac(cb.ts[$N], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - push!(expr.args, :(return $sym_ladj)) - - return expr +function logabsdetjac_multiple!(cb::ComposedFunction, x, logjac) + y = similar(x) + forward!(cb.inner, x, y, logjac) + return logabdetjac!(cb.outer, y, y, logjac) end -""" - logabsdetjac_batch(cb::Composed{<:Tuple}, x) - -Generates something of the form -```julia -quote - (y, logjac_1) = forward_batch(cb.ts[1], x) - logjac_2 = logabsdetjac_batch(cb.ts[2], y) - return logjac_1 + logjac_2 +function forward_single(cb::ComposedFunction, x) + y1, logjac1 = forward(cb.inner, x) + y2, logjac2 = forward(cb.outer, y1) + return y2, logjac1 + logjac2 end -``` -""" -@generated function logabsdetjac_batch(cb::Composed{T}, x) where {T<:Tuple} - N = length(T.parameters) - - expr = Expr(:block) - push!(expr.args, :((y, logjac_1) = forward_batch(cb.ts[1], x))) - - for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward_batch(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.result)) - push!(expr.args, :($(Symbol("logjac_$i")) = $temp.logabsdetjac)) - end - # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :($(Symbol("logjac_$N")) = logabsdetjac_batch(cb.ts[$N], y))) - sum_expr = Expr(:call, :+, [Symbol("logjac_$i") for i = 1:N]...) - push!(expr.args, :(return $(sum_expr))) - - return expr +function forward_multiple(cb::ComposedFunction, x) + y1, logjac1 = forward(cb.inner, x) + y2, logjac2 = forward(cb.outer, y1) + return y2, logjac1 + logjac2 end - -function with_logabsdet_jacobian(cb::Composed, x) - rv, logjac = with_logabsdet_jacobian(cb.ts[1], x) - - for t in cb.ts[2:end] - rv, res_logjac = with_logabsdet_jacobian(t, rv) - logjac += res_logjac - end - return (rv, logjac) +function forward_single!(cb::ComposedFunction, x, y, logjac) + forward!(cb.inner, x, y, logjac) + return forward!(cb.outer, y, y, logjac) end -@generated function with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} - expr = Expr(:block) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - for i = 2:length(T.parameters) - sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) - push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) - push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) - sym_last_y, sym_last_ladj = sym_y, sym_ladj - end - push!(expr.args, :(return ($sym_y, $sym_ladj))) - - return expr +function forward_multiple!(cb::ComposedFunction, x, y, logjac) + forward!(cb.inner, x, y, logjac) + return forward!(cb.outer, y, y, logjac) end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index d3794a59..92197f80 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,20 +1,83 @@ -############# -# Exp & Log # -############# +transform_single!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) -struct Exp <: Bijector end -struct Log <: Bijector end +transform_multiple(b::Union{typeof(log), typeof(exp)}, x) = b.(x) +transform_multiple!(b::Union{typeof(log), typeof(exp)}, x, y) = broadcast!(b, y, x) -inv(b::Exp) = Log() -inv(b::Log) = Exp() 
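# Scalar identities used below: for y = exp(x), log|dy/dx| = x, while for
# x = log(y), log|dx/dy| = -log(y); the `Elementwise` variants simply sum these terms.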
+logabsdetjac_single(b::typeof(exp), x::Real) = x +logabsdetjac_single(b::Elementwise{typeof(exp)}, x) = sum(x) -transform(b::Exp, y) = exp.(y) -transform(b::Log, x) = log.(x) +logabsdetjac_single(b::typeof(log), x::Real) = -log(x) +logabsdetjac_single(b::Elementwise{typeof(log)}, x) = -sum(log, x) -logabsdetjac(b::Exp, x) = sum(x) -logabsdetjac(b::Log, x) = -sum(log, x) +logabsdetjac_multiple(b::typeof(exp), xs) = xs +logabsdetjac_multiple(b::Elementwise{typeof(exp)}, xs) = map(sum, xs) -function forward(b::Log, x) - y = transform(b, x) - return (result = y, logabsdetjac = -sum(y)) +logabsdetjac_multiple(b::typeof(log), xs) = -map(log, xs) +logabsdetjac_multiple(b::Elementwise{typeof(log)}, xs) = -map(sum ∘ log, xs) + +function forward_single(b::typeof(exp), x::Real) + y = b(x) + return y, x +end +function forward_single(b::Elementwise{typeof(exp)}, x) + y = b(x) + return y, sum(x) +end + +function forward_multiple(b::typeof(exp), xs::AbstractBatch{<:Real}) + ys = transform(b, xs) + return ys, xs +end +function forward_multiple!( + b::typeof(exp), + xs::AbstractBatch{<:Real}, + ys::AbstractBatch{<:Real}, + logjacs::AbstractBatch{<:Real} +) + transform!(b, xs, ys) + logjacs += xs + return ys, logjacs +end +function forward_multiple(b::Elementwise{typeof(exp)}, xs) + ys = transform(b, xs) + return ys, map(sum, xs) +end +function forward_multiple!(b::Elementwise{typeof(exp)}, xs, ys, logjacs) + # Do this before `transform!` in case `xs === ys`. + logjacs += map(sum, xs) + transform!(b, xs, ys) + return ys, logjacs +end + +function forward_single(b::typeof(log), y::Real) + x = transform(b, y) + return x, -x +end +function forward_single(b::Elementwise{typeof(log)}, y) + x = transform(b, y) + return x, -sum(x) +end + +function forward_multiple(b::typeof(log), ys::AbstractBatch{<:Real}) + xs = transform(b, ys) + return xs, -xs +end +function forward_multiple!( + b::typeof(log), + ys::AbstractBatch{<:Real}, + xs::AbstractBatch{<:Real}, + logjacs::AbstractBatch{<:Real} +) + transform!(b, ys, xs) + logjacs -= xs + return xs, logjacs +end +function forward_multiple(b::Elementwise{typeof(log)}, ys) + xs = transform(b, ys) + return xs, -map(sum, xs) +end +function forward_multiple!(b::Elementwise{typeof(log)}, ys, xs, logjacs) + transform!(b, ys, xs) + logjacs -= map(sum, xs) + return xs, logjacs end diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 9d26de5d..17afd125 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -1,8 +1,6 @@ ###################### # Logit and Logistic # ###################### -using StatsFuns: logit, logistic - struct Logit{T} <: Bijector a::T b::T @@ -13,11 +11,23 @@ Functors.@functor Logit # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -transform(b::Logit, x) = _logit.(x, b.a, b.b) -_logit(x, a, b) = logit((x - a) / (b - a)) +# TODO: Implement `forward` and batched versions. 
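# One minimal way to cover the `forward` part of the TODO above (sketch only,
# reusing the `transform_single`/`logabsdetjac_single` methods defined below):
#
#     forward_single(b::Logit, x) = (transform_single(b, x), logabsdetjac_single(b, x))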
+ +# Evaluation +_logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a)) +transform_single(b::Logit, x) = _logit.(x, b.a, b.b) +function transform_multiple(b::Logit, xs::Batching.ArrayBatch{<:Real}) + return batch_like(xs, _logit.(Batching.value(xs))) +end -transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) -_ilogit(y, a, b) = (b - a) * logistic(y) + a +# Inverse +_ilogit(y, a, b) = (b - a) * LogExpFunctions.logistic(y) + a + +transform_single(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) +function transform_multiple(ib::Inverse{<:Logit}, ys::Batching.ArrayBatch) + return batch_like(ys, _ilogit.(Batching.value(ys))) +end -logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) +# `logabsdetjac` logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) +logabsdetjac_single(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) diff --git a/src/interface.jl b/src/interface.jl index 809022df..8e96a27f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -3,6 +3,21 @@ import Base: ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf +const Elementwise{F} = Base.Fix1{<:Union{typeof(map),typeof(broadcast)}, F} +""" + elementwise(f) + +Alias for `Base.Fix1(broadcast, f)`. + +In the case where `f::ComposedFunction`, the result is +`Base.Fix1(broadcast, f.outer) ∘ Base.Fix1(broadcast, f.inner)` rather than +`Base.Fix1(broadcast, f)`. +""" +elementwise(f) = Base.Fix1(broadcast, f) +# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really +# the way to go. +elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) + ####################################### # AD stuff "extracted" from Turing.jl # ####################################### @@ -35,90 +50,321 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t Abstract type for a transformation. -## Implementing +# Implementing -A subtype of `Transform` of should at least implement `transform(b, x)`. +A subtype of `Transform` of should at least implement [`transform(b, x)`](@ref). If the `Transform` is also invertible: - Required: - [`invertible`](@ref): should return [`Invertible`](@ref). - _Either_ of the following: - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. - - `Base.inv(b::MyTransform)`: returns an existing `Transform`. + - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. - Optional: - [`forward`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we can exploit shared computation in the two. For the above methods, there are mutating versions which can _optionally_ be implemented: -- [`transform!`](@ref) -- [`logabsdetjac!`](@ref) -- [`forward!`](@ref) +- [`transform_single!`](@ref) +- [`logabsdetjac_single!`](@ref) +- [`forward_single!`](@ref) Finally, there are _batched_ versions of the above methods which can _optionally_ be implemented: -- [`transform_batch`](@ref) -- [`logabsdetjac_batch`](@ref) -- [`forward_batch`](@ref) +- [`transform_multiple`](@ref) +- [`logabsdetjac_multiple`](@ref) +- [`forward_multiple`](@ref) and similarly for the mutating versions. Default implementations depends on the type of `xs`. -Note that these methods are usually used through broadcasting, i.e. `b.(x)` with `x` a `AbstractBatch` -falls back to `transform_batch(b, x)`. 
""" abstract type Transform end +(t::Transform)(x) = transform(t, x) + Broadcast.broadcastable(b::Transform) = Ref(b) """ transform(b, x) +Transform `x` using `b`, treating `x` as a single input. + +Alias for [`transform_single(b, x)`](@ref). +""" +transform(b, x) = transform_single(b, x) + +""" + transform_single(b, x) + Transform `x` using `b`. -Alternatively, one can just call `b`, i.e. `b(x)`. +Defaults to `b(x)` if `b isa Function` and `first(forward(b, x))` otherwise. """ -transform -(t::Transform)(x) = transform(t, x) +transform_single(b, x) = first(forward(b, x)) +transform_single(b::Function, x) = b(x) + +""" + transform(b, xs::AbstractBatch) + +Transform `xs` using `b`, treating `xs` as a collection of inputs. +Alias for [`transform_multiple(b, x)`](@ref). """ - transform!(b, x, y) +transform(b, xs::AbstractBatch) = transform_multiple(b, xs) -Transforms `x` using `b`, storing the result in `y`. """ -transform!(b, x, y) = (y .= transform(b, x)) + transform_multiple(b, xs) + +Transform `xs` using `b`, treating `xs` as a collection of inputs. + +Defaults to `map(Base.Fix1(transform, b), xs)`. +""" +transform_multiple(b, xs) = map(Base.Fix1(transform, b), xs) +function transform_multiple(b::Elementwise, x::Batching.ArrayBatch) + return batch_like(x, transform(b, Batching.value(x))) +end + + +""" + transform!(b, x[, y]) + +Transform `x` using `b`, storing the result in `y`. + +If `y` is not provided, `x` is used as the output. + +Alias for [`transform_single!(b, x, y)`](@ref). +""" +transform!(b, x, y=x) = transform_single!(b, x, y) + +""" + transform_single!(b, x, y) + +Transform `x` using `b`, storing the result in `y`. +""" +transform_single!(b, x, y) = (y .= transform(b, x)) + +""" + transform!(b, xs::AbstractBatch[, ys::AbstractBatch]) + +Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. + +If `ys` is not provided, `xs` is used as the output. + +Alias for [`transform_multiple!(b, xs, ys)`](@ref). +""" +transform!(b, xs::AbstractBatch, ys::AbstractBatch=xs) = transform_multiple!(b, xs, ys) + +""" + transform_multiple!(b, xs::AbstractBatch[, ys::AbstractBatch]) + +Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. +""" +transform_multiple!(b, xs, ys) = broadcast!(Base.Fix1(transform, b), xs, ys) +function transform_multiple!(b::Elementwise, x::Batching.ArrayBatch, y::Batching.ArrayBatch) + broadcast!(b, Batching.value(y), Batching.value(x)) + return y +end """ logabsdetjac(b, x) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. +Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single`](@ref). + +See also: [`logabsdetjac(b, xs::AbstractBatch)`](@ref). +""" +logabsdetjac(b, x) = logabsdetjac_single(b, x) + +""" + logabsdetjac_single(b, x) + +Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. + +Defaults to `last(forward(b, x))`. +""" +logabsdetjac_single(b, x) = last(forward(b, x)) + +""" + logabsdetjac(b, xs::AbstractBatch) + +Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, +where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`Bijectors.logabsdetjac_multiple`](@ref). + +See also: [`logabsdetjac(b, x)`](@ref). +""" +logabsdetjac(b, xs::AbstractBatch) = logabsdetjac_multiple(b, xs) + +""" + logabsdetjac_multiple(b, xs) + +Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, +where `J(b, x)` is the jacobian of `b` at `x`. 
+ +Defaults to `map(Base.Fix1(logabsdetjac, b), xs)`. +""" +logabsdetjac_multiple(b, xs) = map(Base.Fix1(logabsdetjac, b), xs) + +""" + logabsdetjac!(b, x[, logjac]) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). """ -logabsdetjac +logabsdetjac!(b, x, logjac=zero(eltype(x))) = logabsdetjac_single!(b, x, logjac) """ - logabsdetjac!(b, x, logjac) + logabsdetjac_single!(b, x[, logjac]) -Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform, -_accumulating_ the result in `logjac`. +Compute `log(abs(det(J(b, x))))` and accumulate the result in `logjac`, +where `J(b, x)` is the jacobian of `b` at `x`. """ -logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) +logabsdetjac_single!(b, x, logjac) = (logjac += logabsdetjac(b, x)) + +""" + logabsdetjac!(b, xs::AbstractBatch[, logjacs::AbstractBatch]) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for +every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. + +Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). +""" +function logabsdetjac!( + b, + xs::AbstractBatch, + logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) +) + return logabsdetjac_multiple!(b, xs, logjacs) +end + +""" + logabsdetjac_multiple!(b, xs::AbstractBatch, logjacs::AbstractBatch) + +Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for +every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. +""" +logabsdetjac_multiple!(b, xs, logjacs) = (logjacs .+= logabsdetjac(b, xs)) """ forward(b, x) -Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(b(x), logabsdetjac(b, x))`. +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as single input. + +Alias for [`forward_single(b, x)`](@ref). + +See also: [`forward(b, xs::AbstractBatch)`](@ref). +""" +forward(b, x) = forward_single(b, x) + +""" + forward_single(b, x) + +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as a single input. + +Defaults to `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. +""" +forward_single(b, x) = with_logabsdet_jacobian(b, x) + +""" + forward(b, xs::AbstractBatch) + +Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as +collection of inputs. -This defaults to the call above, but often one can re-use computation -in the computation of the forward pass and the computation of the -`logabsdetjac`. `forward` allows the user to take advantange of such -efficiencies, if they exist. +Alias for [`forward_multiple(b, xs)`](@ref). + +See also: [`forward(b, x)`](@ref). +""" +forward(b, xs::AbstractBatch) = forward_multiple(b, xs) + +""" + forward_multiple(b, xs) + +Return `(transform(b, xs), logabsdetjac(b, xs))` treating +`xs` as a batch, i.e. a collection inputs. + +See also: [`forward_single(b, x)`](@ref). +""" +function forward_multiple(b, xs) + # If `b` doesn't have its own definition of `forward_multiple` + # we just broadcast `forward_single`, resulting in a batch of `(y, logjac)` + # pairs which we then unwrap. + results = forward.(Ref(b), xs) + ys = map(first, results) + logjacs = map(last, results) + return batch_like(xs, ys, logjacs) +end + +# `logjac` as an argument doesn't make too much sense for `forward_single!` +# when the inputs have `eltype` `Real`. 
+""" + forward!(b, x[, y, logjac]) + +Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result +in `y` and `logjac`, respetively. + +If `y` is not provided, then `x` will be used in its place. + +Alias for [`forward_single!(b, x, y, logjac)`](@ref). + +See also: [`forward!(b, xs::AbstractBatch, ys::AbstractBatch, logjacs::AbstractBatch)`](@ref). """ -forward(b, x) = (result = transform(b, x), logabsdetjac = logabsdetjac(b, x)) +forward!(b, x, y=x, logjac=zero(eltype(x))) = forward_single!(b, x, y, logjac) -function forward!(b, x, out) - y, logjac = forward(b, x) - out.result .= y - out.logabsdetjac .+= logjac +""" + forward_single!(b, x, y, logjac) + +Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result +in `y` and `logjac`, respetively. - return out +Defaults to calling `forward(b, x)` and updating `y` and `logjac` with the result. +""" +function forward_single!(b, x, y, logjac) + y_, logjac_ = forward(b, x) + y .= y_ + return (y, logjac + logjac_) +end + +""" + forward!(b, xs[, ys, logjacs]) + +Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, +storing the results in `ys` and `logjacs`, respetively. + +If `ys` is not provided, then `xs` will be used in its place. + +Alias for [`forward_multiple!(b, xs, ys, logjacs)`](@ref). + +See also: [`forward!(b, x, y, logjac)`](@ref). +""" +function forward!( + b, + xs::AbstractBatch, + ys::AbstractBatch=xs, + logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) +) + return forward_multiple!(b, xs, ys, logjacs) +end + +""" + forward!(b, xs, ys, logjacs) + +Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, +storing the results in `ys` and `logjacs`, respetively. + +Defaults to iterating through the `xs` and calling [`forward_single(b, x)`](@ref) +for every `x` in `xs`. +""" +function forward_multiple!(b, xs, ys, logjacs) + for i in eachindex(ys) + res = forward_single(b, xs[i]) + ys[i] .= first(res) + logjacs[i] += last(res) + end + + return ys, logjacs end """ @@ -132,7 +378,7 @@ Most transformations have closed-form evaluations, but there are cases where this is not the case. For example the *inverse* evaluation of `PlanarLayer` requires an iterative procedure to evaluate. """ -isclosedform(b::Transform) = true +isclosedform(t::Transform) = true # Invertibility "trait". struct NotInvertible end @@ -144,11 +390,22 @@ Base.:+(::Invertible, ::NotInvertible) = NotInvertible() Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() Base.:+(::Invertible, ::Invertible) = Invertible() +""" + invertible(t) + +Return `Invertible()` if `t` is invertible, and `NotInvertible()` otherwise. +""" invertible(::Transform) = NotInvertible() + +""" + isinvertible(t) + +Return `true` if `t` is invertible, and `false` otherwise. +""" isinvertible(t::Transform) = invertible(t) isa Invertible """ - inv(b::Transform) + inverse(b::Transform) Inverse(b::Transform) A `Transform` representing the inverse transform of `b`. @@ -168,12 +425,12 @@ end Functors.@functor Inverse """ - inv(t::Transform[, ::Invertible]) + inverse(t::Transform[, ::Invertible]) Returns the inverse of transform `t`. 
""" -Base.inv(t::Transform) = Inverse(t) -Base.inv(ib::Inverse) = ib.orig +inverse(t::Transform) = Inverse(t) +inverse(ib::Inverse) = ib.orig invertible(ib::Inverse) = Invertible() @@ -197,157 +454,55 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) ############################## # Example bijector: Identity # ############################## +# Here we don't need to separate between batched version and non-batched, and so +# we can just overload `transform`, etc. directly. +transform(::typeof(identity), x) = copy(x) +transform!(::typeof(identity), x, y) = copy!(y, x) -struct Identity <: Bijector end -inv(b::Identity) = b +logabsdetjac_single(::typeof(identity), x) = zero(eltype(x)) +logabsdetjac_multiple(::typeof(identity), x) = batch_like(x, zeros(eltype(eltype(x)), length(x))) -transform(::Identity, x) = copy(x) -transform!(::Identity, x, y) = (y .= x; return y) -logabsdetjac(::Identity, x) = zero(eltype(x)) -logabsdetjac!(::Identity, x, logjac) = logjac +logabsdetjac_single!(::typeof(identity), x, logjac) = logjac +logabsdetjac_multiple!(::typeof(identity), x, logjac) = logjac #################### # Batched versions # #################### # NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. # so we can actually reference them. Since we just did this for `Identity`, we're good. -Broadcast.broadcasted(b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_batch(b, xs) -Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_batch(b, xs) -Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_batch(b, xs) - - -""" - transform_batch(b, xs) - -Transform `xs` by `b`, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`transform`](@ref) -""" -transform_batch(b, xs) = _transform_batch(b, xs) -# Default implementations uses private methods to avoid method ambiguity. -_transform_batch(b, xs::VectorBatch) = reconstruct(xs, map(b, value(xs))) -function _transform_batch(b, xs::ArrayBatch{2}) - # TODO: Check if we can avoid using these custom methods. - return Batch(eachcolmaphcat(b, value(xs))) -end -function _transform_batch(b, xs::ArrayBatch{N}) where {N} - res = reduce(map(b, eachslice(value(xs), Val{N}()))) do acc, x - cat(acc, x; dims = N) - end - return reconstruct(xs, res) -end - -""" - transform_batch!(b, xs, ys) - -Transform `xs` by `b` treating `xs` as a "batch", i.e. a collection of independent inputs, -and storing the result in `ys`. - -See also: [`transform!`](@ref) -""" -transform_batch!(b, xs, ys) = _transform_batch!(b, xs, ys) -function _transform_batch!(b, xs, ys) - for i = 1:length(xs) - if eltype(ys) <: Real - ys[i] = transform(b, xs[i]) - else - transform!(b, xs[i], ys[i]) - end - end - - return ys -end - -""" - logabsdetjac_batch(b, xs) - -Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`logabsdetjac`](@ref) -""" -logabsdetjac_batch(b, xs) = _logabsdetjac_batch(b, xs) -# Default implementations uses private methods to avoid method ambiguity. 
-_logabsdetjac_batch(b, xs::VectorBatch) = reconstruct(xs, map(x -> logabsdetjac(b, x), value(xs))) -function _logabsdetjac_batch(b, xs::ArrayBatch{2}) - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachcol(value(xs)))) -end -function _logabsdetjac_batch(b, xs::ArrayBatch{N}) where {N} - return reconstruct(xs, map(x -> logabsdetjac(b, x), eachslice(value(xs), Val{N}()))) -end - -""" - logabsdetjac_batch!(b, xs, logjacs) - -Computes `logabsdetjac(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs, -accumulating the result in `logjacs`. - -See also: [`logabsdetjac!`](@ref) -""" -logabsdetjac_batch!(b, xs, logjacs) = _logabsdetjac_batch!(b, xs, logjacs) -function _logabsdetjac_batch!(b, xs, logjacs) - for i = 1:length(xs) - if eltype(logjacs) <: Real - logjacs[i] += logabsdetjac(b, xs[i]) - else - logabsdetjac!(b, xs[i], logjacs[i]) - end - end - - return logjacs -end - -""" - forward_batch(b, xs) - -Computes `forward(b, xs)`, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`transform`](@ref) -""" -forward_batch(b, xs) = (result = transform_batch(b, xs), logabsdetjac = logabsdetjac_batch(b, xs)) - -""" - forward_batch!(b, xs, out) - -Computes `forward(b, xs)` in place, treating `xs` as a "batch", i.e. a collection of independent inputs. - -See also: [`forward!`](@ref) -""" -function forward_batch!(b, xs, out) - transform_batch!(b, xs, out.result) - logabsdetjac_batch!(b, xs, out.logabsdetjac) - - return out -end +# Broadcast.broadcasted(b::Transform, xs::Batch) = transform_multiple(b, xs) +# Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_multiple(b, xs) +# Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_multiple(b, xs) +# Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_multiple(b, xs) ###################### # Bijectors includes # ###################### # General -include("bijectors/adbijector.jl") +# include("bijectors/adbijector.jl") include("bijectors/composed.jl") -include("bijectors/stacked.jl") +# include("bijectors/stacked.jl") # Specific include("bijectors/exp_log.jl") include("bijectors/logit.jl") -include("bijectors/scale.jl") -include("bijectors/shift.jl") -include("bijectors/permute.jl") -include("bijectors/simplex.jl") -include("bijectors/pd.jl") -include("bijectors/corr.jl") -include("bijectors/truncated.jl") -include("bijectors/named_bijector.jl") -include("bijectors/ordered.jl") +# include("bijectors/scale.jl") +# include("bijectors/shift.jl") +# include("bijectors/permute.jl") +# include("bijectors/simplex.jl") +# include("bijectors/pd.jl") +# include("bijectors/corr.jl") +# include("bijectors/truncated.jl") +# include("bijectors/named_bijector.jl") +# include("bijectors/ordered.jl") # Normalizing flow related -include("bijectors/planar_layer.jl") -include("bijectors/radial_layer.jl") -include("bijectors/leaky_relu.jl") -include("bijectors/coupling.jl") -include("bijectors/normalise.jl") -include("bijectors/rational_quadratic_spline.jl") +# include("bijectors/planar_layer.jl") +# include("bijectors/radial_layer.jl") +# include("bijectors/leaky_relu.jl") +# include("bijectors/coupling.jl") +# include("bijectors/normalise.jl") +# include("bijectors/rational_quadratic_spline.jl") ################## # Other includes # From 306aa662f3fe54c7d73d93e2837f5580bb49ae3c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 24 Jan 2022 18:20:55 +0000 Subject: [PATCH 45/98] added docs --- docs/Project.toml | 3 + 
docs/make.jl | 15 +++++ docs/src/index.md | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 docs/Project.toml create mode 100644 docs/make.jl create mode 100644 docs/src/index.md diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..7e7c8131 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,3 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..54fb95c5 --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,15 @@ +using Documenter +using Bijectors + +makedocs( + sitename = "Bijectors", + format = Documenter.HTML(), + modules = [Bijectors] +) + +# Documenter can also automatically deploy documentation to gh-pages. +# See "Hosting Documentation" and deploydocs() in the Documenter manual +# for more information. +#=deploydocs( + repo = "" +)=# diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..ce0f1b0a --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,144 @@ +# Bijectors.jl + +Documentation for Bijectors.jl + +## Usage + +A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: +- The inverse of `exp` is `log`. +- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. + +```@repl usage +using Bijectors +transform(exp, 1.0) +logabsdetjac(exp, 1.0) +forward(exp, 1.0) +``` + +If you want to instead transform a collection of inputs, you can use the `batch` method from Batching.jl to inform Bijectors.jl that the input now represents a collection of inputs rather than a single input: + +```@repl usage +xs = batch(ones(2)); +transform(exp, xs) +logabsdetjac(exp, xs) +forward(exp, xs) +``` + +Some transformations is well-defined for different types of inputs, e.g. `exp` can also act elementwise on a `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: + +```@repl usage +x = ones(2, 2) +transform(elementwise(exp), x) +logabsdetjac(elementwise(exp), x) +forward(elementwise(exp), x) +``` + +And batched versions: + +```@repl usage +xs = batch(ones(2, 2, 3)); +transform(elementwise(exp), xs) +logabsdetjac(elementwise(exp), xs) +forward(elementwise(exp), xs) +``` + +These methods also work nicely for compositions of transformations: + +```@repl usage +transform(elementwise(log ∘ exp), xs) +``` + +Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. `Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: + +```@repl usage +using Bijectors: Logit + +f = Logit(0.0, 1.0) +f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` +``` + +## User-facing methods + +Without mutation: + +```@docs +transform +logabsdetjac +forward(b, x) +``` + +With mutation: + +```@docs +transform! +logabsdetjac! +forward! +``` + +## Implementing a transformation + +Any callable can be made into a bijector by providing an implementation of [`forward(b, x)`](@ref), which is done by overloading + +```@docs +Bijectors.forward_single +``` + +where + +```@docs +Bijectors.transform_single +Bijectors.logabsdetjac_single +``` + +You can then optionally implement `transform` and `logabsdetjac` to avoid redundant computations. 
This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. + +Note that a _user_ of the bijector should generally be using [`forward(b, x)`](@ref) rather than calling [`forward_single`](@ref) directly. + +To implement "batched" versions of the above functionalities, i.e. methods which act on a _collection_ of inputs rather than a single input, you can overload the following method: + +```@docs +Bijectors.forward_multiple +``` + +And similarly, if you want to specialize [`transform`](@ref) and [`logabsdetjac`](@ref), you can implement + +```@docs +Bijectors.transform_multiple +Bijectors.logabsdetjac_multiple +``` + +### Mutability + +There are also _mutable_ versions of all of the above: + +```@docs +Bijectors.forward_single! +Bijectors.forward_multiple! +Bijectors.transform_single! +Bijectors.transform_multiple! +Bijectors.logabsdetjac_single! +Bijectors.logabsdetjac_multiple! +``` + +## Working with Distributions.jl + +```@docs +Bijectors.bijector +Bijectors.transformed(d::Distribution, b::Bijector) +``` + +## Utilities + +```@docs +Bijectors.elementwise +Bijectors.isinvertible +Bijectors.isclosedform(t::Bijectors.Transform) +``` + +## API + +```@docs +Bijectors.Transform +Bijectors.Bijector +Bijectors.Inverse +``` From 8cd371a75e6d8d3b833233d06d1083f752c5415c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Feb 2022 17:18:29 +0000 Subject: [PATCH 46/98] removed all batch related functionality --- Project.toml | 1 - src/Bijectors.jl | 6 +- src/bijectors/composed.jl | 40 +-- src/bijectors/exp_log.jl | 78 +----- src/bijectors/leaky_relu.jl | 48 ---- src/bijectors/logit.jl | 12 +- src/bijectors/named_bijector.jl | 32 +-- src/bijectors/permute.jl | 5 +- src/bijectors/radial_layer.jl | 15 - src/bijectors/rational_quadratic_spline.jl | 12 +- src/bijectors/scale.jl | 14 - src/bijectors/shift.jl | 8 +- src/bijectors/simplex.jl | 44 --- src/interface.jl | 309 +++------------------ 14 files changed, 76 insertions(+), 548 deletions(-) diff --git a/Project.toml b/Project.toml index 90b3abac..4d25f7cb 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "0.10.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -Batching = "a4738bf4-2ff1-4f03-87a2-28fa6f9d5d14" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 5ba75b9c..be2d71cf 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,9 +36,6 @@ using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular -using Batching -export batch - import ChangesOfVariables: with_logabsdet_jacobian import InverseFunctions: inverse @@ -260,7 +257,6 @@ function getlogp(d::InverseWishart, Xcf, X) end include("utils.jl") -# include("batch.jl") include("interface.jl") # include("chainrules.jl") @@ -271,6 +267,8 @@ Base.@deprecate forward(b::Transform, x) NamedTuple{(:rv,:logabsdetjac)}(with_lo inverse(b) end +Base.@deprecate NamedBijector(bs) NamedTransform(bs) + # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) 
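
For reference, a rough usage sketch of the `ComposedFunction` support added below (illustrative only; it assumes `forward` returns a `(result, logabsdetjac)` tuple, as defined in this patch):

```julia
b = elementwise(exp) ∘ elementwise(log)   # ComposedFunction: inner = log, outer = exp
x = rand(3)                               # positive input so the elementwise log is defined
y, logjac = forward(b, x)                 # log applied elementwise first, then exp
y ≈ x                                     # round-trips back to x
iszero(logjac)                            # -sum(log, x) + sum(log.(x)) == 0
```
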
diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index a3c07158..fc581cca 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -1,59 +1,31 @@ invertible(cb::ComposedFunction) = invertible(cb.inner) + invertible(cb.outer) isclosedform(cb::ComposedFunction) = isclosedform(cb.inner) && isclosedform(cb.outer) -transform_single(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) -transform_multiple(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) +transform(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) -function transform_single!(cb::ComposedFunction, x, y) +function transform!(cb::ComposedFunction, x, y) transform!(cb.inner, x, y) return transform!(cb.outer, y, y) end -function transform_multiple!(cb::ComposedFunction, x, y) - transform!(cb.inner, x, y) - return transform!(cb.outer, y, y) -end - -function logabsdetjac_single(cb::ComposedFunction, x) +function logabsdetjac(cb::ComposedFunction, x) y, logjac = forward(cb.inner, x) return logabsdetjac(cb.outer, y) + logjac end -function logabsdetjac_multiple(cb::ComposedFunction, x) - y, logjac = forward(cb.inner, x) - return logabsdetjac(cb.outer, y) + logjac -end - -function logabsdetjac_single!(cb::ComposedFunction, x, logjac) - y = similar(x) - forward!(cb.inner, x, y, logjac) - return logabdetjac!(cb.outer, y, y, logjac) -end - -function logabsdetjac_multiple!(cb::ComposedFunction, x, logjac) +function logabsdetjac!(cb::ComposedFunction, x, logjac) y = similar(x) forward!(cb.inner, x, y, logjac) return logabdetjac!(cb.outer, y, y, logjac) end -function forward_single(cb::ComposedFunction, x) +function forward(cb::ComposedFunction, x) y1, logjac1 = forward(cb.inner, x) y2, logjac2 = forward(cb.outer, y1) return y2, logjac1 + logjac2 end -function forward_multiple(cb::ComposedFunction, x) - y1, logjac1 = forward(cb.inner, x) - y2, logjac2 = forward(cb.outer, y1) - return y2, logjac1 + logjac2 -end - -function forward_single!(cb::ComposedFunction, x, y, logjac) - forward!(cb.inner, x, y, logjac) - return forward!(cb.outer, y, y, logjac) -end - -function forward_multiple!(cb::ComposedFunction, x, y, logjac) +function forward!(cb::ComposedFunction, x, y, logjac) forward!(cb.inner, x, y, logjac) return forward!(cb.outer, y, y, logjac) end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 92197f80..74c81c71 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,83 +1,29 @@ -transform_single!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) +# TODO: Do we really need this? 
+Exp() = elementwise(exp) +Log() = elementwise(log) -transform_multiple(b::Union{typeof(log), typeof(exp)}, x) = b.(x) -transform_multiple!(b::Union{typeof(log), typeof(exp)}, x, y) = broadcast!(b, y, x) +transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) -logabsdetjac_single(b::typeof(exp), x::Real) = x -logabsdetjac_single(b::Elementwise{typeof(exp)}, x) = sum(x) +logabsdetjac(b::typeof(exp), x::Real) = x +logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x) -logabsdetjac_single(b::typeof(log), x::Real) = -log(x) -logabsdetjac_single(b::Elementwise{typeof(log)}, x) = -sum(log, x) +logabsdetjac(b::typeof(log), x::Real) = -log(x) +logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x) -logabsdetjac_multiple(b::typeof(exp), xs) = xs -logabsdetjac_multiple(b::Elementwise{typeof(exp)}, xs) = map(sum, xs) - -logabsdetjac_multiple(b::typeof(log), xs) = -map(log, xs) -logabsdetjac_multiple(b::Elementwise{typeof(log)}, xs) = -map(sum ∘ log, xs) - -function forward_single(b::typeof(exp), x::Real) +function forward(b::typeof(exp), x::Real) y = b(x) return y, x end -function forward_single(b::Elementwise{typeof(exp)}, x) +function forward(b::Elementwise{typeof(exp)}, x) y = b(x) return y, sum(x) end -function forward_multiple(b::typeof(exp), xs::AbstractBatch{<:Real}) - ys = transform(b, xs) - return ys, xs -end -function forward_multiple!( - b::typeof(exp), - xs::AbstractBatch{<:Real}, - ys::AbstractBatch{<:Real}, - logjacs::AbstractBatch{<:Real} -) - transform!(b, xs, ys) - logjacs += xs - return ys, logjacs -end -function forward_multiple(b::Elementwise{typeof(exp)}, xs) - ys = transform(b, xs) - return ys, map(sum, xs) -end -function forward_multiple!(b::Elementwise{typeof(exp)}, xs, ys, logjacs) - # Do this before `transform!` in case `xs === ys`. - logjacs += map(sum, xs) - transform!(b, xs, ys) - return ys, logjacs -end - -function forward_single(b::typeof(log), y::Real) +function forward(b::typeof(log), y::Real) x = transform(b, y) return x, -x end -function forward_single(b::Elementwise{typeof(log)}, y) +function forward(b::Elementwise{typeof(log)}, y) x = transform(b, y) return x, -sum(x) end - -function forward_multiple(b::typeof(log), ys::AbstractBatch{<:Real}) - xs = transform(b, ys) - return xs, -xs -end -function forward_multiple!( - b::typeof(log), - ys::AbstractBatch{<:Real}, - xs::AbstractBatch{<:Real}, - logjacs::AbstractBatch{<:Real} -) - transform!(b, ys, xs) - logjacs -= xs - return xs, logjacs -end -function forward_multiple(b::Elementwise{typeof(log)}, ys) - xs = transform(b, ys) - return xs, -map(sum, xs) -end -function forward_multiple!(b::Elementwise{typeof(log)}, ys, xs, logjacs) - transform!(b, ys, xs) - logjacs -= map(sum, xs) - return xs, logjacs -end diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 51768978..ed499227 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -36,15 +36,6 @@ function forward(b::LeakyReLU, x::Real) return (result=J * x, logabsdetjac=log(abs(J))) end -function forward_batch(b::LeakyReLU, xs::Batch{<:AbstractVector}) - x = value(xs) - - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - return (result=Batch(J .* x), logabsdetjac=Batch(log.(abs.(J)))) -end - # Array inputs. 
function transform(b::LeakyReLU, x::AbstractArray) return let z = zero(eltype(x)) @@ -52,27 +43,6 @@ function transform(b::LeakyReLU, x::AbstractArray) end end -function logabsdetjac(b::LeakyReLU, x::AbstractArray) - return sum(value(logabsdetjac_batch(b, Batch(x)))) -end - -function logabsdetjac_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} - x = value(xs) - - # Is really diagonal of jacobian - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - - logjac = if N ≤ 1 - sum(log ∘ abs, J) - else - vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) - end - - return Batch(logjac) -end - # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. @@ -81,21 +51,3 @@ function forward(b::LeakyReLU, x::AbstractArray) return (result = value(y), logabsdetjac = sum(value(logjac))) end - -function forward_batch(b::LeakyReLU, xs::ArrayBatch{N}) where {N} - x = value(xs) - - # Is really diagonal of jacobian - J = let T = eltype(x), z = zero(T), o = one(T) - @. (x < z) * b.α + (x > z) * o - end - - logjac = if N ≤ 1 - sum(log ∘ abs, J) - else - vec(sum(map(log ∘ abs, J); dims = 1:N - 1)) - end - - y = J .* x - return (result=Batch(y), logabsdetjac=Batch(logjac)) -end diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 17afd125..84d61c82 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -15,19 +15,13 @@ Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b # Evaluation _logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a)) -transform_single(b::Logit, x) = _logit.(x, b.a, b.b) -function transform_multiple(b::Logit, xs::Batching.ArrayBatch{<:Real}) - return batch_like(xs, _logit.(Batching.value(xs))) -end +transform(b::Logit, x) = _logit.(x, b.a, b.b) # Inverse _ilogit(y, a, b) = (b - a) * LogExpFunctions.logistic(y) + a -transform_single(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) -function transform_multiple(ib::Inverse{<:Logit}, ys::Batching.ArrayBatch) - return batch_like(ys, _ilogit.(Batching.value(ys))) -end +transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) # `logabsdetjac` logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) -logabsdetjac_single(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) +logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index f8dbe888..8f73386b 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,20 +1,18 @@ -abstract type AbstractNamedBijector <: Bijector end - -forward(b::AbstractNamedBijector, x) = (result = b(x), logabsdetjac = logabsdetjac(b, x)) +abstract type AbstractNamedTransform <: Transform end ####################### -### `NamedBijector` ### +### `NamedTransform` ### ####################### """ - NamedBijector <: AbstractNamedBijector + NamedTransform <: AbstractNamedTransform Wraps a `NamedTuple` of key -> `Bijector` pairs, implementing evaluation, inversion, etc. 
# Examples ```julia-repl -julia> using Bijectors: NamedBijector, Scale, Exp +julia> using Bijectors: NamedTransform, Scale, Exp -julia> b = NamedBijector((a = Scale(2.0), b = Exp())); +julia> b = NamedTransform((a = Scale(2.0), b = Exp())); julia> x = (a = 1., b = 0., c = 42.); @@ -25,21 +23,21 @@ julia> (a = 2 * x.a, b = exp(x.b), c = x.c) (a = 2.0, b = 1.0, c = 42.0) ``` """ -struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector +struct NamedTransform{names, Bs<:NamedTuple{names}} <: AbstractNamedTransform bs::Bs end # fields contain nested numerical parameters -function Functors.functor(::Type{<:NamedBijector{names}}, x) where names +function Functors.functor(::Type{<:NamedTransform{names}}, x) where names function reconstruct_namedbijector(xs) - return NamedBijector{names,typeof(xs.bs)}(xs.bs) + return NamedTransform{names,typeof(xs.bs)}(xs.bs) end return (bs = x.bs,), reconstruct_namedbijector end -names_to_bijectors(b::NamedBijector) = b.bs +names_to_bijectors(b::NamedTransform) = b.bs -@generated function (b::NamedBijector{names1})( +@generated function (b::NamedTransform{names1})( x::NamedTuple{names2} ) where {names1, names2} exprs = [] @@ -55,11 +53,11 @@ names_to_bijectors(b::NamedBijector) = b.bs return :($(exprs...), ) end -@generated function inverse(b::NamedBijector{names}) where {names} - return :(NamedBijector(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) +@generated function inverse(b::NamedTransform{names}) where {names} + return :(NamedTransform(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) end -@generated function logabsdetjac(b::NamedBijector{names}, x::NamedTuple) where {names} +@generated function logabsdetjac(b::NamedTransform{names}, x::NamedTuple) where {names} exprs = [:(logabsdetjac(b.bs.$n, x.$n)) for n in names] return :(+($(exprs...))) end @@ -69,7 +67,7 @@ end ############################ # TODO: Add ref to `Coupling` or `CouplingLayer` once that's merged. """ - NamedCoupling{target, deps, F} <: AbstractNamedBijector + NamedCoupling{target, deps, F} <: AbstractNamedTransform Implements a coupling layer for named bijectors. @@ -89,7 +87,7 @@ julia> (a = x.a, b = (x.a + x.c) * x.b, c = x.c) (a = 1.0, b = 8.0, c = 3.0) ``` """ -struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target} +struct NamedCoupling{target, deps, F} <: AbstractNamedTransform where {F, target} f::F end diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index f67a0562..7d8f5220 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -150,8 +150,7 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) 
end -@inline transform(b::Permute, x::AbstractVecOrMat) = b.A * x -@inline inv(b::Permute) = Permute(transpose(b.A)) +transform(b::Permute, x::AbstractVecOrMat) = b.A * x +inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) -logabsdetjac_batch(b::Permute, x::Batch) = zero(eltype(x), length(x)) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 6b6f0d94..8ba10270 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -49,8 +49,6 @@ end transform(b::RadialLayer, z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) transform(b::RadialLayer, z::AbstractMatrix{<:Real}) = _transform(b, z).transformed -transform_batch(b::RadialLayer, z::ArrayBatch{2}) = Batch(transform(b, value(z))) - function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) # Compute log_det_jacobian @@ -68,11 +66,6 @@ function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) return (result = transformed, logabsdetjac = log_det_jacobian) end -function forward_batch(b::RadialLayer, z::ArrayBatch{2}) - result, logjac = forward(b, value(z)) - return (result = Batch(result), logabsdetjac = Batch(logjac)) -end - function transform(ib::Inverse{<:RadialLayer}, y::AbstractVector{<:Real}) flow = ib.orig z0 = flow.z_0 @@ -103,10 +96,6 @@ function transform(ib::Inverse{<:RadialLayer}, y::AbstractMatrix{<:Real}) return z0 .+ γ .* y_minus_z0 end -function transform_batch(ib::Inverse{<:RadialLayer}, y::ArrayBatch{2}) - return Batch(transform(ib, value(y))) -end - """ compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) @@ -135,7 +124,3 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test) end logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac - -function logabsdetjac_batch(flow::RadialLayer, x::ArrayBatch{2}) - return Batch(logabsdetjac(flow, value(x))) -end diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index ae89c0d8..72fa4076 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -178,9 +178,6 @@ end function transform(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_univariate(b.widths, b.heights, b.derivatives, x) end -function transform_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) - return Batch(transform.(b, value(x))) -end # multivariate # TODO: Improve. @@ -233,9 +230,6 @@ end function transform(ib::Inverse{<:RationalQuadraticSpline}, y::Real) return rqs_univariate_inverse(ib.orig.widths, ib.orig.heights, ib.orig.derivatives, y) end -function transform_batch(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) - return Batch(transform.(ib, value(y))) -end # TODO: Improve. function transform(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) @@ -314,9 +308,7 @@ end function logabsdetjac(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_logabsdetjac(b.widths, b.heights, b.derivatives, x) end -function logabsdetjac_batch(b::RationalQuadraticSpline{<:AbstractVector}, x::ArrayBatch{1}) - return Batch(logabsdetjac.(b, value(x))) -end + # TODO: Improve. 
function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return sum([ @@ -376,6 +368,6 @@ function rqs_forward( return (y, logjac) end -function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index ad6df463..efeac48a 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -18,29 +18,15 @@ function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} return _logabsdetjac_scale(b.a, x, Val(N)) end -function logabsdetjac_batch(b::Scale, x::ArrayBatch{N}) where {N} - return Batch(_logabsdetjac_scale(b.a, value(x), Val(N - 1))) -end - # Scalar: single input. _logabsdetjac_scale(a::Real, x::Real, ::Val{0}) = log(abs(a)) _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{1}) = log(abs(a)) * length(x) _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{2}) = log(abs(a)) * length(x) -# Scalar: batch. -_logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) = fill(log(abs(a)), length(x)) -_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{1}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) - # Vector: single input. _logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Val{1}) = sum(log ∘ abs, a) _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{2}) = sum(log ∘ abs, a) -# Vector: batch. -_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Val{1}) = fill(sum(log ∘ abs, a), size(x, 2)) - # Matrix: single input. _logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Val{1}) = logabsdet(a)[1] _logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix, ::Val{2}) = logabsdet(a)[1] - -# Matrix: batch. -_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Val{1}) where {T} = logabsdet(a)[1] * ones(T, size(x, 2)) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 85050e5c..2561d0f8 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -9,18 +9,14 @@ Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a Functors.@functor Shift -transform(b::Shift, x) = b.a .+ x +inverse(b::Shift) = Shift(-b.a) -inv(b::Shift) = Shift(-b.a) +transform(b::Shift, x) = b.a .+ x # FIXME: implement custom adjoint to ensure we don't get tracking function logabsdetjac(b::Shift, x::Union{Real, AbstractArray{<:Real}}) return _logabsdetjac_shift(b.a, x) end -function logabsdetjac_batch(b::Shift, x::ArrayBatch) - return Batch(_logabsdetjac_shift_array_batch(b.a, value(x))) -end - _logabsdetjac_shift(a, x) = zero(eltype(x)) _logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x))) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 34677220..a453a489 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -35,17 +35,6 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where return y end -function transform_batch(b::SimplexBijector, X::ArrayBatch{2}) - Batch(_simplex_bijector(value(X), b)) -end -function transform_batch!( - b::SimplexBijector, - Y::Batch{<:AbstractMatrix{T}}, - X::Batch{<:AbstractMatrix{T}}, -) where {T} - Batch(_simplex_bijector!(value(Y), value(X), b)) -end - # Matrix implementation. 
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} K, N = size(X, 1), size(X, 2) @@ -107,16 +96,6 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) return x end -# Batched versions. -transform_batch(ib::Inverse{<:SimplexBijector}, Y::ArrayBatch{2}) = _simplex_inv_bijector(Y, ib.orig) -function transform_batch!( - ib::Inverse{<:SimplexBijector}, - X::Batch{<:AbstractMatrix{T}}, - Y::Batch{<:AbstractMatrix{T}}, -) where {T<:Real} - return Batch(_simplex_inv_bijector!(value(X), value(Y), ib.orig)) -end - function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} K, N = size(Y, 1), size(Y, 2) @assert K > 1 "x needs to be of length greater than 1" @@ -192,29 +171,6 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) end return g end -function logabsdetjac_batch(b::SimplexBijector, x::Batch{<:AbstractMatrix{T}}) where {T} - x = value(x) - - ϵ = _eps(T) - nlp = similar(x, T, size(x, 2)) - nlp .= zero(T) - - K = size(x, 1) - for col in 1:size(x, 2) - sum_tmp = zero(eltype(x)) - z = x[1,col] - nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) - for k in 2:(K - 1) - sum_tmp += x[k-1,col] - z = x[k,col] / max(one(T) - sum_tmp, ϵ) - nlp[col] -= log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) + log(max(one(T) - sum_tmp, ϵ)) - end - end - return Batch(nlp) -end -function logabsdetjac(b::SimplexBijector, x::AbstractMatrix) - return sum(value(logabsdetjac(b, Batch(x)))) -end function simplex_logabsdetjac_gradient(x::AbstractMatrix) T = eltype(x) diff --git a/src/interface.jl b/src/interface.jl index 8e96a27f..ec012635 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -66,16 +66,9 @@ If the `Transform` is also invertible: can exploit shared computation in the two. For the above methods, there are mutating versions which can _optionally_ be implemented: -- [`transform_single!`](@ref) -- [`logabsdetjac_single!`](@ref) -- [`forward_single!`](@ref) - -Finally, there are _batched_ versions of the above methods which can _optionally_ be implemented: -- [`transform_multiple`](@ref) -- [`logabsdetjac_multiple`](@ref) -- [`forward_multiple`](@ref) - -and similarly for the mutating versions. Default implementations depends on the type of `xs`. +- [`transform!`](@ref) +- [`logabsdetjac!`](@ref) +- [`forward!`](@ref) """ abstract type Transform end @@ -87,42 +80,8 @@ Broadcast.broadcastable(b::Transform) = Ref(b) transform(b, x) Transform `x` using `b`, treating `x` as a single input. - -Alias for [`transform_single(b, x)`](@ref). -""" -transform(b, x) = transform_single(b, x) - -""" - transform_single(b, x) - -Transform `x` using `b`. - -Defaults to `b(x)` if `b isa Function` and `first(forward(b, x))` otherwise. """ -transform_single(b, x) = first(forward(b, x)) -transform_single(b::Function, x) = b(x) - -""" - transform(b, xs::AbstractBatch) - -Transform `xs` using `b`, treating `xs` as a collection of inputs. - -Alias for [`transform_multiple(b, x)`](@ref). -""" -transform(b, xs::AbstractBatch) = transform_multiple(b, xs) - -""" - transform_multiple(b, xs) - -Transform `xs` using `b`, treating `xs` as a collection of inputs. - -Defaults to `map(Base.Fix1(transform, b), xs)`. -""" -transform_multiple(b, xs) = map(Base.Fix1(transform, b), xs) -function transform_multiple(b::Elementwise, x::Batching.ArrayBatch) - return batch_like(x, transform(b, Batching.value(x))) -end - +function transform end """ transform!(b, x[, y]) @@ -130,175 +89,34 @@ end Transform `x` using `b`, storing the result in `y`. 
If `y` is not provided, `x` is used as the output. - -Alias for [`transform_single!(b, x, y)`](@ref). -""" -transform!(b, x, y=x) = transform_single!(b, x, y) - -""" - transform_single!(b, x, y) - -Transform `x` using `b`, storing the result in `y`. """ -transform_single!(b, x, y) = (y .= transform(b, x)) - -""" - transform!(b, xs::AbstractBatch[, ys::AbstractBatch]) - -Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. - -If `ys` is not provided, `xs` is used as the output. - -Alias for [`transform_multiple!(b, xs, ys)`](@ref). -""" -transform!(b, xs::AbstractBatch, ys::AbstractBatch=xs) = transform_multiple!(b, xs, ys) - -""" - transform_multiple!(b, xs::AbstractBatch[, ys::AbstractBatch]) - -Transform `x` for `x` in `xs` using `b`, storing the result in `ys`. -""" -transform_multiple!(b, xs, ys) = broadcast!(Base.Fix1(transform, b), xs, ys) -function transform_multiple!(b::Elementwise, x::Batching.ArrayBatch, y::Batching.ArrayBatch) - broadcast!(b, Batching.value(y), Batching.value(x)) - return y -end +transform!(b, x) = transform!(b, x, x) +transform!(b, x, y) = (y .= transform(b, x)) """ logabsdetjac(b, x) Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. - -Alias for [`logabsdetjac_single`](@ref). - -See also: [`logabsdetjac(b, xs::AbstractBatch)`](@ref). -""" -logabsdetjac(b, x) = logabsdetjac_single(b, x) - -""" - logabsdetjac_single(b, x) - -Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. - -Defaults to `last(forward(b, x))`. -""" -logabsdetjac_single(b, x) = last(forward(b, x)) - -""" - logabsdetjac(b, xs::AbstractBatch) - -Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, -where `J(b, x)` is the jacobian of `b` at `x`. - -Alias for [`Bijectors.logabsdetjac_multiple`](@ref). - -See also: [`logabsdetjac(b, x)`](@ref). -""" -logabsdetjac(b, xs::AbstractBatch) = logabsdetjac_multiple(b, xs) - """ - logabsdetjac_multiple(b, xs) - -Return a collection representing `log(abs(det(J(b, x))))` for every `x` in `xs`, -where `J(b, x)` is the jacobian of `b` at `x`. - -Defaults to `map(Base.Fix1(logabsdetjac, b), xs)`. -""" -logabsdetjac_multiple(b, xs) = map(Base.Fix1(logabsdetjac, b), xs) +logabsdetjac(b, x) = last(forward(b, x)) """ logabsdetjac!(b, x[, logjac]) Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x)` is the jacobian of `b` at `x`. - -Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). """ -logabsdetjac!(b, x, logjac=zero(eltype(x))) = logabsdetjac_single!(b, x, logjac) - -""" - logabsdetjac_single!(b, x[, logjac]) - -Compute `log(abs(det(J(b, x))))` and accumulate the result in `logjac`, -where `J(b, x)` is the jacobian of `b` at `x`. -""" -logabsdetjac_single!(b, x, logjac) = (logjac += logabsdetjac(b, x)) - -""" - logabsdetjac!(b, xs::AbstractBatch[, logjacs::AbstractBatch]) - -Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for -every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. - -Alias for [`logabsdetjac_single!(b, x, logjac)`](@ref). -""" -function logabsdetjac!( - b, - xs::AbstractBatch, - logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) -) - return logabsdetjac_multiple!(b, xs, logjacs) -end - -""" - logabsdetjac_multiple!(b, xs::AbstractBatch, logjacs::AbstractBatch) - -Compute `log(abs(det(J(b, x))))` and store the result in `logjacs` for -every `x` in `xs`, where `J(b, x)` is the jacobian of `b` at `x`. 
-""" -logabsdetjac_multiple!(b, xs, logjacs) = (logjacs .+= logabsdetjac(b, xs)) +logabsdetjac!(b, x) = logabsdetjac!(b, x, zero(eltype(x))) +logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) """ forward(b, x) Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as single input. -Alias for [`forward_single(b, x)`](@ref). - -See also: [`forward(b, xs::AbstractBatch)`](@ref). -""" -forward(b, x) = forward_single(b, x) - -""" - forward_single(b, x) - -Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as a single input. - Defaults to `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. """ -forward_single(b, x) = with_logabsdet_jacobian(b, x) - -""" - forward(b, xs::AbstractBatch) - -Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as -collection of inputs. - -Alias for [`forward_multiple(b, xs)`](@ref). - -See also: [`forward(b, x)`](@ref). -""" -forward(b, xs::AbstractBatch) = forward_multiple(b, xs) - -""" - forward_multiple(b, xs) - -Return `(transform(b, xs), logabsdetjac(b, xs))` treating -`xs` as a batch, i.e. a collection inputs. +forward(b, x) = with_logabsdet_jacobian(b, x) -See also: [`forward_single(b, x)`](@ref). -""" -function forward_multiple(b, xs) - # If `b` doesn't have its own definition of `forward_multiple` - # we just broadcast `forward_single`, resulting in a batch of `(y, logjac)` - # pairs which we then unwrap. - results = forward.(Ref(b), xs) - ys = map(first, results) - logjacs = map(last, results) - return batch_like(xs, ys, logjacs) -end - -# `logjac` as an argument doesn't make too much sense for `forward_single!` -# when the inputs have `eltype` `Real`. """ forward!(b, x[, y, logjac]) @@ -307,66 +125,16 @@ in `y` and `logjac`, respetively. If `y` is not provided, then `x` will be used in its place. -Alias for [`forward_single!(b, x, y, logjac)`](@ref). - -See also: [`forward!(b, xs::AbstractBatch, ys::AbstractBatch, logjacs::AbstractBatch)`](@ref). -""" -forward!(b, x, y=x, logjac=zero(eltype(x))) = forward_single!(b, x, y, logjac) - -""" - forward_single!(b, x, y, logjac) - -Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result -in `y` and `logjac`, respetively. - Defaults to calling `forward(b, x)` and updating `y` and `logjac` with the result. """ -function forward_single!(b, x, y, logjac) +forward!(b, x) = forward!(b, x, x) +forward!(b, x, y) = forward!(b, x, y, zero(eltype(x))) +function forward!(b, x, y, logjac) y_, logjac_ = forward(b, x) y .= y_ return (y, logjac + logjac_) end -""" - forward!(b, xs[, ys, logjacs]) - -Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, -storing the results in `ys` and `logjacs`, respetively. - -If `ys` is not provided, then `xs` will be used in its place. - -Alias for [`forward_multiple!(b, xs, ys, logjacs)`](@ref). - -See also: [`forward!(b, x, y, logjac)`](@ref). -""" -function forward!( - b, - xs::AbstractBatch, - ys::AbstractBatch=xs, - logjacs::AbstractBatch=batch_like(xs, zeros(eltype(eltype(xs)), length(xs))) -) - return forward_multiple!(b, xs, ys, logjacs) -end - -""" - forward!(b, xs, ys, logjacs) - -Compute `transform(b, x)` and `logabsdetjac(b, x)` for every `x` in the collection `xs`, -storing the results in `ys` and `logjacs`, respetively. - -Defaults to iterating through the `xs` and calling [`forward_single(b, x)`](@ref) -for every `x` in `xs`. 
-""" -function forward_multiple!(b, xs, ys, logjacs) - for i in eachindex(ys) - res = forward_single(b, xs[i]) - ys[i] .= first(res) - logjacs[i] += last(res) - end - - return ys, logjacs -end - """ isclosedform(b::Transform)::bool isclosedform(b⁻¹::Inverse{<:Transform})::bool @@ -459,50 +227,37 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) transform(::typeof(identity), x) = copy(x) transform!(::typeof(identity), x, y) = copy!(y, x) -logabsdetjac_single(::typeof(identity), x) = zero(eltype(x)) -logabsdetjac_multiple(::typeof(identity), x) = batch_like(x, zeros(eltype(eltype(x)), length(x))) - -logabsdetjac_single!(::typeof(identity), x, logjac) = logjac -logabsdetjac_multiple!(::typeof(identity), x, logjac) = logjac - -#################### -# Batched versions # -#################### -# NOTE: This needs to be after we've defined some `transform`, `logabsdetjac`, etc. -# so we can actually reference them. Since we just did this for `Identity`, we're good. -# Broadcast.broadcasted(b::Transform, xs::Batch) = transform_multiple(b, xs) -# Broadcast.broadcasted(::typeof(transform), b::Transform, xs::Batch) = transform_multiple(b, xs) -# Broadcast.broadcasted(::typeof(logabsdetjac), b::Transform, xs::Batch) = logabsdetjac_multiple(b, xs) -# Broadcast.broadcasted(::typeof(forward), b::Transform, xs::Batch) = forward_multiple(b, xs) +logabsdetjac(::typeof(identity), x) = zero(eltype(x)) +logabsdetjac!(::typeof(identity), x, logjac) = logjac ###################### # Bijectors includes # ###################### # General -# include("bijectors/adbijector.jl") +include("bijectors/adbijector.jl") include("bijectors/composed.jl") -# include("bijectors/stacked.jl") +include("bijectors/stacked.jl") # Specific include("bijectors/exp_log.jl") include("bijectors/logit.jl") -# include("bijectors/scale.jl") -# include("bijectors/shift.jl") -# include("bijectors/permute.jl") -# include("bijectors/simplex.jl") -# include("bijectors/pd.jl") -# include("bijectors/corr.jl") -# include("bijectors/truncated.jl") -# include("bijectors/named_bijector.jl") -# include("bijectors/ordered.jl") +include("bijectors/scale.jl") +include("bijectors/shift.jl") +include("bijectors/permute.jl") +include("bijectors/simplex.jl") +include("bijectors/pd.jl") +include("bijectors/corr.jl") +include("bijectors/truncated.jl") +include("bijectors/named_bijector.jl") +include("bijectors/ordered.jl") # Normalizing flow related -# include("bijectors/planar_layer.jl") -# include("bijectors/radial_layer.jl") -# include("bijectors/leaky_relu.jl") -# include("bijectors/coupling.jl") -# include("bijectors/normalise.jl") -# include("bijectors/rational_quadratic_spline.jl") +include("bijectors/planar_layer.jl") +include("bijectors/radial_layer.jl") +include("bijectors/leaky_relu.jl") +include("bijectors/coupling.jl") +include("bijectors/normalise.jl") +include("bijectors/rational_quadratic_spline.jl") ################## # Other includes # From 190f1a556cf304c43323de66dedd86a5b51461ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 10:24:25 +0000 Subject: [PATCH 47/98] move bijectors over to with_logabsdet_jacobian and drop official batch support --- src/Bijectors.jl | 46 ++++++++--------- src/bijectors/adbijector.jl | 2 + src/bijectors/composed.jl | 16 ++---- src/bijectors/corr.jl | 2 + src/bijectors/coupling.jl | 24 +++++++++ src/bijectors/exp_log.jl | 17 +++++-- src/bijectors/leaky_relu.jl | 40 +++------------ src/bijectors/logit.jl | 3 ++ src/bijectors/named_bijector.jl | 59 +++++++++++++++------- 
src/bijectors/normalise.jl | 6 +-- src/bijectors/ordered.jl | 2 + src/bijectors/permute.jl | 2 + src/bijectors/radial_layer.jl | 4 +- src/bijectors/rational_quadratic_spline.jl | 4 ++ src/bijectors/scale.jl | 2 + src/bijectors/shift.jl | 2 + src/bijectors/simplex.jl | 2 + src/bijectors/stacked.jl | 17 ++++--- src/bijectors/truncated.jl | 2 + src/interface.jl | 43 ++++++++-------- src/transformed_distribution.jl | 28 +++++----- 21 files changed, 188 insertions(+), 135 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index be2d71cf..0941fcbe 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -56,10 +56,10 @@ export TransformDistribution, isclosedform, transform, transform!, + forward, with_logabsdet_jacobian, + with_logabsdet_jacobian!, inverse, - forward, - forward!, logabsdetjac, logabsdetjac!, logabsdetjacinv, @@ -258,9 +258,9 @@ end include("utils.jl") include("interface.jl") -# include("chainrules.jl") +include("chainrules.jl") -Base.@deprecate forward(b::Transform, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) +Base.@deprecate forward(b::Union{Transform,Function}, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) @noinline function Base.inv(b::Transform) Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) @@ -273,24 +273,24 @@ Base.@deprecate NamedBijector(bs) NamedTransform(bs) maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) -# # optional dependencies -# function __init__() -# @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin -# function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) -# return copy(f.(x1, x...)) -# end -# function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) -# return copy(f.(x1, x2, x...)) -# end -# function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) -# return copy(f.(x1, x2, x3, x...)) -# end -# end -# @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") -# @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") -# @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") -# @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") -# @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") -# end +# optional dependencies +function __init__() + @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin + function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) + return copy(f.(x1, x...)) + end + function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...) + return copy(f.(x1, x2, x...)) + end + function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...) 
+ return copy(f.(x1, x2, x3, x...)) + end + end + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") + @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") +end end # module diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl index 135f2bba..44d9420b 100644 --- a/src/bijectors/adbijector.jl +++ b/src/bijectors/adbijector.jl @@ -25,3 +25,5 @@ function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b)) end + +with_logabsdet_jacobian(b::ADBijector, x) = (b(x), logabsdetjac(b, x)) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index fc581cca..5e44c997 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -9,23 +9,17 @@ function transform!(cb::ComposedFunction, x, y) end function logabsdetjac(cb::ComposedFunction, x) - y, logjac = forward(cb.inner, x) + y, logjac = with_logabsdet_jacobian(cb.inner, x) return logabsdetjac(cb.outer, y) + logjac end function logabsdetjac!(cb::ComposedFunction, x, logjac) y = similar(x) - forward!(cb.inner, x, y, logjac) + logjac = last(with_logabsdet_jacobian!(cb.inner, x, y, logjac)) return logabdetjac!(cb.outer, y, y, logjac) end -function forward(cb::ComposedFunction, x) - y1, logjac1 = forward(cb.inner, x) - y2, logjac2 = forward(cb.outer, y1) - return y2, logjac1 + logjac2 -end - -function forward!(cb::ComposedFunction, x, y, logjac) - forward!(cb.inner, x, y, logjac) - return forward!(cb.outer, y, y, logjac) +function with_logabsdet_jacobian!(cb::ComposedFunction, x, y, logjac) + logjac = last(with_logabsdet_jacobian!(cb.inner, x, y, logjac)) + return with_logabsdet_jacobian!(cb.outer, y, y, logjac) end diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f5ac2af5..36b3e784 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -63,6 +63,8 @@ above the "manageable expression" directly, which is also described in above doc """ struct CorrBijector <: Bijector end +with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) + function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 5b529cfe..8f25eff0 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -195,6 +195,30 @@ function couple(cl::Coupling, x::AbstractVector) return b end +function with_logabsdet_jacobian(cl::Coupling, x) + # partition vector using `cl.mask::PartitionMask` + x_1, x_2, x_3 = partition(cl.mask, x) + + # construct bijector `B` using θ(x₂) + b = cl.θ(x_2) + + y_1, logjac = with_logabsdet_jacobian(b, x_1) + return combine(cl.mask, y_1, x_2, x_3), logjac +end + +function with_logabsdet_jacobian(icl::Inverse{<:Coupling}, y) + cl = icl.orig + + # partition vector using `cl.mask::PartitionMask` + y_1, y_2, y_3 = partition(cl.mask, y) + + # construct bijector `B` using θ(y₂) + b = cl.θ(y_2) + + x_1, logjac = with_logabsdet_jacobian(inverse(b), y_1) + return combine(cl.mask, x_1, y_2, y_3), logjac +end + function 
transform(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 74c81c71..fe350a74 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -2,6 +2,15 @@ Exp() = elementwise(exp) Log() = elementwise(log) +invertible(::typeof(exp)) = Invertible() +invertible(::Elementwise{typeof(exp)}) = Invertible() + +invertible(::typeof(log)) = Invertible() +invertible(::Elementwise{typeof(log)}) = Invertible() + +transform(b::Union{typeof(exp),typeof(log)}, x::Real) = b(x) +transform(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x) = b(x) + transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) logabsdetjac(b::typeof(exp), x::Real) = x @@ -10,20 +19,20 @@ logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x) logabsdetjac(b::typeof(log), x::Real) = -log(x) logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x) -function forward(b::typeof(exp), x::Real) +function with_logabsdet_jacobian(b::typeof(exp), x::Real) y = b(x) return y, x end -function forward(b::Elementwise{typeof(exp)}, x) +function with_logabsdet_jacobian(b::Elementwise{typeof(exp)}, x) y = b(x) return y, sum(x) end -function forward(b::typeof(log), y::Real) +function with_logabsdet_jacobian(b::typeof(log), y::Real) x = transform(b, y) return x, -x end -function forward(b::Elementwise{typeof(log)}, y) +function with_logabsdet_jacobian(b::Elementwise{typeof(log)}, y) x = transform(b, y) return x, -sum(x) end diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index ed499227..9d76bbb0 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -13,41 +13,17 @@ end Functors.@functor LeakyReLU -Base.inv(b::LeakyReLU) = LeakyReLU(inv.(b.α)) +inverse(b::LeakyReLU) = LeakyReLU(inv.(b.α)) -# (N=0) Univariate case -function transform(b::LeakyReLU, x::Real) +function with_logabsdet_jacobian(b::LeakyReLU, x::Real) mask = x < zero(x) - return mask * b.α * x + !mask * x -end - -function logabsdetjac(b::LeakyReLU, x::Real) - mask = x < zero(x) - J = mask * b.α + (1 - mask) * one(x) - return log(abs(J)) -end - -# We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of -# the Jacobian of the transformation. This will lead to faster sampling -# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU, x::Real) - mask = x < zero(x) - J = mask * b.α + !mask * one(x) - return (result=J * x, logabsdetjac=log(abs(J))) + J = mask * b.α + !mask + return J * x, log(abs(J)) end # Array inputs. -function transform(b::LeakyReLU, x::AbstractArray) - return let z = zero(eltype(x)) - @. (x < z) * b.α * x + (x > z) * x - end -end - -# We implement `forward` by hand since we can re-use the computation of -# the Jacobian of the transformation. This will lead to faster sampling -# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. 
-function forward(b::LeakyReLU, x::AbstractArray) - y, logjac = forward_batch(b, Batch(x)) - - return (result = value(y), logabsdetjac = sum(value(logjac))) +function with_logabsdet_jacobian(b::LeakyReLU, x::AbstractArray) + mask = x .< zero(eltype(x)) + J = mask .* b.α .+ (!).(mask) + return J .* x, sum(log.(abs.(J))) end diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 84d61c82..d3b90981 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -25,3 +25,6 @@ transform(ib::Inverse{<:Logit}, y) = _ilogit.(y, ib.orig.a, ib.orig.b) # `logabsdetjac` logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) + +# `with_logabsdet_jacobian` +with_logabsdet_jacobian(b::Logit, x) = _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b)) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 8f73386b..adaa4676 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -37,7 +37,12 @@ end names_to_bijectors(b::NamedTransform) = b.bs -@generated function (b::NamedTransform{names1})( +@generated function inverse(b::NamedTransform{names}) where {names} + return :(NamedTransform(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) +end + +@generated function transform( + b::NamedTransform{names1}, x::NamedTuple{names2} ) where {names1, names2} exprs = [] @@ -53,15 +58,36 @@ names_to_bijectors(b::NamedTransform) = b.bs return :($(exprs...), ) end -@generated function inverse(b::NamedTransform{names}) where {names} - return :(NamedTransform(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) -end - @generated function logabsdetjac(b::NamedTransform{names}, x::NamedTuple) where {names} exprs = [:(logabsdetjac(b.bs.$n, x.$n)) for n in names] return :(+($(exprs...))) end +@generated function with_logabsdet_jacobian( + b::NamedTransform{names1}, + x::NamedTuple{names2} +) where {names1, names2} + body_exprs = [] + logjac_expr = Expr(:call, :+) + val_expr = Expr(:tuple, ) + for n in names2 + if n in names1 + val_sym = Symbol("y_$n") + logjac_sym = Symbol("logjac_$n") + + push!(body_exprs, :(($val_sym, $logjac_sym) = with_logabsdet_jacobian(b.bs.$n, x.$n))) + push!(logjac_expr.args, logjac_sym) + push!(val_expr.args, :($n = $val_sym)) + else + push!(val_expr.args, :($n = x.$n)) + end + end + return quote + $(body_exprs...) + return NamedTuple{$names2}($val_expr), $logjac_expr + end +end + ############################ ### `NamedCouplingLayer` ### ############################ @@ -96,31 +122,26 @@ function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F} return NamedCoupling{target, deps, F}(f) end +invertible(::NamedCoupling) = Invertible() + coupling(b::NamedCoupling) = b.f # For some reason trying to use the parameteric types doesn't always work # so we have to do this weird approach of extracting type and then index `parameters`. 
target(b::NamedCoupling{Target}) where {Target} = Target deps(b::NamedCoupling{<:Any, Deps}) where {Deps} = Deps -@generated function (nc::NamedCoupling{target, deps, F})(x::NamedTuple) where {target, deps, F} +@generated function with_logabsdet_jacobian(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F} return quote b = nc.f($([:(x.$d) for d in deps]...)) - return merge(x, ($target = b(x.$target), )) - end -end - -@generated function (ni::Inverse{<:NamedCoupling{target, deps, F}})( - x::NamedTuple -) where {target, deps, F} - return quote - b = ni.orig.f($([:(x.$d) for d in deps]...)) - return merge(x, ($target = inverse(b)(x.$target), )) + x_target, logjac = with_logabsdet_jacobian(b, x.$target) + return merge(x, ($target = x_target, )), logjac end end -@generated function logabsdetjac(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F} +@generated function with_logabsdet_jacobian(ni::Inverse{<:NamedCoupling{target, deps, F}}, x::NamedTuple) where {target, deps, F} return quote - b = nc.f($([:(x.$d) for d in deps]...)) - return logabsdetjac(b, x.$target) + ib = inverse(ni.orig.f($([:(x.$d) for d in deps]...))) + x_target, logjac = with_logabsdet_jacobian(ib, x.$target) + return merge(x, ($target = x_target, )), logjac end end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 434f6f37..f9ef2601 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -79,8 +79,8 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) return (result=result, logabsdetjac=logabsdetjac) end -logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac -transform(bn::InvertibleBatchNorm, x) = forward(bn, x).result +logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) +transform(bn::InvertibleBatchNorm, x) = first(with_logabsdet_jacobian(bn, x)) function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`with_logabsdet_jacobian(::Inverse{InvertibleBatchNorm})` is only available in test mode." 
@@ -96,7 +96,7 @@ function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) return (result=x, logabsdetjac=-logabsdetjac(bn, x)) end -transform(bn::Inverse{<:InvertibleBatchNorm}, y) = forward(bn, y).result +transform(bn::Inverse{<:InvertibleBatchNorm}, y) = first(with_logabsdet_jacobian(bn, y)) function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index ec2b3ed8..f3af8f52 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -16,6 +16,8 @@ Return a `Distribution` whose support are ordered vectors, i.e., vectors with in """ ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) +with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x) + transform(b::OrderedBijector, y::AbstractVecOrMat) = _transform_ordered(y) function _transform_ordered(y::AbstractVector) diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index 7d8f5220..0c352886 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -154,3 +154,5 @@ transform(b::Permute, x::AbstractVecOrMat) = b.A * x inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) + +with_logabsdet_jacobian(b::Permute, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 8ba10270..7c344f54 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -116,11 +116,11 @@ where ``γ = \\|y_minus_z0\\|_2``. For details see appendix A.2 of the reference D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_test) +function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) γ = norm(y_minus_z0) a = α_plus_β_hat - γ r = (sqrt(a^2 + 4 * α * γ) - a) / 2 return r end -logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac +logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = last(with_logabsdet_jacobian(flow, x)) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 72fa4076..9458f26f 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -371,3 +371,7 @@ end function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end + +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) + return transform(b, x), logabsdetjac(b, x) +end diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index efeac48a..1e2c623f 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -6,6 +6,8 @@ Base.:(==)(b1::Scale, b2::Scale) = b1.a == b2.a Functors.@functor Scale +with_logabsdet_jacobian(b::Scale, x) = transform(b, x), logabsdetjac(b, x) + transform(b::Scale, x) = b.a .* x transform(b::Scale{<:AbstractMatrix}, x::AbstractVecOrMat) = b.a * x transform(ib::Inverse{<:Scale}, y) = transform(Scale(inv(ib.orig.a)), y) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 2561d0f8..908815a6 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -20,3 +20,5 @@ end _logabsdetjac_shift(a, x) = zero(eltype(x)) _logabsdetjac_shift_array_batch(a, x) = zeros(eltype(x), size(x, ndims(x))) + 
+with_logabsdet_jacobian(b::Shift, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index a453a489..1fbc28d2 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -4,6 +4,8 @@ struct SimplexBijector{T} <: Bijector end SimplexBijector() = SimplexBijector{true}() +with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) + transform(b::SimplexBijector, x) = _simplex_bijector(x, b) transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 2507ad86..e03e0f1f 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -28,6 +28,9 @@ end Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) +# Avoid mixing tuples and arrays. +Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) + # define nested numerical parameters # TODO: replace with `Functors.@functor Stacked (bs,)` when # https://github.com/FluxML/Functors.jl/pull/7 is merged @@ -50,13 +53,13 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) invertible(b::Stacked) = sum(map(invertible, b.bs)) -stack(bs::Bijector...) = Stacked(bs) +stack(bs...) = Stacked(bs) -# For some reason `inv.(sb.bs)` was unstable... This works though. -inv(sb::Stacked, ::Invertible) = Stacked(map(inv, sb.bs), sb.ranges) +# For some reason `inverse.(sb.bs)` was unstable... This works though. +inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inv(sb::Stacked{A}, ::Invertible) where {A <: Tuple} +@generated function inverse(sb::Stacked{A}) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) @@ -64,18 +67,18 @@ inv(sb::Stacked, ::Invertible) = Stacked(map(inv, sb.bs), sb.ranges) :(Stacked(($(exprs...), ), sb.ranges)) end -@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs::Bijector...) where N +@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs...) where N exprs = [] for i = 1:N push!(exprs, :(bs[$i](x[rs[$i]]))) end return :(vcat($(exprs...))) end -function _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) +function _transform(x, rs::NTuple{1, UnitRange{Int}}, b) @assert rs[1] == 1:length(x) return b(x) end -function (sb::Stacked{<:Tuple,<:Tuple})(x::AbstractVector{<:Real}) +function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 26605861..180377ad 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -65,3 +65,5 @@ function truncated_logabsdetjac(x, a, b) return zero(x) end end + +with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) diff --git a/src/interface.jl b/src/interface.jl index ec012635..5597451e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -62,11 +62,11 @@ If the `Transform` is also invertible: - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor. - Optional: - - [`forward`](@ref): `transform` and `logabsdetjac` combined. 
Useful in cases where we
+  - [`with_logabsdet_jacobian`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we
    can exploit shared computation in the two.

For the above methods, there are mutating versions which can _optionally_ be implemented:
-- [`transform!`](@ref)
+- [`with_logabsdet_jacobian!`](@ref)
- [`logabsdetjac!`](@ref)
- [`forward!`](@ref)
"""
abstract type Transform end
@@ -81,7 +81,8 @@ Broadcast.broadcastable(b::Transform) = Ref(b)

Transform `x` using `b`, treating `x` as a single input.
"""
-function transform end
+transform(f::Function, x) = f(x)
+transform(t::Transform, x) = first(with_logabsdet_jacobian(t, x))

"""
    transform!(b, x[, y])
@@ -98,7 +99,7 @@ transform!(b, x, y) = (y .= transform(b, x))

Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`.
"""
-logabsdetjac(b, x) = last(forward(b, x))
+logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x))

"""
    logabsdetjac!(b, x[, logjac])
@@ -108,29 +109,22 @@ Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x
logabsdetjac!(b, x) = logabsdetjac!(b, x, zero(eltype(x)))
logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x))

-"""
-    forward(b, x)
-
-Return `(transform(b, x), logabsdetjac(b, x))` treating `x` as single input.
-
-Defaults to `ChangeOfVariables.with_logabsdet_jacobian(b, x)`.
-"""
-forward(b, x) = with_logabsdet_jacobian(b, x)
+# with_logabsdet_jacobian(b::Transform, x) = (transform(b, x), logabsdetjac(b, x))

"""
-    forward!(b, x[, y, logjac])
+    with_logabsdet_jacobian!(b, x[, y, logjac])

Compute `transform(b, x)` and `logabsdetjac(b, x)`, storing the result
in `y` and `logjac`, respectively.

If `y` is not provided, then `x` will be used in its place.

-Defaults to calling `forward(b, x)` and updating `y` and `logjac` with the result.
+Defaults to calling `with_logabsdet_jacobian(b, x)` and updating `y` and `logjac` with the result.
"""
-forward!(b, x) = forward!(b, x, x)
-forward!(b, x, y) = forward!(b, x, y, zero(eltype(x)))
-function forward!(b, x, y, logjac)
-    y_, logjac_ = forward(b, x)
+with_logabsdet_jacobian!(b, x) = with_logabsdet_jacobian!(b, x, x)
+with_logabsdet_jacobian!(b, x, y) = with_logabsdet_jacobian!(b, x, y, zero(eltype(x)))
+function with_logabsdet_jacobian!(b, x, y, logjac)
+    y_, logjac_ = with_logabsdet_jacobian(b, x)
    y .= y_
    return (y, logjac + logjac_)
end
@@ -210,18 +204,27 @@ abstract type Bijector <: Transform end
invertible(::Bijector) = Invertible()

# Default implementation for inverse of a `Bijector`.
-logabsdetjac(ib::Inverse{<:Bijector}, y) = -logabsdetjac(ib.orig, ib(y))
+logabsdetjac(ib::Inverse{<:Transform}, y) = -logabsdetjac(ib.orig, transform(ib, y))
+
+function with_logabsdet_jacobian(ib::Inverse{<:Transform}, y)
+    x = transform(ib, y)
+    return x, -logabsdetjac(inverse(ib), x)
+end

"""
    logabsdetjacinv(b::Bijector, y)

Just an alias for `logabsdetjac(inverse(b), y)`.
"""
-logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y)
+logabsdetjacinv(b, y) = logabsdetjac(inverse(b), y)

##############################
# Example bijector: Identity #
##############################

+Identity() = identity
+
+invertible(::typeof(identity)) = Invertible()
+
# Here we don't need to separate between batched version and non-batched, and so
# we can just overload `transform`, etc. directly.
transform(::typeof(identity), x) = copy(x) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 02152f6a..cf3b8778 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,20 +1,20 @@ # Transformed distributions -struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B} dist::D transform::B -end -TransformedDistribution(d::UnivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Univariate}(d, b) -TransformedDistribution(d::MultivariateDistribution, b::Bijector) = new{typeof(d), typeof(b), Multivariate}(d, b) -TransformedDistribution(d::MatrixDistribution, b::Bijector) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b) + TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b) + TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b) +end # fields may contain nested numerical parameters Functors.@functor TransformedDistribution -const UnivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Univariate} -const MultivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Multivariate} +const UnivariateTransformed = TransformedDistribution{<:Distribution,<:Any,Univariate} +const MultivariateTransformed = TransformedDistribution{<:Distribution,<:Any,Multivariate} const MvTransformed = MultivariateTransformed -const MatrixTransformed = TransformedDistribution{<:Distribution, <:Bijector, Matrixvariate} +const MatrixTransformed = TransformedDistribution{<:Distribution,<:Any,Matrixvariate} const Transformed = TransformedDistribution @@ -27,7 +27,7 @@ Couples distribution `d` with the bijector `b` by returning a `TransformedDistri If no bijector is provided, i.e. `transformed(d)` is called, then `transformed(d, bijector(d))` is returned. """ -transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b) +transformed(d::Distribution, b) = TransformedDistribution(d, b) transformed(d) = transformed(d, bijector(d)) """ @@ -36,8 +36,8 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. 
""" bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inverse(td.transform) -bijector(d::DiscreteUnivariateDistribution) = Identity{0}() -bijector(d::DiscreteMultivariateDistribution) = Identity{1}() +bijector(d::DiscreteUnivariateDistribution) = Identity() +bijector(d::DiscreteMultivariateDistribution) = Identity() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) bijector(d::Product{Discrete}) = Identity() function bijector(d::Product{Continuous}) @@ -218,7 +218,7 @@ end const GLOBAL_RNG = Distributions.GLOBAL_RNG function _forward(d::UnivariateDistribution, x) - y, logjac = with_logabsdet_jacobian(Identity{0}(), x) + y, logjac = with_logabsdet_jacobian(Identity(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) end @@ -227,7 +227,7 @@ function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) return _forward(d, rand(rng, d, num_samples)) end function _forward(d::Distribution, x) - y, logjac = with_logabsdet_jacobian(Identity{length(size(d))}(), x) + y, logjac = with_logabsdet_jacobian(Identity(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) end @@ -244,7 +244,7 @@ function forward(rng::AbstractRNG, td::Transformed) return _forward(td, rand(rng, td.dist)) end function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) - return _forward(td, rand(rng, td.dist, num_samples)) + return [_forward(td, rand(rng, td.dist)) for _ = 1:num_samples] end """ From 52c8ed7b13e696aeab84fc288172c66eb388583f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 10:24:43 +0000 Subject: [PATCH 48/98] updated compat --- src/compat/distributionsad.jl | 34 ++++++++-------- src/compat/reversediff.jl | 39 +++++------------- src/compat/tracker.jl | 75 ++++++++--------------------------- src/compat/zygote.jl | 40 +++---------------- 4 files changed, 49 insertions(+), 139 deletions(-) diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 85ef72ad..215f9d7f 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -9,28 +9,28 @@ using Distributions: AbstractMvLogNormal bijector(::TuringDirichlet) = SimplexBijector() bijector(::TuringWishart) = PDBijector() bijector(::TuringInverseWishart) = PDBijector() -bijector(::TuringScalMvNormal) = Identity{1}() -bijector(::TuringDiagMvNormal) = Identity{1}() -bijector(::TuringDenseMvNormal) = Identity{1}() +bijector(::TuringScalMvNormal) = Identity() +bijector(::TuringDiagMvNormal) = Identity() +bijector(::TuringDenseMvNormal) = Identity() -bijector(d::FillVectorOfUnivariate{Continuous}) = up1(bijector(d.v.value)) -bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(up1(bijector(d.dists.value))) -bijector(d::MatrixOfUnivariate{Discrete}) = Identity{2}() -bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector{2}(_minmax(d.dists)...) -bijector(d::VectorOfMultivariate{Discrete}) = Identity{2}() +bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value) +bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::MatrixOfUnivariate{Discrete}) = Identity() +bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...) 
+bijector(d::VectorOfMultivariate{Discrete}) = Identity() for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) @eval begin - bijector(d::$T{Continuous, <:MvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringScalMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = Identity{2}() - bijector(d::$T{Continuous, <:MvNormalCanon}) = Identity{2}() - bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log{2}() - bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector{2}() - bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector{2}() + bijector(d::$T{Continuous, <:MvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringScalMvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = Identity() + bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = Identity() + bijector(d::$T{Continuous, <:MvNormalCanon}) = Identity() + bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log() + bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector() + bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() end end -bijector(d::FillVectorOfMultivariate{Continuous}) = up1(bijector(d.dists.value)) +bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value) isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 116d8531..86130418 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -4,7 +4,7 @@ using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVecto TrackedMatrix using Requires, LinearAlgebra -using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, +using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector, ReverseDiffAD, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, @@ -49,13 +49,10 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) -@grad function logabsdetjac(b::Log{1}, x::AbstractVector) +logabsdetjac(b::Elementwise{typeof(log)}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +@grad function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) end -@grad function logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, value(x); dims = 1)), Δ -> (nothing, .- Δ' ./ value(x)) -end function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) end @@ -100,30 +97,22 @@ end Jᵀ = repeat(inv.(da), 1, size(x, 2)) return _logabsdetjac_scale(da, value(x), Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) end -function _simplex_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector{1}) +function _simplex_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) return track(_simplex_bijector, X, b) end -@grad function _simplex_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function _simplex_bijector(Y::AbstractVector, b::SimplexBijector) Yd = value(Y) return _simplex_bijector(Yd, b), Δ -> (simplex_link_jacobian(Yd)' * Δ, nothing) end -@grad function _simplex_bijector(Y::AbstractMatrix, 
b::SimplexBijector{1}) - Yd = value(Y) - return _simplex_bijector(Yd, b), Δ -> begin - maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 - simplex_link_jacobian(c1)' * c2 - end, nothing - end -end -function _simplex_inv_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector{1}) +function _simplex_inv_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) return track(_simplex_inv_bijector, X, b) end -@grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) Yd = value(Y) return _simplex_inv_bijector(Yd, b), Δ -> (simplex_invlink_jacobian(Yd)' * Δ, nothing) end -@grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) +@grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) Yd = value(Y) return _simplex_inv_bijector(Yd, b), Δ -> begin maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 @@ -154,21 +143,13 @@ replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X) end end -logabsdetjac(b::SimplexBijector{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +@grad function logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = value(x) return logabsdetjac(b, xd), Δ -> begin (nothing, simplex_logabsdetjac_gradient(xd) * Δ) end end -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - xd = value(x) - return logabsdetjac(b, xd), Δ -> begin - (nothing, maphcat(eachcol(xd), Δ) do c, g - simplex_logabsdetjac_gradient(c) * g - end) - end -end getpd(X::TrackedMatrix) = track(getpd, X) @grad function getpd(X::AbstractMatrix) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index eed09521..53727813 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,8 +12,8 @@ using ..Tracker: Tracker, param import ..Bijectors -using ..Bijectors: Log, SimplexBijector, ADBijector, - TrackerAD, Inverse, Stacked, Exp +using ..Bijectors: Elementwise, SimplexBijector, ADBijector, + TrackerAD, Inverse, Stacked import ChainRulesCore import LogExpFunctions @@ -91,13 +91,10 @@ end # Log bijector -@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractVector) +@grad function Bijectors.logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ data(x)) end -@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, data(x); dims = 1)), Δ -> (nothing, .- Δ' ./ data(x)) -end -@grad function Bijectors.logabsdetjac(b::Log{2}, x::AbstractMatrix) +@grad function Bijectors.logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractMatrix) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ data(x)) end @@ -154,49 +151,23 @@ end # @grad function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1}) # throw # end -# implementations for Stacked bijector -function Bijectors.logabsdetjac(b::Stacked, x::TrackedMatrix{<:Real}) - return map(eachcol(x)) do c - Bijectors.logabsdetjac(b, c) - end -end -# TODO: implement custom adjoint since we can exploit block-diagonal nature of `Stacked` -function (sb::Stacked)(x::TrackedMatrix{<:Real}) - return Bijectors.eachcolmaphcat(sb, x) -end + # Simplex adjoints -function Bijectors._simplex_bijector(X::TrackedVecOrMat, b::SimplexBijector{1}) +function Bijectors._simplex_bijector(X::TrackedVecOrMat, 
b::SimplexBijector) return track(Bijectors._simplex_bijector, X, b) end -function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector{1}) +function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector) return track(Bijectors._simplex_inv_bijector, Y, b) end -@grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector) Xd = data(X) return Bijectors._simplex_bijector(Xd, b), Δ -> (Bijectors.simplex_link_jacobian(Xd)' * Δ, nothing) end -@grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) Yd = data(Y) return Bijectors._simplex_inv_bijector(Yd, b), Δ -> (Bijectors.simplex_invlink_jacobian(Yd)' * Δ, nothing) end -@grad function Bijectors._simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) - Xd = data(X) - return Bijectors._simplex_bijector(Xd, b), Δ -> begin - Bijectors.maphcat(eachcol(Xd), eachcol(Δ)) do c1, c2 - Bijectors.simplex_link_jacobian(c1)' * c2 - end, nothing - end -end -@grad function Bijectors._simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) - Yd = data(Y) - return Bijectors._simplex_inv_bijector(Yd, b), Δ -> begin - Bijectors.maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 - Bijectors.simplex_invlink_jacobian(c1)' * c2 - end, nothing - end -end - Bijectors.replace_diag(::typeof(log), X::TrackedMatrix) = track(Bijectors.replace_diag, log, X) @grad function Bijectors.replace_diag(::typeof(log), X) Xd = data(X) @@ -219,21 +190,13 @@ Bijectors.replace_diag(::typeof(exp), X::TrackedMatrix) = track(Bijectors.replac end end -Bijectors.logabsdetjac(b::SimplexBijector{1}, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) -@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +Bijectors.logabsdetjac(b::SimplexBijector, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) +@grad function Bijectors.logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = data(x) return Bijectors.logabsdetjac(b, xd), Δ -> begin (nothing, Bijectors.simplex_logabsdetjac_gradient(xd) * Δ) end end -@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - xd = data(x) - return Bijectors.logabsdetjac(b, xd), Δ -> begin - (nothing, Bijectors.maphcat(eachcol(xd), Δ) do c, g - Bijectors.simplex_logabsdetjac_gradient(c) * g - end) - end -end for header in [ (:(α_::TrackedReal), :β, :z_0, :(z::AbstractVector)), @@ -327,18 +290,12 @@ function vectorof(::Type{TrackedReal{T}}) where {T<:Real} return TrackedArray{T,1,Vector{T}} end -(b::Exp{0})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) -(b::Exp{1})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) -(b::Exp{1})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) -(b::Exp{2})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) - -(b::Log{0})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) -(b::Log{1})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) -(b::Log{1})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -(b::Log{2})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedVector) = exp.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(exp)})(x::TrackedMatrix) = exp.(x)::matrixof(float(eltype(x))) -Bijectors.logabsdetjac(b::Log{0}, x::TrackedVector) 
= .-log.(x)::vectorof(float(eltype(x))) -Bijectors.logabsdetjac(b::Log{1}, x::TrackedMatrix) = - vec(sum(log.(x); dims = 1)) +(b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) +(b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) @grad function Bijectors.getpd(X::AbstractMatrix) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index c971597b..af200c98 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -28,12 +28,9 @@ end return pullback(g, f, x1, x2) end -@adjoint function logabsdetjac(b::Log{1}, x::AbstractVector) +@adjoint function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, x), Δ -> (nothing, -Δ ./ x) end -@adjoint function logabsdetjac(b::Log{1}, x::AbstractMatrix) - return -vec(sum(log, x; dims = 1)), Δ -> (nothing, .- Δ' ./ x) -end # AD implementations function jacobian( @@ -119,21 +116,21 @@ end # Simplex adjoints -@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) +@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector) return _simplex_bijector(X, b), Δ -> (simplex_link_jacobian(X)' * Δ, nothing) end -@adjoint function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@adjoint function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) return _simplex_inv_bijector(Y, b), Δ -> (simplex_invlink_jacobian(Y)' * Δ, nothing) end -@adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) +@adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector) return _simplex_bijector(X, b), Δ -> begin maphcat(eachcol(X), eachcol(Δ)) do c1, c2 simplex_link_jacobian(c1)' * c2 end, nothing end end -@adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) +@adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) return _simplex_inv_bijector(Y, b), Δ -> begin maphcat(eachcol(Y), eachcol(Δ)) do c1, c2 simplex_invlink_jacobian(c1)' * c2 @@ -141,18 +138,11 @@ end end end -@adjoint function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +@adjoint function logabsdetjac(b::SimplexBijector, x::AbstractVector) return logabsdetjac(b, x), Δ -> begin (nothing, simplex_logabsdetjac_gradient(x) * Δ) end end -@adjoint function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) - return logabsdetjac(b, x), Δ -> begin - (nothing, maphcat(eachcol(x), Δ) do c, g - simplex_logabsdetjac_gradient(c) * g - end) - end -end # LocationScale fix @@ -290,21 +280,3 @@ end return z, pullback_link_chol_lkj end - -# Otherwise Zygote complains. -Zygote.@adjoint function Batch(x) - return Batch(x), function(_Δ) - # Sometimes `Batch` has been extraced using `value`, in which case - # we get back a `NamedTuple`. - # Other times the value was extracted by iterating over `Batch`, - # in which case we don't get a `NamedTuple`. - Δ = _Δ isa NamedTuple{(:value, )} ? _Δ.value : _Δ - return if Δ isa AbstractArray{<:Real} - (Δ, ) - else - Δ_new = similar(x, eltype(Δ[1])) - Δ_new[:] .= vcat(Δ...) 
- (Δ_new, ) - end - end -end From 99a421e43da2ab5c13e9e6abd5fbb0e8672211ba Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 10:24:56 +0000 Subject: [PATCH 49/98] updated tests --- test/bijectors/coupling.jl | 10 +- test/bijectors/leaky_relu.jl | 96 ++---- test/bijectors/named_bijector.jl | 18 +- test/bijectors/ordered.jl | 10 +- test/bijectors/rational_quadratic_spline.jl | 12 +- test/bijectors/utils.jl | 245 +++----------- test/interface.jl | 355 ++------------------ test/runtests.jl | 7 +- test/transform.jl | 60 +--- 9 files changed, 136 insertions(+), 677 deletions(-) diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 38ebd763..5d6c6607 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -54,8 +54,12 @@ using Bijectors: # With `Scale` cl = Coupling(x -> Scale(x[1]), m) - x = hcat([-1., -2., -3.], [1., 2., 3.]) - y = hcat([2., -2., -3.], [2., 2., 3.]) - test_bijector(cl, x, y, log.([2., 2.])) + x = [-1., -2., -3.] + y = [2., -2., -3.] + test_bijector(cl, x; y=y, logjac=log(2)) + + x = [1., 2., 3.] + y = [2., 2., 3.] + test_bijector(cl, x; y=y, logjac=log(2)) end end diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index e110a046..a51f33e5 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -3,84 +3,48 @@ using Test using Bijectors using Bijectors: LeakyReLU -using LinearAlgebra -using ForwardDiff - -true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x)) -true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x)) -true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1] -true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs)) - @testset "0-dim parameter, 0-dim input" begin - b = LeakyReLU(0.1; dim=Val(0)) - x = 1. - @test inverse(b)(b(x)) == x - @test inverse(b)(b(-x)) == -x - - # Mixing of types - # 1. Changes in input-type - @assert eltype(b(Float32(1.))) == Float64 - @assert eltype(b(Float64(1.))) == Float64 + b = LeakyReLU(0.1) - # 2. 
Changes in parameter-type - b = LeakyReLU(Float32(0.1); dim=Val(0)) - @assert eltype(b(Float32(1.))) == Float32 - @assert eltype(b(Float64(1.))) == Float64 + # < 0 + x = -1.0 + test_bijector(b, x) - # logabsdetjac - @test logabsdetjac(b, x) == true_logabsdetjac(b, x) - @test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x) + # ≥ 0 + x = 1.0 + test_bijector(b, x; test_not_identity=false, test_types=true) - # Batch - xs = randn(10) - @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) - @test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) + # Float32 + b = LeakyReLU(Float32(b.α)) - @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) - @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) + # < 0 + x = -1f0 + test_bijector(b, x) - # Forward - f = with_logabsdet_jacobian(b, xs) - @test f[2] ≈ logabsdetjac(b, xs) - @test f[1] ≈ b(xs) - - f = with_logabsdet_jacobian(b, Float32.(xs)) - @test f[2] == logabsdetjac(b, Float32.(xs)) - @test f[1] ≈ b(Float32.(xs)) + # ≥ 0 + x = 1f0 + test_bijector(b, x; test_not_identity=false, test_types=true) end @testset "0-dim parameter, 1-dim input" begin d = 2 + b = LeakyReLU(0.1) - b = LeakyReLU(0.1; dim=Val(1)) - x = ones(d) - @test inverse(b)(b(x)) == x - @test inverse(b)(b(-x)) == -x + # < 0 + x = -ones(d) + test_bijector(b, x) - # Batch - xs = randn(d, 10) - @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) - @test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) - - @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) - @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) - - # Forward - f = with_logabsdet_jacobian(b, xs) - @test f[2] ≈ logabsdetjac(b, xs) - @test f[1] ≈ b(xs) - - f = with_logabsdet_jacobian(b, Float32.(xs)) - @test f[2] == logabsdetjac(b, Float32.(xs)) - @test f[1] ≈ b(Float32.(xs)) + # ≥ 0 + x = ones(d) + test_bijector(b, x; test_not_identity=false) - # Mixing of types - # 1. Changes in input-type - @assert eltype(b(ones(Float32, 2))) == Float64 - @assert eltype(b(ones(Float64, 2))) == Float64 + # Float32 + b = LeakyReLU(Float32.(b.α)) + # < 0 + x = -ones(Float32, d) + test_bijector(b, x; test_types=true) - # 2. Changes in parameter-type - b = LeakyReLU(Float32(0.1); dim=Val(1)) - @assert eltype(b(ones(Float32, 2))) == Float32 - @assert eltype(b(ones(Float64, 2))) == Float64 + # ≥ 0 + x = ones(Float32, d) + test_bijector(b, x; test_not_identity=false, test_types=true) end diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index a7248fae..9dae0365 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -1,26 +1,12 @@ using Test using Bijectors -using Bijectors: Exp, Log, Logit, AbstractNamedBijector, NamedBijector, NamedInverse, NamedCoupling, NamedComposition, Shift +using Bijectors: Exp, Log, Logit, AbstractNamedTransform, NamedBijector, NamedCoupling, Shift @testset "NamedBijector" begin b = NamedBijector((a = Exp(), b = Log())) @test b((a = 0.0, b = exp(1.0))) == (a = 1.0, b = 1.0) -end - -@testset "NamedComposition" begin - b = NamedBijector((a = Exp(), )) - x = (a = 0., b = 1.) 
- - nc1 = NamedComposition((b, b)) - @test nc1(x) == b(b(x)) - @test logabsdetjac(nc1, x) ≈ logabsdetjac(b, x) + logabsdetjac(b, b(x)) - - nc2 = b ∘ b - @test nc1 == nc2 - inc2 = inverse(nc2) - @test (inc2 ∘ nc2)(x) == x - @test logabsdetjac((inc2 ∘ nc2), x) ≈ 0.0 + with_logabsdet_jacobian(b, (a = 0.0, b = exp(1.0))) end @testset "NamedCoupling" begin diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 1bf53931..058ee77d 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -5,18 +5,12 @@ import Bijectors: OrderedBijector # Length 1 x = randn(1) - y = b(x) - test_bijector(b, hcat(x, x), hcat(y, y), zeros(2)) + test_bijector(b, x; test_not_identity=false) # Larger x = randn(5) - xs = hcat(x, x) - test_bijector(b, xs) + test_bijector(b, x) y = b(x) @test sort(y) == y - - ys = b(xs) - @test sort(ys[:, 1]) == ys[:, 1] - @test sort(ys[:, 2]) == ys[:, 2] end diff --git a/test/bijectors/rational_quadratic_spline.jl b/test/bijectors/rational_quadratic_spline.jl index a80c4bc3..fe0ddc31 100644 --- a/test/bijectors/rational_quadratic_spline.jl +++ b/test/bijectors/rational_quadratic_spline.jl @@ -38,23 +38,23 @@ using Bijectors: RationalQuadraticSpline # Inside of domain x = 0.5 - test_bijector(b, [-x, x]) + test_bijector(b, -x) + test_bijector(b, x) # Outside of domain - x = 5. - test_bijector(b, [-x, x], [-x, x], [0., 0.]) + x = 5.0 + test_bijector(b, -x; y=-x, logjac=0) + test_bijector(b, x; y=x, logjac=0) # multivariate b = b_mv # Inside of domain x = [-0.5, 0.5] - x = hcat(x, -x, x) # batch test_bijector(b, x) # Outside of domain x = [-5., 5.] - x = hcat(x, -x, x) # batch - test_bijector(b, x, x, zeros(size(x, 2))) + test_bijector(b, x; y=x, logjac=zero(eltype(x))) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 82a34a5f..cf1283dd 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,195 +1,72 @@ -import Bijectors: AbstractBatch +# Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` +include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) +test_bijector(b, x; kwargs...) = test_bijector(b, x, getjacobian; kwargs...) -function test_bijector_single( - b::Bijector, - x_true, - y_true, - logjac_true; - isequal = true, - tol = 1e-6 +# TODO: Should we move this into `src/`? +function test_bijector( + b, + x, + getjacobian; + y=nothing, + logjac=nothing, + test_not_identity=isnothing(y) && isnothing(logjac), + test_types=false, + compare=isapprox, + kwargs... ) + # Ensure that everything is type-stable. ib = @inferred inverse(b) - y = @inferred b(x_true) - logjac = @inferred logabsdetjac(b, x_true) - ilogjac = @inferred logabsdetjac(ib, y_true) - res = @inferred with_logabsdet_jacobian(b, x_true) - - # If `isequal` is false, then we use the computed `y`, - # but if it's true, we use the true `y`. - ires = isequal ? @inferred(with_logabsdet_jacobian(inverse(b), y_true)) : @inferred(with_logabsdet_jacobian(inverse(b), y)) + logjac_test = @inferred logabsdetjac(b, x) + res = @inferred with_logabsdet_jacobian(b, x) - # Always want the following to hold - @test ires[1] ≈ x_true atol=tol - @test ires[2] ≈ -logjac atol=tol - - if isequal - @test y ≈ y_true atol=tol # forward - @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse - @test logjac ≈ logjac_true # logjac forward - @test res[1] ≈ y_true atol=tol # forward using `forward` - @test res[2] ≈ logjac_true atol=tol # logjac using `forward` + y_test = @inferred b(x) + ilogjac_test = !isnothing(y) ? 
@inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) + ires = if !isnothing(y) + @inferred(with_logabsdet_jacobian(inverse(b), y)) else - @test y ≠ y_true # forward - @test (@inferred ib(y)) ≈ x_true atol=tol # inverse - @test logjac ≠ logjac_true # logjac forward - @test res[1] ≠ y_true # forward using `forward` - @test res[2] ≠ logjac_true # logjac using `forward` + @inferred(with_logabsdet_jacobian(inverse(b), y_test)) end -end - -function test_bijector_batch( - b::Bijector, - xs_true::AbstractBatch, - ys_true::AbstractBatch, - logjacs_true; - isequal = true, - tol = 1e-6 -) - ib = @inferred inverse(b) - ys = @inferred b(xs_true) - logjacs = @inferred logabsdetjac(b, xs_true) - res = @inferred with_logabsdet_jacobian(b, xs_true) - # If `isequal` is false, then we use the computed `y`, - # but if it's true, we use the true `y`. - ires = isequal ? @inferred(with_logabsdet_jacobian(inverse(b), ys_true)) : @inferred(with_logabsdet_jacobian(inverse(b), ys)) - - # always want the following to hold - @test ys isa typeof(ys_true) - @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires[1] - xs_true) ≤ tol - @test mean(abs, ires[2] + logjacs) ≤ tol - - if isequal - @test mean(abs, ys - ys_true) ≤ tol # forward - @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse - @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res[1] - ys_true) ≤ tol # forward using `forward` - @test mean(abs, res[2] - logjacs_true) ≤ tol # logjac `forward` - @test mean(abs, ires[2] + logjacs_true) ≤ tol # inverse logjac `forward` - else - # Don't want the following to be equal to their "true" values - @test mean(norm, ys - ys_true) > tol # forward - @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res[1] - ys_true) > tol # forward using `forward` - - # Still want the following to be equal to the COMPUTED values - @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse - @test mean(abs, res[2] - logjacs) ≤ tol # logjac forward using `forward` - end -end -""" - test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) - -Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` -and `logjacs`. - -If `ys` and `logjacs` are NOT provided, `isequal` will be set to `false` and -`ys` and `logjacs` will be set to `zeros`. These `ys` and `logjacs` will be -treated as "counter-examples", i.e. values NOT to match. - -# Arguments -- `b::Bijector`: the bijector to test -- `xs`: inputs (has to be several!!!)(has to be several, i.e. a batch!!!) to test -- `ys`: outputs (has to be several, i.e. a batch!!!) to test against -- `logjacs`: `logabsdetjac` outputs (has to be several!!!)(has to be several, i.e. - a batch!!!) to test against - -# Keywords -- `isequal = true`: if `false`, it will be assumed that the given values are - provided as "counter-examples" in the sense that the inputs `xs` should NOT map - to the given outputs. This is useful in cases where one might not know the expected - output, but still wants to test that the evaluation, etc. works. - This is set to `true` by default if `ys` and `logjacs` are not provided. -- `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check - arrays where we check that the L1-norm is sufficiently small. -""" -function test_bijector( - b::Bijector, - xs_true::AbstractBatch, - ys_true::AbstractBatch, - logjacs_true::AbstractBatch; - kwargs... 
-) - ib = inverse(b) + # ChangesOfVariables.jl + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...) + ChangesOfVariables.test_with_logabsdet_jacobian(ib, isnothing(y) ? y_test : y, getjacobian; compare=compare, kwargs...) - # Batch - test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + # InverseFunctions.jl + InverseFunctions.test_inverse(b, x; compare, kwargs...) + InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...) - # Test `logabsdetjac` against jacobians - test_logabsdetjac(b, xs_true) - - if Bijectors.isinvertible(b) - ib = inv(b) - test_logabsdetjac(ib, ys_true) + # Always want the following to hold + @test compare(ires[1], x; kwargs...) + @test compare(ires[2], -logjac_test; kwargs...) + + # Verify values. + if !isnothing(y) + @test compare(y_test, y; kwargs...) + @test compare((@inferred ib(y)), x; kwargs...) # inverse + @test compare(res[1], y; kwargs...) # forward using `forward` end - - for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) - # Test validity of single input. - test_bijector_single(b, x_true, y_true, logjac_true; kwargs...) - - # Test AD wrt. inputs. - test_ad(x -> sum(b(x)), x_true) - test_ad(x -> logabsdetjac(b, x), x_true) - if Bijectors.isinvertible(b) - y = b(x_true) - test_ad(x -> sum(ib(x)), y) - end + if !isnothing(logjac) + # We've already checked `ires[2]` against `res[2]`, so if `res[2]` is correct, then so is `ires[2]`. + @test compare(logjac_test, logjac; kwargs...) # logjac forward + @test compare(res[2], logjac; kwargs...) # logjac using `forward` end - # Test AD wrt. parameters. - test_bijector_parameter_gradient(b, xs_true[1], ys_true[1]) - - # Test validity of collection of inputs. - test_bijector_batch(b, xs_true, ys_true, logjacs_true; kwargs...) - - # AD testing for batch. - f, arg = make_gradient_function(x -> sum(sum(b.(x))), xs_true) - test_ad(f, arg) - f, arg = make_gradient_function(x -> sum(logabsdetjac.(b, x)), xs_true) - test_ad(f, arg) - - if Bijectors.isinvertible(b) - ys = b.(xs_true) - f, arg = make_gradient_function(y -> sum(sum(ib.(y))), ys) - test_ad(f, arg) + # Useful for testing when you don't know the true outputs but know that + # `b` is definitively not identity. + if test_not_identity + @test y_test ≠ x + @test logjac_test ≠ zero(eltype(x)) + @test res[2] ≠ zero(eltype(x)) end -end - -function make_gradient_function(f, xs::ArrayBatch) - s = size(Bijectors.value(xs)) - -function test_bijector( - b::Bijector{1}, - xs_true::AbstractMatrix{<:Real}, - ys_true::AbstractMatrix{<:Real}, - logjacs_true::AbstractVector{<:Real}; - kwargs... 
-) - ib = inverse(b) - - return g, vec(Bijectors.value(xs)) -end - -function make_gradient_function(f, xs::VectorBatch{<:AbstractArray{<:Real}}) - xs_new = vcat(map(vec, Bijectors.value(xs))) - n = length(xs_new) - - s = size(Bijectors.value(xs[1])) - stride = n ÷ length(xs) - - function g(x) - x_vec = map(1:stride:n) do i - reshape(x[i:i + stride - 1], s) - end - x_batch = Bijectors.reconstruct(xs, x_vec) - return f(x_batch) + if test_types + @test typeof(first(res)) === typeof(x) + @test typeof(res) === typeof(ires) + @test typeof(y_test) === typeof(x) + @test typeof(logjac_test) === typeof(ilogjac_test) end - - return g, xs_new end make_jacobian_function(f, xs::AbstractVector) = f, xs @@ -204,22 +81,6 @@ function make_jacobian_function(f, xs::AbstractArray) return g, xs_new end -function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:AbstractArray}; tol=1e-6) - f, _ = make_jacobian_function(b, xs[1]) - logjac_ad = map(xs) do x - first(logabsdet(ForwardDiff.jacobian(f, x))) - end - - @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol -end - -function test_logabsdetjac(b::Transform, xs::Batch{<:Any, <:Real}; tol=1e-6) - logjac_ad = map(xs) do x - log(abs(ForwardDiff.derivative(b, x))) - end - @test mean(collect(logabsdetjac.(b, xs)) - logjac_ad) ≤ tol -end - # Check if `Functors.functor` works properly function test_functor(x, xs) _xs, re = Functors.functor(x) @@ -227,7 +88,7 @@ function test_functor(x, xs) @test _xs == xs end -function test_bijector_parameter_gradient(b::Transform, x, y = b(x)) +function test_bijector_parameter_gradient(b::Bijectors.Transform, x, y = b(x)) args, re = Functors.functor(b) recon(k, param) = re(merge(args, NamedTuple{(k, )}((param, )))) diff --git a/test/interface.jl b/test/interface.jl index 11fc27f6..b9f9b88c 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,3 +1,6 @@ +# using Pkg; Pkg.activate("..") +# using TestEnv; TestEnv.activate() + using Test using Random using LinearAlgebra @@ -11,19 +14,20 @@ using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Per Random.seed!(123) -struct MyADBijector{AD, N, B <: Bijector{N}} <: ADBijector{AD, N} +struct MyADBijector{AD,B} <: ADBijector{AD} b::B end MyADBijector(d::Distribution) = MyADBijector{Bijectors.ADBackend()}(d) MyADBijector{AD}(d::Distribution) where {AD} = MyADBijector{AD}(bijector(d)) -MyADBijector{AD}(b::B) where {AD, N, B <: Bijector{N}} = MyADBijector{AD, N, B}(b) +MyADBijector{AD}(b) where {AD} = MyADBijector{AD, typeof(b)}(b) (b::MyADBijector)(x) = b.b(x) -(b::Inverse{<:MyADBijector})(x) = inverse(b.orig.b)(x) +Bijectors.transform(b::MyADBijector, x) = b.b(x) +Bijectors.transform(b::Inverse{<:MyADBijector}, x) = inverse(b.orig.b)(x) -struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end +struct NonInvertibleBijector{AD} <: ADBijector{AD} end contains(predicate::Function, b::Bijector) = predicate(b) -contains(predicate::Function, b::Composed) = any(contains.(predicate, b.ts)) +contains(predicate::Function, b::ComposedFunction) = any(contains.(predicate, b.ts)) contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) # Scalar tests @@ -142,234 +146,6 @@ end end end -@testset "Batch computation" begin - bs_xs = [ - (Scale(2.0), randn(3)), - (Scale([1.0, 2.0]), randn(2, 3)), - (Shift(2.0), randn(3)), - (Shift([1.0, 2.0]), randn(2, 3)), - (Log{0}(), exp.(randn(3))), - (Log{1}(), exp.(randn(2, 3))), - (Exp{0}(), randn(3)), - (Exp{1}(), randn(2, 3)), - (Log{1}() ∘ Exp{1}(), randn(2, 3)), - (inverse(Logit(-1.0, 1.0)), randn(3)), - 
(Identity{0}(), randn(3)), - (Identity{1}(), randn(2, 3)), - (PlanarLayer(2), randn(2, 3)), - (RadialLayer(2), randn(2, 3)), - (PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), - (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), - (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), - (stack(Exp{0}(), Scale(2.0)), randn(2, 3)), - (Stacked((Exp{1}(), SimplexBijector()), (1:1, 2:3)), - mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)), - (RationalQuadraticSpline(randn(3), randn(3), randn(3 - 1), 2.), [-0.5, 0.5]), - (LeakyReLU(0.1), randn(3)), - (LeakyReLU(Float32(0.1)), randn(3)), - (LeakyReLU(0.1; dim = Val(1)), randn(2, 3)), - ] - - for (b, xs) in bs_xs - @testset "$b" begin - D = @inferred Bijectors.dimension(b) - ib = @inferred inverse(b) - - @test Bijectors.dimension(ib) == D - - x = D == 0 ? xs[1] : xs[:, 1] - - y = @inferred b(x) - ys = @inferred b(xs) - @inferred(b(param(xs))) - - x_ = @inferred ib(y) - xs_ = @inferred ib(ys) - @inferred(ib(param(ys))) - - result = @inferred with_logabsdet_jacobian(b, x) - results = @inferred with_logabsdet_jacobian(b, xs) - - iresult = @inferred with_logabsdet_jacobian(ib, y) - iresults = @inferred with_logabsdet_jacobian(ib, ys) - - # Sizes - @test size(y) == size(x) - @test size(ys) == size(xs) - - @test size(x_) == size(x) - @test size(xs_) == size(xs) - - @test size(result[1]) == size(x) - @test size(results[1]) == size(xs) - - @test size(iresult[1]) == size(y) - @test size(iresults[1]) == size(ys) - - # Values - @test ys ≈ hcat([b(xs[:, i]) for i = 1:size(xs, 2)]...) - @test ys ≈ results[1] - - if D == 0 - # Sizes - @test y == ys[1] - - @test length(logabsdetjac(b, xs)) == length(xs) - @test length(logabsdetjac(ib, ys)) == length(xs) - - @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - - @test size(results[2]) == size(xs, ) - @test size(iresults[2]) == size(ys, ) - - # Values - b_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(b, xs[i])) for i = 1:length(xs)] - ib_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(ib, ys[i])) for i = 1:length(ys)] - @test logabsdetjac.(b, xs) == @inferred(logabsdetjac(b, xs)) - @test @inferred(logabsdetjac(b, xs)) ≈ b_logjac_ad atol=1e-9 - @test logabsdetjac.(ib, ys) == @inferred(logabsdetjac(ib, ys)) - @test @inferred(logabsdetjac(ib, ys)) ≈ ib_logjac_ad atol=1e-9 - - @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) - @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) - - @test results[2] ≈ vec(logabsdetjac.(b, xs)) - @test iresults[2] ≈ vec(logabsdetjac.(ib, ys)) - elseif D == 1 - @test y == ys[:, 1] - # Comparing sizes instead of lengths ensures we catch errors s.t. - # length(x) == 3 when size(x) == (1, 3). 
- # Sizes - @test size(logabsdetjac(b, xs)) == (size(xs, 2), ) - @test size(logabsdetjac(ib, ys)) == (size(xs, 2), ) - - @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - - @test size(results[2]) == (size(xs, 2), ) - @test size(iresults[2]) == (size(ys, 2), ) - - # Test all values - @test @inferred(logabsdetjac(b, xs)) ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test @inferred(logabsdetjac(ib, ys)) ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - - @test results[2] ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test iresults[2] ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - - # FIXME: `SimplexBijector` results in ∞ gradient if not in the domain - if !contains(t -> t isa SimplexBijector, b) - b_logjac_ad = [logabsdet(ForwardDiff.jacobian(b, xs[:, i]))[1] for i = 1:size(xs, 2)] - @test logabsdetjac(b, xs) ≈ b_logjac_ad atol=1e-9 - - ib_logjac_ad = [logabsdet(ForwardDiff.jacobian(ib, ys[:, i]))[1] for i = 1:size(ys, 2)] - @test logabsdetjac(ib, ys) ≈ ib_logjac_ad atol=1e-9 - end - else - error("tests not implemented yet") - end - end - end - - @testset "Composition" begin - @test_throws DimensionMismatch (Exp{1}() ∘ Log{0}()) - - # Check that type-stable composition stays type-stable - cb1 = Composed((Exp(), Log())) ∘ Exp() - @test cb1 isa Composed{<:Tuple} - cb2 = Exp() ∘ Composed((Exp(), Log())) - @test cb2 isa Composed{<:Tuple} - cb3 = cb1 ∘ cb2 - @test cb3 isa Composed{<:Tuple} - - @test logabsdetjac(cb1, 1.) isa Real - @test logabsdetjac(cb1, 1.) == 1. - - @test inverse(cb1) isa Composed{<:Tuple} - @test inverse(cb2) isa Composed{<:Tuple} - @test inverse(cb3) isa Composed{<:Tuple} - - # Check that type-unstable composition stays type-unstable - cb1 = Composed([Exp(), Log()]) ∘ Exp() - @test cb1 isa Composed{<:AbstractArray} - cb2 = Exp() ∘ Composed([Exp(), Log()]) - @test cb2 isa Composed{<:AbstractArray} - cb3 = cb1 ∘ cb2 - @test cb3 isa Composed{<:AbstractArray} - - @test logabsdetjac(cb1, 1.) isa Real - @test logabsdetjac(cb1, 1.) == 1. 
- - @test inverse(cb1) isa Composed{<:AbstractArray} - @test inverse(cb2) isa Composed{<:AbstractArray} - @test inverse(cb3) isa Composed{<:AbstractArray} - - # combining the two - @test_throws ErrorException (Log() ∘ Exp()) ∘ cb1 - @test_throws ErrorException cb1 ∘ (Log() ∘ Exp()) - end - - @testset "Batch-computation with Tracker.jl" begin - @testset "Scale" begin - # 0-dim with `Real` parameter - b = Scale(param(2.0)) - lj = logabsdetjac(b, 1.0) - Tracker.back!(lj, 1.0) - @test Tracker.extract_grad!(b.a) == 0.5 - - # 0-dim with `Real` parameter for batch-computation - lj = logabsdetjac(b, [1.0, 2.0, 3.0]) - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == sum([0.5, 0.5, 0.5]) - - - # 1-dim with `Vector` parameter - x = [3.0, 4.0, 5.0] - xs = [3.0 4.0; 4.0 7.0; 5.0 8.0] - a = [2.0, 3.0, 5.0] - - b = Scale(param(a)) - lj = logabsdetjac(b, x) - Tracker.back!(lj) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> logabsdetjac(Scale(a), x), a) - - # batch - lj = logabsdetjac(b, xs) - Tracker.back!(mean(lj), 1.0) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> mean(logabsdetjac(Scale(a), xs)), a) - - # Forward when doing a composition - y, logjac = logabsdetjac(b, xs) - Tracker.back!(mean(logjac), 1.0) - @test Tracker.extract_grad!(b.a) == ForwardDiff.gradient(a -> mean(logabsdetjac(Scale(a), xs)), a) - end - - @testset "Shift" begin - b = Shift(param(1.0)) - lj = logabsdetjac(b, 1.0) - Tracker.back!(lj, 1.0) - @test Tracker.extract_grad!(b.a) == 0.0 - - # 0-dim with `Real` parameter for batch-computation - lj = logabsdetjac(b, [1.0, 2.0, 3.0]) - @test lj isa TrackedArray - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == 0.0 - - # 1-dim with `Vector` parameter - b = Shift(param([2.0, 3.0, 5.0])) - lj = logabsdetjac(b, [3.0, 4.0, 5.0]) - Tracker.back!(lj) - @test Tracker.extract_grad!(b.a) == zeros(3) - - lj = logabsdetjac(b, [3.0 4.0 5.0; 6.0 7.0 8.0]) - @test lj isa TrackedArray - Tracker.back!(lj, [1.0, 1.0, 1.0]) - @test Tracker.extract_grad!(b.a) == zeros(3) - end - end -end - @testset "Truncated" begin d = truncated(Normal(), -1, 1) b = bijector(d) @@ -426,12 +202,6 @@ end @test lp ≈ logpdf(td, y) @test logjac ≈ logabsdetjacinv(td.transform, y) - # multi-sample - y = rand(td, 10) - x = inverse(td.transform)(y) - @test inverse(td.transform)(param(y)) isa TrackedArray - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - # forward f = forward(td) @test f.x ≈ inverse(td.transform)(f.y) @@ -442,7 +212,7 @@ end # verify against AD # similar to what we do in test/transform.jl for Dirichlet if dist isa Dirichlet - b = Bijectors.SimplexBijector{1, false}() + b = Bijectors.SimplexBijector{false}() x = rand(dist) y = b(x) @test b(param(x)) isa TrackedArray @@ -490,80 +260,10 @@ end # lp, logjac = logpdf_with_jac(td, y) # @test lp ≈ logpdf(td, y) # @test logjac ≈ logabsdetjacinv(td.transform, y) - - # multi-sample - y = rand(td, 10) - x = inverse(td.transform)(y) - @test inverse(td.transform)(param.(y)) isa Vector{<:TrackedArray} - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) end end end -@testset "Composition <: Bijector" begin - d = Beta() - td = transformed(d) - - x = rand(d) - y = td.transform(x) - - b = @inferred Bijectors.composel(td.transform, Bijectors.Identity{0}()) - ib = @inferred inverse(b) - - @test with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(td.transform, x) - @test with_logabsdet_jacobian(ib, y) == with_logabsdet_jacobian(inverse(td.transform), y) - - @test 
with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(Bijectors.composer(b.ts...), x) - - # inverse works fine for composition - cb = @inferred b ∘ ib - @test cb(x) ≈ x - - cb2 = @inferred cb ∘ cb - @test cb(x) ≈ x - - # ensures that the `logabsdetjac` is correct - x = rand(d) - b = inverse(bijector(d)) - @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) - - # order of composed evaluation - b1 = MyADBijector(d) - b2 = MyADBijector(Gamma()) - - cb = inverse(b1) ∘ b2 - @test cb(x) ≈ inverse(b1)(b2(x)) - - # contrived example - b = bijector(d) - cb = @inferred inverse(b) ∘ b - cb = @inferred cb ∘ cb - @test @inferred(cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x - - # forward for tuple and array - d = Beta() - b = @inferred inverse(bijector(d)) - b⁻¹ = @inferred inverse(b) - x = rand(d) - - cb_t = b⁻¹ ∘ b⁻¹ - f_t = with_logabsdet_jacobian(cb_t, x) - - cb_a = Composed([b⁻¹, b⁻¹]) - f_a = with_logabsdet_jacobian(cb_a, x) - - @test f_t == f_a - - # `composer` and `composel` - cb_l = Bijectors.composel(b⁻¹, b⁻¹, b) - cb_r = Bijectors.composer(reverse(cb_l.ts)...) - y = cb_l(x) - @test y == Bijectors.composel(cb_r.ts...)(x) - - k = length(cb_l.ts) - @test all([cb_l.ts[i] == cb_r.ts[i] for i = 1:k]) -end - @testset "Stacked <: Bijector" begin # `logabsdetjac` withOUT AD d = Beta() @@ -712,7 +412,6 @@ end @test td isa Distribution{Multivariate, Continuous} # check that wrong ranges fails - @test_throws MethodError stack(ibs...) sb = Stacked(ibs) x = rand(d) @test_throws AssertionError sb(x) @@ -801,15 +500,9 @@ end @testset "Equality" begin bs = [ - Identity{0}(), - Identity{1}(), - Identity{2}(), - Exp{0}(), - Exp{1}(), - Exp{2}(), - Log{0}(), - Log{1}(), - Log{2}(), + Identity(), + Exp(), + Log(), Scale(2.0), Scale(3.0), Scale(rand(2,2)), @@ -832,14 +525,12 @@ end RadialLayer(2), RadialLayer(3), SimplexBijector(), - Stacked((Exp{0}(), Log{0}())), - Stacked((Log{0}(), Exp{0}())), - Stacked([Exp{0}(), Log{0}()]), - Stacked([Log{0}(), Exp{0}()]), - Composed((Exp{0}(), Log{0}())), - Composed((Log{0}(), Exp{0}())), - # Composed([Exp{0}(), Log{0}()]), - # Composed([Log{0}(), Exp{0}()]), + Stacked((Exp(), Log())), + Stacked((Log(), Exp())), + Stacked([Exp(), Log()]), + Stacked([Log(), Exp()]), + Exp() ∘ Log(), + Log() ∘ Exp(), TruncatedBijector(1.0, 2.0), TruncatedBijector(1.0, 3.0), TruncatedBijector(0.0, 2.0), @@ -855,16 +546,16 @@ end end @testset "test_inverse and test_with_logabsdet_jacobian" begin - b = Bijectors.Scale{Float64,0}(4.2) + b = Bijectors.Scale{Float64,}(4.2) x = 0.3 - test_inverse(b, x) - test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) + InverseFunctions.test_inverse(b, x) + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) end @testset "deprecations" begin - b = Bijectors.Exp() + b = Bijectors.Logit(0.0, 1.0) x = 0.3 @test @test_deprecated(forward(b, x)) == NamedTuple{(:rv, :logabsdetjac)}(with_logabsdet_jacobian(b, x)) diff --git a/test/runtests.jl b/test/runtests.jl index 65000d5f..0c25dd3d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,11 +16,10 @@ using Random, LinearAlgebra, Test using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector -using ChangesOfVariables: test_with_logabsdet_jacobian -using InverseFunctions: test_inverse +using ChangesOfVariables: ChangesOfVariables +using InverseFunctions: InverseFunctions -using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, - 
TuringPoissonBinomial +using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial const GROUP = get(ENV, "GROUP", "All") diff --git a/test/transform.jl b/test/transform.jl index e119c1b2..de595d67 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -1,6 +1,6 @@ using Test using Bijectors -using ForwardDiff: derivative, jacobian +using ForwardDiff: ForwardDiff using LinearAlgebra: logabsdet, I, norm using Random @@ -67,21 +67,6 @@ function single_sample_tests(dist) @test typeof(x) == typeof(y) end -# Standard tests for all distributions involving multiple samples. xs should be whatever -# the appropriate repeated version of x is for the distribution in question. ie. for -# univariate distributions, just a vector of identical values. For vector-valued -# distributions, a matrix whose columns are identical. -function multi_sample_tests(dist, x, xs, N) - ys = @inferred(link(dist, copy(xs))) - @test @inferred(invlink(dist, link(dist, copy(xs)))) ≈ xs atol=1e-9 - @test @inferred(link(dist, invlink(dist, copy(ys)))) ≈ ys atol=1e-9 - @test logpdf_with_trans(dist, xs, true) == fill(logpdf_with_trans(dist, x, true), N) - @test logpdf_with_trans(dist, xs, false) == fill(logpdf_with_trans(dist, x, false), N) - - # This is a quirk of the current implementation, of which it would be nice to be rid. - @test typeof(xs) == typeof(ys) -end - # Scalar tests @testset "scalar" begin let @@ -116,13 +101,7 @@ let ] for dist in uni_dists - single_sample_tests(dist, derivative) - - # specialised multi-sample tests. - N = 10 - x = rand(dist) - xs = fill(x, N) - multi_sample_tests(dist, x, xs, N) + single_sample_tests(dist, ForwardDiff.derivative) end end end @@ -155,7 +134,7 @@ let ϵ = eps(Float64) end logpdf_turing = logpdf_with_trans(dist, x, true) - J = jacobian(x->link(dist, x, Val(false)), x) + J = ForwardDiff.jacobian(x->link(dist, x, Val(false)), x) @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing # Issue #12 @@ -164,14 +143,8 @@ let ϵ = eps(Float64) x = [logpdf_with_trans(dist, invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), true) for _ in 1:1_000] @test !any(isinf, x) && !any(isnan, x) else - single_sample_tests(dist, jacobian) + single_sample_tests(dist, ForwardDiff.jacobian) end - - # Multi-sample tests. Columns are observations due to Distributions.jl conventions. - N = 10 - x = rand(dist) - xs = repeat(x, 1, N) - multi_sample_tests(dist, x, xs, N) end end end @@ -191,15 +164,9 @@ let lowerinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2]] upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]] logpdf_turing = logpdf_with_trans(dist, x, true) - J = jacobian(x->link(dist, x), x) + J = ForwardDiff.jacobian(x->link(dist, x), x) J = J[lowerinds, upperinds] @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing - - # Multi-sample tests comprising vectors of matrices. - N = 10 - x = rand(dist) - xs = [x for _ in 1:N] - multi_sample_tests(dist, x, xs, N) end end end @@ -216,17 +183,10 @@ end x = d .* x .* d' upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] - J = jacobian(x->link(dist, x), x) + J = ForwardDiff.jacobian(x->link(dist, x), x) J = J[upperinds, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing - - # Multi-sample tests comprising vectors of matrices. 
- N = 10 - x = rand(dist) - xs = [x for _ in 1:N] - multi_sample_tests(dist, x, xs, N) - end ################################## Miscelaneous old tests ################################## @@ -279,10 +239,10 @@ end g1 = y -> invlink(dist, y, Val(true)) g2 = y -> invlink(dist, y, Val(false)) - @test @aeq jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true))) - @test @aeq jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false))) - @test @aeq jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true))) - @test @aeq jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false))) + @test @aeq ForwardDiff.jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true))) + @test @aeq ForwardDiff.jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false))) + @test @aeq ForwardDiff.jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true))) + @test @aeq ForwardDiff.jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false))) @test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * Bijectors.simplex_invlink_jacobian(y, Val(false)) I end for i in 1:4 From a28c9b19da1a91aaf21731916e7cfa1b7dae99cd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 11:06:27 +0000 Subject: [PATCH 50/98] updated docs --- docs/src/index.md | 73 ++++++----------------------------------------- src/interface.jl | 2 +- 2 files changed, 9 insertions(+), 66 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index ce0f1b0a..e24fa768 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,7 +1,5 @@ # Bijectors.jl -Documentation for Bijectors.jl - ## Usage A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: @@ -12,16 +10,7 @@ A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable tran using Bijectors transform(exp, 1.0) logabsdetjac(exp, 1.0) -forward(exp, 1.0) -``` - -If you want to instead transform a collection of inputs, you can use the `batch` method from Batching.jl to inform Bijectors.jl that the input now represents a collection of inputs rather than a single input: - -```@repl usage -xs = batch(ones(2)); -transform(exp, xs) -logabsdetjac(exp, xs) -forward(exp, xs) +with_logabsdet_jacobian(exp, 1.0) ``` Some transformations is well-defined for different types of inputs, e.g. `exp` can also act elementwise on a `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: @@ -30,22 +19,13 @@ Some transformations is well-defined for different types of inputs, e.g. `exp` c x = ones(2, 2) transform(elementwise(exp), x) logabsdetjac(elementwise(exp), x) -forward(elementwise(exp), x) -``` - -And batched versions: - -```@repl usage -xs = batch(ones(2, 2, 3)); -transform(elementwise(exp), xs) -logabsdetjac(elementwise(exp), xs) -forward(elementwise(exp), xs) +with_logabsdet_jacobian(elementwise(exp), x) ``` These methods also work nicely for compositions of transformations: ```@repl usage -transform(elementwise(log ∘ exp), xs) +transform(elementwise(log ∘ exp), x) ``` Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. 
`Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: @@ -64,7 +44,7 @@ Without mutation: ```@docs transform logabsdetjac -forward(b, x) +with_logabsdet_jacobian ``` With mutation: @@ -72,53 +52,16 @@ With mutation: ```@docs transform! logabsdetjac! -forward! +with_logabsdet_jacobian! ``` ## Implementing a transformation -Any callable can be made into a bijector by providing an implementation of [`forward(b, x)`](@ref), which is done by overloading - -```@docs -Bijectors.forward_single -``` +Any callable can be made into a bijector by providing an implementation of `with_logabsdet_jacobian(b, x)`. -where - -```@docs -Bijectors.transform_single -Bijectors.logabsdetjac_single -``` +You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. -You can then optionally implement `transform` and `logabsdetjac` to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. - -Note that a _user_ of the bijector should generally be using [`forward(b, x)`](@ref) rather than calling [`forward_single`](@ref) directly. - -To implement "batched" versions of the above functionalities, i.e. methods which act on a _collection_ of inputs rather than a single input, you can overload the following method: - -```@docs -Bijectors.forward_multiple -``` - -And similarly, if you want to specialize [`transform`](@ref) and [`logabsdetjac`](@ref), you can implement - -```@docs -Bijectors.transform_multiple -Bijectors.logabsdetjac_multiple -``` - -### Mutability - -There are also _mutable_ versions of all of the above: - -```@docs -Bijectors.forward_single! -Bijectors.forward_multiple! -Bijectors.transform_single! -Bijectors.transform_multiple! -Bijectors.logabsdetjac_single! -Bijectors.logabsdetjac_multiple! -``` +Similarly with the mutable versions [`with_logabsdet_jacobian`](@ref), [`transform!`](@ref), and [`logabsdetjac!`]. ## Working with Distributions.jl diff --git a/src/interface.jl b/src/interface.jl index 5597451e..37218f1a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,7 +56,7 @@ A subtype of `Transform` of should at least implement [`transform(b, x)`](@ref). If the `Transform` is also invertible: - Required: - - [`invertible`](@ref): should return [`Invertible`](@ref). + - [`invertible`](@ref): should return instance of [`Invertible`](@ref) or [`NotInvertible`](@ref). - _Either_ of the following: - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. 
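The documentation commit above points users at `with_logabsdet_jacobian(b, x)` as the single method to implement for a custom transformation, with `transform` and `logabsdetjac` falling back to it via the new definitions in `src/interface.jl`. A minimal sketch of what such a user-defined bijector could look like under this interface is given below. The `Affine` type and its fields are purely illustrative and do not appear anywhere in this patch series; the sketch assumes that `with_logabsdet_jacobian` is the `ChangesOfVariables.jl` function and `inverse` the `InverseFunctions.jl` function that Bijectors extends, as suggested by the dependencies listed in `Project.toml`.

```julia
# Minimal sketch of a custom bijector under the reworked interface.
# `Affine` is a hypothetical example type, not part of Bijectors.jl.
using Bijectors
using ChangesOfVariables: ChangesOfVariables
using InverseFunctions: InverseFunctions

# y = a * x + b with a ≠ 0, acting on scalar inputs.
struct Affine{T<:Real} <: Bijectors.Bijector
    a::T
    b::T
end

# The one required method: forward evaluation together with log(abs(det(J))).
function ChangesOfVariables.with_logabsdet_jacobian(f::Affine, x::Real)
    return f.a * x + f.b, log(abs(f.a))
end

# Invertibility: return another transform rather than relying on `Inverse`.
InverseFunctions.inverse(f::Affine) = Affine(inv(f.a), -f.b / f.a)

f = Affine(2.0, 1.0)
with_logabsdet_jacobian(f, 0.5)            # (2.0, log(2.0))
transform(f, 0.5)                          # 2.0, via the fallback in interface.jl
logabsdetjac(f, 0.5)                       # log(2.0), via the fallback
transform(inverse(f), transform(f, 0.5))   # ≈ 0.5
```

Defining `transform` and `logabsdetjac` separately on top of this is only worthwhile when one of them is expected to be called heavily without the other, as the updated `docs/src/index.md` notes.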
From b553b084e3ee3b5738957f61b613f9fb13d4124a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 11:19:03 +0000 Subject: [PATCH 51/98] removed reundndat dep --- Project.toml | 1 - src/Bijectors.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4d25f7cb..15a1887d 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 0941fcbe..7aaf80e3 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -32,7 +32,6 @@ using Reexport, Requires @reexport using Distributions using LinearAlgebra using MappedArrays -using ConstructionBase using Base.Iterators: drop using LinearAlgebra: AbstractTriangular From 77fbdb688cd07a4f2c11e78f55f6ef53af986b07 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 11:19:39 +0000 Subject: [PATCH 52/98] remove batch --- src/batch.jl | 88 ---------------------------------------------------- 1 file changed, 88 deletions(-) delete mode 100644 src/batch.jl diff --git a/src/batch.jl b/src/batch.jl deleted file mode 100644 index 0089618a..00000000 --- a/src/batch.jl +++ /dev/null @@ -1,88 +0,0 @@ -abstract type AbstractBatch{T} <: AbstractVector{T} end - -""" - value(x) - -Returns the underlying storage used for the entire batch. - -If `x` is not `AbstractBatch`, then this is the identity function. -""" -value(x) = x - -struct Batch{V, T} <: AbstractBatch{T} - value::V -end - -value(x::Batch) = x.value - -# Convenient aliases -const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} -const VectorBatch{T} = Batch{<:AbstractVector{T}} - -# Constructor for `ArrayBatch`. -Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) -function Batch(x::AbstractArray{<:Real}) - V = typeof(x) - # HACK: This assumes the batch is non-empty. - T = typeof(getindex_for_last(x, 1)) - return Batch{V, T}(x) -end - -# Constructor for `VectorBatch`. -Batch(x::AbstractVector{<:AbstractArray}) = Batch{typeof(x), eltype(x)}(x) - -# `AbstractVector` interface. -Base.size(batch::Batch{<:AbstractArray{<:Any, N}}) where {N} = (size(value(batch), N), ) - -# Since impl inherited from `AbstractVector` doesn't do exactly what we want. -Base.similar(b::AbstractBatch) = reconstruct(b, similar(value(b))) - -# For `VectorBatch` -Base.getindex(batch::VectorBatch, i::Int) = value(batch)[i] -Base.getindex(batch::VectorBatch, i::CartesianIndex{1}) = value(batch)[i] -Base.getindex(batch::VectorBatch, i) = Batch(value(batch)[i]) -function Base.setindex!(batch::VectorBatch, v, i) - # `v` can also be a `Batch`. 
- Base.setindex!(value(batch), value(v), i) - return batch -end - -# For `ArrayBatch` -@generated function getindex_for_last(x::AbstractArray{<:Any, N}, inds) where {N} - e = Expr(:call) - push!(e.args, :(Base.view)) - push!(e.args, :x) - - for i = 1:N - 1 - push!(e.args, :(:)) - end - - push!(e.args, :(inds)) - - return e -end - -@generated function setindex_for_last!(out::AbstractArray{<:Any, N}, x, inds) where {N} - e = Expr(:call) - push!(e.args, :(Base.setindex!)) - push!(e.args, :out) - push!(e.args, :x) - - for i = 1:N - 1 - push!(e.args, :(:)) - end - - push!(e.args, :(inds)) - - return e -end - -# General arrays. -Base.getindex(batch::ArrayBatch, i::Int) = getindex_for_last(value(batch), i) -Base.getindex(batch::ArrayBatch, i::CartesianIndex{1}) = getindex_for_last(value(batch), i) -Base.getindex(batch::ArrayBatch, i) = Batch(getindex_for_last(value(batch), i)) -function Base.setindex!(batch::ArrayBatch, v, i) - # `v` can also be a `Batch`. - setindex_for_last!(value(batch), value(v), i) - return batch -end From 3e8d65da984458d41d9b891b70bd22134ebc170f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 11:21:28 +0000 Subject: [PATCH 53/98] remove redundant defs of transform --- src/bijectors/exp_log.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index fe350a74..f7d5e73a 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -8,9 +8,6 @@ invertible(::Elementwise{typeof(exp)}) = Invertible() invertible(::typeof(log)) = Invertible() invertible(::Elementwise{typeof(log)}) = Invertible() -transform(b::Union{typeof(exp),typeof(log)}, x::Real) = b(x) -transform(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x) = b(x) - transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) logabsdetjac(b::typeof(exp), x::Real) = x From fa469b8bec97b5f206fb9d3d1e28b5e84b639f8b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 11:27:30 +0000 Subject: [PATCH 54/98] removed unnecessary impls of with_logabsdet_jacobian --- src/bijectors/exp_log.jl | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index f7d5e73a..9c4a1085 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -15,21 +15,3 @@ logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x) logabsdetjac(b::typeof(log), x::Real) = -log(x) logabsdetjac(b::Elementwise{typeof(log)}, x) = -sum(log, x) - -function with_logabsdet_jacobian(b::typeof(exp), x::Real) - y = b(x) - return y, x -end -function with_logabsdet_jacobian(b::Elementwise{typeof(exp)}, x) - y = b(x) - return y, sum(x) -end - -function with_logabsdet_jacobian(b::typeof(log), y::Real) - x = transform(b, y) - return x, -x -end -function with_logabsdet_jacobian(b::Elementwise{typeof(log)}, y) - x = transform(b, y) - return x, -sum(x) -end From 0a5d55e01718b06a4370dcad8fc4ec5796152ca0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Feb 2022 12:10:57 +0000 Subject: [PATCH 55/98] remove usage of Exp and Log in tests --- src/Bijectors.jl | 3 +++ src/bijectors/exp_log.jl | 4 ---- src/transformed_distribution.jl | 8 ++++---- test/bijectors/named_bijector.jl | 6 +++--- test/interface.jl | 28 ++++++++++++++-------------- 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 7aaf80e3..fad28f7f 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -268,6 +268,9 @@ end 
Base.@deprecate NamedBijector(bs) NamedTransform(bs) +Base.@deprecate Exp() elementwise(exp) false +Base.@deprecate Log() elementwise(log) false + # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 9c4a1085..899bde31 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,7 +1,3 @@ -# TODO: Do we really need this? -Exp() = elementwise(exp) -Log() = elementwise(log) - invertible(::typeof(exp)) = Invertible() invertible(::Elementwise{typeof(exp)}) = Invertible() diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index cf3b8778..3964d119 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -54,14 +54,14 @@ end bijector(d::Normal) = Identity() bijector(d::Distributions.AbstractMvNormal) = Identity() -bijector(d::Distributions.AbstractMvLogNormal) = Log() -bijector(d::PositiveDistribution) = Log() +bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log) +bijector(d::PositiveDistribution) = elementwise(log) bijector(d::SimplexDistribution) = SimplexBijector() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) -bijector_lowerbounded(d, a=minimum(d)) = Log() ∘ Shift(-a) -bijector_upperbounded(d, b=maximum(d)) = Log() ∘ Shift(b) ∘ Scale(- one(typeof(b))) +bijector_lowerbounded(d, a=minimum(d)) = elementwise(log) ∘ Shift(-a) +bijector_upperbounded(d, b=maximum(d)) = elementwise(log) ∘ Shift(b) ∘ Scale(- one(typeof(b))) const BoundedDistribution = Union{ Arcsine, Biweight, Cosine, Epanechnikov, Beta, NoncentralBeta diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index 9dae0365..8cf57a95 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -1,9 +1,9 @@ using Test using Bijectors -using Bijectors: Exp, Log, Logit, AbstractNamedTransform, NamedBijector, NamedCoupling, Shift +using Bijectors: Exp, Log, Logit, AbstractNamedTransform, NamedTransform, NamedCoupling, Shift -@testset "NamedBijector" begin - b = NamedBijector((a = Exp(), b = Log())) +@testset "NamedTransform" begin + b = NamedTransform((a = elementwise(exp), b = elementwise(log))) @test b((a = 0.0, b = exp(1.0))) == (a = 1.0, b = 1.0) with_logabsdet_jacobian(b, (a = 0.0, b = exp(1.0))) diff --git a/test/interface.jl b/test/interface.jl index b9f9b88c..11077875 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -10,7 +10,7 @@ using Tracker using DistributionsAD using Bijectors -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, RationalQuadraticSpline, LeakyReLU +using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, RationalQuadraticSpline, LeakyReLU Random.seed!(123) @@ -309,7 +309,7 @@ end # value-test x = ones(3) - sb = @inferred stack(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + sb = @inferred stack(elementwise(exp), elementwise(log), Shift(5.0)) res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @@ -319,7 +319,7 @@ end # TODO: change when we have dimensionality in the type - sb = @inferred Stacked((Bijectors.Exp(), 
Bijectors.SimplexBijector()), (1:1, 2:3)) + sb = @inferred Stacked((elementwise(exp), SimplexBijector()), (1:1, 2:3)) x = ones(3) ./ 3.0 res = @inferred with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -332,7 +332,7 @@ end @test_throws AssertionError sb(x) # Array-version - sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], [1:1, 2:3]) + sb = Stacked([elementwise(exp), SimplexBijector()], [1:1, 2:3]) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -346,7 +346,7 @@ end # Mixed versions # Tuple, Array - sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], (1:1, 2:3)) + sb = Stacked([elementwise(exp), SimplexBijector()], (1:1, 2:3)) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -359,7 +359,7 @@ end @test_throws AssertionError sb(x) # Array, Tuple - sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) + sb = Stacked((elementwise(exp), SimplexBijector()), [1:1, 2:3]) x = ones(3) ./ 3.0 res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @@ -501,8 +501,8 @@ end @testset "Equality" begin bs = [ Identity(), - Exp(), - Log(), + elementwise(exp), + elementwise(log), Scale(2.0), Scale(3.0), Scale(rand(2,2)), @@ -525,12 +525,12 @@ end RadialLayer(2), RadialLayer(3), SimplexBijector(), - Stacked((Exp(), Log())), - Stacked((Log(), Exp())), - Stacked([Exp(), Log()]), - Stacked([Log(), Exp()]), - Exp() ∘ Log(), - Log() ∘ Exp(), + Stacked((elementwise(exp), elementwise(log))), + Stacked((elementwise(log), elementwise(exp))), + Stacked([elementwise(exp), elementwise(log)]), + Stacked([elementwise(log), elementwise(exp)]), + elementwise(exp) ∘ elementwise(log), + elementwise(log) ∘ elementwise(exp), TruncatedBijector(1.0, 2.0), TruncatedBijector(1.0, 3.0), TruncatedBijector(0.0, 2.0), From d63c07efc231f83cebd92fa90104cc0e0106b876 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 19 Jul 2022 10:41:13 +0100 Subject: [PATCH 56/98] fixed docs --- docs/src/index.md | 8 +++++++- src/interface.jl | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index e24fa768..1bbf0826 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -44,6 +44,9 @@ Without mutation: ```@docs transform logabsdetjac +``` + +```julia with_logabsdet_jacobian ``` @@ -61,7 +64,7 @@ Any callable can be made into a bijector by providing an implementation of `with You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. -Similarly with the mutable versions [`with_logabsdet_jacobian`](@ref), [`transform!`](@ref), and [`logabsdetjac!`]. +Similarly with the mutable versions [`with_logabsdet_jacobian!`](@ref), [`transform!`](@ref), and [`logabsdetjac!`](@ref). ## Working with Distributions.jl @@ -76,6 +79,9 @@ Bijectors.transformed(d::Distribution, b::Bijector) Bijectors.elementwise Bijectors.isinvertible Bijectors.isclosedform(t::Bijectors.Transform) +Bijectors.invertible +Bijectors.NotInvertible +Bijectors.Invertible ``` ## API diff --git a/src/interface.jl b/src/interface.jl index 37218f1a..d8fbaf77 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -62,13 +62,13 @@ If the `Transform` is also invertible: - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. 
 - [`logabsdetjac`](@ref): computes the log-abs-det jacobian factor.
 - Optional:
-    - [`with_logabsdet_jacobian`](@ref): `transform` and `logabsdetjac` combined. Useful in cases where we
+    - `with_logabsdet_jacobian`: `transform` and `logabsdetjac` combined. Useful in cases where we
       can exploit shared computation in the two.
 
 For the above methods, there are mutating versions which can _optionally_ be implemented:
 - [`with_logabsdet_jacobian!`](@ref)
 - [`logabsdetjac!`](@ref)
-- [`forward!`](@ref)
+- [`transform!`](@ref)
 """
 abstract type Transform end
@@ -143,7 +143,17 @@ requires an iterative procedure to evaluate.
 isclosedform(t::Transform) = true
 
 # Invertibility "trait".
+"""
+    NotInvertible
+
+Represents the trait of being, well, non-invertible.
+"""
 struct NotInvertible end
+"""
+    Invertible
+
+Represents the trait of being, well, invertible.
+"""
 struct Invertible end
 
 # Useful for checking if compositions, etc. are invertible or not.

From c00b9f2fb345a6385773c67e1c717e1c8a2d8c68 Mon Sep 17 00:00:00 2001
From: Tor Erlend Fjelde
Date: Tue, 19 Jul 2022 10:52:19 +0100
Subject: [PATCH 57/98] added bijectors with docs to docs

---
 docs/src/index.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/docs/src/index.md b/docs/src/index.md
index 1bbf0826..511c83dc 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -91,3 +91,16 @@ Bijectors.Transform
 Bijectors.Bijector
 Bijectors.Inverse
 ```
+
+## Bijectors
+
+```@docs
+Bijectors.CorrBijector
+Bijectors.LeakyReLU
+Bijectors.Stacked
+Bijectors.RationalQuadraticSpline
+Bijectors.Coupling
+Bijectors.OrderedBijector
+Bijectors.NamedTransform
+Bijectors.NamedCoupling
+```

From 0a0858d4bcef5442dc3151c97e09f9f5f8a1ce89 Mon Sep 17 00:00:00 2001
From: Tor Erlend Fjelde
Date: Wed, 20 Jul 2022 08:43:08 +0100
Subject: [PATCH 58/98] small change to docs

---
 docs/src/index.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 511c83dc..562c7dcd 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -60,7 +60,7 @@ with_logabsdet_jacobian!
 
 ## Implementing a transformation
 
-Any callable can be made into a bijector by providing an implementation of `with_logabsdet_jacobian(b, x)`.
+Any callable can be made into a bijector by providing an implementation of `ChangesOfVariables.with_logabsdet_jacobian(b, x)`.
 
 You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other.
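
As a sketch of the paragraph above — turning an arbitrary callable into a bijector by overloading `with_logabsdet_jacobian` — consider a softplus transform. The `Softplus` type is a hypothetical example, and the snippet assumes the `ChangesOfVariables` package (which owns `with_logabsdet_jacobian`) is available:

```julia
using Bijectors
import ChangesOfVariables

# Hypothetical callable: softplus(x) = log(1 + exp(x)), written in a numerically stable form.
struct Softplus end
(::Softplus)(x::Real) = log1p(exp(-abs(x))) + max(x, zero(x))

# The one required method: return the transformed value together with log|dy/dx|.
# For softplus, dy/dx = exp(x) / (1 + exp(x)), so log|dy/dx| = x - y.
function ChangesOfVariables.with_logabsdet_jacobian(f::Softplus, x::Real)
    y = f(x)
    return y, x - y
end

# Optional: a direct `logabsdetjac`, for when the Jacobian term is needed on its own.
# Note that log|dy/dx| = log(sigmoid(x)) = -softplus(-x).
Bijectors.logabsdetjac(f::Softplus, x::Real) = -f(-x)
```

With these definitions, `with_logabsdet_jacobian(Softplus(), 1.0)` returns both the transformed value and the log-abs-det Jacobian in a single call.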
From a53f971189f01c638b28274e12d83d0ab9327a41 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Jul 2022 08:43:17 +0100 Subject: [PATCH 59/98] fixed bug in computation of logabsdetjac of truncated --- src/bijectors/truncated.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 180377ad..52c09f4d 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -50,7 +50,7 @@ end function logabsdetjac(b::TruncatedBijector, x) a, b = b.lb, b.ub - return truncated_logabsdetjac.(_clamp.(x, a, b), a, b) + return sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b)) end function truncated_logabsdetjac(x, a, b) From 99765f34bf23bd0cbb0d104eb34b892a3a06c5f5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 20 Jul 2022 08:43:31 +0100 Subject: [PATCH 60/98] bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 162f651e..e7d0f937 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.10.3" +version = "0.11.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 21d21fc086b4a428e410e10bcde1dfcc08ea2df3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 15:01:07 +0100 Subject: [PATCH 61/98] run GH actions on Julia 1.6, which is the new LTS, instead of 1.3 --- .github/workflows/AD.yml | 2 +- .github/workflows/Interface.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index ebbd9179..541ccdd9 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.3' + - '1.6' - '1' os: - ubuntu-latest diff --git a/.github/workflows/Interface.yml b/.github/workflows/Interface.yml index d670855d..ef1f4dc7 100644 --- a/.github/workflows/Interface.yml +++ b/.github/workflows/Interface.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: version: - - '1.3' + - '1.6' - '1' os: - ubuntu-latest From 34bb350098a790f1cfdf9f4c39307eff1c387249 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 15:59:50 +0100 Subject: [PATCH 62/98] added Github actions for making docs, etc. --- .github/workflows/Docs.yml | 32 ++++++++++++++++++++++++ .github/workflows/DocsPreviewCleanup.yml | 26 +++++++++++++++++++ docs/Project.toml | 4 ++- docs/make.jl | 6 ++--- 4 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/Docs.yml create mode 100644 .github/workflows/DocsPreviewCleanup.yml diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml new file mode 100644 index 00000000..e5f79a0f --- /dev/null +++ b/.github/workflows/Docs.yml @@ -0,0 +1,32 @@ +name: Documentation + +on: + push: + branches: + # Build the master branch. + - master + tags: '*' + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: '1' + - name: Install dependencies + run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - name: Build and deploy + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + JULIA_DEBUG: Documenter # Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) + run: julia --project=docs/ docs/make.jl diff --git a/.github/workflows/DocsPreviewCleanup.yml b/.github/workflows/DocsPreviewCleanup.yml new file mode 100644 index 00000000..4f57bc46 --- /dev/null +++ b/.github/workflows/DocsPreviewCleanup.yml @@ -0,0 +1,26 @@ +name: DocsPreviewCleanup + +on: + pull_request: + types: [closed] + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v2 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "Documenter.jl" + git config user.email "documenter@juliadocs.github.io" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} diff --git a/docs/Project.toml b/docs/Project.toml index 7e7c8131..3a52a5db 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,5 @@ [deps] -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "0.27" diff --git a/docs/make.jl b/docs/make.jl index 54fb95c5..b3aafc3d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,6 +10,6 @@ makedocs( # Documenter can also automatically deploy documentation to gh-pages. # See "Hosting Documentation" and deploydocs() in the Documenter manual # for more information. -#=deploydocs( - repo = "" -)=# +deploydocs( + repo = "github.com/TuringLang/Bijectors.jl.git", +) From c5046f5dfc77a45ea774e7ff637ff6fab24f161d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 16:06:05 +0100 Subject: [PATCH 63/98] removed left-overs from batch impls --- src/utils.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index dcb99eb7..8203e1b4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,20 +6,3 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x - -# Useful for reconstructing objects. -reconstruct(b, args...) = constructorof(typeof(b))(args...) - -# Despite kwargs using `NamedTuple` in Julia 1.6, I'm still running -# into type-instability issues when using `eachslice`. So we define our own. -# https://github.com/JuliaLang/julia/issues/39639 -# TODO: Check if this is the case in Julia 1.7. -# Adapted from https://github.com/JuliaLang/julia/blob/ef673537f8622fcaea92ac85e07962adcc17745b/base/abstractarraymath.jl#L506-L513. -# FIXME: It seems to break Zygote though. Which is weird because normal `eachslice` does not. 
-@inline function Base.eachslice(A::AbstractArray, ::Val{N}) where {N} - dim = N - dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions")) - inds_before = ntuple(_ -> Colon(), dim-1) - inds_after = ntuple(_ -> Colon(), ndims(A)-dim) - return (view(A, inds_before..., i, inds_after...) for i in axes(A, dim)) -end From 7956d05c48bf657585f6496178cc9e196ce0ea1a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 16:17:26 +0100 Subject: [PATCH 64/98] removed redundant comment --- src/bijectors/logit.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index d3b90981..1df73514 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -11,8 +11,6 @@ Functors.@functor Logit # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -# TODO: Implement `forward` and batched versions. - # Evaluation _logit(x, a, b) = LogExpFunctions.logit((x - a) / (b - a)) transform(b::Logit, x) = _logit.(x, b.a, b.b) From 82f8ba83a21f18f0e55d1818b32b803aa1da4807 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Aug 2022 16:31:59 +0100 Subject: [PATCH 65/98] dont return NamedTuple from with_logabsdet_jacobian --- src/bijectors/normalise.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index f9ef2601..b04a10fd 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -76,7 +76,7 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (result=result, logabsdetjac=logabsdetjac) + return (result, logabsdetjac) end logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) @@ -93,7 +93,7 @@ function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) v = reshape(bn.v, as...) x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m - return (result=x, logabsdetjac=-logabsdetjac(bn, x)) + return (x, -logabsdetjac(bn, x)) end transform(bn::Inverse{<:InvertibleBatchNorm}, y) = first(with_logabsdet_jacobian(bn, y)) From d5d2274a536a3bc1df2782bbb06ea07692ff1c84 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Aug 2022 18:08:00 +0100 Subject: [PATCH 66/98] remove unnused methods --- src/transformed_distribution.jl | 133 -------------------------------- test/ad/flows.jl | 12 ++- test/interface.jl | 25 ------ test/norm_flows.jl | 2 +- 4 files changed, 9 insertions(+), 163 deletions(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 3964d119..e515d65d 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -151,123 +151,6 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) x .= td.transform(x) end -############################################################# -# Additional useful functions for `TransformedDistribution` # -############################################################# -""" - logpdf_with_jac(td::UnivariateTransformed, y::Real) - logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - -Makes use of the `forward` method to potentially re-use computation -and returns a tuple `(logpdf, logabsdetjac)`. 
-""" -function logpdf_with_jac(td::UnivariateTransformed, y::Real) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -# TODO: implement more efficiently for flows in the case of `Matrix` -function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf(td.dist, x) + logjac, logjac) -end - -function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) - T = eltype(y) - ϵ = _eps(T) - - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac - return (lp, logjac) -end - -# TODO: should eventually drop using `logpdf_with_trans` -function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return (logpdf_with_trans(td.dist, x, true), logjac) -end - -""" - logpdf_forward(td::Transformed, x) - logpdf_forward(td::Transformed, x, logjac) - -Computes the `logpdf` using the forward pass of the bijector rather than using -the inverse transform to compute the necessary `logabsdetjac`. - -This is similar to `logpdf_with_trans`. -""" -# TODO: implement more efficiently for flows in the case of `Matrix` -logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) - logjac -logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) - -function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) - T = eltype(x) - ϵ = _eps(T) - - return logpdf(td.dist, mappedarray(z->z+ϵ, x)) - logjac -end - - -# forward function -const GLOBAL_RNG = Distributions.GLOBAL_RNG - -function _forward(d::UnivariateDistribution, x) - y, logjac = with_logabsdet_jacobian(Identity(), x) - return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) -end - -forward(rng::AbstractRNG, d::Distribution) = _forward(d, rand(rng, d)) -function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) - return _forward(d, rand(rng, d, num_samples)) -end -function _forward(d::Distribution, x) - y, logjac = with_logabsdet_jacobian(Identity(), x) - return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) -end - -function _forward(td::Transformed, x) - y, logjac = with_logabsdet_jacobian(td.transform, x) - return ( - x = x, - y = y, - logabsdetjac = logjac, - logpdf = logpdf_forward(td, x, logjac) - ) -end -function forward(rng::AbstractRNG, td::Transformed) - return _forward(td, rand(rng, td.dist)) -end -function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) - return [_forward(td, rand(rng, td.dist)) for _ = 1:num_samples] -end - -""" - forward(d::Distribution) - forward(d::Distribution, num_samples::Int) - -Returns a `NamedTuple` with fields `x`, `y`, `logabsdetjac` and `logpdf`. - -In the case where `d isa TransformedDistribution`, this means -- `x = rand(d.dist)` -- `y = d.transform(x)` -- `logabsdetjac` is the logabsdetjac of the "forward" transform. 
-- `logpdf` is the logpdf of `y`, not `x` - -In the case where `d isa Distribution`, this means -- `x = rand(d)` -- `y = x` -- `logabsdetjac = 0.0` -- `logpdf` is logpdf of `x` -""" -forward(d::Distribution) = forward(GLOBAL_RNG, d) -forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) - # utility stuff Distributions.params(td::Transformed) = Distributions.params(td.dist) function Base.maximum(td::UnivariateTransformed) @@ -281,19 +164,3 @@ function Base.minimum(td::UnivariateTransformed) return max < min ? max : min end -# logabsdetjac for distributions -logabsdetjacinv(d::UnivariateDistribution, x::T) where T <: Real = zero(T) -logabsdetjacinv(d::MultivariateDistribution, x::AbstractVector{T}) where {T<:Real} = zero(T) - - -""" - logabsdetjacinv(td::UnivariateTransformed, y::Real) - logabsdetjacinv(td::MultivariateTransformed, y::AbstractVector{<:Real}) - -Computes the `logabsdetjac` of the _inverse_ transformation, since `rand(td)` returns -the _transformed_ random variable. -""" -logabsdetjacinv(td::UnivariateTransformed, y::Real) = logabsdetjac(inverse(td.transform), y) -function logabsdetjacinv(td::MvTransformed, y::AbstractVector{<:Real}) - return logabsdetjac(inverse(td.transform), y) -end diff --git a/test/ad/flows.jl b/test/ad/flows.jl index 335f6333..bfcbaacc 100644 --- a/test/ad/flows.jl +++ b/test/ad/flows.jl @@ -3,23 +3,27 @@ test_ad(randn(7)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), layer) - return logpdf_forward(flow, θ[6:7]) + x = θ[6:7] + return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x) end test_ad(randn(11)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), layer) - return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :))) + x = reshape(θ[6:end], 2, :) + return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)) end # logpdf of a flow with the inverse of a planar layer and two-dimensional inputs test_ad(randn(7)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), inverse(layer)) - return logpdf_forward(flow, θ[6:7]) + x = θ[6:7] + return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x) end test_ad(randn(11)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) flow = transformed(MvNormal(zeros(2), I), inverse(layer)) - return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :))) + x = reshape(θ[6:end], 2, :) + return sum(logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)) end end diff --git a/test/interface.jl b/test/interface.jl index 11077875..7bcfcc5a 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -81,11 +81,6 @@ end @test y ≈ @inferred td.transform(x) @test @inferred(logpdf(td, y)) ≈ @inferred(logpdf_with_trans(dist, x, true)) - # logpdf_with_jac - lp, logjac = logpdf_with_jac(td, y) - @test lp ≈ logpdf(td, y) - @test logjac ≈ logabsdetjacinv(td.transform, y) - # multi-sample y = @inferred rand(td, 10) x = inverse(td.transform).(y) @@ -99,14 +94,6 @@ end @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) - # forward - f = @inferred forward(td) - @test f.x ≈ inverse(td.transform)(f.y) - @test f.y ≈ td.transform(f.x) - @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) - @test f.logpdf ≈ logpdf(td.dist, f.x) - f.logabsdetjac - # verify against AD d = dist b = bijector(d) @@ -197,18 +184,6 @@ end @test 
td.transform(param(x)) isa TrackedArray @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - # logpdf_with_jac - lp, logjac = logpdf_with_jac(td, y) - @test lp ≈ logpdf(td, y) - @test logjac ≈ logabsdetjacinv(td.transform, y) - - # forward - f = forward(td) - @test f.x ≈ inverse(td.transform)(f.y) - @test f.y ≈ td.transform(f.x) - @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) - # verify against AD # similar to what we do in test/transform.jl for Dirichlet if dist isa Dirichlet diff --git a/test/norm_flows.jl b/test/norm_flows.jl index fdf79676..eb07488d 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -103,7 +103,7 @@ end x = rand(d) y = flow.transform(x) res = with_logabsdet_jacobian(flow.transform, x) - lp = logpdf_forward(flow, x, res[2]) + lp = logpdf(d, x) - res[2] @test res[1] ≈ y @test logpdf(flow, y) ≈ lp rtol=0.1 From 39746db5b6fafb294dae9d556f7bc810fbd653a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 24 Aug 2022 11:21:53 +0100 Subject: [PATCH 67/98] remove old deprecation warnings --- src/Bijectors.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index b8291c9e..e01b3be4 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -258,13 +258,6 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") -Base.@deprecate forward(b::Union{Transform,Function}, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) - -@noinline function Base.inv(b::Transform) - Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) - inverse(b) -end - Base.@deprecate NamedBijector(bs) NamedTransform(bs) Base.@deprecate Exp() elementwise(exp) false From 859a6ba20c26dece1542ca213c55748d361e0cb7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 24 Aug 2022 13:57:17 +0100 Subject: [PATCH 68/98] fix exports --- src/Bijectors.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e01b3be4..7a8e2d67 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -55,7 +55,6 @@ export TransformDistribution, isclosedform, transform, transform!, - forward, with_logabsdet_jacobian, with_logabsdet_jacobian!, inverse, @@ -65,8 +64,6 @@ export TransformDistribution, Bijector, ADBijector, Inverse, - Composed, - compose, Stacked, stack, Identity, @@ -74,13 +71,10 @@ export TransformDistribution, transformed, UnivariateTransformed, MultivariateTransformed, - logpdf_with_jac, - logpdf_forward, PlanarLayer, RadialLayer, - CouplingLayer, + Coupling, InvertibleBatchNorm, - Elementwise, elementwise if VERSION < v"1.1" From 0f5f9f1c11c827156ad16b8166d3981f08db81ea Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 24 Aug 2022 13:57:29 +0100 Subject: [PATCH 69/98] updated tests for deprecations --- test/interface.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 7bcfcc5a..c9037b19 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -533,6 +533,8 @@ end b = Bijectors.Logit(0.0, 1.0) x = 0.3 - @test @test_deprecated(forward(b, x)) == NamedTuple{(:rv, :logabsdetjac)}(with_logabsdet_jacobian(b, x)) - @test @test_deprecated(inv(b)) == inverse(b) + @test @test_deprecated(Bijectors.Exp()) == elementwise(exp) + @test @test_deprecated(Bijectors.Log()) == elementwise(log) + + @test @test_deprecated(Bijectors.NamedBijector((x = b, ))) == Bijectors.NamedBijector((x = b, )) end From 
6717172a3e01f3ffb58c8589c50600881cfe701f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 24 Aug 2022 14:07:05 +0100 Subject: [PATCH 70/98] completed some random TODOs --- src/bijectors/normalise.jl | 10 +--------- src/bijectors/stacked.jl | 10 +--------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index b04a10fd..6b131d27 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -38,15 +38,7 @@ function InvertibleBatchNorm( ) end -# define numerical parameters -# TODO: replace with `Functors.@functor InvertibleBatchNorm (b, logs)` when -# https://github.com/FluxML/Functors.jl/pull/7 is merged -function Functors.functor(::Type{<:InvertibleBatchNorm}, x) - function reconstruct_invertiblebatchnorm(xs) - return InvertibleBatchNorm(xs.b, xs.logs, x.m, x.v, x.eps, x.mtm) - end - return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm -end +Functors.@functor InvertibleBatchNorm (b, logs) function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index e03e0f1f..b974afb9 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -31,15 +31,7 @@ Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) # Avoid mixing tuples and arrays. Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) -# define nested numerical parameters -# TODO: replace with `Functors.@functor Stacked (bs,)` when -# https://github.com/FluxML/Functors.jl/pull/7 is merged -function Functors.functor(::Type{<:Stacked}, x) - function reconstruct_stacked(xs) - return Stacked(xs.bs, x.ranges) - end - return (bs = x.bs,), reconstruct_stacked -end +Functors.@functor Stacked (bs,) function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs From fe2b5e9b1ec238bfcd4c0a9fc782f1a5a085db8c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 6 Oct 2022 23:51:25 +0100 Subject: [PATCH 71/98] fix SimplexBijector tests --- test/interface.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index c9037b19..19c4b1b9 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -188,7 +188,11 @@ end # similar to what we do in test/transform.jl for Dirichlet if dist isa Dirichlet b = Bijectors.SimplexBijector{false}() - x = rand(dist) + # HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]` + # which in turn will lead to differences between `ForwardDiff.jacobian` + # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`. + # We therefore test the realizations _on_ the boundary rather if we're near the boundary. + x = any(rand(dist) .> 0.9999) ? 
[0.0, 1.0][sortperm(rand(dist))] : rand(dist) y = b(x) @test b(param(x)) isa TrackedArray @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) From a9be9c94d7a944cf69fa1ca7dc52fb73cb0ea177 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 00:50:54 +0100 Subject: [PATCH 72/98] removed whitespace --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 36b3e784..4139fc46 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -65,7 +65,7 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) +function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) From a30d56dd8317b3282d657b2e490b73396080ea49 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:03:21 +0100 Subject: [PATCH 73/98] made some docstrings into doctests --- src/bijectors/named_bijector.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index adaa4676..94a02332 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -9,10 +9,10 @@ abstract type AbstractNamedTransform <: Transform end Wraps a `NamedTuple` of key -> `Bijector` pairs, implementing evaluation, inversion, etc. # Examples -```julia-repl -julia> using Bijectors: NamedTransform, Scale, Exp +```jldoctest +julia> using Bijectors: NamedTransform, Scale -julia> b = NamedTransform((a = Scale(2.0), b = Exp())); +julia> b = NamedTransform((a = Scale(2.0), b = exp)); julia> x = (a = 1., b = 0., c = 42.); @@ -98,11 +98,10 @@ end Implements a coupling layer for named bijectors. # Examples -```julia-repl +```jldoctest julia> using Bijectors: NamedCoupling, Scale -julia> b = NamedCoupling(:b, (:a, :c), (a, c) -> Scale(a + c)) -NamedCoupling{:b,(:a, :c),var"#3#4"}(var"#3#4"()) +julia> b = NamedCoupling(:b, (:a, :c), (a, c) -> Scale(a + c)); julia> x = (a = 1., b = 2., c = 3.); From c0f11f4dab40e912cd2c285b4f2b7deeb9d56b9f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:03:32 +0100 Subject: [PATCH 74/98] removed unnused method --- src/bijectors/named_bijector.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 94a02332..5b1b22d6 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -35,8 +35,7 @@ function Functors.functor(::Type{<:NamedTransform{names}}, x) where names return (bs = x.bs,), reconstruct_namedbijector end -names_to_bijectors(b::NamedTransform) = b.bs - +# TODO: Use recursion instead of `@generated`? 
@generated function inverse(b::NamedTransform{names}) where {names} return :(NamedTransform(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) end From 5397d331f7336820ed17c6c1d27802518123b7b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:12:21 +0100 Subject: [PATCH 75/98] improved show for scale and shift --- src/bijectors/scale.jl | 2 ++ src/bijectors/shift.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index 1e2c623f..bff549e5 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -6,6 +6,8 @@ Base.:(==)(b1::Scale, b2::Scale) = b1.a == b2.a Functors.@functor Scale +Base.show(io::IO, b::Scale) = print(io, "Scale($(b.a))") + with_logabsdet_jacobian(b::Scale, x) = transform(b, x), logabsdetjac(b, x) transform(b::Scale, x) = b.a .* x diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 908815a6..22aed9d2 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -9,6 +9,8 @@ Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a Functors.@functor Shift +Base.show(io::IO, b::Shift) = print(io, "Shift($(b.a))") + inverse(b::Shift) = Shift(-b.a) transform(b::Shift, x) = b.a .+ x From 901a6ef6831e063de3953f6a2ca0a39360f5d4b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:12:43 +0100 Subject: [PATCH 76/98] converted example for Coupling into doctest --- src/bijectors/coupling.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 8f25eff0..500246db 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -134,25 +134,23 @@ Partitions `x` into 3 disjoint subvectors. Implements a coupling-layer as defined in [1]. # Examples -```julia-repl -julia> m = PartitionMask(3, [1], [2]) # <= going to use x[2] to parameterize transform of x[1] -PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( - [1, 1] = 1.0, - [2, 1] = 1.0, - [3, 1] = 1.0) +```jldoctest +julia> using Bijectors: Shift, Coupling, PartitionMask, coupling, couple -julia> cl = Coupling(θ -> Shift(θ[1]), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> m = PartitionMask(3, [1], [2]); # <= going to use x[2] to parameterize transform of x[1] + +julia> cl = Coupling(Shift, m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; julia> cl(x) -3-element Array{Float64,1}: +3-element Vector{Float64}: 3.0 2.0 3.0 julia> inverse(cl)(cl(x)) -3-element Array{Float64,1}: +3-element Vector{Float64}: 1.0 2.0 3.0 @@ -161,7 +159,10 @@ julia> coupling(cl) # get the `Bijector` map `θ -> b(⋅, θ)` Shift julia> couple(cl, x) # get the `Bijector` resulting from `x` -Shift{Array{Float64,1},1}([2.0]) +Shift([2.0]) + +julia> logabsdetjac(cl, x) +0.0 ``` # References From 557f826b12eb0539d3fd1b790b81a6bbf899affe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:12:55 +0100 Subject: [PATCH 77/98] added reference to Coupling bijector for NamedCoupling --- src/bijectors/named_bijector.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 5b1b22d6..bbff828e 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -96,6 +96,8 @@ end Implements a coupling layer for named bijectors. 
+See also: [`Coupling`](@ref) + # Examples ```jldoctest julia> using Bijectors: NamedCoupling, Scale From 92651e865e2d3675f3c4c9ef139fb2913811d6e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:23:21 +0100 Subject: [PATCH 78/98] fixed docstring --- src/bijectors/coupling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 500246db..9aaaf829 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -139,7 +139,7 @@ julia> using Bijectors: Shift, Coupling, PartitionMask, coupling, couple julia> m = PartitionMask(3, [1], [2]); # <= going to use x[2] to parameterize transform of x[1] -julia> cl = Coupling(Shift, m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> cl = Coupling(Shift, m); # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; @@ -161,8 +161,8 @@ Shift julia> couple(cl, x) # get the `Bijector` resulting from `x` Shift([2.0]) -julia> logabsdetjac(cl, x) -0.0 +julia> with_logabsdet_jacobian(cl, x) +([3.0, 2.0, 3.0], 0.0) ``` # References From 3719f33d0fb5c4d0d2ec03c53fe4d0710e8ec80a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:25:16 +0100 Subject: [PATCH 79/98] fixed documentation setup --- docs/make.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/make.jl b/docs/make.jl index b3aafc3d..5c3f4d6b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,9 @@ using Documenter using Bijectors +# Doctest setup +DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) + makedocs( sitename = "Bijectors", format = Documenter.HTML(), From f26d5a646fdc351741b3162ef0e7ef6c57740f3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 01:25:35 +0100 Subject: [PATCH 80/98] nvm, now I fixed documentation setup --- docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 5c3f4d6b..7a65b16a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,7 @@ using Documenter using Bijectors # Doctest setup -DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) +DocMeta.setdocmeta!(Bijectors, :DocTestSetup, :(using Bijectors); recursive=true) makedocs( sitename = "Bijectors", From ad3ecc9eed4bfbcdefb26c4655d0b505c321b79f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 17:19:55 +0100 Subject: [PATCH 81/98] removed references to dimensionality in code --- src/bijectors/adbijector.jl | 2 +- src/bijectors/corr.jl | 2 +- src/bijectors/permute.jl | 2 +- src/bijectors/stacked.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bijectors/adbijector.jl b/src/bijectors/adbijector.jl index 44d9420b..7c62c4ac 100644 --- a/src/bijectors/adbijector.jl +++ b/src/bijectors/adbijector.jl @@ -1,5 +1,5 @@ """ -Abstract type for a `Bijector{N}` making use of auto-differentation (AD) to +Abstract type for a `Bijector` making use of auto-differentation (AD) to implement `jacobian` and, by impliciation, `logabsdetjac`. 
""" abstract type ADBijector{AD} <: Bijector end diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 4139fc46..29f5aa59 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -1,5 +1,5 @@ """ - CorrBijector <: Bijector{2} + CorrBijector <: Bijector A bijector implementation of Stan's parametrization method for Correlation matrix: https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index 0c352886..a2b49aa7 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -2,7 +2,7 @@ using SparseArrays using ArgCheck """ - Permute{A} <: Bijector{1} + Permute{A} <: Bijector A bijector implementation of a permutation. The permutation is performed using a matrix of type `A`. There are a couple of different ways to construct `Permute`: diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 7ec75398..6c7c6e9d 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -1,7 +1,7 @@ """ Stacked(bs) Stacked(bs, ranges) - stack(bs::Bijector{0}...) # where `0` means 0-dim `Bijector` + stack(bs::Bijector...) # where `0` means 0-dim `Bijector` A `Bijector` which stacks bijectors together which can then be applied to a vector where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. From 95ff4b6b91a1529ca61869e2a695f29beda74233 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 17:20:06 +0100 Subject: [PATCH 82/98] fixed typo --- src/bijectors/composed.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 5e44c997..79ffe9c7 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -16,7 +16,7 @@ end function logabsdetjac!(cb::ComposedFunction, x, logjac) y = similar(x) logjac = last(with_logabsdet_jacobian!(cb.inner, x, y, logjac)) - return logabdetjac!(cb.outer, y, y, logjac) + return logabsdetjac!(cb.outer, y, y, logjac) end function with_logabsdet_jacobian!(cb::ComposedFunction, x, y, logjac) From c8c3bdca62a5755e40c3df3f3abd6318360b68e0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 7 Oct 2022 17:20:41 +0100 Subject: [PATCH 83/98] add impl of invertible for Elementwise --- src/interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/interface.jl b/src/interface.jl index d8fbaf77..4f4ddc27 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -168,6 +168,7 @@ Base.:+(::Invertible, ::Invertible) = Invertible() Return `Invertible()` if `t` is invertible, and `NotInvertible()` otherwise. 
""" invertible(::Transform) = NotInvertible() +invertible(f::Elementwise) = invertible(f.x) """ isinvertible(t) From 43d204f73ac374514f71d1c643a4ed575b93198d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 9 Oct 2022 19:44:04 +0100 Subject: [PATCH 84/98] added transforms and distributions as separate pages in docs --- docs/make.jl | 12 ++-- docs/src/distributions.md | 70 ++++++++++++++++++++ docs/src/index.md | 134 +++++++++----------------------------- docs/src/transforms.md | 105 +++++++++++++++++++++++++++++ 4 files changed, 210 insertions(+), 111 deletions(-) create mode 100644 docs/src/distributions.md create mode 100644 docs/src/transforms.md diff --git a/docs/make.jl b/docs/make.jl index 7a65b16a..706cfd02 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,12 +7,10 @@ DocMeta.setdocmeta!(Bijectors, :DocTestSetup, :(using Bijectors); recursive=true makedocs( sitename = "Bijectors", format = Documenter.HTML(), - modules = [Bijectors] + modules = [Bijectors], + pages = ["Home" => "index.md", "Transforms" => "transforms.md", "Distributions.jl integration" => "distributions.md"], + strict=false, + checkdocs=:exports, ) -# Documenter can also automatically deploy documentation to gh-pages. -# See "Hosting Documentation" and deploydocs() in the Documenter manual -# for more information. -deploydocs( - repo = "github.com/TuringLang/Bijectors.jl.git", -) +deploydocs(repo = "github.com/TuringLang/Bijectors.jl.git", push_preview=true) diff --git a/docs/src/distributions.md b/docs/src/distributions.md new file mode 100644 index 00000000..7f960d66 --- /dev/null +++ b/docs/src/distributions.md @@ -0,0 +1,70 @@ +# Basic usage +Other than the `logpdf_with_trans` methods, the package also provides a more composable interface through the `Bijector` types. Consider for example the one from above with `Beta(2, 2)`. + +```julia +julia> using Random; Random.seed!(42); + +julia> using Bijectors; using Bijectors: Logit + +julia> dist = Beta(2, 2) +Beta{Float64}(α=2.0, β=2.0) + +julia> x = rand(dist) +0.36888689965963756 + +julia> b = bijector(dist) # bijection (0, 1) → ℝ +Logit{Float64}(0.0, 1.0) + +julia> y = b(x) +-0.5369949942509267 +``` + +In this case we see that `bijector(d::Distribution)` returns the corresponding constrained-to-unconstrained bijection for `Beta`, which indeed is a `Logit` with `a = 0.0` and `b = 1.0`. The resulting `Logit <: Bijector` has a method `(b::Logit)(x)` defined, allowing us to call it just like any other function. Comparing with the above example, `b(x) ≈ link(dist, x)`. Just to convince ourselves: + +```julia +julia> b(x) ≈ link(dist, x) +true +``` + +## Transforming distributions + +```@setup transformed-dist-simple +using Bijectors +``` + +We can create a _transformed_ `Distribution`, i.e. a `Distribution` defined by sampling from a given `Distribution` and then transforming using a given transformation: + +```@repl transformed-dist-simple +dist = Beta(2, 2) # support on (0, 1) +tdist = transformed(dist) # support on ℝ + +tdist isa UnivariateDistribution +``` + +We can the then compute the `logpdf` for the resulting distribution: + +```@repl transformed-dist-simple +# Some example values +x = rand(dist) +y = tdist.transform(x) + +logpdf(tdist, y) +``` + +When computing `logpdf(tdist, y)` where `tdist` is the _transformed_ distribution corresponding to `Beta(2, 2)`, it makes more semantic sense to compute the pdf of the _transformed_ variable `y` rather than using the "un-transformed" variable `x` to do so, as we do in `logpdf_with_trans`. 
With that being said, we can also do + +```julia +logpdf_forward(tdist, x) +``` + +We can of course also sample from `tdist`: + +```julia +julia> y = rand(td) # ∈ ℝ +0.999166054552483 + +julia> x = inverse(td.transform)(y) # transform back to interval [0, 1] +0.7308945834125756 +``` + + diff --git a/docs/src/index.md b/docs/src/index.md index 562c7dcd..7f7ed0a4 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,106 +1,32 @@ # Bijectors.jl -## Usage - -A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: -- The inverse of `exp` is `log`. -- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. - -```@repl usage -using Bijectors -transform(exp, 1.0) -logabsdetjac(exp, 1.0) -with_logabsdet_jacobian(exp, 1.0) -``` - -Some transformations is well-defined for different types of inputs, e.g. `exp` can also act elementwise on a `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: - -```@repl usage -x = ones(2, 2) -transform(elementwise(exp), x) -logabsdetjac(elementwise(exp), x) -with_logabsdet_jacobian(elementwise(exp), x) -``` - -These methods also work nicely for compositions of transformations: - -```@repl usage -transform(elementwise(log ∘ exp), x) -``` - -Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. `Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: - -```@repl usage -using Bijectors: Logit - -f = Logit(0.0, 1.0) -f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` -``` - -## User-facing methods - -Without mutation: - -```@docs -transform -logabsdetjac -``` - -```julia -with_logabsdet_jacobian -``` - -With mutation: - -```@docs -transform! -logabsdetjac! -with_logabsdet_jacobian! -``` - -## Implementing a transformation - -Any callable can be made into a bijector by providing an implementation of `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. - -You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. - -Similarly with the mutable versions [`with_logabsdet_jacobian!`](@ref), [`transform!`](@ref), and [`logabsdetjac!`](@ref). - -## Working with Distributions.jl - -```@docs -Bijectors.bijector -Bijectors.transformed(d::Distribution, b::Bijector) -``` - -## Utilities - -```@docs -Bijectors.elementwise -Bijectors.isinvertible -Bijectors.isclosedform(t::Bijectors.Transform) -Bijectors.invertible -Bijectors.NotInvertible -Bijectors.Invertible -``` - -## API - -```@docs -Bijectors.Transform -Bijectors.Bijector -Bijectors.Inverse -``` - -## Bijectors - -```@docs -Bijectors.CorrBijector -Bijectors.LeakyReLU -Bijectors.Stacked -Bijectors.RationalQuadraticSpline -Bijectors.Coupling -Bijectors.OrderedBijector -Bijectors.NamedTransform -Bijectors.NamedCoupling -``` +This package implements a set of functions for transforming constrained random variables (e.g. simplexes, intervals) to Euclidean space. The 3 main functions implemented in this package are the `link`, `invlink` and `logpdf_with_trans` for a number of distributions. The distributions supported are: +1. `RealDistribution`: `Union{Cauchy, Gumbel, Laplace, Logistic, NoncentralT, Normal, NormalCanon, TDist}`, +2. 
`PositiveDistribution`: `Union{BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}`, +3. `UnitDistribution`: `Union{Beta, KSOneSided, NoncentralBeta}`, +4. `SimplexDistribution`: `Union{Dirichlet}`, +5. `PDMatDistribution`: `Union{InverseWishart, Wishart}`, and +6. `TransformDistribution`: `Union{T, Truncated{T}} where T<:ContinuousUnivariateDistribution`. + +All exported names from the [Distributions.jl](https://github.com/TuringLang/Bijectors.jl) package are reexported from `Bijectors`. + +Bijectors.jl also provides a nice interface for working with these maps: composition, inversion, etc. +The following table lists mathematical operations for a bijector and the corresponding code in Bijectors.jl. + +| Operation | Method | Automatic | +|:------------------------------------:|:-----------------:|:-----------:| +| `b ↦ b⁻¹` | `inverse(b)` | ✓ | +| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ | +| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ | +| `x ↦ b(x)` | `b(x)` | × | +| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × | +| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | +| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | +| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | +| `y ∼ q` | `y = rand(q)` | ✓ | +| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ | +| `(x ∼ p, b(x), log|det J(b, x)|, log q(y))` | `forward(q)` | ✓ | + +In this table, `b` denotes a `Bijector`, `J(b, x)` denotes the Jacobian of `b` evaluated at `x`, `b_*` denotes the [push-forward](https://www.wikiwand.com/en/Pushforward_measure) of `p` by `b`, and `x ∼ p` denotes `x` sampled from the distribution with density `p`. + +The "Automatic" column in the table refers to whether or not you are required to implement the feature for a custom `Bijector`. "AD" refers to the fact that it can be implemented "automatically" using automatic differentiation. diff --git a/docs/src/transforms.md b/docs/src/transforms.md new file mode 100644 index 00000000..e2a4c915 --- /dev/null +++ b/docs/src/transforms.md @@ -0,0 +1,105 @@ +## Usage + +A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: +- The inverse of `exp` is `log`. +- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. + +```@repl usage +using Bijectors +transform(exp, 1.0) +logabsdetjac(exp, 1.0) +with_logabsdet_jacobian(exp, 1.0) +``` + +Some transformations is well-defined for different types of inputs, e.g. `exp` can also act elementwise on a `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: + +```@repl usage +x = ones(2, 2) +transform(elementwise(exp), x) +logabsdetjac(elementwise(exp), x) +with_logabsdet_jacobian(elementwise(exp), x) +``` + +These methods also work nicely for compositions of transformations: + +```@repl usage +transform(elementwise(log ∘ exp), x) +``` + +Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. 
`Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: + +```@repl usage +using Bijectors: Logit + +f = Logit(0.0, 1.0) +f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` +``` + +## User-facing methods + +Without mutation: + +```@docs +transform +logabsdetjac +``` + +```julia +with_logabsdet_jacobian +``` + +With mutation: + +```@docs +transform! +logabsdetjac! +with_logabsdet_jacobian! +``` + +## Implementing a transformation + +Any callable can be made into a bijector by providing an implementation of `ChangeOfVariables.with_logabsdet_jacobian(b, x)`. + +You can also optionally implement [`transform`](@ref) and [`logabsdetjac`](@ref) to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. + +Similarly with the mutable versions [`with_logabsdet_jacobian!`](@ref), [`transform!`](@ref), and [`logabsdetjac!`](@ref). + +## Working with Distributions.jl + +```@docs +Bijectors.bijector +Bijectors.transformed(d::Distribution, b::Bijector) +``` + +## Utilities + +```@docs +Bijectors.elementwise +Bijectors.isinvertible +Bijectors.isclosedform(t::Bijectors.Transform) +Bijectors.invertible +Bijectors.NotInvertible +Bijectors.Invertible +``` + +## API + +```@docs +Bijectors.Transform +Bijectors.Bijector +Bijectors.Inverse +``` + +## Bijectors + +```@docs +Bijectors.CorrBijector +Bijectors.LeakyReLU +Bijectors.Stacked +Bijectors.RationalQuadraticSpline +Bijectors.Coupling +Bijectors.OrderedBijector +Bijectors.NamedTransform +Bijectors.NamedCoupling +``` + From 29a0b59e96485e53f0c9d7969cc826499c01252d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:09:10 +0100 Subject: [PATCH 85/98] removed all the unnecessary stuff in README --- README.md | 252 +----------------------------------------------------- 1 file changed, 1 insertion(+), 251 deletions(-) diff --git a/README.md b/README.md index 14fba13a..34abd42b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Bijectors.jl +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.github.io/Bijectors.jl/stable) [![Interface tests](https://github.com/TuringLang/Bijectors.jl/workflows/Interface%20tests/badge.svg?branch=master)](https://github.com/TuringLang/Bijectors.jl/actions?query=workflow%3A%22Interface+tests%22+branch%3Amaster) [![AD tests](https://github.com/TuringLang/Bijectors.jl/workflows/AD%20tests/badge.svg?branch=master)](https://github.com/TuringLang/Bijectors.jl/actions?query=workflow%3A%22AD+tests%22+branch%3Amaster) @@ -135,19 +136,6 @@ true Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true. -#### Dimensionality -One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`: - -```julia -julia> Bijectors.dimension(b) -0 - -julia> Bijectors.dimension(Exp{1}()) -1 -``` - -In most cases specification of the dimensionality is unnecessary as a `Bijector{N}` is usually only defined for a particular value of `N`, e.g. 
`Logit isa Bijector{0}` since it only makes sense to apply `Logit` to a real number (or a vector of reals if you're doing batch-computation). As a user, you'll rarely have to deal with this dimensionality specification. Unfortunately there are exceptions, e.g. `Exp` which can be applied to both real numbers and a vector of real numbers, in both cases treating it as a single input. This means that when `Exp` receives a vector input `x` as input, it's ambiguous whether or not to treat `x` as a *batch* of 0-dim inputs or as a single 1-dim input. As a result, to support batch-computation it is necessary to know the expected dimensionality of the input and output. Notice that we assume the dimensionality of the input and output to be the *same*. This is a reasonable assumption considering we're working with *bijections*. - #### Composition Also, we can _compose_ bijectors: @@ -491,244 +479,6 @@ julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns a This method is for example useful when computing quantities such as the _expected lower bound (ELBO)_ between this transformed distribution and some other joint density. If no analytical expression is available, we have to approximate the ELBO by a Monte Carlo estimate. But one term in the ELBO is the entropy of the base density, which we _do_ know analytically in this case. Using the analytical expression for the entropy and then using a monte carlo estimate for the rest of the terms in the ELBO gives an estimate with lower variance than if we used the monte carlo estimate for the entire expectation. -### Normalizing flows with bounded support - - -## Implementing your own `Bijector` -There's mainly two ways you can implement your own `Bijector`, and which way you choose mainly depends on the following question: are you bothered enough to manually implement `logabsdetjac`? If the answer is "Yup!", then you subtype from `Bijector`, if "Naaaah" then you subtype `ADBijector`. - -### `<:Bijector` -Here's a simple example taken from the source code, the `Identity`: - -```julia -import Bijectors: logabsdetjac - -struct Identity{N} <: Bijector{N} end -(::Identity)(x) = x # transform itself, "forward" -(::Inverse{<: Identity})(y) = y # inverse tramsform, "backward" - -# see the proper implementation for `logabsdetjac` in general -logabsdetjac(::Identity{0}, y::Real) = zero(eltype(y)) # ∂ₓid(x) = ∂ₓ x = 1 → log(abs(1)) = log(1) = 0 -``` - -A slightly more complex example is `Logit`: - -```julia -using LogExpFunctions: logit, logistic - -struct Logit{T<:Real} <: Bijector{0} - a::T - b::T -end - -(b::Logit)(x::Real) = logit((x - b.a) / (b.b - b.a)) -(b::Logit)(x) = map(b, x) -# `orig` contains the `Bijector` which was inverted -(ib::Inverse{<:Logit})(y::Real) = (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -(ib::Inverse{<:Logit})(y) = map(ib, y) - -logabsdetjac(b::Logit, x::Real) = - log((x - b.a) * (b.b - x) / (b.b - b.a)) -logabsdetjac(b::Logit, x) = map(logabsdetjac, x) -``` - -(Batch computation is not fully supported by all bijectors yet (see issue #35), but is actively worked on. In the particular case of `Logit` there's only one thing that makes sense, which is elementwise application. Therefore we've added `@.` to the implementation above, thus this works for any `AbstractArray{<:Real}`.) 
- -Then - -```julia -julia> b = Logit(0.0, 1.0) -Logit{Float64}(0.0, 1.0) - -julia> b(0.6) -0.4054651081081642 - -julia> inverse(b)(y) -Tracked 2-element Array{Float64,1}: - 0.3078149833748082 - 0.72380041667891 - -julia> logabsdetjac(b, 0.6) -1.4271163556401458 - -julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))` -Tracked 2-element Array{Float64,1}: - -1.546158373866469 - -1.6098711387913573 - -julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))` -(0.4054651081081642, 1.4271163556401458) -``` - -For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`: - -```julia -julia> using Bijectors: Logit - -julia> import Bijectors: with_logabsdet_jacobian - -julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x) - totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not - y = logit.(totally_worth_saving) - logjac = @. - log((b.b - x) * totally_worth_saving) - return (y, logjac) - end -forward (generic function with 16 methods) - -julia> with_logabsdet_jacobian(b, 0.6) -(0.4054651081081642, 1.4271163556401458) - -julia> @which with_logabsdet_jacobian(b, 0.6) -with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 -``` - -As you can see it's a very contrived example, but you get the idea. - -### `<:ADBijector` - -We could also have implemented `Logit` as an `ADBijector`: - -```julia -using LogExpFunctions: logit, logistic -using Bijectors: ADBackend - -struct ADLogit{T, AD} <: ADBijector{AD, 0} - a::T - b::T -end - -# ADBackend() returns ForwardDiffAD, which means we use ForwardDiff.jl for AD -ADLogit(a::T, b::T) where {T<:Real} = ADLogit{T, ADBackend()}(a, b) - -(b::ADLogit)(x) = @. logit((x - b.a) / (b.b - b.a)) -(ib::Inverse{<:ADLogit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -``` - -No implementation of `logabsdetjac`, but: - -```julia -julia> b_ad = ADLogit(0.0, 1.0) -ADLogit{Float64,Bijectors.ForwardDiffAD}(0.0, 1.0) - -julia> logabsdetjac(b_ad, 0.6) -1.4271163556401458 - -julia> y = b_ad(0.6) -0.4054651081081642 - -julia> inverse(b_ad)(y) -0.6 - -julia> logabsdetjac(inverse(b_ad), y) --1.4271163556401458 -``` - -Neat! And just to verify that everything works: - -```julia -julia> b = Logit(0.0, 1.0) -Logit{Float64}(0.0, 1.0) - -julia> logabsdetjac(b, 0.6) -1.4271163556401458 - -julia> logabsdetjac(b_ad, 0.6) ≈ logabsdetjac(b, 0.6) -true -``` - -We can also use Tracker.jl for the AD, rather than ForwardDiff.jl: - -```julia -julia> Bijectors.setadbackend(:reversediff) -:reversediff - -julia> b_ad = ADLogit(0.0, 1.0) -ADLogit{Float64,Bijectors.TrackerAD}(0.0, 1.0) - -julia> logabsdetjac(b_ad, 0.6) -1.4271163556401458 -``` - - -### Reference -Most of the methods and types mention below will have docstrings with more elaborate explanation and examples, e.g. -```julia -help?> Bijectors.Composed - Composed(ts::A) - - ∘(b1::Bijector{N}, b2::Bijector{N})::Composed{<:Tuple} - composel(ts::Bijector{N}...)::Composed{<:Tuple} - composer(ts::Bijector{N}...)::Composed{<:Tuple} - - where A refers to either - - • Tuple{Vararg{<:Bijector{N}}}: a tuple of bijectors of dimensionality N - - • AbstractArray{<:Bijector{N}}: an array of bijectors of dimensionality N - - A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively. - - Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. 
This ensures type-stability of implementations of all relating methods, e.g. inverse. - - If you want to use an Array as the container instead you can do - - Composed([b1, b2, ...]) - - In general this is not advised since you lose type-stability, but there might be cases where this is desired, e.g. if you have a insanely large number of bijectors to compose. - - Examples - ≡≡≡≡≡≡≡≡≡≡ - - It's important to note that ∘ does what is expected mathematically, which means that the bijectors are applied to the input right-to-left, e.g. first applying b2 and then b1: - - (b1 ∘ b2)(x) == b1(b2(x)) # => true - - But in the Composed struct itself, we store the bijectors left-to-right, so that - - cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) - cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) - cb1(x) == cb2(x) == b1(b2(x)) # => true -``` -If anything is lacking or not clear in docstrings, feel free to open an issue or PR. - -#### Types -The following are the bijectors available: -- Abstract: - - `Bijector`: super-type of all bijectors. - - `ADBijector{AD} <: Bijector`: subtypes of this only require the user to implement `(b::UserBijector)(x)` and `(ib::Inverse{<:UserBijector})(y)`. Automatic differentation will be used to compute the `jacobian(b, x)` and thus `logabsdetjac(b, x). -- Concrete: - - `Composed`: represents a composition of bijectors. - - `Stacked`: stacks univariate and multivariate bijectors - - `Identity`: does what it says, i.e. nothing. - - `Logit` - - `Exp` - - `Log` - - `Scale`: scaling by scalar value, though at the moment only well-defined `logabsdetjac` for univariate. - - `Shift`: shifts by a scalar value. - - `Permute`: permutes the input array using matrix multiplication - - `SimplexBijector`: mostly used as the constrained-to-unconstrained bijector for `SimplexDistribution`, e.g. `Dirichlet`. - - `PlanarLayer`: §4.1 Eq. (10) in [1] - - `RadialLayer`: §4.1 Eq. (14) in [1] - -The distribution interface consists of: -- `TransformedDistribution <: Distribution`: implements the `Distribution` interface from Distributions.jl. This means `rand` and `logpdf` are provided at the moment. - -#### Methods -The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`. -- `(b::Bijector)(x)`: implements the transform of the `Bijector` -- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. -- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). -- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. -- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. -- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency. -- `dimension(b::Bijector)`: returns the dimensionality of `b`. -- `isclosedform(b::Bijector)`: returns `true` or `false` depending on whether or not `b(x)` has a closed-form implementation. 
- -For `TransformedDistribution`, together with default implementations for `Distribution`, we have the following methods: -- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d` -- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`. -- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand. -- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient. - # Bibliography 1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6). 2. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2016). Automatic Differentiation Variational Inference. [arXiv:1603.00788](https://arxiv.org/abs/1603.00788v1). From 871874e7883915a4b2c210cfe19c491fe3bb01df Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:09:41 +0100 Subject: [PATCH 86/98] added examples to docs --- docs/Project.toml | 3 + docs/make.jl | 2 +- docs/src/examples.md | 163 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 docs/src/examples.md diff --git a/docs/Project.toml b/docs/Project.toml index 3a52a5db..265d975f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.27" diff --git a/docs/make.jl b/docs/make.jl index 706cfd02..e1138577 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs( sitename = "Bijectors", format = Documenter.HTML(), modules = [Bijectors], - pages = ["Home" => "index.md", "Transforms" => "transforms.md", "Distributions.jl integration" => "distributions.md"], + pages = ["Home" => "index.md", "Transforms" => "transforms.md", "Distributions.jl integration" => "distributions.md", "Examples" => "examples.md"], strict=false, checkdocs=:exports, ) diff --git a/docs/src/examples.md b/docs/src/examples.md new file mode 100644 index 00000000..fff55f28 --- /dev/null +++ b/docs/src/examples.md @@ -0,0 +1,163 @@ +```@setup advi +using Bijectors +``` + +## Univariate ADVI example +But the real utility of `TransformedDistribution` becomes more apparent when using `transformed(dist, b)` for any bijector `b`. To get the transformed distribution corresponding to the `Beta(2, 2)`, we called `transformed(dist)` before. This is simply an alias for `transformed(dist, bijector(dist))`. Remember `bijector(dist)` returns the constrained-to-constrained bijector for that particular `Distribution`. But we can of course construct a `TransformedDistribution` using different bijectors with the same `dist`. 
This is particularly useful in something called _Automatic Differentiation Variational Inference (ADVI)_.[2] An important part of ADVI is to approximate a constrained distribution, e.g. `Beta`, as follows: +1. Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`. +2. Transform `x` to `y` s.t. `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector") + +This then defines a probability density with same _support_ as `Beta`! Of course, it's unlikely that it will be the same density, but it's an _approximation_. Creating such a distribution becomes trivial with `Bijector` and `TransformedDistribution`: + +```@repl advi +using StableRNGs: StableRNG +rng = StableRNG(42); +dist = Beta(2, 2) +b = bijector(dist) # (0, 1) → ℝ +b⁻¹ = inverse(b) # ℝ → (0, 1) +td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1) + x = rand(rng, td) # ∈ (0, 1) +``` + +It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while the constrained-to-unconstrained bijection, `Logit` in this case, is only well-defined as a map `(0, 1) → ℝ` for the _open_ interval `(0, 1)`. This is of course not an implementation detail. `ℝ` is itself open, thus no continuous bijection exists from a _closed_ interval to `ℝ`. But since the boundaries of a closed interval has what's known as measure zero, this doesn't end up affecting the resulting density with support on the entire real line. In practice, this means that + +```@repl advi +td = transformed(Beta()) +inverse(td.transform)(rand(rng, td)) +``` + +will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._ + +## Multivariate ADVI example +We can also do _multivariate_ ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a singe multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector. + +```@repl advi +using Bijectors: SimplexBijector + +# Original distributions +dists = ( + Beta(), + InverseGamma(), + Dirichlet(2, 3) +); + +# Construct the corresponding ranges +ranges = []; +idx = 1; + +for i = 1:length(dists) + d = dists[i] + push!(ranges, idx:idx + length(d) - 1) + + global idx + idx += length(d) +end; + +ranges + +# Base distribution; mean-field normal +num_params = ranges[end][end] + +d = MvNormal(zeros(num_params), ones(num_params)); + +# Construct the transform +bs = bijector.(dists); # constrained-to-unconstrained bijectors for dists +ibs = inverse.(bs); # invert, so we get unconstrained-to-constrained +sb = Stacked(ibs, ranges) # => Stacked <: Bijector + +# Mean-field normal with unconstrained-to-constrained stacked bijector +td = transformed(d, sb); +y = rand(td) +0.0 ≤ y[1] ≤ 1.0 +0.0 < y[2] +sum(y[3:4]) ≈ 1.0 +``` + +## Normalizing flows +A very interesting application is that of _normalizing flows_.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. 
Let's create a flow with a single `PlanarLayer`: + +```@setup normalizing-flows +using Bijectors +using StableRNGs: StableRNG +rng = StableRNG(42); +``` + +```@repl normalizing-flows +d = MvNormal(zeros(2), ones(2)); +b = PlanarLayer(2) +flow = transformed(d, b) +flow isa MultivariateDistribution +``` + +That's it. Now we can sample from it using `rand` and compute the `logpdf`, like any other `Distribution`. + +```@repl normalizing-flows +y = rand(rng, flow) +logpdf(flow, y) # uses inverse of `b` +``` + +Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bounded_ flow: + +```@repl normalizing-flows +d = MvNormal(zeros(2), ones(2)); +ibs = inverse.(bijector.((InverseGamma(2, 3), Beta()))); +sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)] +b = sb ∘ PlanarLayer(2) +td = transformed(d, b); +y = rand(rng, td) +0 < y[1] +0 ≤ y[2] ≤ 1 +``` + +Want to fit the flow? + +```@repl normalizing-flows +using Zygote + +# Construct the flow. +b = PlanarLayer(2) + +# Convenient for extracting parameters and reconstructing the flow. +using Functors +θs, reconstruct = Functors.functor(b); + +# Make the objective a `struct` to avoid capturing global variables. +struct NLLObjective{R,D,T} + reconstruct::R + basedist::D + data::T +end + +function (obj::NLLObjective)(θs...) + transformed_dist = transformed(obj.basedist, obj.reconstruct(θs)) + return -sum(Base.Fix1(logpdf, transformed_dist), eachcol(obj.data)) +end + +# Some random data to estimate the density of. +xs = randn(2, 1000); + +# Construct the objective. +f = NLLObjective(reconstruct, MvNormal(2, 1), xs); + +# Initial loss. +@info "Initial loss: $(f(θs...))" + +# Train using gradient descent. +ε = 1e-3; +for i = 1:100 + ∇s = Zygote.gradient(f, θs...) + θs = map(θs, ∇s) do θ, ∇ + θ - ε .* ∇ + end +end + +# Final loss +@info "Finall loss: $(f(θs...))" + +# Very simple check to see if we learned something useful. +samples = rand(transformed(f.basedist, f.reconstruct(θs)), 1000); +mean(eachcol(samples)) # ≈ [0, 0] +cov(samples; dims=2) # ≈ I +``` + +We can easily create more complex flows by simply doing `PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)` and so on. 
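For instance, a rough sketch of such a deeper flow, following the same conventions as the examples above (the particular layers, their count, and the dimensionality below are arbitrary illustrative choices, not part of the documented examples):

```julia
# Sketch only: compose several invertible layers into a single, deeper flow.
# The mix of two planar layers and one radial layer is an arbitrary choice.
using Bijectors

d = MvNormal(zeros(2), ones(2))
deep_b = PlanarLayer(2) ∘ PlanarLayer(2) ∘ RadialLayer(2)
deep_flow = transformed(d, deep_b)

y = rand(deep_flow)     # sample by drawing from `d` and pushing through each layer
logpdf(deep_flow, y)    # log-density; requires inverting each layer in turn
```

Such a composition can also be combined with a `Stacked` bijector, exactly as in the bounded-flow example above, to obtain deeper flows with constrained support.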
From 84d6863744639de9ed6d06da12151e0678f1a1e6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:10:08 +0100 Subject: [PATCH 87/98] added some show methods for certain bijectors --- src/bijectors/planar_layer.jl | 2 ++ src/bijectors/radial_layer.jl | 2 ++ src/bijectors/stacked.jl | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 7b321d33..be46de3a 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -29,6 +29,8 @@ end # all fields are numerical parameters Functors.@functor PlanarLayer +Base.show(io::IO, b::PlanarLayer) = print(io, "PlanarLayer(w = $(b.w), u = $(b.u), b = $(b.b))") + """ get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real}) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 7c344f54..d4156f01 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -27,6 +27,8 @@ end # all fields are numerical parameters Functors.@functor RadialLayer +Base.show(io::IO, b::RadialLayer) = print(io, "RadialLayer(α_ = $(b.α_), β = $(b.β), z_0 = $(b.z_0))") + h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) #dh(α, r) = .- (1 ./ (α .+ r)) .^ 2 # for radial flow; derivative of h() diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 6c7c6e9d..470c0aa6 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -1,7 +1,7 @@ """ Stacked(bs) Stacked(bs, ranges) - stack(bs::Bijector...) # where `0` means 0-dim `Bijector` + stack(bs::Bijector...) A `Bijector` which stacks bijectors together which can then be applied to a vector where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. @@ -16,7 +16,7 @@ where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. 
# Examples ``` b1 = Logit(0.0, 1.0) -b2 = Identity{0}() +b2 = Identity() b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` @@ -33,6 +33,8 @@ Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Functors.@functor Stacked (bs,) +Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges))") + function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector) From 80bea9476848b0e4d9fb43ebb66d1179997ddd81 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:11:39 +0100 Subject: [PATCH 88/98] added compat entries to docs --- docs/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 265d975f..9c81e994 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,3 +6,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.27" +Functors = "0.3" +StableRNGs = "1" +Zygote = "0.6" \ No newline at end of file From e0e1792083cb19abda8c0cb3ec8449e78ef9eaf4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:11:45 +0100 Subject: [PATCH 89/98] updated docstring for RationalQuadraticSpline --- src/bijectors/rational_quadratic_spline.jl | 32 ++++++++++------------ 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 9458f26f..6c0dd601 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -28,46 +28,44 @@ There are two constructors for `RationalQuadraticSpline`: # Examples ## Univariate -```julia-repl +```jldoctest +julia> using StableRNGs: StableRNG; rng = StableRNG(42); # For reproducibility. + julia> using Bijectors: RationalQuadraticSpline julia> K = 3; B = 2; julia> # Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points". - b = RationalQuadraticSpline(randn(K), randn(K), randn(K - 1), B); + b = RationalQuadraticSpline(randn(rng, K), randn(rng, K), randn(rng, K - 1), B); julia> b(0.5) # inside of `[-B, B]` → transformed -1.412300607463467 +1.1943325397834206 julia> b(5.) 
# outside of `[-B, B]` → not transformed 5.0 -``` -Or we can use the constructor with the parameters correctly constrained: -```julia-repl + julia> b = RationalQuadraticSpline(b.widths, b.heights, b.derivatives); julia> b(0.5) # inside of `[-B, B]` → transformed -1.412300607463467 -``` -## Multivariate -```julia-repl +1.1943325397834206 + julia> d = 2; K = 3; B = 2; -julia> b = RationalQuadraticSpline(randn(d, K), randn(d, K), randn(d, K - 1), B); +julia> b = RationalQuadraticSpline(randn(rng, d, K), randn(rng, d, K), randn(rng, d, K - 1), B); julia> b([-1., 1.]) -2-element Array{Float64,1}: - -1.2568224171342797 - 0.5537259740554675 +2-element Vector{Float64}: + -1.5660106244288925 + 0.5384702734738573 julia> b([-5., 5.]) -2-element Array{Float64,1}: +2-element Vector{Float64}: -5.0 5.0 julia> b([-1., 5.]) -2-element Array{Float64,1}: - -1.2568224171342797 +2-element Vector{Float64}: + -1.5660106244288925 5.0 ``` From fc476336ee9e6e225f01d5771ebdf603b3a7cd89 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 02:12:03 +0100 Subject: [PATCH 90/98] removed commented code --- src/interface.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 4f4ddc27..5195275d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -109,8 +109,6 @@ Compute `log(abs(det(J(b, x))))` and store the result in `logjac`, where `J(b, x logabsdetjac!(b, x) = logabsdetjac!(b, x, zero(eltype(x))) logabsdetjac!(b, x, logjac) = (logjac += logabsdetjac(b, x)) -# with_logabsdet_jacobian(b::Transform, x) = (transform(b, x), logabsdetjac(b, x)) - """ with_logabsdet_jacobian!(b, x[, y, logjac]) From 161b6f17925bdda0c223091f4877b32bae3b2ea5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 10 Oct 2022 22:35:06 +0100 Subject: [PATCH 91/98] remove reference to logpdf_forward --- docs/src/distributions.md | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/docs/src/distributions.md b/docs/src/distributions.md index 7f960d66..514fc4bf 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -1,4 +1,4 @@ -# Basic usage +## Basic usage Other than the `logpdf_with_trans` methods, the package also provides a more composable interface through the `Bijector` types. Consider for example the one from above with `Beta(2, 2)`. ```julia @@ -50,21 +50,3 @@ y = tdist.transform(x) logpdf(tdist, y) ``` - -When computing `logpdf(tdist, y)` where `tdist` is the _transformed_ distribution corresponding to `Beta(2, 2)`, it makes more semantic sense to compute the pdf of the _transformed_ variable `y` rather than using the "un-transformed" variable `x` to do so, as we do in `logpdf_with_trans`. 
With that being said, we can also do - -```julia -logpdf_forward(tdist, x) -``` - -We can of course also sample from `tdist`: - -```julia -julia> y = rand(td) # ∈ ℝ -0.999166054552483 - -julia> x = inverse(td.transform)(y) # transform back to interval [0, 1] -0.7308945834125756 -``` - - From ecb54d34fd1e35112a3cb5f22ce9709258985964 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 11 Oct 2022 12:49:28 +0100 Subject: [PATCH 92/98] remove enforcement of type of input and output being the same in tests --- test/transform.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index de595d67..7be147d3 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -63,8 +63,6 @@ function single_sample_tests(dist) @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100])) end - # This is a quirk of the current implementation, of which it would be nice to be rid. - @test typeof(x) == typeof(y) end # Scalar tests From 269f11a9ede5c44901289bb353e6e6cd2af1ec45 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 31 Jan 2023 17:37:42 +0000 Subject: [PATCH 93/98] make logpdf_with_trans compatible with logpdf when it comes to handling batches --- src/Bijectors.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 7a8e2d67..033acbff 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -126,6 +126,19 @@ end link(d::Distribution, x) = bijector(d)(x) invlink(d::Distribution, y) = inverse(bijector(d))(y) + +# To still allow `logpdf_with_trans` to work with "batches" in a similar way +# as `logpdf` can. +_logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) = logabsdetjac.((bijector(d),), x) + +_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractVector) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) = logabsdetjac.((bijector(d),), eachcol(x)) + +_logabsdetjac_dist(d::MatrixDistribution, x::AbstractMatrix) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) = logabsdetjac.((bijector(d),), x) + + function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) return pd_logpdf_with_trans(d, x, transform) @@ -135,7 +148,7 @@ function logpdf_with_trans(d::Distribution, x, transform::Bool) l = logpdf(d, x) end if transform - return l - logabsdetjac(bijector(d), x) + return l - _logabsdetjac_dist(d, x) else return l end From 09fe97d0e957491c2ded0999d42c2d4f7900123f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 10:29:19 +0000 Subject: [PATCH 94/98] Apply suggestions from code review Co-authored-by: David Widmann --- docs/src/transforms.md | 3 ++- src/bijectors/corr.jl | 4 ++-- src/bijectors/named_bijector.jl | 4 +--- src/bijectors/ordered.jl | 2 +- src/bijectors/pd.jl | 2 +- src/interface.jl | 2 +- 6 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/src/transforms.md b/docs/src/transforms.md index e2a4c915..50ccfbea 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -11,7 +11,8 @@ logabsdetjac(exp, 1.0) with_logabsdet_jacobian(exp, 1.0) ``` -Some transformations is well-defined for different types of inputs, e.g. `exp` can also act elementwise on a `N`-dimensional `Array{<:Real,N}`. 
To specify that a transformation should be acting elementwise, we use the [`elementwise`](@ref) method: +Some transformations are well-defined for different types of inputs, e.g. `exp` can also act elementwise on an `N`-dimensional `Array{<:Real,N}`. +To specify that a transformation should act elementwise, we use the [`elementwise`](@ref) method: ```@repl usage x = ones(2, 2) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 29f5aa59..252ecc68 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -73,12 +73,12 @@ function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end -function transform(ib::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) return w' * w end -function logabsdetjac(::Inverse{<:CorrBijector}, y::AbstractMatrix{<:Real}) +function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) K = LinearAlgebra.checksquare(y) result = float(zero(eltype(y))) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index bbff828e..6ff514fe 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -36,9 +36,7 @@ function Functors.functor(::Type{<:NamedTransform{names}}, x) where names end # TODO: Use recursion instead of `@generated`? -@generated function inverse(b::NamedTransform{names}) where {names} - return :(NamedTransform(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) -end +inverse(t::NamedTransform) = NamedTransform(map(inverse, t.bs)) @generated function transform( b::NamedTransform{names1}, diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index f3af8f52..2a43a661 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -47,7 +47,7 @@ function _transform_ordered(y::AbstractMatrix) return x end -transform(ib::Inverse{<:OrderedBijector}, x::AbstractVecOrMat) = _transform_inverse_ordered(x) +transform(ib::Inverse{OrderedBijector}, x::AbstractVecOrMat) = _transform_inverse_ordered(x) function _transform_inverse_ordered(x::AbstractVector) y = similar(x) @assert !isempty(y) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index deac1bd6..5b57f55b 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -14,7 +14,7 @@ function pd_link(X) end lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) -function transform(ib::Inverse{<:PDBijector}, Y::AbstractMatrix{<:Real}) +function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end diff --git a/src/interface.jl b/src/interface.jl index 5195275d..2e368eb1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -92,7 +92,7 @@ Transform `x` using `b`, storing the result in `y`. If `y` is not provided, `x` is used as the output. """ transform!(b, x) = transform!(b, x, x) -transform!(b, x, y) = (y .= transform(b, x)) +transform!(b, x, y) = copyto!(y, transform(b, x)) """ logabsdetjac(b, x) From de849ee5e5f4b3a87d7bba2eb7174adec27cb3bc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 11:20:33 +0000 Subject: [PATCH 95/98] remove usage of invertible, etc. 
and use InverseFunctions.NoInverse instead --- docs/src/transforms.md | 3 --- src/Bijectors.jl | 2 ++ src/bijectors/composed.jl | 2 +- src/bijectors/exp_log.jl | 6 ----- src/bijectors/named_bijector.jl | 8 +++++-- src/bijectors/stacked.jl | 2 +- src/interface.jl | 41 ++++----------------------------- 7 files changed, 14 insertions(+), 50 deletions(-) diff --git a/docs/src/transforms.md b/docs/src/transforms.md index 50ccfbea..cf9223aa 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -78,9 +78,6 @@ Bijectors.transformed(d::Distribution, b::Bijector) Bijectors.elementwise Bijectors.isinvertible Bijectors.isclosedform(t::Bijectors.Transform) -Bijectors.invertible -Bijectors.NotInvertible -Bijectors.Invertible ``` ## API diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 033acbff..42ba29e7 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,6 +35,8 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +using InverseFunctions: InverseFunctions + import ChangesOfVariables: with_logabsdet_jacobian import InverseFunctions: inverse diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 79ffe9c7..ee103a0e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -1,4 +1,4 @@ -invertible(cb::ComposedFunction) = invertible(cb.inner) + invertible(cb.outer) +isinvertible(cb::ComposedFunction) = isinvertible(cb.inner) && isinvertible(cb.outer) isclosedform(cb::ComposedFunction) = isclosedform(cb.inner) && isclosedform(cb.outer) transform(cb::ComposedFunction, x) = transform(cb.outer, transform(cb.inner, x)) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 899bde31..7236a74e 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,9 +1,3 @@ -invertible(::typeof(exp)) = Invertible() -invertible(::Elementwise{typeof(exp)}) = Invertible() - -invertible(::typeof(log)) = Invertible() -invertible(::Elementwise{typeof(log)}) = Invertible() - transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) logabsdetjac(b::typeof(exp), x::Real) = x diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 6ff514fe..d147afc6 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,4 +1,5 @@ abstract type AbstractNamedTransform <: Transform end +abstract type AbstractNamedBijector <: Transform end ####################### ### `NamedTransform` ### @@ -37,6 +38,9 @@ end # TODO: Use recursion instead of `@generated`? inverse(t::NamedTransform) = NamedTransform(map(inverse, t.bs)) +# NOTE: Need explicit definition, since `inverse(::NamedTransform)` will +# end up wrapping a potential `NoInverse` in `NamedTransform`. 
+isinvertible(t::NamedTransform) = all(isinvertible, t.bs) @generated function transform( b::NamedTransform{names1}, @@ -111,7 +115,7 @@ julia> (a = x.a, b = (x.a + x.c) * x.b, c = x.c) (a = 1.0, b = 8.0, c = 3.0) ``` """ -struct NamedCoupling{target, deps, F} <: AbstractNamedTransform where {F, target} +struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target} f::F end @@ -120,7 +124,7 @@ function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F} return NamedCoupling{target, deps, F}(f) end -invertible(::NamedCoupling) = Invertible() +isinvertible(::NamedCoupling) = true coupling(b::NamedCoupling) = b.f # For some reason trying to use the parameteric types doesn't always work diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 470c0aa6..eec82eab 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -45,7 +45,7 @@ end isclosedform(b::Stacked) = all(isclosedform, b.bs) -invertible(b::Stacked) = sum(map(invertible, b.bs)) +isinvertible(b::Stacked) = all(isinvertible, b.bs) stack(bs...) = Stacked(bs) diff --git a/src/interface.jl b/src/interface.jl index 2e368eb1..5bd41132 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,7 +56,6 @@ A subtype of `Transform` of should at least implement [`transform(b, x)`](@ref). If the `Transform` is also invertible: - Required: - - [`invertible`](@ref): should return instance of [`Invertible`](@ref) or [`NotInvertible`](@ref). - _Either_ of the following: - `transform(::Inverse{<:MyTransform}, x)`: the `transform` for its inverse. - `InverseFunctions.inverse(b::MyTransform)`: returns an existing `Transform`. @@ -140,40 +139,12 @@ requires an iterative procedure to evaluate. """ isclosedform(t::Transform) = true -# Invertibility "trait". -""" - Invertible - -Represents the trait of being, well, non-invertible. -""" -struct NotInvertible end -""" - Invertible - -Represents the trait of being, well, invertible. -""" -struct Invertible end - -# Useful for checking if compositions, etc. are invertible or not. -Base.:+(::NotInvertible, ::Invertible) = NotInvertible() -Base.:+(::Invertible, ::NotInvertible) = NotInvertible() -Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() -Base.:+(::Invertible, ::Invertible) = Invertible() - -""" - invertible(t) - -Return `Invertible()` if `t` is invertible, and `NotInvertible()` otherwise. -""" -invertible(::Transform) = NotInvertible() -invertible(f::Elementwise) = invertible(f.x) - """ isinvertible(t) Return `true` if `t` is invertible, and `false` otherwise. """ -isinvertible(t::Transform) = invertible(t) isa Invertible +isinvertible(t) = inverse(t) !== InverseFunctions.NoInverse() """ inverse(b::Transform) @@ -196,21 +167,19 @@ end Functors.@functor Inverse """ - inverse(t::Transform[, ::Invertible]) + inverse(t::Transform) Returns the inverse of transform `t`. """ inverse(t::Transform) = Inverse(t) inverse(ib::Inverse) = ib.orig -invertible(ib::Inverse) = Invertible() - Base.:(==)(b1::Inverse, b2::Inverse) = b1.orig == b2.orig "Abstract type of a bijector, i.e. differentiable bijection with differentiable inverse." abstract type Bijector <: Transform end -invertible(::Bijector) = Invertible() +isinvertible(::Bijector) = true # Default implementation for inverse of a `Bijector`. 
logabsdetjac(ib::Inverse{<:Transform}, y) = -logabsdetjac(ib.orig, transform(ib, y)) @@ -221,7 +190,7 @@ function with_logabsdet_jacobian(ib::Inverse{<:Transform}, y) end """ - logabsdetjacinv(b::Bijector, y) + logabsdetjacinv(b, y) Just an alias for `logabsdetjac(inverse(b), y)`. """ @@ -232,8 +201,6 @@ logabsdetjacinv(b, y) = logabsdetjac(inverse(b), y) ############################## Identity() = identity -invertible(::typeof(identity)) = Invertible() - # Here we don't need to separate between batched version and non-batched, and so # we can just overload `transform`, etc. directly. transform(::typeof(identity), x) = copy(x) From 2ed35b3de53b85903fcf5dbeebbd89072e8eefdf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 11:22:46 +0000 Subject: [PATCH 96/98] specialze transform on Function --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 5bd41132..5487d01a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -80,7 +80,7 @@ Broadcast.broadcastable(b::Transform) = Ref(b) Transform `x` using `b`, treating `x` as a single input. """ -transform(f::Function, x) = f(x) +transform(f::F, x) where {F<:Function} = f(x) transform(t::Transform, x) = first(with_logabsdet_jacobian(t, x)) """ From 6ba5a0b98fec53dad3fbe46c4e7bd47feb972ae6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 11:49:42 +0000 Subject: [PATCH 97/98] removed unnecessary show and deprecation warnings --- src/Bijectors.jl | 5 ----- src/bijectors/shift.jl | 2 -- 2 files changed, 7 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 42ba29e7..0bb967fb 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -267,11 +267,6 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") -Base.@deprecate NamedBijector(bs) NamedTransform(bs) - -Base.@deprecate Exp() elementwise(exp) false -Base.@deprecate Log() elementwise(log) false - # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) 
diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 22aed9d2..908815a6 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -9,8 +9,6 @@ Base.:(==)(b1::Shift, b2::Shift) = b1.a == b2.a Functors.@functor Shift -Base.show(io::IO, b::Shift) = print(io, "Shift($(b.a))") - inverse(b::Shift) = Shift(-b.a) transform(b::Shift, x) = b.a .+ x From fb979a7437248cc0df6bff00f555c4781486bd48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 1 Feb 2023 12:18:54 +0000 Subject: [PATCH 98/98] remove references to Log and Exp --- test/bijectors/named_bijector.jl | 2 +- test/interface.jl | 11 ----------- test/runtests.jl | 2 +- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index 8cf57a95..fbdce0ec 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ using Test using Bijectors -using Bijectors: Exp, Log, Logit, AbstractNamedTransform, NamedTransform, NamedCoupling, Shift +using Bijectors: Logit, AbstractNamedTransform, NamedTransform, NamedCoupling, Shift @testset "NamedTransform" begin b = NamedTransform((a = elementwise(exp), b = elementwise(log))) diff --git a/test/interface.jl b/test/interface.jl index 19c4b1b9..e975e486 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -531,14 +531,3 @@ end InverseFunctions.test_inverse(b, x) ChangesOfVariables.test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) end - - -@testset "deprecations" begin - b = Bijectors.Logit(0.0, 1.0) - x = 0.3 - - @test @test_deprecated(Bijectors.Exp()) == elementwise(exp) - @test @test_deprecated(Bijectors.Log()) == elementwise(log) - - @test @test_deprecated(Bijectors.NamedBijector((x = b, ))) == Bijectors.NamedBijector((x = b, )) -end diff --git a/test/runtests.jl b/test/runtests.jl index e3f1323f..7fcd6dfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,7 @@ using Zygote using Random, LinearAlgebra, Test -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, +using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector using ChangesOfVariables: ChangesOfVariables