Skip to content

Composable scalar transforms #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 64 commits into from
Apr 28, 2025
Merged
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
488c780
Took a stab at this idea. Needs testing
Ickaser Mar 6, 2025
efbb5e4
The crucial step: composite!
Ickaser Mar 6, 2025
09bb188
don't nest composite transforms if possible
Ickaser Mar 6, 2025
8e0d88e
Properly refer to composite operator
Ickaser Mar 14, 2025
ab6e1ff
Rename elemental scalar transformations
Ickaser Mar 14, 2025
d5d244a
Export newly renamed transforms
Ickaser Mar 14, 2025
b1d9bdc
Rename and export Logistic
Ickaser Mar 14, 2025
d5e0e13
Make indexing directly on CompositeScalarTransform pass through to un…
Ickaser Mar 14, 2025
92479ea
get indexing right
Ickaser Mar 14, 2025
30b327e
Make types inferrable for transform_and_logjac
Ickaser Mar 14, 2025
f333507
Logistic, not logit
Ickaser Mar 14, 2025
677579d
Add inverses
Ickaser Mar 14, 2025
17b69fe
Add consistency tests parallel to the existing ones
Ickaser Mar 14, 2025
f813763
Add a TVNeg transformation for negative
Ickaser Mar 19, 2025
1c0d71b
Add test to make sure scalar transformations at least compose
Ickaser Mar 19, 2025
3e492db
More composing in test
Ickaser Mar 19, 2025
45aabf6
Test consistency for a bunch of arbitrary compositions
Ickaser Mar 20, 2025
a133017
Add inverse_and_logjac for new types, consistent transform_and_logjac…
Ickaser Mar 25, 2025
54b601b
Commit to new style, mark a few tests as broken for now
Ickaser Mar 25, 2025
d87fa2a
Provide some default show methods
Ickaser Mar 25, 2025
683d924
Add some documentation for composable scalar transforms
Ickaser Mar 26, 2025
f5d9740
Widen types allowed by scalar inverses to Number, and TVScale to anyt…
Ickaser Mar 26, 2025
04f3a33
Try adding Unitful tests: commented out because not working yet
Ickaser Mar 26, 2025
4107271
Some docs explaining Unitful transform
Ickaser Mar 26, 2025
3c7b1a5
Add tests to cover more constructions of composition
Ickaser Mar 26, 2025
07ef76f
Don't test Jacobian for transforms that add units
Ickaser Mar 28, 2025
74da494
Remove log-Jacobian functionality for Unitful-type transforms
Ickaser Mar 28, 2025
6dff1f9
Fix ill-formed test
Ickaser Mar 28, 2025
d106ef7
Improve test coverage
Ickaser Mar 28, 2025
e3747e3
Improve pretty printing for common transforms to keep printing behavi…
Ickaser Apr 2, 2025
14bba55
Add a non-unicode alias for composition operator
Ickaser Apr 2, 2025
5ff6832
More alias tests in scalar show
Ickaser Apr 2, 2025
67f1eb4
Fix log-Jacobian of inverse of Exp transform; move inverse_and_logjac…
Ickaser Apr 4, 2025
12137f9
important typo in docs
Ickaser Apr 4, 2025
e5bb567
make compose a direct constructor for CompositeScalarTransform, rathe…
Ickaser Apr 4, 2025
4e7a085
Make sure only ScalarTransforms get passed to CompositeScalarTransform
Ickaser Apr 4, 2025
dbeb48a
Move Unitful-related tests to be closer together
Ickaser Apr 8, 2025
a5ff605
Remove unnecessary test
Ickaser Apr 8, 2025
9b2938f
Add some internal docs
Ickaser Apr 8, 2025
5b1f055
Get inverse_and_logjac tested for Identity()
Ickaser Apr 8, 2025
6cf09c1
make test look nicer
Ickaser Apr 8, 2025
2dd2bb4
Remove argcheck on inverse(Exp)
Ickaser Apr 9, 2025
ad49fd0
Remove argcheck on inverse(Logistic)
Ickaser Apr 9, 2025
05754ac
Cleaner typing
Ickaser Apr 9, 2025
e7c231e
Better zeros for inverse_and_logjac
Ickaser Apr 9, 2025
0604ddd
more logjac_zero
Ickaser Apr 9, 2025
593c2f8
Clean up foldl/foldr operation
Ickaser Apr 9, 2025
6ebb9ab
Remove indexing into CompositeTransform
Ickaser Apr 9, 2025
f467e63
Merge branch 'composite-scalar' of https://github.com/Ickaser/Transfo…
Ickaser Apr 9, 2025
de68e4d
Restrict cases where asR+, etc. get string printed
Ickaser Apr 9, 2025
0af7406
Clearly better suggestions
Ickaser Apr 9, 2025
0935932
Use CompositionsBase to provide non-Unicode compose
Ickaser Apr 9, 2025
d7d2398
Nice little elision
Ickaser Apr 9, 2025
bf0142d
Merge branch 'composite-scalar' of https://github.com/Ickaser/Transfo…
Ickaser Apr 9, 2025
a214583
Get scalar show tests passing again, for now
Ickaser Apr 9, 2025
7996d6f
Actually import compose, redo test
Ickaser Apr 9, 2025
f65fe05
Remove ShiftedExp,etc. from new tests
Ickaser Apr 9, 2025
4172d02
Amend asR+ an asI to be individual scalar transforms, not single-elem…
Ickaser Apr 16, 2025
13034d4
Fully remove ShiftedExp and ScaledShiftedLogistic
Ickaser Apr 16, 2025
090b08f
Add another note to docs about extending scaling for custom number types
Ickaser Apr 16, 2025
f0e3d90
Update docs to clarify use of non-Real numbers
Ickaser Apr 23, 2025
7daeb1e
Trim a whitespace
Ickaser Apr 24, 2025
9018fb0
Catch CompositeScalarTransforms inside the Vararg composite method
Ickaser Apr 24, 2025
8d168fe
Merge branch 'composite-scalar' of https://github.com/Ickaser/Transfo…
Ickaser Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ version = "0.8.14"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,6 +24,7 @@ InverseFunctionsExt = "InverseFunctions"
[compat]
ArgCheck = "1, 2"
ChangesOfVariables = "0.1"
CompositionsBase = "0.1.2"
DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10"
InverseFunctions = "0.1"
35 changes: 35 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -98,6 +98,41 @@ asℝ₋
as𝕀
```

For more granular control than the `as(Real, a, b)`, scalar transformations can be built from individual elements with the composition operator `∘` (typed as `\circ<tab>`):

```@docs
TVExp
TVLogistic
TVScale
TVShift
TVNeg
```

Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`.
If you are working in an editor where typing Unicode is difficult, `TransformVariables.compose` is also available, as in `TransformVariables.compose(TVScale(5.0), TVNeg(), TVExp())`.

