Rewrite: removing dimensionality and allow non-bijective transformations #183
Closed
48 commits
1d9a1a3 renamed rv to result in forward (torfjelde)
0717e3e added abstrac type Transform and removed dimensionality from Bijector (torfjelde)
81a2ed6 updated Composed to new interface (torfjelde)
251ab9c updated Exp and Log to new interface (torfjelde)
f1ef968 updated Logit to new interface (torfjelde)
45ff364 removed something that shouldnt be there (torfjelde)
eb94e00 removed false statement in docstring of Transform (torfjelde)
0d8783f fixed a typo in implementation of logabsdetjac_batch (torfjelde)
8f9988e added types for representing batches (torfjelde)
9fa37d1 make it possible to use broadcasting for working with batches (torfjelde)
168dd43 updated SimplexBijector to new interface, I think (torfjelde)
d44cf42 updated PDBijector to new interface (torfjelde)
c719b07 use transform_batch rather than broadcasting (torfjelde)
0962f06 Merge branch 'master' into tor/rewrite (torfjelde)
0a62e96 added default implementations for batches (torfjelde)
21a66ab updated ADBijector to new interface (torfjelde)
0b04d14 updated CorrBijector to new interface (torfjelde)
2cfd24b updated Coupling to new interface (torfjelde)
c272cd4 updated LeakyReLU to new interface (torfjelde)
5e2a585 updated NamedBijector to new interface (torfjelde)
8e41a50 updated BatchNormalisation to new interface (torfjelde)
9f45b16 updated Permute to new interface (torfjelde)
f793dad updated PlanarLayer to new interface (torfjelde)
19d1ef1 updated RadialLayer to new interface (torfjelde)
5015551 updated RationalQuadraticSpline to new interface (torfjelde)
3b75526 updated Scale to new interface (torfjelde)
195a107 updated Shift to new interface (torfjelde)
f56ed6a updated Stacked to new interface (torfjelde)
72d68b8 updated TruncatedBijector to new interface (torfjelde)
214aa92 added ConstructionBase as dependency (torfjelde)
836e152 fixed a bunch of small typos and errors from previous commits (torfjelde)
ff6b756 forgot to wrap some in Batch (torfjelde)
989aaa8 allow inverses of non-bijectors (torfjelde)
4d23882 relax definition of VectorBatch so Vector{<:Real} is covered (torfjelde)
0f9d334 just perform invertibility check in Inverse rather than inv (torfjelde)
af0b24b moved some code arround (torfjelde)
0777fab added docstrings and default impls for mutating batched methods (torfjelde)
42839c3 add elementype to VectorBatch (torfjelde)
2ce74d4 simplify Shift bijector (torfjelde)
926ef27 added rrules for logabsdetjac_shift (torfjelde)
c3745d7 use type-stable implementation of eachslice (torfjelde)
1600986 initial work on adding proper testing (torfjelde)
2f4d328 make Batch compatible with Zygote (torfjelde)
6da2498 Merge branch 'master' into tor/rewrite (torfjelde)
e5439f5 updated OrderedBijector (torfjelde)
dbf06d9 Merge branch 'master' into tor/rewrite (torfjelde)
5681358 temporary stuff (torfjelde)
306aa66 added docs (torfjelde)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
using Documenter | ||
using Bijectors | ||
|
||
makedocs( | ||
sitename = "Bijectors", | ||
format = Documenter.HTML(), | ||
modules = [Bijectors] | ||
) | ||
|
||
# Documenter can also automatically deploy documentation to gh-pages. | ||
# See "Hosting Documentation" and deploydocs() in the Documenter manual | ||
# for more information. | ||
#=deploydocs( | ||
repo = "<repository url>" | ||
)=# |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Bijectors.jl | ||
|
||
Documentation for Bijectors.jl | ||
|
||
## Usage | ||
|
||
A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: | ||
- The inverse of `exp` is `log`. | ||
- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. | ||
|
||
```@repl usage | ||
using Bijectors | ||
transform(exp, 1.0) | ||
logabsdetjac(exp, 1.0) | ||
forward(exp, 1.0) | ||
``` | ||
|
||
If you want to instead transform a collection of inputs, you can use the `batch` method from Batching.jl to inform Bijectors.jl that the input now represents a collection of inputs rather than a single input: | ||
|
||
```@repl usage | ||
xs = batch(ones(2)); | ||
transform(exp, xs) | ||
logabsdetjac(exp, xs) | ||
forward(exp, xs) | ||
``` | ||
|
||
Some transformations are well-defined for different types of inputs, e.g. `exp` can also act elementwise on an `N`-dimensional `Array{<:Real,N}`. To specify that a transformation should act elementwise, we use the [`elementwise`](@ref) method: | ||
|
||
```@repl usage | ||
x = ones(2, 2) | ||
transform(elementwise(exp), x) | ||
logabsdetjac(elementwise(exp), x) | ||
forward(elementwise(exp), x) | ||
``` | ||
|
||
And batched versions: | ||
|
||
```@repl usage | ||
xs = batch(ones(2, 2, 3)); | ||
transform(elementwise(exp), xs) | ||
logabsdetjac(elementwise(exp), xs) | ||
forward(elementwise(exp), xs) | ||
``` | ||
|
||
These methods also work nicely for compositions of transformations: | ||
|
||
```@repl usage | ||
transform(elementwise(log ∘ exp), xs) | ||
``` | ||
|
||
Unlike `exp`, some transformations have parameters affecting the resulting transformation they represent, e.g. `Logit` has two parameters `a` and `b` representing the lower- and upper-bound, respectively, of its domain: | ||
|
||
```@repl usage | ||
using Bijectors: Logit | ||
|
||
f = Logit(0.0, 1.0) | ||
f(rand()) # takes us from `(0, 1)` to `(-∞, ∞)` | ||
``` | ||
|
||
## User-facing methods | ||
|
||
Without mutation: | ||
|
||
```@docs | ||
transform | ||
logabsdetjac | ||
forward(b, x) | ||
``` | ||
|
||
With mutation: | ||
|
||
```@docs | ||
transform! | ||
logabsdetjac! | ||
forward! | ||
``` | ||
|
||
## Implementing a transformation | ||
|
||
Any callable can be made into a bijector by providing an implementation of [`forward(b, x)`](@ref), which is done by overloading | ||
|
||
```@docs | ||
Bijectors.forward_single | ||
``` | ||
|
||
where | ||
|
||
```@docs | ||
Bijectors.transform_single | ||
Bijectors.logabsdetjac_single | ||
``` | ||
|
||
You can then optionally implement `transform` and `logabsdetjac` to avoid redundant computations. This is usually only worth it if you expect `transform` or `logabsdetjac` to be used heavily without the other. | ||
|
||
Note that a _user_ of the bijector should generally be using [`forward(b, x)`](@ref) rather than calling [`forward_single`](@ref) directly. | ||
|
||
To implement "batched" versions of the above functionalities, i.e. methods which act on a _collection_ of inputs rather than a single input, you can overload the following method: | ||
|
||
```@docs | ||
Bijectors.forward_multiple | ||
``` | ||
|
||
And similarly, if you want to specialize [`transform`](@ref) and [`logabsdetjac`](@ref), you can implement | ||
|
||
```@docs | ||
Bijectors.transform_multiple | ||
Bijectors.logabsdetjac_multiple | ||
``` | ||
|
||
### Mutability | ||
|
||
There are also _mutable_ versions of all of the above: | ||
|
||
```@docs | ||
Bijectors.forward_single! | ||
Bijectors.forward_multiple! | ||
Bijectors.transform_single! | ||
Bijectors.transform_multiple! | ||
Bijectors.logabsdetjac_single! | ||
Bijectors.logabsdetjac_multiple! | ||
``` | ||
|
||
## Working with Distributions.jl | ||
|
||
```@docs | ||
Bijectors.bijector | ||
Bijectors.transformed(d::Distribution, b::Bijector) | ||
``` | ||
|
||
## Utilities | ||
|
||
```@docs | ||
Bijectors.elementwise | ||
Bijectors.isinvertible | ||
Bijectors.isclosedform(t::Bijectors.Transform) | ||
``` | ||
|
||
## API | ||
|
||
```@docs | ||
Bijectors.Transform | ||
Bijectors.Bijector | ||
Bijectors.Inverse | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
abstract type AbstractBatch{T} <: AbstractVector{T} end | ||
|
||
""" | ||
value(x) | ||
|
||
Returns the underlying storage used for the entire batch. | ||
|
||
If `x` is not `AbstractBatch`, then this is the identity function. | ||
""" | ||
value(x) = x | ||
|
||
struct Batch{V, T} <: AbstractBatch{T} | ||
value::V | ||
end | ||
|
||
value(x::Batch) = x.value | ||
|
||
# Convenient aliases | ||
const ArrayBatch{N} = Batch{<:AbstractArray{<:Real, N}} | ||
const VectorBatch{T} = Batch{<:AbstractVector{T}} | ||
|
||
# Constructor for `ArrayBatch`. | ||
Batch(x::AbstractVector{<:Real}) = Batch{typeof(x), eltype(x)}(x) | ||
function Batch(x::AbstractArray{<:Real}) | ||
V = typeof(x) | ||
# HACK: This assumes the batch is non-empty. | ||
T = typeof(getindex_for_last(x, 1)) | ||
return Batch{V, T}(x) | ||
end | ||
|
||
# Constructor for `VectorBatch`. | ||
Batch(x::AbstractVector{<:AbstractArray}) = Batch{typeof(x), eltype(x)}(x) | ||
|
||
# `AbstractVector` interface. | ||
Base.size(batch::Batch{<:AbstractArray{<:Any, N}}) where {N} = (size(value(batch), N), ) | ||
|
||
# Since impl inherited from `AbstractVector` doesn't do exactly what we want. | ||
Base.similar(b::AbstractBatch) = reconstruct(b, similar(value(b))) | ||
|
||
# For `VectorBatch` | ||
Base.getindex(batch::VectorBatch, i::Int) = value(batch)[i] | ||
Base.getindex(batch::VectorBatch, i::CartesianIndex{1}) = value(batch)[i] | ||
Base.getindex(batch::VectorBatch, i) = Batch(value(batch)[i]) | ||
function Base.setindex!(batch::VectorBatch, v, i) | ||
# `v` can also be a `Batch`. | ||
Base.setindex!(value(batch), value(v), i) | ||
return batch | ||
end | ||
|
||
# For `ArrayBatch` | ||
@generated function getindex_for_last(x::AbstractArray{<:Any, N}, inds) where {N} | ||
e = Expr(:call) | ||
push!(e.args, :(Base.view)) | ||
push!(e.args, :x) | ||
|
||
for i = 1:N - 1 | ||
push!(e.args, :(:)) | ||
end | ||
|
||
push!(e.args, :(inds)) | ||
|
||
return e | ||
end | ||
|
||
@generated function setindex_for_last!(out::AbstractArray{<:Any, N}, x, inds) where {N} | ||
e = Expr(:call) | ||
push!(e.args, :(Base.setindex!)) | ||
push!(e.args, :out) | ||
push!(e.args, :x) | ||
|
||
for i = 1:N - 1 | ||
push!(e.args, :(:)) | ||
end | ||
|
||
push!(e.args, :(inds)) | ||
|
||
return e | ||
end | ||
|
||
# General arrays. | ||
Base.getindex(batch::ArrayBatch, i::Int) = getindex_for_last(value(batch), i) | ||
Base.getindex(batch::ArrayBatch, i::CartesianIndex{1}) = getindex_for_last(value(batch), i) | ||
Base.getindex(batch::ArrayBatch, i) = Batch(getindex_for_last(value(batch), i)) | ||
function Base.setindex!(batch::ArrayBatch, v, i) | ||
# `v` can also be a `Batch`. | ||
setindex_for_last!(value(batch), value(v), i) | ||
return batch | ||
end |
@devmotion I'm curious what you think of this design:) You can ignore the code for now, the main thing is just whether or not you're happy with the idea of doing

We can then introduce a `Batch` type which generalizes `ColVecs` etc., and do

My plan is to split this into two PRs:
- `forward_single`, etc., just `forward`.
- `forward_single` instead of `forward`.

But ideally we'd fully adopt the ChangesOfVariables.jl interface, i.e. replace `forward` with `with_logabsdet_jacobian`. The issue here is that we of course can no longer do

etc. which makes me want to keep `forward` or ChangesOfVariables decides to take on a similar interface and encourage people to instead implement `with_logabsdet_jacobian_single`.

Thoughts?
I think it will be a huge improvement if batching is declared more explicitly. Is there a specific reason why one needs `forward_single` and `forward_multiple` instead of just `forward(::MyBijector, x)` and `forward(::MyBijector, x::AbstractBatch)`? More functions means more entry points and hence more possible confusion for developers. If there's no dedicated `_single` function anymore, it would also work better with the CoV API, I assume?
The issue is method ambiguity 😕 If only `forward` or `with_logabsdet_jacobian` is the entry-point, then we cannot provide sane defaults, e.g. `forward(b, x::AbstractBatch)` without every other implementation of `forward` being very explicit.
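A minimal sketch of the ambiguity described here, using stand-in types rather than the actual Bijectors.jl definitions (`MyBijector`, this `Batch`, and both `forward` methods are hypothetical and only serve the illustration):

```julia
# Stand-in types for illustration only; not the Bijectors.jl definitions.
abstract type AbstractBatch end
struct Batch{V} <: AbstractBatch
    value::V
end

struct MyBijector end

# Generic default for batches: specific in the second argument only.
forward(b, xs::AbstractBatch) = map(x -> forward(b, x), xs.value)

# A typical user implementation: specific in the first argument only.
forward(b::MyBijector, x) = (exp(x), x)

# Neither method is more specific for (MyBijector, Batch), so this call
# throws a MethodError reporting an ambiguity.
ambiguous = try
    forward(MyBijector(), Batch([0.0, 1.0]))
    false
catch err
    err isa MethodError
end
```

In this sketch, annotating the user method as `forward(b::MyBijector, x::Real)` would remove the ambiguity, which is presumably the "being very explicit" cost; dispatching to a separate `forward_single` sidesteps it without requiring that annotation.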
Yeah this is annoying - but there are multiple different options I think. One would be e.g. to implement (possibly different) default implementations but not as `forward(...)` but some `forward_batch_style1(..)` (whatever) and then to "activate" it by defining `forward(b::MyBijector, x::AbstractBatch) = forward_batch_style1(b, x)` etc. I.e., batch support would be enabled manually but without much effort. And one could reduce the amount of code even more with some helper macro.
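A rough sketch of this opt-in approach, again with stand-in types; `forward_batch_style1` is the placeholder name from the comment, and its body here is an assumption about what such a default might look like, not code from the PR:

```julia
# Stand-in types for illustration only; not the Bijectors.jl definitions.
abstract type AbstractBatch end
struct Batch{V<:AbstractVector} <: AbstractBatch
    value::V
end

# Hypothetical default batch implementation, deliberately NOT a method of
# `forward`, so it cannot clash with user-defined `forward` methods.
function forward_batch_style1(b, xs::AbstractBatch)
    results = map(x -> forward(b, x), xs.value)
    return Batch(map(first, results)), map(last, results)
end

struct MyBijector end

# Single-input implementation: returns (transformed value, logabsdetjac).
forward(::MyBijector, x::Real) = (exp(x), x)

# Batch support is "activated" with one extra line per bijector.
forward(b::MyBijector, xs::AbstractBatch) = forward_batch_style1(b, xs)

ys, logjacs = forward(MyBijector(), Batch([0.0, 1.0]))
```

The trade-off is the one raised in the thread: there is no ambiguity, but the activation line has to be written for every bijector that should support batches.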
Yeah that works of course, but I think this will be easy to forget vs. always just implementing `*_single` and batching is guaranteed to at least work.

Both approaches feel sub-optimal 😕