Rewrite: removing dimensionality and allow non-bijective transformations #183

Closed
wants to merge 48 commits into from
1d9a1a3
renamed rv to result in forward
torfjelde Jun 5, 2021
0717e3e
added abstrac type Transform and removed dimensionality from Bijector
torfjelde Jun 5, 2021
81a2ed6
updated Composed to new interface
torfjelde Jun 5, 2021
251ab9c
updated Exp and Log to new interface
torfjelde Jun 5, 2021
f1ef968
updated Logit to new interface
torfjelde Jun 5, 2021
45ff364
removed something that shouldnt be there
torfjelde Jun 5, 2021
eb94e00
removed false statement in docstring of Transform
torfjelde Jun 5, 2021
0d8783f
fixed a typo in implementation of logabsdetjac_batch
torfjelde Jun 5, 2021
8f9988e
added types for representing batches
torfjelde Jun 5, 2021
9fa37d1
make it possible to use broadcasting for working with batches
torfjelde Jun 5, 2021
168dd43
updated SimplexBijector to new interface, I think
torfjelde Jun 5, 2021
d44cf42
updated PDBijector to new interface
torfjelde Jun 5, 2021
c719b07
use transform_batch rather than broadcasting
torfjelde Jun 5, 2021
0962f06
Merge branch 'master' into tor/rewrite
torfjelde Jun 5, 2021
0a62e96
added default implementations for batches
torfjelde Jun 5, 2021
21a66ab
updated ADBijector to new interface
torfjelde Jun 5, 2021
0b04d14
updated CorrBijector to new interface
torfjelde Jun 5, 2021
2cfd24b
updated Coupling to new interface
torfjelde Jun 5, 2021
c272cd4
updated LeakyReLU to new interface
torfjelde Jun 5, 2021
5e2a585
updated NamedBijector to new interface
torfjelde Jun 5, 2021
8e41a50
updated BatchNormalisation to new interface
torfjelde Jun 5, 2021
9f45b16
updated Permute to new interface
torfjelde Jun 5, 2021
f793dad
updated PlanarLayer to new interface
torfjelde Jun 5, 2021
19d1ef1
updated RadialLayer to new interface
torfjelde Jun 5, 2021
5015551
updated RationalQuadraticSpline to new interface
torfjelde Jun 5, 2021
3b75526
updated Scale to new interface
torfjelde Jun 5, 2021
195a107
updated Shift to new interface
torfjelde Jun 5, 2021
f56ed6a
updated Stacked to new interface
torfjelde Jun 5, 2021
72d68b8
updated TruncatedBijector to new interface
torfjelde Jun 5, 2021
214aa92
added ConstructionBase as dependency
torfjelde Jun 5, 2021
836e152
fixed a bunch of small typos and errors from previous commits
torfjelde Jun 5, 2021
ff6b756
forgot to wrap some in Batch
torfjelde Jun 5, 2021
989aaa8
allow inverses of non-bijectors
torfjelde Jun 6, 2021
4d23882
relax definition of VectorBatch so Vector{<:Real} is covered
torfjelde Jun 6, 2021
0f9d334
just perform invertibility check in Inverse rather than inv
torfjelde Jun 6, 2021
af0b24b
moved some code arround
torfjelde Jun 6, 2021
0777fab
added docstrings and default impls for mutating batched methods
torfjelde Jun 6, 2021
42839c3
add elementype to VectorBatch
torfjelde Jun 6, 2021
2ce74d4
simplify Shift bijector
torfjelde Jun 6, 2021
926ef27
added rrules for logabsdetjac_shift
torfjelde Jun 6, 2021
c3745d7
use type-stable implementation of eachslice
torfjelde Jun 6, 2021
1600986
initial work on adding proper testing
torfjelde Jun 6, 2021
2f4d328
make Batch compatible with Zygote
torfjelde Jun 6, 2021
6da2498
Merge branch 'master' into tor/rewrite
torfjelde Aug 1, 2021
e5439f5
updated OrderedBijector
torfjelde Aug 1, 2021
dbf06d9
Merge branch 'master' into tor/rewrite
torfjelde Dec 25, 2021
5681358
temporary stuff
torfjelde Jan 24, 2022
306aa66
added docs
torfjelde Jan 24, 2022
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,9 +4,11 @@ version = "0.10.0"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
+Batching = "a4738bf4-2ff1-4f03-87a2-28fa6f9d5d14"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
3 changes: 3 additions & 0 deletions docs/Project.toml
@@ -0,0 +1,3 @@
[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
15 changes: 15 additions & 0 deletions docs/make.jl
@@ -0,0 +1,15 @@
using Documenter
using Bijectors

makedocs(
    sitename = "Bijectors",
    format = Documenter.HTML(),
    modules = [Bijectors]
)

# Documenter can also automatically deploy documentation to gh-pages.
# See "Hosting Documentation" and deploydocs() in the Documenter manual
# for more information.
#=deploydocs(
    repo = "<repository url>"
)=#
144 changes: 144 additions & 0 deletions docs/src/index.md
@@ -0,0 +1,144 @@
# Bijectors.jl
Member Author

@devmotion I'm curious what you think of this design:) You can ignore the code for now, the main thing is just whether or not you're happy with the idea of doing

forward(b, x) = forward_single(b, x)

# Non-batched version
forward_single(b, x) = ...

# Batched version
forward_multiple(b, x) = ...

We can then introduce a Batch type which generalizes ColVecs etc., and do

forward(b, x::AbstractBatch) = forward_multiple(b, x)

My plan is to split this into two PRs:

  1. Remove all "official" support for batched computation, thus always assuming that the input given represents a single input (some bijectors might still support it, but there won't be any "official" support for it). In this PR there is no forward_single, etc., just forward.
  2. Add support for batching and overload forward_single instead of forward.

But ideally we'd fully adopt the ChangesOfVariables.jl interface, i.e. replace forward with with_logabsdet_jacobian. The issue here is that we of course can no longer do

with_logabsdet_jacobian(b, x) = with_logabsdet_jacobian_single(b, x)

etc., which makes me want to either keep forward or have ChangesOfVariables adopt a similar interface and encourage people to implement with_logabsdet_jacobian_single instead.

Thoughts?
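
The routing proposed in this comment can be sketched in isolation. This is a minimal illustration under stated assumptions, not the PR's implementation: `VecBatch` and `MyExp` are hypothetical names, `AbstractBatch` is a local stand-in rather than the package's type, and the `result`/`logabsdetjac` field names are assumed from the commit that renamed `rv` to `result`.

```julia
# Hypothetical stand-ins; none of these definitions come from Bijectors.jl.
abstract type AbstractBatch{T} <: AbstractVector{T} end

struct VecBatch{T} <: AbstractBatch{T}
    value::Vector{T}
end
Base.size(b::VecBatch) = size(b.value)
Base.getindex(b::VecBatch, i::Int) = b.value[i]

# Single entry point that routes on whether the input is a batch.
forward(b, x) = forward_single(b, x)
forward(b, x::AbstractBatch) = forward_multiple(b, x)

# Default batched fallback: apply the single-input method to each element.
forward_multiple(b, xs::AbstractBatch) = [forward_single(b, x) for x in xs]

# A transformation only has to implement `forward_single`.
struct MyExp end
forward_single(::MyExp, x::Real) = (result = exp(x), logabsdetjac = x)

single = forward(MyExp(), 1.0)
batched = forward(MyExp(), VecBatch([0.0, 1.0]))
```

Because implementers only ever add methods to `forward_single`, the generic `forward(b, x::AbstractBatch)` method never competes with their definitions.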

Member

I think it will be a huge improvement if batching is declared more explicitly. Is there a specific reason why one needs forward_single and forward_multiple instead of just forward(::MyBijector, x) and forward(::MyBijector, x::AbstractBatch)? More functions means more entry points and hence more possible confusion for developers. If there's no dedicated _single function anymore, it would also work better with the CoV API, I assume?

Member Author

The issue is method ambiguity 😕 If only forward or with_logabsdet_jacobian is the entry-point, then we cannot provide sane defaults, e.g. forward(b, x::AbstractBatch) without every other implementation of forward being very explicit.
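
A minimal sketch of the ambiguity described here, using hypothetical types (`MyBatch`, `MyBijector`) and a local `AbstractBatch` stand-in. It assumes a downstream author constrains only the first argument of `forward`, which is the common case.

```julia
abstract type AbstractBatch end   # hypothetical stand-in
struct MyBatch <: AbstractBatch end
struct MyBijector end

# Package-level default for every transformation acting on a batch.
forward(b, x::AbstractBatch) = "generic batched fallback"
# A downstream implementation that constrains only the first argument.
forward(b::MyBijector, x) = "MyBijector on a single input"

# For (MyBijector, MyBatch) neither method is more specific than the other,
# so the call raises a MethodError reporting the ambiguity.
is_ambiguous = try
    forward(MyBijector(), MyBatch())
    false
catch err
    err isa MethodError
end
```

Resolving this requires either an explicit `forward(::MyBijector, ::AbstractBatch)` method or a narrower type on `x` in every downstream definition.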

Member

Yeah this is annoying - but there are multiple different options I think. One would be e.g. to implement (possibly different) default implementations but not as forward(...) but some forward_batch_style1(..) (whatever) and then to "activate" it by defining forward(b::MyBijector, x::AbstractBatch) = forward_batch_style1(b, x) etc. I.e., batch support would be enabled manually but without much effort. And one could reduce the amount of code even more with some helper macro.
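
The opt-in alternative suggested here might look roughly as follows. This is a sketch under assumptions: `ListBatch` and `MyBijector` are hypothetical, `forward_batch_style1` is the placeholder name from the comment, and the helper macro mentioned above is not shown.

```julia
abstract type AbstractBatch end          # hypothetical stand-in
struct ListBatch <: AbstractBatch
    items::Vector{Float64}
end

# Reusable batched implementation that is NOT itself a method of `forward`.
forward_batch_style1(b, xs::ListBatch) = [forward(b, x) for x in xs.items]

struct MyBijector end
forward(::MyBijector, x::Real) = (result = exp(x), logabsdetjac = x)
# Opt in to batching with one explicit, unambiguous method.
forward(b::MyBijector, xs::AbstractBatch) = forward_batch_style1(b, xs)

out = forward(MyBijector(), ListBatch([0.0, 2.0]))
```

Every `forward` method here constrains the bijector type, so no two methods overlap, at the cost of one extra line per transformation.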

Member Author

Yeah that works of course, but I think this will be easy to forget, whereas if people always just implement *_single, batching is guaranteed to at least work.

Both approaches feel sub-optimal 😕


Documentation for Bijectors.jl

## Usage

A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function:
- The inverse of `exp` is `log`.
- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`.

```@repl usage
using Bijectors
transform(exp, 1.0)
logabsdetjac(exp, 1.0)
forward(exp, 1.0)
```
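
The statement that `logabsdetjac` of `exp` at `x` equals `x` can be checked numerically in plain Julia, without relying on any Bijectors.jl call. The sketch below is illustrative only; the step size `h` and the variable names are arbitrary choices.

```julia
# Plain Julia: check that log|d/dx exp(x)| is (numerically) equal to x.
x = 1.0
h = 1e-6
derivative = (exp(x + h) - exp(x - h)) / (2h)   # central finite difference
logabsdetjac_numeric = log(abs(derivative))
matches = isapprox(logabsdetjac_numeric, x; atol = 1e-8)
```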

If you want to instead transform a collection of inputs, you can use the `batch` method from Batching.jl to inform Bijectors.jl that the input now represents a collection of inputs rather than a single input:

```@repl usage
xs = batch(ones(2));
transform(exp, xs)
logabsdetjac(exp, xs)
forward(exp, xs)
```

Some transformations are well-defined for different types of inputs, e.g. `exp` can also act elementwise on an `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should act elementwise, we use the [`elementwise`](@ref) method:

```@repl usage
x = ones(2, 2)
transform(elementwise(exp), x)
logabsdetjac(elementwise(exp), x)
forward(elementwise(exp), x)
```

And batched versions:

```@repl usage
xs = batch(ones(2, 2, 3));
transform(elementwise(exp), xs)
logabsdetjac(elementwise(exp), xs)
forward(elementwise(exp), xs)
```

These methods also work nicely for compositions of transformations:

```@repl usage
transform(elementwise(log ∘ exp), xs)
```

Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. `Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain:

```@repl usage
using Bijectors: Logit

f = Logit(0.0, 1.0)
f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)`
```
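
For intuition, the map from `(a, b)` to the real line can be written out by hand. This is a minimal sketch assuming the usual scaled-logit formula; `scaled_logit` and `scaled_logit_inverse` are hypothetical helper names and not part of Bijectors.jl.

```julia
# Hypothetical standalone helpers, not the Bijectors.jl implementation.
# Maps x in (a, b) to the real line; equals logit((x - a) / (b - a)).
scaled_logit(x, a, b) = log((x - a) / (b - x))
# Inverse map from the real line back into (a, b).
scaled_logit_inverse(y, a, b) = a + (b - a) / (1 + exp(-y))

y = scaled_logit(0.25, 0.0, 1.0)
x_back = scaled_logit_inverse(y, 0.0, 1.0)
```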

## User-facing methods

Without mutation:

```@docs
transform
logabsdetjac
forward(b, x)
```

With mutation:

```@docs
transform!
logabsdetjac!
forward!
```

## Implementing a transformation

Any callable can be made into a bijector by providing an implementation of [`forward(b, x)`](@ref), which is done by overloading

```@docs
Bijectors.forward_single
```

where

```@docs
Bijectors.transform_single
Bijectors.logabsdetjac_single
```

You can then optionally implement `transform` and `logabsdetjac` to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other.

Note that a _user_ of the bijector should generally be using [`forward(b, x)`](@ref) rather than calling [`forward_single`](@ref) directly.

To implement "batched" versions of the above functionalities, i.e. methods which act on a _collection_ of inputs rather than a single input, you can overload the following method:

```@docs
Bijectors.forward_multiple
```

And similarly, if you want to specialize [`transform`](@ref) and [`logabsdetjac`](@ref), you can implement

```@docs
Bijectors.transform_multiple
Bijectors.logabsdetjac_multiple
```

### Mutability

There are also _mutating_ versions of all of the above:

```@docs
Bijectors.forward_single!
Bijectors.forward_multiple!
Bijectors.transform_single!
Bijectors.transform_multiple!
Bijectors.logabsdetjac_single!
Bijectors.logabsdetjac_multiple!
```

## Working with Distributions.jl

```@docs
Bijectors.bijector
Bijectors.transformed(d::Distribution, b::Bijector)
```

## Utilities

```@docs
Bijectors.elementwise
Bijectors.isinvertible
Bijectors.isclosedform(t::Bijectors.Transform)
```

## API

```@docs
Bijectors.Transform
Bijectors.Bijector
Bijectors.Inverse
```
56 changes: 33 additions & 23 deletions src/Bijectors.jl
@@ -32,9 +32,13 @@ using Reexport, Requires
 @reexport using Distributions
 using LinearAlgebra
 using MappedArrays
+using ConstructionBase
 using Base.Iterators: drop
 using LinearAlgebra: AbstractTriangular

+using Batching
+export batch
+
 import ChangesOfVariables: with_logabsdet_jacobian
 import InverseFunctions: inverse

@@ -54,10 +58,13 @@ export TransformDistribution,
         logpdf_with_trans,
         isclosedform,
         transform,
+        transform!,
         with_logabsdet_jacobian,
         inverse,
         forward,
+        forward!,
         logabsdetjac,
+        logabsdetjac!,
         logabsdetjacinv,
         Bijector,
         ADBijector,
@@ -76,7 +83,9 @@ export TransformDistribution,
         PlanarLayer,
         RadialLayer,
         CouplingLayer,
-        InvertibleBatchNorm
+        InvertibleBatchNorm,
+        Elementwise,
+        elementwise

 if VERSION < v"1.1"
     using Compat: eachcol
@@ -251,12 +260,13 @@ function getlogp(d::InverseWishart, Xcf, X)
 end

 include("utils.jl")
+# include("batch.jl")
 include("interface.jl")
-include("chainrules.jl")
+# include("chainrules.jl")

-Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))
+Base.@deprecate forward(b::Transform, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))

-@noinline function Base.inv(b::AbstractBijector)
+@noinline function Base.inv(b::Transform)
     Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv)
     inverse(b)
 end
@@ -265,24 +275,24 @@ end
 maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
 maporbroadcast(f, x::AbstractArray...) = f.(x...)

-# optional dependencies
-function __init__()
-    @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
-        function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...)
-            return copy(f.(x1, x...))
-        end
-        function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...)
-            return copy(f.(x1, x2, x...))
-        end
-        function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...)
-            return copy(f.(x1, x2, x3, x...))
-        end
-    end
-    @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl")
-    @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl")
-    @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl")
-    @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl")
-    @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl")
-end
+# # optional dependencies
+# function __init__()
+#     @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
+#         function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...)
+#             return copy(f.(x1, x...))
+#         end
+#         function maporbroadcast(f, x1, x2::LazyArrays.BroadcastArray, x...)
+#             return copy(f.(x1, x2, x...))
+#         end
+#         function maporbroadcast(f, x1, x2, x3::LazyArrays.BroadcastArray, x...)
+#             return copy(f.(x1, x2, x3, x...))
+#         end
+#     end
+#     @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl")
+#     @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl")
+#     @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl")
+#     @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl")
+#     @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl")
+# end

 end # module
88 changes: 88 additions & 0 deletions src/batch.jl
@@ -0,0 +1,88 @@
abstract type AbstractBatch{T} <: AbstractVector{T} end

"""
    value(x)

Returns the underlying storage used for the entire batch.

If `x` is not `AbstractBatch`, then this is the identity function.
"""
value(x) = x

struct Batch{V, T} <: AbstractBatch{T}
    value::V
end

value(x::Batch) = x.value

# Convenient aliases
const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}}
const VectorBatch{T} = Batch{<:AbstractVector{T}}

# Constructor for `ArrayBatch`.
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x)
function Batch(x::AbstractArray{<:Real})
    V = typeof(x)
    # HACK: This assumes the batch is non-empty.
    T = typeof(getindex_for_last(x, 1))
    return Batch{V, T}(x)
end

# Constructor for `VectorBatch`.
Batch(x::AbstractVector{<:AbstractArray}) = Batch{typeof(x), eltype(x)}(x)

# `AbstractVector` interface.
Base.size(batch::Batch{<:AbstractArray{<:Any, N}}) where {N} = (size(value(batch), N), )

# Since impl inherited from `AbstractVector` doesn't do exactly what we want.
Base.similar(b::AbstractBatch) = reconstruct(b, similar(value(b)))

# For `VectorBatch`
Base.getindex(batch::VectorBatch, i::Int) = value(batch)[i]
Base.getindex(batch::VectorBatch, i::CartesianIndex{1}) = value(batch)[i]
Base.getindex(batch::VectorBatch, i) = Batch(value(batch)[i])
function Base.setindex!(batch::VectorBatch, v, i)
    # `v` can also be a `Batch`.
    Base.setindex!(value(batch), value(v), i)
    return batch
end

# For `ArrayBatch`
@generated function getindex_for_last(x::AbstractArray{<:Any, N}, inds) where {N}
    e = Expr(:call)
    push!(e.args, :(Base.view))
    push!(e.args, :x)

    for i = 1:N - 1
        push!(e.args, :(:))
    end

    push!(e.args, :(inds))

    return e
end

@generated function setindex_for_last!(out::AbstractArray{<:Any, N}, x, inds) where {N}
    e = Expr(:call)
    push!(e.args, :(Base.setindex!))
    push!(e.args, :out)
    push!(e.args, :x)

    for i = 1:N - 1
        push!(e.args, :(:))
    end

    push!(e.args, :(inds))

    return e
end

# General arrays.
Base.getindex(batch::ArrayBatch, i::Int) = getindex_for_last(value(batch), i)
Base.getindex(batch::ArrayBatch, i::CartesianIndex{1}) = getindex_for_last(value(batch), i)
Base.getindex(batch::ArrayBatch, i) = Batch(getindex_for_last(value(batch), i))
function Base.setindex!(batch::ArrayBatch, v, i)
    # `v` can also be a `Batch`.
    setindex_for_last!(value(batch), value(v), i)
    return batch
end
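
For readers unfamiliar with the generated helpers above: `getindex_for_last` builds a call equivalent to `view(x, :, ..., :, inds)`, i.e. it slices along the last dimension. The sketch below shows the same slice using Base's `selectdim`. It is an illustration only, not a proposed replacement, and it may not infer as well as the generated version.

```julia
# Slice a 2x2x3 array along its last dimension, treating it as a batch of
# three 2x2 matrices.
x = reshape(collect(1.0:12.0), 2, 2, 3)
slice = selectdim(x, ndims(x), 2)    # a view of the same data as x[:, :, 2]
```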
4 changes: 2 additions & 2 deletions src/bijectors/adbijector.jl
@@ -2,7 +2,7 @@
 Abstract type for a `Bijector{N}` making use of auto-differentation (AD) to
 implement `jacobian` and, by impliciation, `logabsdetjac`.
 """
-abstract type ADBijector{AD, N} <: Bijector{N} end
+abstract type ADBijector{AD} <: Bijector end

 struct SingularJacobianException{B<:Bijector} <: Exception
     b::B
@@ -24,4 +24,4 @@ end
 function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real})
     fact = lu(jacobian(b, x), check=false)
     return issuccess(fact) ? logabsdet(fact)[1] : throw(SingularJacobianException(b))
-end
+end