This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform.
This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable.

In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition) to add, for example, `Unitful` units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`).

However, note that calculating log Jacobian determinants may error for types that are not real numbers.
For example,

```julia
using Unitful
t = TVScale(5u"m") ∘ TVExp()
```
produces positive quantities with the dimension of length.
!!! note
Because the log-Jacobian of a transform that adds units is not defined, `transform_and_logjac` and `inverse_and_logjac`
only have methods defined for `TVScale{T} where {T<:Real}`.
!!! note
The inverse transform of `TVScale(scale)` divides by `scale`, which is the correct inverse for adding units to a number, but may be inappropriate for other custom number types. A transform that doesn't just multiply or an inverse that extracts a float from an exotic number type could be defined by adding methods to `transform` and `inverse` like the following:
```
transform(t::TVScale{T}, x) where T<:MyCustomNumberType = MyCustomNumberType(x)
inverse(t::TVScale{T}, x) where T<:MyCustomNumberType = get_the_float_part(x)```

## Special arrays

```@docs
2 changes: 0 additions & 2 deletions docs/src/internals.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,6 @@ These are not part of the API, use the `as` constructor or one of the predefined

```@docs
TransformVariables.Identity
TransformVariables.ScaledShiftedLogistic
TransformVariables.ShiftedExp
```

### Aggregating transformations
1 change: 1 addition & 0 deletions src/TransformVariables.jl
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ using LogExpFunctions
using LinearAlgebra: UpperTriangular, logabsdet
using Random: AbstractRNG, GLOBAL_RNG
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
using CompositionsBase

include("utilities.jl")
include("generic.jl")
198 changes: 124 additions & 74 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export TVExp, TVScale, TVShift, TVLogistic, TVNeg
export ∞, asℝ, asℝ₊, asℝ₋, as𝕀, as_real, as_positive_real, as_negative_real,
as_unit_interval

@@ -50,90 +51,131 @@ struct Identity <: ScalarTransform end

transform(::Identity, x::Real) = x

transform_and_logjac(::Identity, x::Real) = x, zero(x)
transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))

inverse(::Identity, x::Real) = x
inverse(::Identity, x::Number) = x

inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))

####
#### shifted exponential
#### elementary scalar transforms
####

"""
$(TYPEDEF)

Shifted exponential. When `D::Bool == true`, maps to `(shift, ∞)` using `x ↦
shift + eˣ`, otherwise to `(-∞, shift)` using `x ↦ shift - eˣ`.
Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals.
"""
struct ShiftedExp{D, T <: Real} <: ScalarTransform
shift::T
function ShiftedExp{D,T}(shift::T) where {D, T <: Real}
@argcheck D isa Bool
new(shift)
end
struct TVExp <: ScalarTransform
end
transform(::TVExp, x::Real) = exp(x)
transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x

ShiftedExp(D::Bool, shift::T) where {T <: Real} = ShiftedExp{D,T}(shift)
function inverse(::TVExp, x::Number)
log(x)
end
inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -log(x)

transform(t::ShiftedExp{D}, x::Real) where D =
D ? t.shift + exp(x) : t.shift - exp(x)
"""
$(TYPEDEF)

transform_and_logjac(t::ShiftedExp, x::Real) = transform(t, x), x
Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct TVLogistic <: ScalarTransform
end
transform(::TVLogistic, x::Real) = logistic(x)
transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)

function inverse(t::ShiftedExp{D}, x::Real) where D
(; shift) = t
if D
@argcheck x > shift DomainError
log(x - shift)
else
@argcheck x < shift DomainError
log(shift - x)
end
function inverse(::TVLogistic, x::Number)
logit(x)
end
inverse_and_logjac(t::TVLogistic, x::Number) = inverse(t, x), logit_logjac(x)

####
#### scaled and shifted logistic
####
"""
$(TYPEDEF)

Shift transformation `x ↦ x + shift`.
"""
struct TVShift{T <: Real} <: ScalarTransform
shift::T
end
transform(t::TVShift, x::Real) = x + t.shift
transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse(t::TVShift, x::Number) = x - t.shift
inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x))

"""
$(TYPEDEF)

Maps to `(scale, shift + scale)` using `logistic(x) * scale + shift`.
Scale transformation `x ↦ scale * x`.
"""
struct ScaledShiftedLogistic{T <: Real} <: ScalarTransform
struct TVScale{T} <: ScalarTransform
scale::T
shift::T
function ScaledShiftedLogistic{T}(scale::T, shift::T) where {T <: Real}
@argcheck scale > 0
new(scale, shift)
function TVScale{T}(scale::T) where {T}
@argcheck scale > zero(scale) DomainError
new(scale)
end
end
TVScale(scale::T) where {T} = TVScale{T}(scale)

ScaledShiftedLogistic(scale::T, shift::T) where {T <: Real} =
ScaledShiftedLogistic{T}(scale, shift)
transform(t::TVScale, x::Real) = t.scale * x
transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)

ScaledShiftedLogistic(scale::Real, shift::Real) =
ScaledShiftedLogistic(promote(scale, shift)...)
inverse(t::TVScale, x::Number) = x / t.scale
inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)

# Switch to muladd and now it does have a DiffRule defined
transform(t::ScaledShiftedLogistic, x::Real) = muladd(logistic(x), t.scale, t.shift)
"""
$(TYPEDEF)

transform_and_logjac(t::ScaledShiftedLogistic, x) =
transform(t, x), log(t.scale) + logistic_logjac(x)
Negative transformation `x ↦ -x`.
"""
struct TVNeg <: ScalarTransform
end

transform(::TVNeg, x::Real) = -x
transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse(::TVNeg, x::Number) = -x
inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x))

function inverse(t::ScaledShiftedLogistic, y)
@argcheck y > t.shift DomainError
@argcheck y < t.scale + t.shift DomainError
logit((y - t.shift)/t.scale)
####
#### composite scalar transforms
####
"""
$(TYPEDEF)

A composite scalar transformation, i.e. a sequence of scalar transformations.
"""
struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform
transforms::Ts
function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple{ScalarTransform,Vararg{ScalarTransform}}}
new{Ts}(transforms)
end
end

# NOTE: inverse_and_logjac interface experimental and sporadically implemented for now
function inverse_and_logjac(t::ScaledShiftedLogistic, y)
@argcheck y > t.shift DomainError
@argcheck y < t.scale + t.shift DomainError
z = (y - t.shift) / t.scale
logit(z), logit_logjac(z) - log(t.scale)
transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x)
function transform_and_logjac(ts::CompositeScalarTransform, x)
foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac)
nx, nlogjac = transform_and_logjac(t, x)
(nx, logjac + nlogjac)
end
end

inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x)
function inverse_and_logjac(ts::CompositeScalarTransform, x)
foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t
nx, nlogjac = inverse_and_logjac(t, x)
(nx, logjac + nlogjac)
end
end

Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s))
Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...))
Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t))
Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...))
Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = foldl(∘, tt; init=t)

####
#### to_interval interface
####
@@ -173,21 +215,22 @@ as(::Type{Real}, left, right) =

as(::Type{Real}, ::Infinity{false}, ::Infinity{true}) = Identity()

as(::Type{Real}, left::Real, ::Infinity{true}) = ShiftedExp(true, left)
as(::Type{Real}, left::Real, ::Infinity{true}) = TVShift(left) ∘ TVExp()

as(::Type{Real}, ::Infinity{false}, right::Real) = ShiftedExp(false, right)
as(::Type{Real}, ::Infinity{false}, right::Real) = TVShift(right) ∘ TVNeg() ∘ TVExp()

function as(::Type{Real}, left::Real, right::Real)
@argcheck left < right "the interval ($(left), $(right)) is empty"
ScaledShiftedLogistic(right - left, left)
shift, scale = promote(left, right - left)
TVShift(shift) ∘ TVScale(scale) ∘ TVLogistic()
end

"""
Transform to a positive real number. See [`as`](@ref).

`asℝ₊` and `as_positive_real` are equivalent alternatives.
"""
const asℝ₊ = as(Real, 0, ∞)
const asℝ₊ = TVExp()

const as_positive_real = asℝ₊

@@ -196,7 +239,7 @@ Transform to a negative real number. See [`as`](@ref).

`asℝ₋` and `as_negative_real` are equivalent alternatives.
"""
const asℝ₋ = as(Real, -∞, 0)
const asℝ₋ = TVNeg() ∘ TVExp()

const as_negative_real = asℝ₋

@@ -205,7 +248,7 @@ Transform to the unit interval `(0, 1)`. See [`as`](@ref).

`as𝕀` and `as_unit_interval` are equivalent alternatives.
"""
const as𝕀 = as(Real, 0, 1)
const as𝕀 = TVLogistic()

const as_unit_interval = as𝕀

@@ -218,24 +261,31 @@ const asℝ = as(Real, -∞, ∞)

const as_real = asℝ

function Base.show(io::IO, t::ShiftedExp)
if t === asℝ₊
print(io, "asℝ₊")
elseif t === asℝ₋
print(io, "asℝ₋")
elseif t isa ShiftedExp{true}
print(io, "as(Real, ", t.shift, ", ∞)")
else
print(io, "as(Real, -∞, ", t.shift, ")")
end
# Single scalar transforms
Base.show(io::IO, ::Identity) = print(io, "asℝ")
Base.show(io::IO, ::TVExp) = print(io, "asℝ₊")
Base.show(io::IO, ::TVLogistic) = print(io, "as𝕀")
function Base.show(io::IO, t::TVScale)
print(io, "TVScale(", t.scale, ")")
end
function Base.show(io::IO, t::TVShift)
print(io, "TVShift(", t.shift, ")")
end

function Base.show(io::IO, t::ScaledShiftedLogistic)
if t === as𝕀
print(io, "as𝕀")
else
print(io, "as(Real, ", t.shift, ", ", t.shift + t.scale, ")")
end
# Fallback method: print all transforms in order
Base.show(io::IO, ct::CompositeScalarTransform) = join(io, ct.transforms, " ∘ ")

# Special cases which are constructed by as(Real, ...)
function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T
print(io, "as(Real, ", ct.transforms[1].shift, ", ∞)")
end
function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where T
print(io, "as(Real, -∞, ", ct.transforms[1].shift, ")")
end
function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2}
print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift +
ct.transforms[2].scale, ")")
end

Base.show(io::IO, t::Identity) = print(io, "asℝ")
# Special case for asR-
Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋")
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -13,3 +13,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Loading