From 488c7809481261fef40e8ad1d85ef1922011d8e2 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Thu, 6 Mar 2025 21:30:34 +0100 Subject: [PATCH 01/61] Took a stab at this idea. Needs testing --- src/scalar.jl | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index db3503e..daf6e9b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -54,6 +54,77 @@ transform_and_logjac(::Identity, x::Real) = x, zero(x) inverse(::Identity, x::Real) = x +#### +#### composite scalar transforms +#### +""" +$(TYPEDEF) + +A composite scalar transformation, i.e. a sequence of scalar transformations. +""" +struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform + transforms::Ts +end + +transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) +function transform_and_logjac(ts::CompositeScalarTransform, x) + logjac = zero(x) + for t in ts[end:begin] + nx, nlogjac = transform_and_logjac(t, x) + x = nx + logjac += nlogjac + end + return x, logjac +end + +#### +#### elementary scalar transforms +#### + +""" +$(TYPEDEF) + +Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals. +""" +struct Exp <: ScalarTransform +end +transform(::Exp, x::Real) = exp(x) +transform_and_logjac(::Exp, x::Real) = transform(Exp(), x), x + +""" +$(TYPEDEF) + +Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). +""" +struct Logistic <: ScalarTransform +end +transform(::Logistic, x::Real) = logit(x) +transform_and_logjac(::Logistic, x::Real) = transform(Logistic(), x), logit_logjac(x) + +""" +$(TYPEDEF) + +Shift transformation `x ↦ x + shift`. +""" +struct Shift{T <: Real} <: ScalarTransform + shift::T +end +transform(t::Shift, x::Real) = x + t.shift +transform_and_logjac(t::Shift, x::Real) = transform(t, x), zero(x) + +""" +$(TYPEDEF) + +Scale transformation `x ↦ scale * x`. +""" +struct Scale{T <: Real} <: ScalarTransform + scale::T +end + +transform(t::Scale, x::Real) = t.scale * x +transform_and_logjac(t::Scale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs + + #### #### shifted exponential #### From efbb5e48c5fc6b597ecfca63b1ff22c0b0c13430 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Thu, 6 Mar 2025 21:34:50 +0100 Subject: [PATCH 02/61] The crucial step: composite! --- src/scalar.jl | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index daf6e9b..995447e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -54,28 +54,6 @@ transform_and_logjac(::Identity, x::Real) = x, zero(x) inverse(::Identity, x::Real) = x -#### -#### composite scalar transforms -#### -""" -$(TYPEDEF) - -A composite scalar transformation, i.e. a sequence of scalar transformations. -""" -struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform - transforms::Ts -end - -transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) -function transform_and_logjac(ts::CompositeScalarTransform, x) - logjac = zero(x) - for t in ts[end:begin] - nx, nlogjac = transform_and_logjac(t, x) - x = nx - logjac += nlogjac - end - return x, logjac -end #### #### elementary scalar transforms @@ -124,6 +102,32 @@ end transform(t::Scale, x::Real) = t.scale * x transform_and_logjac(t::Scale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs +#### +#### composite scalar transforms +#### +""" +$(TYPEDEF) + +A composite scalar transformation, i.e. a sequence of scalar transformations. +""" +struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform + transforms::Ts +end + +transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) +function transform_and_logjac(ts::CompositeScalarTransform, x) + logjac = zero(x) + for t in ts[end:begin] + nx, nlogjac = transform_and_logjac(t, x) + x = nx + logjac += nlogjac + end + return x, logjac +end + +Base.∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) +Base.∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) + #### #### shifted exponential From 09bb1888fd2a40a305f4793292f616187bc2de53 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Thu, 6 Mar 2025 21:42:39 +0100 Subject: [PATCH 03/61] don't nest composite transforms if possible --- src/scalar.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 995447e..b254b42 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -126,6 +126,9 @@ function transform_and_logjac(ts::CompositeScalarTransform, x) end Base.∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) +Base.∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) +Base.∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) +Base.∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) From 8e0d88e16cbcc3408ff259f2d1aa9de3f9907dd8 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:02:30 +0100 Subject: [PATCH 04/61] Properly refer to composite operator --- src/scalar.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index b254b42..2342b6a 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -125,11 +125,11 @@ function transform_and_logjac(ts::CompositeScalarTransform, x) return x, logjac end -Base.∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) -Base.∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) -Base.∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) -Base.∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) -Base.∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) +Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) +Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) +Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) +Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) +Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) #### From ab6e1ff2dab82d68163de013c6c7c90f21961494 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:20:50 +0100 Subject: [PATCH 05/61] Rename elemental scalar transformations --- src/scalar.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 2342b6a..fc86f5b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -64,10 +64,10 @@ $(TYPEDEF) Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals. """ -struct Exp <: ScalarTransform +struct TVExp <: ScalarTransform end -transform(::Exp, x::Real) = exp(x) -transform_and_logjac(::Exp, x::Real) = transform(Exp(), x), x +transform(::TVExp, x::Real) = exp(x) +transform_and_logjac(::TVExp, x::Real) = transform(TVExp(), x), x """ $(TYPEDEF) @@ -84,23 +84,23 @@ $(TYPEDEF) Shift transformation `x ↦ x + shift`. """ -struct Shift{T <: Real} <: ScalarTransform +struct TVShift{T <: Real} <: ScalarTransform shift::T end -transform(t::Shift, x::Real) = x + t.shift -transform_and_logjac(t::Shift, x::Real) = transform(t, x), zero(x) +transform(t::TVShift, x::Real) = x + t.shift +transform_and_logjac(t::TVShift, x::Real) = transform(t, x), zero(x) """ $(TYPEDEF) Scale transformation `x ↦ scale * x`. """ -struct Scale{T <: Real} <: ScalarTransform +struct TVScale{T <: Real} <: ScalarTransform scale::T end -transform(t::Scale, x::Real) = t.scale * x -transform_and_logjac(t::Scale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs +transform(t::TVScale, x::Real) = t.scale * x +transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs #### #### composite scalar transforms From d5d244ac70c2349152dd8ce9f0debe1576417428 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:21:48 +0100 Subject: [PATCH 06/61] Export newly renamed transforms --- src/scalar.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/scalar.jl b/src/scalar.jl index fc86f5b..8980760 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,3 +1,4 @@ +export TVExp, TVScale, TVShift export ∞, asℝ, asℝ₊, asℝ₋, as𝕀, as_real, as_positive_real, as_negative_real, as_unit_interval From b1d9bdc310583195addc2720b34eb8ccfc8260e1 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:27:33 +0100 Subject: [PATCH 07/61] Rename and export Logistic --- src/scalar.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 8980760..6b928e5 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,4 +1,4 @@ -export TVExp, TVScale, TVShift +export TVExp, TVScale, TVShift, TVLogistic export ∞, asℝ, asℝ₊, asℝ₋, as𝕀, as_real, as_positive_real, as_negative_real, as_unit_interval @@ -75,10 +75,10 @@ $(TYPEDEF) Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). """ -struct Logistic <: ScalarTransform +struct TVLogistic <: ScalarTransform end -transform(::Logistic, x::Real) = logit(x) -transform_and_logjac(::Logistic, x::Real) = transform(Logistic(), x), logit_logjac(x) +transform(::TVLogistic, x::Real) = logit(x) +transform_and_logjac(::TVLogistic, x::Real) = transform(TVLogistic(), x), logit_logjac(x) """ $(TYPEDEF) From d5e0e13f68da7bdd5be8d1d8b8a971d83116e2e1 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:31:35 +0100 Subject: [PATCH 08/61] Make indexing directly on CompositeScalarTransform pass through to underlying tuple --- src/scalar.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 6b928e5..1c0eeed 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -115,6 +115,10 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform transforms::Ts end +Base.getindex(t::CompositeScalarTransform, i) = t.transforms[i] +Base.firstindex(t::CompositeScalarTransform) = lastindex(t.transforms) +Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) + transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) logjac = zero(x) From 92479ea96d4f502d4b20ed0304c46aa51520da8c Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:39:57 +0100 Subject: [PATCH 09/61] get indexing right --- src/scalar.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 1c0eeed..0b499d7 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -116,13 +116,14 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform end Base.getindex(t::CompositeScalarTransform, i) = t.transforms[i] -Base.firstindex(t::CompositeScalarTransform) = lastindex(t.transforms) +Base.firstindex(t::CompositeScalarTransform) = firstindex(t.transforms) Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) logjac = zero(x) - for t in ts[end:begin] + for t in ts[end:-1:begin] + @info "check" t x nx, nlogjac = transform_and_logjac(t, x) x = nx logjac += nlogjac From 30b327e43f1b94ca77ef2cf9998fba6eac3d3e7b Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:51:34 +0100 Subject: [PATCH 10/61] Make types inferrable for transform_and_logjac --- src/scalar.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 0b499d7..f07ad83 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -121,14 +121,12 @@ Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) - logjac = zero(x) - for t in ts[end:-1:begin] - @info "check" t x + foldr(ts.transforms, init=(x, zero(x))) do t, (x, logjac) nx, nlogjac = transform_and_logjac(t, x) x = nx logjac += nlogjac + (x, logjac) end - return x, logjac end Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) From f333507bb011a4eb7e0f1ef7c1bcb7034cfa01d9 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 11:56:07 +0100 Subject: [PATCH 11/61] Logistic, not logit --- src/scalar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index f07ad83..112d058 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -77,8 +77,8 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). """ struct TVLogistic <: ScalarTransform end -transform(::TVLogistic, x::Real) = logit(x) -transform_and_logjac(::TVLogistic, x::Real) = transform(TVLogistic(), x), logit_logjac(x) +transform(::TVLogistic, x::Real) = logistic(x) +transform_and_logjac(::TVLogistic, x::Real) = transform(TVLogistic(), x), logistic_logjac(x) """ $(TYPEDEF) From 677579da6c16506b06df310a191054d8067f8d3c Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 12:04:37 +0100 Subject: [PATCH 12/61] Add inverses --- src/scalar.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 112d058..276045a 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -70,6 +70,11 @@ end transform(::TVExp, x::Real) = exp(x) transform_and_logjac(::TVExp, x::Real) = transform(TVExp(), x), x +function inverse(::TVExp, x::Real) + @argcheck x > 0 DomainError + log(x) +end + """ $(TYPEDEF) @@ -80,6 +85,11 @@ end transform(::TVLogistic, x::Real) = logistic(x) transform_and_logjac(::TVLogistic, x::Real) = transform(TVLogistic(), x), logistic_logjac(x) +function inverse(::TVLogistic, x::Real) + @argcheck 0 < x < 1 DomainError + logit(x) +end + """ $(TYPEDEF) @@ -91,6 +101,8 @@ end transform(t::TVShift, x::Real) = x + t.shift transform_and_logjac(t::TVShift, x::Real) = transform(t, x), zero(x) +inverse(t::TVShift, x::Real) = x - t.shift + """ $(TYPEDEF) @@ -103,6 +115,8 @@ end transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs +inverse(t::TVScale, x::Real) = x / t.scale + #### #### composite scalar transforms #### @@ -119,7 +133,7 @@ Base.getindex(t::CompositeScalarTransform, i) = t.transforms[i] Base.firstindex(t::CompositeScalarTransform) = firstindex(t.transforms) Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) -transform(t::CompositeScalarTransform, x) = foldr((t, x) -> transform(t, x), t.transforms, init=x) +transform(t::CompositeScalarTransform, x) = foldr((t, y) -> transform(t, y), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) foldr(ts.transforms, init=(x, zero(x))) do t, (x, logjac) nx, nlogjac = transform_and_logjac(t, x) @@ -129,6 +143,8 @@ function transform_and_logjac(ts::CompositeScalarTransform, x) end end +inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x) + Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) From 17b69fe49cfb4c548d5e29c3ddead1068a31059b Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 14 Mar 2025 12:04:59 +0100 Subject: [PATCH 13/61] Add consistency tests parallel to the existing ones --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 73b3e22..6c79286 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,13 @@ end test_transformation(as(Real, a, ∞), y -> y > a) b = a + 0.5 + rand(Float64) + exp(randn() * 10) test_transformation(as(Real, a, b), y -> a < y < b) + + posexp = TVShift(a) ∘ TVExp() + negexp = TVShift(a) ∘ TVScale(-1) ∘ TVExp() + finite = TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() + test_transformation(posexp, y -> y > a) + test_transformation(negexp, y -> y < a) + test_transformation(finite, y -> a < y < b) end test_transformation(as(Real, -∞, ∞), _ -> true) end From f81376336dd1cd6c1e658a03428cbda8c316259a Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 19 Mar 2025 16:02:27 +0100 Subject: [PATCH 14/61] Add a TVNeg transformation for negative --- src/scalar.jl | 22 ++++++++++++++++++++-- test/runtests.jl | 2 +- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 276045a..ad1699f 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,4 +1,4 @@ -export TVExp, TVScale, TVShift, TVLogistic +export TVExp, TVScale, TVShift, TVLogistic, TVNeg export ∞, asℝ, asℝ₊, asℝ₋, as𝕀, as_real, as_positive_real, as_negative_real, as_unit_interval @@ -110,13 +110,31 @@ Scale transformation `x ↦ scale * x`. """ struct TVScale{T <: Real} <: ScalarTransform scale::T + function TVScale{T}(scale::T) where {T <: Real} + @argcheck scale > 0 DomainError + new(scale) + end end +TVScale(scale::T) where {T<:Real} = TVScale{T}(scale) transform(t::TVScale, x::Real) = t.scale * x -transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(abs(t.scale)) #???? need to think about this abs +transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(t.scale) inverse(t::TVScale, x::Real) = x / t.scale +""" +$(TYPEDEF) + +Negative transformation `x ↦ -x`. +""" +struct TVNeg <: ScalarTransform +end + +transform(::TVNeg, x::Real) = -x +transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), zero(x) + +inverse(::TVNeg, x::Real) = -x + #### #### composite scalar transforms #### diff --git a/test/runtests.jl b/test/runtests.jl index 6c79286..b5943b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,7 +55,7 @@ end test_transformation(as(Real, a, b), y -> a < y < b) posexp = TVShift(a) ∘ TVExp() - negexp = TVShift(a) ∘ TVScale(-1) ∘ TVExp() + negexp = TVShift(a) ∘ TVNeg() ∘ TVExp() finite = TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() test_transformation(posexp, y -> y > a) test_transformation(negexp, y -> y < a) From 1c0d71b99fdca88b1cad5d73b668ffe28ecd3686 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 19 Mar 2025 16:21:06 +0100 Subject: [PATCH 15/61] Add test to make sure scalar transformations at least compose --- test/runtests.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index b5943b6..90681e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,23 @@ end test_transformation(as(Real, -∞, ∞), _ -> true) end +@testset "composite scalar transformations" begin + all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), + as(Real, 1.0, 3.0), as(Real, 5.0, ∞), as(Real, -∞, 4.5)] + for t1 in all_transforms + for t2 in all_transforms + for t3 in all_transforms + t = t1 ∘ t2 ∘ t3 + @test t isa TransformVariables.CompositeScalarTransform + # x = randn() + # y = transform(t, x) + # x2 = inverse(t, y) + # @test x ≈ x2 + end + end + end +end + @testset "scalar transformation corner cases" begin @test_throws ArgumentError as(Real, "a fish", 9) @test as(Real, 1, 4.0) == as(Real, 1.0, 4.0) From 3e492db4f2d1697691cfaf3ab21f4b5c406119ce Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 19 Mar 2025 16:26:33 +0100 Subject: [PATCH 16/61] More composing in test --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 90681e0..bbe0934 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,7 +71,8 @@ end for t2 in all_transforms for t3 in all_transforms t = t1 ∘ t2 ∘ t3 - @test t isa TransformVariables.CompositeScalarTransform + @test t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3)}} + @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} # x = randn() # y = transform(t, x) # x2 = inverse(t, y) From 45aabf614354fccf28c2cf46fc7eea811ad9b01e Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Thu, 20 Mar 2025 11:08:52 +0100 Subject: [PATCH 17/61] Test consistency for a bunch of arbitrary compositions --- test/runtests.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index bbe0934..6e10839 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,22 +66,29 @@ end @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), - as(Real, 1.0, 3.0), as(Real, 5.0, ∞), as(Real, -∞, 4.5)] + as(Real, 1.0, 3.0), as(Real, -5.0, ∞), as(Real, -∞, 4.5)] for t1 in all_transforms for t2 in all_transforms for t3 in all_transforms t = t1 ∘ t2 ∘ t3 @test t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3)}} @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} - # x = randn() - # y = transform(t, x) - # x2 = inverse(t, y) - # @test x ≈ x2 end end end end +@testset "semiarbitrary compositions" begin + same_domain_transforms = [TVShift(3.0), TVScale(2.0), TVNeg()] + new_domain_transforms = [TVExp(), TVLogistic(), as(Real, 1.0, 3.0), as(Real, -5.0, ∞), as(Real, -∞, 4.5)] + for s1 in same_domain_transforms, s2 in same_domain_transforms, n in new_domain_transforms + for s3 in same_domain_transforms, s4 in same_domain_transforms + # don't worry about valid output here, let inverse check that + test_transformation(s1 ∘ s2 ∘ n ∘ s3 ∘ s4, _ -> true; N=5) + end + end +end + @testset "scalar transformation corner cases" begin @test_throws ArgumentError as(Real, "a fish", 9) @test as(Real, 1, 4.0) == as(Real, 1.0, 4.0) From a1330177c2538e282fa77a798fb18b237750cc7d Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 25 Mar 2025 16:44:34 +0100 Subject: [PATCH 18/61] Add inverse_and_logjac for new types, consistent transform_and_logjac style --- src/scalar.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index ad1699f..bbf40a6 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -68,12 +68,13 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real struct TVExp <: ScalarTransform end transform(::TVExp, x::Real) = exp(x) -transform_and_logjac(::TVExp, x::Real) = transform(TVExp(), x), x +transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x function inverse(::TVExp, x::Real) @argcheck x > 0 DomainError log(x) end +inverse_and_logjac(t::TVExp, x::Real) = inverse(t, x), -x """ $(TYPEDEF) @@ -83,12 +84,13 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). struct TVLogistic <: ScalarTransform end transform(::TVLogistic, x::Real) = logistic(x) -transform_and_logjac(::TVLogistic, x::Real) = transform(TVLogistic(), x), logistic_logjac(x) +transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) function inverse(::TVLogistic, x::Real) @argcheck 0 < x < 1 DomainError logit(x) end +inverse_and_logjac(t::TVLogistic, x::Real) = inverse(t, x), logit_logjac(x) """ $(TYPEDEF) @@ -102,6 +104,7 @@ transform(t::TVShift, x::Real) = x + t.shift transform_and_logjac(t::TVShift, x::Real) = transform(t, x), zero(x) inverse(t::TVShift, x::Real) = x - t.shift +inverse_and_logjac(t::TVShift, x::Real) = inverse(t, x), zero(x) """ $(TYPEDEF) @@ -121,6 +124,7 @@ transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(t.scale) inverse(t::TVScale, x::Real) = x / t.scale +inverse_and_logjac(t::TVScale, x::Real) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -134,6 +138,7 @@ transform(::TVNeg, x::Real) = -x transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), zero(x) inverse(::TVNeg, x::Real) = -x +inverse_and_logjac(::TVNeg, x::Real) = -x, zero(x) #### #### composite scalar transforms @@ -162,6 +167,14 @@ function transform_and_logjac(ts::CompositeScalarTransform, x) end inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x) +function inverse_and_logjac(ts::CompositeScalarTransform, x) + foldl(ts.transforms, init=(x, zero(x))) do (x, logjac), t + nx, nlogjac = inverse_and_logjac(t, x) + x = nx + logjac += nlogjac + (x, logjac) + end +end Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) From 54b601b8122dba9c3e7af57eb8ecb6e10b54a562 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 25 Mar 2025 17:27:26 +0100 Subject: [PATCH 19/61] Commit to new style, mark a few tests as broken for now --- src/scalar.jl | 12 ++++++------ test/runtests.jl | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index bbf40a6..9edf592 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -302,13 +302,13 @@ as(::Type{Real}, left, right) = as(::Type{Real}, ::Infinity{false}, ::Infinity{true}) = Identity() -as(::Type{Real}, left::Real, ::Infinity{true}) = ShiftedExp(true, left) +as(::Type{Real}, left::Real, ::Infinity{true}) = TVShift(left) ∘ TVExp() -as(::Type{Real}, ::Infinity{false}, right::Real) = ShiftedExp(false, right) +as(::Type{Real}, ::Infinity{false}, right::Real) = TVShift(right) ∘ TVNeg() ∘ TVExp() function as(::Type{Real}, left::Real, right::Real) @argcheck left < right "the interval ($(left), $(right)) is empty" - ScaledShiftedLogistic(right - left, left) + TVShift(left) ∘ TVScale(right - left) ∘ TVLogistic() end """ @@ -316,7 +316,7 @@ Transform to a positive real number. See [`as`](@ref). `asℝ₊` and `as_positive_real` are equivalent alternatives. """ -const asℝ₊ = as(Real, 0, ∞) +const asℝ₊ = TVExp() const as_positive_real = asℝ₊ @@ -325,7 +325,7 @@ Transform to a negative real number. See [`as`](@ref). `asℝ₋` and `as_negative_real` are equivalent alternatives. """ -const asℝ₋ = as(Real, -∞, 0) +const asℝ₋ = TVNeg() ∘ TVExp() const as_negative_real = asℝ₋ @@ -334,7 +334,7 @@ Transform to the unit interval `(0, 1)`. See [`as`](@ref). `as𝕀` and `as_unit_interval` are equivalent alternatives. """ -const as𝕀 = as(Real, 0, 1) +const as𝕀 = TVLogistic() const as_unit_interval = as𝕀 diff --git a/test/runtests.jl b/test/runtests.jl index 6e10839..27a644f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ using LogDensityProblemsAD using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation, unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, - NOLOGJAC, transform_with + NOLOGJAC, transform_with, ShiftedExp, ScaledShiftedLogistic import ChangesOfVariables, InverseFunctions using Enzyme: autodiff, ReverseWithPrimal, Active, Const @@ -66,7 +66,7 @@ end @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), - as(Real, 1.0, 3.0), as(Real, -5.0, ∞), as(Real, -∞, 4.5)] + ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] for t1 in all_transforms for t2 in all_transforms for t3 in all_transforms @@ -80,7 +80,7 @@ end @testset "semiarbitrary compositions" begin same_domain_transforms = [TVShift(3.0), TVScale(2.0), TVNeg()] - new_domain_transforms = [TVExp(), TVLogistic(), as(Real, 1.0, 3.0), as(Real, -5.0, ∞), as(Real, -∞, 4.5)] + new_domain_transforms = [TVExp(), TVLogistic(), ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] for s1 in same_domain_transforms, s2 in same_domain_transforms, n in new_domain_transforms for s3 in same_domain_transforms, s4 in same_domain_transforms # don't worry about valid output here, let inverse check that @@ -91,7 +91,7 @@ end @testset "scalar transformation corner cases" begin @test_throws ArgumentError as(Real, "a fish", 9) - @test as(Real, 1, 4.0) == as(Real, 1.0, 4.0) + @test_broken as(Real, 1, 4.0) == as(Real, 1.0, 4.0) @test_throws ArgumentError as(Real, 3.0, -4.0) t = as(Real, 1.0, ∞) @@ -575,12 +575,12 @@ end @testset "scalar show" begin @test string(asℝ) == "asℝ" - @test string(asℝ₊) == "asℝ₊" - @test string(asℝ₋) == "asℝ₋" - @test string(as𝕀) == "as𝕀" - @test string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" - @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" - @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" + @test_broken string(asℝ₊) == "asℝ₊" + @test_broken string(asℝ₋) == "asℝ₋" + @test_broken string(as𝕀) == "as𝕀" + @test_broken string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" + @test_broken string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" + @test_broken string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" end @testset "sum dimensions allocations" begin From d87fa2ae27a9b3f5c0991fbd2d5e14f4d04f108b Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 25 Mar 2025 17:27:45 +0100 Subject: [PATCH 20/61] Provide some default show methods --- src/scalar.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 9edf592..f682153 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -347,6 +347,30 @@ const asℝ = as(Real, -∞, ∞) const as_real = asℝ +function Base.show(io::IO, ct::CompositeScalarTransform) + # if ct === asℝ₋ + # print(io, "asℝ₋") + # else + str = string(ct.transforms[1]) + for ti in ct.transforms[begin+1:end] + str *= " ∘ "*string(ti) + end + print(io, str) + # end +end +function Base.show(io::IO, t::TVScale) + print(io, "TVScale(", t.scale, ")") +end +function Base.show(io::IO, t::TVShift) + print(io, "TVShift(", t.shift, ")") +end +# function Base.show(io::IO, ::TVExp) +# print(io, "asℝ₊") +# end +# function Base.show(io::IO, ::TVLogistic) +# print(io, "as𝕀") +# end + function Base.show(io::IO, t::ShiftedExp) if t === asℝ₊ print(io, "asℝ₊") From 683d924f3ef87e99c1d8237360cc5bce4323594a Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 26 Mar 2025 12:03:10 +0100 Subject: [PATCH 21/61] Add some documentation for composable scalar transforms --- docs/src/index.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 515bdaf..89f0b3b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -98,6 +98,21 @@ asℝ₋ as𝕀 ``` +For more granular control than the `as(Real, a, b)`, scalar transformations can be built from individual elements with the composition operator `∘` (typed as `\circ`): + +```@docs +TVExp +TVLogistic +TVScale +TVShift +TVNeg +``` + +Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`. + +This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. +This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. + ## Special arrays ```@docs From f5d9740533422c2f3943ced231be231088f8edc9 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 26 Mar 2025 12:18:48 +0100 Subject: [PATCH 22/61] Widen types allowed by scalar inverses to Number, and TVScale to anything --- src/scalar.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index f682153..fd6dee6 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -53,7 +53,7 @@ transform(::Identity, x::Real) = x transform_and_logjac(::Identity, x::Real) = x, zero(x) -inverse(::Identity, x::Real) = x +inverse(::Identity, x::Number) = x #### @@ -70,11 +70,11 @@ end transform(::TVExp, x::Real) = exp(x) transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x -function inverse(::TVExp, x::Real) +function inverse(::TVExp, x::Number) @argcheck x > 0 DomainError log(x) end -inverse_and_logjac(t::TVExp, x::Real) = inverse(t, x), -x +inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -x """ $(TYPEDEF) @@ -86,11 +86,11 @@ end transform(::TVLogistic, x::Real) = logistic(x) transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) -function inverse(::TVLogistic, x::Real) +function inverse(::TVLogistic, x::Number) @argcheck 0 < x < 1 DomainError logit(x) end -inverse_and_logjac(t::TVLogistic, x::Real) = inverse(t, x), logit_logjac(x) +inverse_and_logjac(t::TVLogistic, x::Number) = inverse(t, x), logit_logjac(x) """ $(TYPEDEF) @@ -103,28 +103,28 @@ end transform(t::TVShift, x::Real) = x + t.shift transform_and_logjac(t::TVShift, x::Real) = transform(t, x), zero(x) -inverse(t::TVShift, x::Real) = x - t.shift -inverse_and_logjac(t::TVShift, x::Real) = inverse(t, x), zero(x) +inverse(t::TVShift, x::Number) = x - t.shift +inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), zero(x) """ $(TYPEDEF) Scale transformation `x ↦ scale * x`. """ -struct TVScale{T <: Real} <: ScalarTransform +struct TVScale{T} <: ScalarTransform scale::T - function TVScale{T}(scale::T) where {T <: Real} - @argcheck scale > 0 DomainError + function TVScale{T}(scale::T) where {T} + @argcheck scale > zero(scale) DomainError new(scale) end end -TVScale(scale::T) where {T<:Real} = TVScale{T}(scale) +TVScale(scale::T) where {T} = TVScale{T}(scale) transform(t::TVScale, x::Real) = t.scale * x transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(t.scale) -inverse(t::TVScale, x::Real) = x / t.scale -inverse_and_logjac(t::TVScale, x::Real) = inverse(t, x), -log(t.scale) +inverse(t::TVScale, x::Number) = x / t.scale +inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -137,8 +137,8 @@ end transform(::TVNeg, x::Real) = -x transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), zero(x) -inverse(::TVNeg, x::Real) = -x -inverse_and_logjac(::TVNeg, x::Real) = -x, zero(x) +inverse(::TVNeg, x::Number) = -x +inverse_and_logjac(::TVNeg, x::Number) = -x, zero(x) #### #### composite scalar transforms From 04f3a3399afad52102063395646317e2799d0a6b Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 26 Mar 2025 12:29:50 +0100 Subject: [PATCH 23/61] Try adding Unitful tests: commented out because not working yet --- test/Project.toml | 1 + test/runtests.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index bfaad77..e26b067 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,3 +13,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/test/runtests.jl b/test/runtests.jl index 27a644f..1e29677 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using TransformVariables: NOLOGJAC, transform_with, ShiftedExp, ScaledShiftedLogistic import ChangesOfVariables, InverseFunctions using Enzyme: autodiff, ReverseWithPrimal, Active, Const +using Unitful: @u_str, ustrip, uconvert const CIENV = get(ENV, "CI", "") == "true" @@ -46,6 +47,7 @@ end #### scalar transformations correctness checks #### + @testset "scalar transformations consistency" begin for _ in 1:100 a = randn() * 100 @@ -64,6 +66,17 @@ end test_transformation(as(Real, -∞, ∞), _ -> true) end +# @testset "scalar non-Real (Unitful) consistency" begin +# for _ in 1:10 +# a = randn() * 100 +# b = a + 0.5 + rand(Float64) + exp(randn() * 10) +# t1 = TVScale(2u"m") ∘ TVShift(a) ∘ TVExp() +# test_transformation(t1, y -> y > 2u"m") +# t2 = TVScale(1u"s") ∘ TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() +# test_transformation(t2, y -> a*u"s" < y < b*u"s") +# end +# end + @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] @@ -122,6 +135,7 @@ end @test transform(as𝕀, a) isa Float32 end + #### #### special array transformation correctness checks #### From 4107271b06c0c8449d67f28120a37add0ea1f249 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 26 Mar 2025 12:30:22 +0100 Subject: [PATCH 24/61] Some docs explaining Unitful transform --- docs/src/index.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 89f0b3b..835833e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -113,6 +113,15 @@ Consistent with common notation, transforms are applied right-to-left; for examp This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. +In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition), to add Unitful units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`). +For example, + +```julia +using Unitful +t = TVScale(5u"m") ∘ TVExp() +``` +produces positive quantities with the dimension of length. + ## Special arrays ```@docs From 3c7b1a5ddae56fbbba579d853033ad7e9edcacff Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 26 Mar 2025 14:32:54 +0100 Subject: [PATCH 25/61] Add tests to cover more constructions of composition --- test/runtests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 1e29677..d9d959c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,8 +84,13 @@ end for t2 in all_transforms for t3 in all_transforms t = t1 ∘ t2 ∘ t3 + # Basic functionality @test t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3)}} + # Other constructions @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} + @test t == t1 ∘ (t2 ∘ t3) + @test t == ∘(t1, t2, t3) + @test all([t[1] == t1, t[2] == t2, t[3] == t3]) end end end From 07ef76fa79e815d77df96a0cb9507729f17f276a Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 28 Mar 2025 13:15:57 +0100 Subject: [PATCH 26/61] Don't test Jacobian for transforms that add units --- test/runtests.jl | 20 ++++++++++---------- test/utilities.jl | 16 +++++++++------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d9d959c..f44393d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,16 +66,16 @@ end test_transformation(as(Real, -∞, ∞), _ -> true) end -# @testset "scalar non-Real (Unitful) consistency" begin -# for _ in 1:10 -# a = randn() * 100 -# b = a + 0.5 + rand(Float64) + exp(randn() * 10) -# t1 = TVScale(2u"m") ∘ TVShift(a) ∘ TVExp() -# test_transformation(t1, y -> y > 2u"m") -# t2 = TVScale(1u"s") ∘ TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() -# test_transformation(t2, y -> a*u"s" < y < b*u"s") -# end -# end +@testset "scalar non-Real (Unitful) consistency" begin + for _ in 1:10 + a = randn() * 100 + b = a + 0.5 + rand(Float64) + exp(randn() * 10) + t1 = TVScale(2u"m") ∘ TVShift(a) ∘ TVExp() + test_transformation(t1, y -> y > 2u"m", jac=false) + t2 = TVScale(1u"s") ∘ TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() + test_transformation(t2, y -> a*u"s" < y < b*u"s", jac=false) + end +end @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), diff --git a/test/utilities.jl b/test/utilities.jl index 18549b8..f1b3fa7 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -31,19 +31,21 @@ automatic differentiation. `test_inverse` determines whether the inverse is tested. """ function test_transformation(t::AbstractTransform, is_valid_y; - vec_y = identity, N = 1000, test_inverse = true) + vec_y = identity, N = 1000, test_inverse = true, jac=true) for _ in 1:N x = random_arg(t) x isa ScalarTransform && @test dimension(x) == 1 y = @inferred transform(t, x) @test is_valid_y(y) @test transform(t, x) == y - y2, lj = @inferred transform_and_logjac(t, x) - @test y2 == y - if t isa ScalarTransform - @test lj ≈ AD_logjac(t, x) - else - @test lj ≈ AD_logjac(t, x, vec_y) + if jac + y2, lj = @inferred transform_and_logjac(t, x) + @test y2 == y + if t isa ScalarTransform + @test lj ≈ AD_logjac(t, x) + else + @test lj ≈ AD_logjac(t, x, vec_y) + end end if test_inverse x2 = inverse(t, y) From 74da4941323057224cce147d58051e43b4619de2 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 28 Mar 2025 13:25:54 +0100 Subject: [PATCH 27/61] Remove log-Jacobian functionality for Unitful-type transforms --- docs/src/index.md | 5 ++++- src/scalar.jl | 4 ++-- test/runtests.jl | 5 +++++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 835833e..53c7772 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -120,7 +120,10 @@ For example, using Unitful t = TVScale(5u"m") ∘ TVExp() ``` -produces positive quantities with the dimension of length. +produces positive quantities with the dimension of length. +!!! note + Because the log-Jacobian of a transform that adds units is not defined, `transform_and_logjac` and `inverse_and_logjac` + only have methods defined for `TVScale{T} where {T<:Real}`. ## Special arrays diff --git a/src/scalar.jl b/src/scalar.jl index fd6dee6..a4273c9 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -121,10 +121,10 @@ end TVScale(scale::T) where {T} = TVScale{T}(scale) transform(t::TVScale, x::Real) = t.scale * x -transform_and_logjac(t::TVScale, x::Real) = transform(t, x), log(t.scale) +transform_and_logjac(t::TVScale{T}, x::Real) where {T<:Real} = transform(t, x), log(t.scale) inverse(t::TVScale, x::Number) = x / t.scale -inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale) +inverse_and_logjac(t::TVScale{T}, x::Number) where {T<:Real} = inverse(t, x), -log(t.scale) """ $(TYPEDEF) diff --git a/test/runtests.jl b/test/runtests.jl index f44393d..ebc11a0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -123,6 +123,11 @@ end @test_throws DomainError inverse(t, 11.0) @test_throws DomainError inverse_and_logjac(t, 0.5) @test_throws DomainError inverse_and_logjac(t, 11.0) + + t = TVScale(5.0u"m") ∘ as(Real, 1.0, 10.0) + @test_throws MethodError transform_and_logjac(t, 0.5) + y = transform(t, 0.5) + @test_throws MethodError inverse_and_logjac(t, y) end @testset "scalar alternatives" begin From 6dff1f989c572890e0b22b7dfd6c71f93941ec38 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 28 Mar 2025 13:50:26 +0100 Subject: [PATCH 28/61] Fix ill-formed test --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ebc11a0..1b9185e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,9 +71,9 @@ end a = randn() * 100 b = a + 0.5 + rand(Float64) + exp(randn() * 10) t1 = TVScale(2u"m") ∘ TVShift(a) ∘ TVExp() - test_transformation(t1, y -> y > 2u"m", jac=false) + test_transformation(t1, y -> y > a*2u"m", jac=false) t2 = TVScale(1u"s") ∘ TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() - test_transformation(t2, y -> a*u"s" < y < b*u"s", jac=false) + test_transformation(t2, y -> (a*u"s" < y < b*u"s"), jac=false) end end From d106ef735a3d7476e6e331fe9bef5d2792c93c59 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 28 Mar 2025 13:54:03 +0100 Subject: [PATCH 29/61] Improve test coverage --- test/runtests.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 1b9185e..86af51e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -91,6 +91,8 @@ end @test t == t1 ∘ (t2 ∘ t3) @test t == ∘(t1, t2, t3) @test all([t[1] == t1, t[2] == t2, t[3] == t3]) + @test all([t[begin] == t1, t[end] == t3]) + @test t[begin:end] == t[:] end end end @@ -531,6 +533,18 @@ end x2, lj2 = TransformVariables.inverse_and_logjac(t, y) @test x2 ≈ x @test lj2 ≈ -lj + + t = as(Real, a, ∞) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj + + t = as(Real, -∞, a) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj end end From e3747e39eedce3b986f7d4ade6380fd0155d92b7 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 2 Apr 2025 20:30:23 +0200 Subject: [PATCH 30/61] Improve pretty printing for common transforms to keep printing behavior; add new scalar show tests --- src/scalar.jl | 56 +++++++++++++++++++++++++++++++++--------------- test/runtests.jl | 30 +++++++++++++++++--------- 2 files changed, 59 insertions(+), 27 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index a4273c9..76df8f6 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -316,7 +316,7 @@ Transform to a positive real number. See [`as`](@ref). `asℝ₊` and `as_positive_real` are equivalent alternatives. """ -const asℝ₊ = TVExp() +const asℝ₊ = ∘(TVExp()) const as_positive_real = asℝ₊ @@ -334,7 +334,7 @@ Transform to the unit interval `(0, 1)`. See [`as`](@ref). `as𝕀` and `as_unit_interval` are equivalent alternatives. """ -const as𝕀 = TVLogistic() +const as𝕀 = ∘(TVLogistic()) const as_unit_interval = as𝕀 @@ -347,29 +347,51 @@ const asℝ = as(Real, -∞, ∞) const as_real = asℝ +# Fallback method: print all transforms in order function Base.show(io::IO, ct::CompositeScalarTransform) - # if ct === asℝ₋ - # print(io, "asℝ₋") - # else - str = string(ct.transforms[1]) - for ti in ct.transforms[begin+1:end] - str *= " ∘ "*string(ti) - end - print(io, str) - # end + str = string(ct.transforms[1]) + for ti in ct.transforms[begin+1:end] + str *= " ∘ "*string(ti) + end + print(io, str) end + +# If equivalent to asℝ₊, print as such. Two ways to achieve this +function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T + if ct[1].shift == 0 + print(io, "asℝ₊") + else + print(io, "as(Real, ", ct[1].shift, ", ∞)") + end +end +Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVExp}}) = print(io, "asℝ₊") + +# If equivalent to asℝ₋, print as such. Two ways to achieve this +function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where T + if ct[1].shift == 0 + print(io, "asℝ₋") + else + print(io, "as(Real, -∞, ", ct[1].shift, ")") + end +end +Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") + +# If equivalent to as𝕀, print as such. Two ways to achieve this +function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2} + if ct[1].shift == 0 && ct[2].scale == 1 + print(io, "as𝕀") + else + print(io, "as(Real, ", ct[1].shift, ", ", ct[1].shift + ct[2].scale, ")") + end +end +Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVLogistic}}) = print(io, "as𝕀") + function Base.show(io::IO, t::TVScale) print(io, "TVScale(", t.scale, ")") end function Base.show(io::IO, t::TVShift) print(io, "TVShift(", t.shift, ")") end -# function Base.show(io::IO, ::TVExp) -# print(io, "asℝ₊") -# end -# function Base.show(io::IO, ::TVLogistic) -# print(io, "as𝕀") -# end function Base.show(io::IO, t::ShiftedExp) if t === asℝ₊ diff --git a/test/runtests.jl b/test/runtests.jl index 86af51e..ebde9df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,26 @@ end #### scalar transformations correctness checks #### +@testset "scalar show" begin + @test string(asℝ) == "asℝ" + @test string(asℝ₊) == "asℝ₊" + @test string(asℝ₋) == "asℝ₋" + @test string(as𝕀) == "as𝕀" + @test string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" + @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" + @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" + + @test string(TVShift(4.0)) == "TVShift(4.0)" + @test string(TVScale(4.0)) == "TVScale(4.0)" + @test string(TVExp()) == "TVExp()" + @test string(TVLogistic()) == "TVLogistic()" + @test string(TVNeg()) == "TVNeg()" + + @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" + @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" + @test string(TVScale(2.0) ∘ TVNeg() ∘ TVExp()) == "TVScale(2.0) ∘ TVNeg() ∘ TVExp()" + @test string(TVScale(5.0u"m") ∘ TVExp()) == "TVScale(5.0 m) ∘ TVExp()" +end @testset "scalar transformations consistency" begin for _ in 1:100 @@ -611,16 +631,6 @@ end #### show #### -@testset "scalar show" begin - @test string(asℝ) == "asℝ" - @test_broken string(asℝ₊) == "asℝ₊" - @test_broken string(asℝ₋) == "asℝ₋" - @test_broken string(as𝕀) == "as𝕀" - @test_broken string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" - @test_broken string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" - @test_broken string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" -end - @testset "sum dimensions allocations" begin shifted = TransformVariables.ShiftedExp{true,Float64}(0.0) tr = (a = shifted, b = TransformVariables.Identity(), c = shifted, d = shifted, e = shifted, f = shifted) From 14bba55139d6c73a6d9ccf0c6aa0353d4fdc59f1 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 2 Apr 2025 20:48:28 +0200 Subject: [PATCH 31/61] Add a non-unicode alias for composition operator --- docs/src/index.md | 3 ++- src/scalar.jl | 1 + test/runtests.jl | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 53c7772..2e8134c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -109,11 +109,12 @@ TVNeg ``` Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`. +If you are working in an editor where typing Unitful is difficult, `TransformVariables.compose` (not exported) is provided as an alias for `∘`, as in `TransformVariables.compose(TVScale(5.0), TVExp(), TVNeg())`. This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. -In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition), to add Unitful units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`). +In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition) to add Unitful units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`). For example, ```julia diff --git a/src/scalar.jl b/src/scalar.jl index 76df8f6..30b92e7 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -181,6 +181,7 @@ Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTra Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) +const compose = Base.:∘ #### diff --git a/test/runtests.jl b/test/runtests.jl index ebde9df..6915804 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,6 +110,7 @@ end @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} @test t == t1 ∘ (t2 ∘ t3) @test t == ∘(t1, t2, t3) + @test t == TransformVariables.compose(t1, t2, t3) @test all([t[1] == t1, t[2] == t2, t[3] == t3]) @test all([t[begin] == t1, t[end] == t3]) @test t[begin:end] == t[:] From 5ff6832046294699d9a6bafddf09961cb49796d8 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 2 Apr 2025 20:55:44 +0200 Subject: [PATCH 32/61] More alias tests in scalar show --- test/runtests.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6915804..3adbab8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,12 @@ end @test string(asℝ₊) == "asℝ₊" @test string(asℝ₋) == "asℝ₋" @test string(as𝕀) == "as𝕀" + @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" + @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" + @test string(TVShift(0) ∘ TVScale(1) ∘ TVLogistic()) == "as𝕀" + @test string(TVNeg() ∘ TVExp()) == "asℝ₋" + @test string(∘(TVExp())) == "asℝ₊" + @test string(∘(TVLogistic())) == "as𝕀" @test string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" @@ -62,8 +68,6 @@ end @test string(TVLogistic()) == "TVLogistic()" @test string(TVNeg()) == "TVNeg()" - @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" - @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" @test string(TVScale(2.0) ∘ TVNeg() ∘ TVExp()) == "TVScale(2.0) ∘ TVNeg() ∘ TVExp()" @test string(TVScale(5.0u"m") ∘ TVExp()) == "TVScale(5.0 m) ∘ TVExp()" end From 67f1eb40508820a53b14ee089f7b39b5d70df825 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 4 Apr 2025 11:15:50 +0200 Subject: [PATCH 33/61] Fix log-Jacobian of inverse of Exp transform; move inverse_and_logjac tests --- src/scalar.jl | 3 ++- test/runtests.jl | 47 +++++++++++++++++++++++------------------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 30b92e7..c8e41f8 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -55,6 +55,7 @@ transform_and_logjac(::Identity, x::Real) = x, zero(x) inverse(::Identity, x::Number) = x +inverse_and_logjac(::Identity, x::Real) = x, zero(x) #### #### elementary scalar transforms @@ -74,7 +75,7 @@ function inverse(::TVExp, x::Number) @argcheck x > 0 DomainError log(x) end -inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -x +inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), log(1/x) """ $(TYPEDEF) diff --git a/test/runtests.jl b/test/runtests.jl index 3adbab8..1643dd0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -172,6 +172,29 @@ end @test transform(as𝕀, a) isa Float32 end +@testset "inverse_and_logjac" begin + for _ in 1:100 + x = randn() + a = randn() + t = as(Real, a, a + abs(randn()) + 0.1) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj + + t = as(Real, a, ∞) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj + + t = as(Real, -∞, a) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj + end +end #### #### special array transformation correctness checks @@ -548,30 +571,6 @@ end # end # end -@testset "inverse_and_logjac" begin - # WIP, test separately until integrated - for _ in 1:100 - x = randn() - a = randn() - t = as(Real, a, a + abs(randn()) + 0.1) - y, lj = transform_and_logjac(t, x) - x2, lj2 = TransformVariables.inverse_and_logjac(t, y) - @test x2 ≈ x - @test lj2 ≈ -lj - - t = as(Real, a, ∞) - y, lj = transform_and_logjac(t, x) - x2, lj2 = TransformVariables.inverse_and_logjac(t, y) - @test x2 ≈ x - @test lj2 ≈ -lj - - t = as(Real, -∞, a) - y, lj = transform_and_logjac(t, x) - x2, lj2 = TransformVariables.inverse_and_logjac(t, y) - @test x2 ≈ x - @test lj2 ≈ -lj - end -end @testset "inference of nested tuples" begin # An MWE adapted from a real-life problem From 12137f9d2d600362145b73da20adb88c62a7316d Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 4 Apr 2025 11:17:58 +0200 Subject: [PATCH 34/61] important typo in docs --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 2e8134c..d26a40a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -109,7 +109,7 @@ TVNeg ``` Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`. -If you are working in an editor where typing Unitful is difficult, `TransformVariables.compose` (not exported) is provided as an alias for `∘`, as in `TransformVariables.compose(TVScale(5.0), TVExp(), TVNeg())`. +If you are working in an editor where typing Unicode is difficult, `TransformVariables.compose` (not exported) is provided as an alias for `∘`, as in `TransformVariables.compose(TVScale(5.0), TVExp(), TVNeg())`. This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. From e5bb567a5a7e80bf525a69f8c48299e5e1ee3ed2 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 4 Apr 2025 11:21:56 +0200 Subject: [PATCH 35/61] make compose a direct constructor for CompositeScalarTransform, rather than alias for composition operator --- docs/src/index.md | 2 +- src/scalar.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index d26a40a..271a888 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -109,7 +109,7 @@ TVNeg ``` Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`. -If you are working in an editor where typing Unicode is difficult, `TransformVariables.compose` (not exported) is provided as an alias for `∘`, as in `TransformVariables.compose(TVScale(5.0), TVExp(), TVNeg())`. +If you are working in an editor where typing Unicode is difficult, `TransformVariables.compose` is also available, as in `TransformVariables.compose(TVScale(5.0), TVNeg(), TVExp())`. This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. diff --git a/src/scalar.jl b/src/scalar.jl index c8e41f8..1756dbc 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -182,7 +182,7 @@ Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTra Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) -const compose = Base.:∘ +compose(args...) = CompositeScalarTransform(args) #### From 4e7a08597f00c0a9d97b59ea3f4c401c48fe95c0 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Fri, 4 Apr 2025 11:47:32 +0200 Subject: [PATCH 36/61] Make sure only ScalarTransforms get passed to CompositeScalarTransform --- src/scalar.jl | 5 +++++ test/runtests.jl | 3 +++ 2 files changed, 8 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 1756dbc..3f45233 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -151,6 +151,11 @@ A composite scalar transformation, i.e. a sequence of scalar transformations. """ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform transforms::Ts + function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple} + @argcheck length(transforms) > 0 DomainError + @argcheck all(t -> t isa ScalarTransform, transforms) DomainError + new{Ts}(transforms) + end end Base.getindex(t::CompositeScalarTransform, i) = t.transforms[i] diff --git a/test/runtests.jl b/test/runtests.jl index 1643dd0..3c0bd5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,6 +155,9 @@ end @test_throws MethodError transform_and_logjac(t, 0.5) y = transform(t, 0.5) @test_throws MethodError inverse_and_logjac(t, y) + + @test_throws DomainError TransformVariables.compose(TVExp(), 5) + @test_throws DomainError TransformVariables.compose() end @testset "scalar alternatives" begin From dbeb48a5a7b8ea861b87a2fe3ee3db0512f1085d Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 8 Apr 2025 11:02:29 +0200 Subject: [PATCH 37/61] Move Unitful-related tests to be closer together --- test/runtests.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3c0bd5f..214e043 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,6 +98,13 @@ end test_transformation(t1, y -> y > a*2u"m", jac=false) t2 = TVScale(1u"s") ∘ TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() test_transformation(t2, y -> (a*u"s" < y < b*u"s"), jac=false) + + @test_throws MethodError transform_and_logjac(t1, 0.5) + @test_throws MethodError transform_and_logjac(t2, 0.5) + y1 = transform(t1, 0.5) + y2 = transform(t1, 0.5) + @test_throws MethodError inverse_and_logjac(t1, y1) + @test_throws MethodError inverse_and_logjac(t2, y2) end end @@ -151,11 +158,6 @@ end @test_throws DomainError inverse_and_logjac(t, 0.5) @test_throws DomainError inverse_and_logjac(t, 11.0) - t = TVScale(5.0u"m") ∘ as(Real, 1.0, 10.0) - @test_throws MethodError transform_and_logjac(t, 0.5) - y = transform(t, 0.5) - @test_throws MethodError inverse_and_logjac(t, y) - @test_throws DomainError TransformVariables.compose(TVExp(), 5) @test_throws DomainError TransformVariables.compose() end From a5ff6053f16757bc9c381f6bb5b4dc77df8d205a Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 8 Apr 2025 11:03:18 +0200 Subject: [PATCH 38/61] Remove unnecessary test --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 214e043..626b596 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -143,7 +143,6 @@ end @testset "scalar transformation corner cases" begin @test_throws ArgumentError as(Real, "a fish", 9) - @test_broken as(Real, 1, 4.0) == as(Real, 1.0, 4.0) @test_throws ArgumentError as(Real, 3.0, -4.0) t = as(Real, 1.0, ∞) From 9b2938f5c8236ab2bdf819cf695f1310a68f966c Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 8 Apr 2025 11:10:37 +0200 Subject: [PATCH 39/61] Add some internal docs --- src/scalar.jl | 2 ++ test/utilities.jl | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 3f45233..4160358 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -148,6 +148,8 @@ inverse_and_logjac(::TVNeg, x::Number) = -x, zero(x) $(TYPEDEF) A composite scalar transformation, i.e. a sequence of scalar transformations. +The component transforms can be accessed indexing into the `CompositeScalarTransform`, +as in `ct[i]` or `ct[end]`. """ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform transforms::Ts diff --git a/test/utilities.jl b/test/utilities.jl index f1b3fa7..cf0076a 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -29,6 +29,9 @@ Test transformation `t` with random values, `N` times. automatic differentiation. `test_inverse` determines whether the inverse is tested. + +`jac` determines whether `transform_and_logjac` is tested against the log +Jacobian from AD, true by default. The Jacobian is not defined for Unitful scaling. """ function test_transformation(t::AbstractTransform, is_valid_y; vec_y = identity, N = 1000, test_inverse = true, jac=true) From 5b1f05579271ed56eb2ba25eb1e2c03405ace5e8 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 8 Apr 2025 11:41:43 +0200 Subject: [PATCH 40/61] Get inverse_and_logjac tested for Identity() --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 626b596..bbaf7ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -197,6 +197,13 @@ end x2, lj2 = TransformVariables.inverse_and_logjac(t, y) @test x2 ≈ x @test lj2 ≈ -lj + + # For completeness sake, + t = as(Real, -∞, ∞) + y, lj = transform_and_logjac(t, x) + x2, lj2 = TransformVariables.inverse_and_logjac(t, y) + @test x2 ≈ x + @test lj2 ≈ -lj end end From 6cf09c1dde4629bfcde38562ee29627b1dd57777 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Tue, 8 Apr 2025 11:50:13 +0200 Subject: [PATCH 41/61] make test look nicer --- test/runtests.jl | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index bbaf7ba..cdd4dcd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -111,22 +111,18 @@ end @testset "composite scalar transformations" begin all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] - for t1 in all_transforms - for t2 in all_transforms - for t3 in all_transforms - t = t1 ∘ t2 ∘ t3 - # Basic functionality - @test t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3)}} - # Other constructions - @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} - @test t == t1 ∘ (t2 ∘ t3) - @test t == ∘(t1, t2, t3) - @test t == TransformVariables.compose(t1, t2, t3) - @test all([t[1] == t1, t[2] == t2, t[3] == t3]) - @test all([t[begin] == t1, t[end] == t3]) - @test t[begin:end] == t[:] - end - end + for t1 in all_transforms, t2 in all_transforms, t3 in all_transforms + t = t1 ∘ t2 ∘ t3 + # Basic functionality + @test t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3)}} + # Other constructions + @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} + @test t == t1 ∘ (t2 ∘ t3) + @test t == ∘(t1, t2, t3) + @test t == TransformVariables.compose(t1, t2, t3) + @test all([t[1] == t1, t[2] == t2, t[3] == t3]) + @test all([t[begin] == t1, t[end] == t3]) + @test t[begin:end] == t[:] end end @@ -136,7 +132,7 @@ end for s1 in same_domain_transforms, s2 in same_domain_transforms, n in new_domain_transforms for s3 in same_domain_transforms, s4 in same_domain_transforms # don't worry about valid output here, let inverse check that - test_transformation(s1 ∘ s2 ∘ n ∘ s3 ∘ s4, _ -> true; N=5) + test_transformation(s1 ∘ s2 ∘ n ∘ s3 ∘ s4, _ -> true; N=100) end end end From 2dd2bb455723334c5942707dc730da459edcf4bf Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:19:18 +0200 Subject: [PATCH 42/61] Remove argcheck on inverse(Exp) Co-authored-by: David Widmann --- src/scalar.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 4160358..499460b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -72,7 +72,6 @@ transform(::TVExp, x::Real) = exp(x) transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x function inverse(::TVExp, x::Number) - @argcheck x > 0 DomainError log(x) end inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), log(1/x) From ad49fd07e5ab73fcde1220ec065a3e4b20e76b4c Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:19:57 +0200 Subject: [PATCH 43/61] Remove argcheck on inverse(Logistic) Co-authored-by: David Widmann --- src/scalar.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index 499460b..fb34aee 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -87,7 +87,6 @@ transform(::TVLogistic, x::Real) = logistic(x) transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) function inverse(::TVLogistic, x::Number) - @argcheck 0 < x < 1 DomainError logit(x) end inverse_and_logjac(t::TVLogistic, x::Number) = inverse(t, x), logit_logjac(x) From 05754ace11dae5711b8adff822e5d52349534e00 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:24:40 +0200 Subject: [PATCH 44/61] Cleaner typing Co-authored-by: David Widmann --- src/scalar.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index fb34aee..a6cd280 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -120,10 +120,10 @@ end TVScale(scale::T) where {T} = TVScale{T}(scale) transform(t::TVScale, x::Real) = t.scale * x -transform_and_logjac(t::TVScale{T}, x::Real) where {T<:Real} = transform(t, x), log(t.scale) +transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) inverse(t::TVScale, x::Number) = x / t.scale -inverse_and_logjac(t::TVScale{T}, x::Number) where {T<:Real} = inverse(t, x), -log(t.scale) +inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -151,9 +151,7 @@ as in `ct[i]` or `ct[end]`. """ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform transforms::Ts - function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple} - @argcheck length(transforms) > 0 DomainError - @argcheck all(t -> t isa ScalarTransform, transforms) DomainError + function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple{ScalarTransform,Vararg{ScalarTransform}}} new{Ts}(transforms) end end From e7c231e99b82b898317dfc976f57e54936bbd19d Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:26:58 +0200 Subject: [PATCH 45/61] Better zeros for inverse_and_logjac Co-authored-by: David Widmann --- src/scalar.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index a6cd280..cdabd4b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -55,7 +55,7 @@ transform_and_logjac(::Identity, x::Real) = x, zero(x) inverse(::Identity, x::Number) = x -inverse_and_logjac(::Identity, x::Real) = x, zero(x) +inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) #### #### elementary scalar transforms @@ -74,7 +74,7 @@ transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x function inverse(::TVExp, x::Number) log(x) end -inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), log(1/x) +inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -log(x) """ $(TYPEDEF) @@ -100,10 +100,10 @@ struct TVShift{T <: Real} <: ScalarTransform shift::T end transform(t::TVShift, x::Real) = x + t.shift -transform_and_logjac(t::TVShift, x::Real) = transform(t, x), zero(x) +transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse(t::TVShift, x::Number) = x - t.shift -inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), zero(x) +inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x)) """ $(TYPEDEF) @@ -162,7 +162,7 @@ Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) transform(t::CompositeScalarTransform, x) = foldr((t, y) -> transform(t, y), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) - foldr(ts.transforms, init=(x, zero(x))) do t, (x, logjac) + foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac) nx, nlogjac = transform_and_logjac(t, x) x = nx logjac += nlogjac @@ -172,7 +172,7 @@ end inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x) function inverse_and_logjac(ts::CompositeScalarTransform, x) - foldl(ts.transforms, init=(x, zero(x))) do (x, logjac), t + foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t nx, nlogjac = inverse_and_logjac(t, x) x = nx logjac += nlogjac From 0604ddd0c4501c7887615d88601e10cfad402de7 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 16:37:09 +0200 Subject: [PATCH 46/61] more logjac_zero --- src/scalar.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index cdabd4b..b363d89 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -51,7 +51,7 @@ struct Identity <: ScalarTransform end transform(::Identity, x::Real) = x -transform_and_logjac(::Identity, x::Real) = x, zero(x) +transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) inverse(::Identity, x::Number) = x @@ -134,10 +134,10 @@ struct TVNeg <: ScalarTransform end transform(::TVNeg, x::Real) = -x -transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), zero(x) +transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse(::TVNeg, x::Number) = -x -inverse_and_logjac(::TVNeg, x::Number) = -x, zero(x) +inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) #### #### composite scalar transforms From 593c2f8f7e3e0992d4256a07e1f6326a1cae5185 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:39:14 +0200 Subject: [PATCH 47/61] Clean up foldl/foldr operation Co-authored-by: David Widmann --- src/scalar.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index b363d89..fc275c3 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -164,9 +164,7 @@ transform(t::CompositeScalarTransform, x) = foldr((t, y) -> transform(t, y), t.t function transform_and_logjac(ts::CompositeScalarTransform, x) foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac) nx, nlogjac = transform_and_logjac(t, x) - x = nx - logjac += nlogjac - (x, logjac) + (nx, logjac + nlogjac) end end @@ -174,9 +172,7 @@ inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.tra function inverse_and_logjac(ts::CompositeScalarTransform, x) foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t nx, nlogjac = inverse_and_logjac(t, x) - x = nx - logjac += nlogjac - (x, logjac) + (nx, logjac + nlogjac) end end From 6ebb9ab31c6e289f83593f1769490eeb1b6a26e6 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 16:51:47 +0200 Subject: [PATCH 48/61] Remove indexing into CompositeTransform --- src/scalar.jl | 6 ------ test/runtests.jl | 3 --- 2 files changed, 9 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index b363d89..35f191f 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -146,8 +146,6 @@ inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) $(TYPEDEF) A composite scalar transformation, i.e. a sequence of scalar transformations. -The component transforms can be accessed indexing into the `CompositeScalarTransform`, -as in `ct[i]` or `ct[end]`. """ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform transforms::Ts @@ -156,10 +154,6 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform end end -Base.getindex(t::CompositeScalarTransform, i) = t.transforms[i] -Base.firstindex(t::CompositeScalarTransform) = firstindex(t.transforms) -Base.lastindex(t::CompositeScalarTransform) = lastindex(t.transforms) - transform(t::CompositeScalarTransform, x) = foldr((t, y) -> transform(t, y), t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac) diff --git a/test/runtests.jl b/test/runtests.jl index cdd4dcd..fce29e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -120,9 +120,6 @@ end @test t == t1 ∘ (t2 ∘ t3) @test t == ∘(t1, t2, t3) @test t == TransformVariables.compose(t1, t2, t3) - @test all([t[1] == t1, t[2] == t2, t[3] == t3]) - @test all([t[begin] == t1, t[end] == t3]) - @test t[begin:end] == t[:] end end From de68e4de880e379d9d81d17aa763d33655c8d115 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 16:53:27 +0200 Subject: [PATCH 49/61] Restrict cases where asR+, etc. get string printed --- src/scalar.jl | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index a0dbf35..e612fdf 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -343,43 +343,25 @@ const asℝ = as(Real, -∞, ∞) const as_real = asℝ # Fallback method: print all transforms in order -function Base.show(io::IO, ct::CompositeScalarTransform) - str = string(ct.transforms[1]) - for ti in ct.transforms[begin+1:end] - str *= " ∘ "*string(ti) - end - print(io, str) -end +Base.show(io::IO, ct::CompositeScalarTransform) = join(io, ct.transforms, " ∘ ") -# If equivalent to asℝ₊, print as such. Two ways to achieve this function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T - if ct[1].shift == 0 - print(io, "asℝ₊") - else - print(io, "as(Real, ", ct[1].shift, ", ∞)") - end + print(io, "as(Real, ", ct[1].shift, ", ∞)") end -Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVExp}}) = print(io, "asℝ₊") +# If equivalent to asℝ₊, print as such. +Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVExp}}) = print(io, "asℝ₊") -# If equivalent to asℝ₋, print as such. Two ways to achieve this +# If equivalent to asℝ₋, print as such. function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where T - if ct[1].shift == 0 - print(io, "asℝ₋") - else - print(io, "as(Real, -∞, ", ct[1].shift, ")") - end + print(io, "as(Real, -∞, ", ct[1].shift, ")") end -Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") +Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") # If equivalent to as𝕀, print as such. Two ways to achieve this function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2} - if ct[1].shift == 0 && ct[2].scale == 1 - print(io, "as𝕀") - else - print(io, "as(Real, ", ct[1].shift, ", ", ct[1].shift + ct[2].scale, ")") - end + print(io, "as(Real, ", ct[1].shift, ", ", ct[1].shift + ct[2].scale, ")") end -Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVLogistic}}) = print(io, "as𝕀") +Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVLogistic}}) = print(io, "as𝕀") function Base.show(io::IO, t::TVScale) print(io, "TVScale(", t.scale, ")") From 0af740685dd7574c4f7164219eb19d836b7409f7 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 16:56:01 +0200 Subject: [PATCH 50/61] Clearly better suggestions Co-authored-by: David Widmann --- src/scalar.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index e612fdf..b069696 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -303,7 +303,8 @@ as(::Type{Real}, ::Infinity{false}, right::Real) = TVShift(right) ∘ TVNeg() function as(::Type{Real}, left::Real, right::Real) @argcheck left < right "the interval ($(left), $(right)) is empty" - TVShift(left) ∘ TVScale(right - left) ∘ TVLogistic() + shift, scale = promote(left, right - left) + TVShift(shift) ∘ TVScale(scale) ∘ TVLogistic() end """ @@ -345,7 +346,7 @@ const as_real = asℝ # Fallback method: print all transforms in order Base.show(io::IO, ct::CompositeScalarTransform) = join(io, ct.transforms, " ∘ ") -function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T +function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{<:TVShift, TVExp}}) print(io, "as(Real, ", ct[1].shift, ", ∞)") end # If equivalent to asℝ₊, print as such. From 0935932140c9a5571710b6c96a6f48c457f73d87 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 17:18:47 +0200 Subject: [PATCH 51/61] Use CompositionsBase to provide non-Unicode compose --- Project.toml | 2 ++ src/scalar.jl | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c34627b..6a3de76 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.8.14" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -23,6 +24,7 @@ InverseFunctionsExt = "InverseFunctions" [compat] ArgCheck = "1, 2" ChangesOfVariables = "0.1" +CompositionsBase = "0.1.2" DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10" InverseFunctions = "0.1" diff --git a/src/scalar.jl b/src/scalar.jl index b069696..3eff776 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -175,7 +175,6 @@ Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTra Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) -compose(args...) = CompositeScalarTransform(args) #### From d7d23982c0a6b4a0464f0f9ce0c975e35a61b9a7 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 9 Apr 2025 17:24:06 +0200 Subject: [PATCH 52/61] Nice little elision Co-authored-by: David Widmann --- src/scalar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scalar.jl b/src/scalar.jl index b069696..8e78c43 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -154,7 +154,7 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform end end -transform(t::CompositeScalarTransform, x) = foldr((t, y) -> transform(t, y), t.transforms, init=x) +transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x) function transform_and_logjac(ts::CompositeScalarTransform, x) foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac) nx, nlogjac = transform_and_logjac(t, x) From a2145833bbc826707b5afaedfb3ee0a8a4a702f1 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 20:03:40 +0200 Subject: [PATCH 53/61] Get scalar show tests passing again, for now --- src/scalar.jl | 9 +++++---- test/runtests.jl | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 06b5fa7..988a61e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -345,21 +345,22 @@ const as_real = asℝ # Fallback method: print all transforms in order Base.show(io::IO, ct::CompositeScalarTransform) = join(io, ct.transforms, " ∘ ") -function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{<:TVShift, TVExp}}) - print(io, "as(Real, ", ct[1].shift, ", ∞)") +function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T + print(io, "as(Real, ", ct.transforms[1].shift, ", ∞)") end # If equivalent to asℝ₊, print as such. Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVExp}}) = print(io, "asℝ₊") # If equivalent to asℝ₋, print as such. function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where T - print(io, "as(Real, -∞, ", ct[1].shift, ")") + print(io, "as(Real, -∞, ", ct.transforms[1].shift, ")") end Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") # If equivalent to as𝕀, print as such. Two ways to achieve this function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2} - print(io, "as(Real, ", ct[1].shift, ", ", ct[1].shift + ct[2].scale, ")") + print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift + + ct.transforms[2].scale, ")") end Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVLogistic}}) = print(io, "as𝕀") diff --git a/test/runtests.jl b/test/runtests.jl index fce29e8..e5190f8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,9 +52,9 @@ end @test string(asℝ₊) == "asℝ₊" @test string(asℝ₋) == "asℝ₋" @test string(as𝕀) == "as𝕀" - @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" - @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" - @test string(TVShift(0) ∘ TVScale(1) ∘ TVLogistic()) == "as𝕀" + # @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" + # @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" + # @test string(TVShift(0) ∘ TVScale(1) ∘ TVLogistic()) == "as𝕀" @test string(TVNeg() ∘ TVExp()) == "asℝ₋" @test string(∘(TVExp())) == "asℝ₊" @test string(∘(TVLogistic())) == "as𝕀" From 7996d6f6f621d6ceec842d9208512cf94359eceb Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 20:14:35 +0200 Subject: [PATCH 54/61] Actually import compose, redo test --- src/TransformVariables.jl | 1 + test/runtests.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/TransformVariables.jl b/src/TransformVariables.jl index bf6bf72..f43b3f1 100644 --- a/src/TransformVariables.jl +++ b/src/TransformVariables.jl @@ -7,6 +7,7 @@ using LogExpFunctions using LinearAlgebra: UpperTriangular, logabsdet using Random: AbstractRNG, GLOBAL_RNG using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst +using CompositionsBase include("utilities.jl") include("generic.jl") diff --git a/test/runtests.jl b/test/runtests.jl index e5190f8..9999359 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -150,8 +150,8 @@ end @test_throws DomainError inverse_and_logjac(t, 0.5) @test_throws DomainError inverse_and_logjac(t, 11.0) - @test_throws DomainError TransformVariables.compose(TVExp(), 5) - @test_throws DomainError TransformVariables.compose() + t = TVExp ∘ 5 + @test_throws MethodError transform(t, 0.5) end @testset "scalar alternatives" begin From f65fe0544897cd6af45c64d64b267f38c1fb7d24 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 9 Apr 2025 20:15:31 +0200 Subject: [PATCH 55/61] Remove ShiftedExp,etc. from new tests --- src/scalar.jl | 1 - test/runtests.jl | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 988a61e..7fa35f2 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -176,7 +176,6 @@ Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTra Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) - #### #### shifted exponential #### diff --git a/test/runtests.jl b/test/runtests.jl index 9999359..1197452 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -109,8 +109,7 @@ end end @testset "composite scalar transformations" begin - all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg(), - ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] + all_transforms = [TVShift(3.0), TVScale(2.0), TVExp(), TVLogistic(), TVNeg()] for t1 in all_transforms, t2 in all_transforms, t3 in all_transforms t = t1 ∘ t2 ∘ t3 # Basic functionality @@ -125,7 +124,7 @@ end @testset "semiarbitrary compositions" begin same_domain_transforms = [TVShift(3.0), TVScale(2.0), TVNeg()] - new_domain_transforms = [TVExp(), TVLogistic(), ScaledShiftedLogistic(2.0, 1.0), ShiftedExp(true, -5), ShiftedExp(false, 4.5)] + new_domain_transforms = [TVExp(), TVLogistic()] for s1 in same_domain_transforms, s2 in same_domain_transforms, n in new_domain_transforms for s3 in same_domain_transforms, s4 in same_domain_transforms # don't worry about valid output here, let inverse check that From 4172d022ee1ea98a823cf3f4377b11e20786057f Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 16 Apr 2025 10:16:19 +0200 Subject: [PATCH 56/61] Amend asR+ an asI to be individual scalar transforms, not single-element composite --- src/scalar.jl | 54 ++++++++++++++---------------------------------- test/runtests.jl | 13 ++++-------- 2 files changed, 20 insertions(+), 47 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 7fa35f2..96edaf7 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -310,7 +310,7 @@ Transform to a positive real number. See [`as`](@ref). `asℝ₊` and `as_positive_real` are equivalent alternatives. """ -const asℝ₊ = ∘(TVExp()) +const asℝ₊ = TVExp() const as_positive_real = asℝ₊ @@ -328,7 +328,7 @@ Transform to the unit interval `(0, 1)`. See [`as`](@ref). `as𝕀` and `as_unit_interval` are equivalent alternatives. """ -const as𝕀 = ∘(TVLogistic()) +const as𝕀 = TVLogistic() const as_unit_interval = as𝕀 @@ -341,53 +341,31 @@ const asℝ = as(Real, -∞, ∞) const as_real = asℝ +# Single scalar transforms +Base.show(io::IO, ::Identity) = print(io, "asℝ") +Base.show(io::IO, ::TVExp) = print(io, "asℝ₊") +Base.show(io::IO, ::TVLogistic) = print(io, "as𝕀") +function Base.show(io::IO, t::TVScale) + print(io, "TVScale(", t.scale, ")") +end +function Base.show(io::IO, t::TVShift) + print(io, "TVShift(", t.shift, ")") +end + # Fallback method: print all transforms in order Base.show(io::IO, ct::CompositeScalarTransform) = join(io, ct.transforms, " ∘ ") +# Special cases which are constructed by as(Real, ...) function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where T print(io, "as(Real, ", ct.transforms[1].shift, ", ∞)") end -# If equivalent to asℝ₊, print as such. -Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVExp}}) = print(io, "asℝ₊") - -# If equivalent to asℝ₋, print as such. function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where T print(io, "as(Real, -∞, ", ct.transforms[1].shift, ")") end -Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") - -# If equivalent to as𝕀, print as such. Two ways to achieve this function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2} print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift + ct.transforms[2].scale, ")") end -Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVLogistic}}) = print(io, "as𝕀") - -function Base.show(io::IO, t::TVScale) - print(io, "TVScale(", t.scale, ")") -end -function Base.show(io::IO, t::TVShift) - print(io, "TVShift(", t.shift, ")") -end -function Base.show(io::IO, t::ShiftedExp) - if t === asℝ₊ - print(io, "asℝ₊") - elseif t === asℝ₋ - print(io, "asℝ₋") - elseif t isa ShiftedExp{true} - print(io, "as(Real, ", t.shift, ", ∞)") - else - print(io, "as(Real, -∞, ", t.shift, ")") - end -end - -function Base.show(io::IO, t::ScaledShiftedLogistic) - if t === as𝕀 - print(io, "as𝕀") - else - print(io, "as(Real, ", t.shift, ", ", t.shift + t.scale, ")") - end -end - -Base.show(io::IO, t::Identity) = print(io, "asℝ") +# Special case for asR- +Base.show(io::IO, ::CompositeScalarTransform{Tuple{TVNeg, TVExp}}) = print(io, "asℝ₋") diff --git a/test/runtests.jl b/test/runtests.jl index 1197452..d151e76 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,24 +52,19 @@ end @test string(asℝ₊) == "asℝ₊" @test string(asℝ₋) == "asℝ₋" @test string(as𝕀) == "as𝕀" - # @test string(TVShift(0) ∘ TVNeg() ∘ TVExp()) == "asℝ₋" - # @test string(TVShift(0) ∘ TVExp()) == "asℝ₊" - # @test string(TVShift(0) ∘ TVScale(1) ∘ TVLogistic()) == "as𝕀" @test string(TVNeg() ∘ TVExp()) == "asℝ₋" - @test string(∘(TVExp())) == "asℝ₊" - @test string(∘(TVLogistic())) == "as𝕀" @test string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" + @test string(TVExp()) == "asℝ₊" + @test string(TVLogistic()) == "as𝕀" @test string(TVShift(4.0)) == "TVShift(4.0)" @test string(TVScale(4.0)) == "TVScale(4.0)" - @test string(TVExp()) == "TVExp()" - @test string(TVLogistic()) == "TVLogistic()" @test string(TVNeg()) == "TVNeg()" - @test string(TVScale(2.0) ∘ TVNeg() ∘ TVExp()) == "TVScale(2.0) ∘ TVNeg() ∘ TVExp()" - @test string(TVScale(5.0u"m") ∘ TVExp()) == "TVScale(5.0 m) ∘ TVExp()" + @test string(TVScale(2.0) ∘ TVNeg() ∘ TVExp()) == "TVScale(2.0) ∘ TVNeg() ∘ asℝ₊" + @test string(TVScale(5.0u"m") ∘ TVLogistic()) == "TVScale(5.0 m) ∘ as𝕀" end @testset "scalar transformations consistency" begin From 13034d4c7d287240097bcd15966779093c186a78 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 16 Apr 2025 10:36:09 +0200 Subject: [PATCH 57/61] Fully remove ShiftedExp and ScaledShiftedLogistic --- docs/src/internals.md | 2 -- src/scalar.jl | 80 ------------------------------------------- test/runtests.jl | 13 ++----- 3 files changed, 3 insertions(+), 92 deletions(-) diff --git a/docs/src/internals.md b/docs/src/internals.md index 6d0fc13..8960f20 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -8,8 +8,6 @@ These are not part of the API, use the `as` constructor or one of the predefined ```@docs TransformVariables.Identity -TransformVariables.ScaledShiftedLogistic -TransformVariables.ShiftedExp ``` ### Aggregating transformations diff --git a/src/scalar.jl b/src/scalar.jl index 96edaf7..65abd57 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -176,86 +176,6 @@ Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTra Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) -#### -#### shifted exponential -#### - -""" -$(TYPEDEF) - -Shifted exponential. When `D::Bool == true`, maps to `(shift, ∞)` using `x ↦ -shift + eˣ`, otherwise to `(-∞, shift)` using `x ↦ shift - eˣ`. -""" -struct ShiftedExp{D, T <: Real} <: ScalarTransform - shift::T - function ShiftedExp{D,T}(shift::T) where {D, T <: Real} - @argcheck D isa Bool - new(shift) - end -end - -ShiftedExp(D::Bool, shift::T) where {T <: Real} = ShiftedExp{D,T}(shift) - -transform(t::ShiftedExp{D}, x::Real) where D = - D ? t.shift + exp(x) : t.shift - exp(x) - -transform_and_logjac(t::ShiftedExp, x::Real) = transform(t, x), x - -function inverse(t::ShiftedExp{D}, x::Real) where D - (; shift) = t - if D - @argcheck x > shift DomainError - log(x - shift) - else - @argcheck x < shift DomainError - log(shift - x) - end -end - -#### -#### scaled and shifted logistic -#### - -""" -$(TYPEDEF) - -Maps to `(scale, shift + scale)` using `logistic(x) * scale + shift`. -""" -struct ScaledShiftedLogistic{T <: Real} <: ScalarTransform - scale::T - shift::T - function ScaledShiftedLogistic{T}(scale::T, shift::T) where {T <: Real} - @argcheck scale > 0 - new(scale, shift) - end -end - -ScaledShiftedLogistic(scale::T, shift::T) where {T <: Real} = - ScaledShiftedLogistic{T}(scale, shift) - -ScaledShiftedLogistic(scale::Real, shift::Real) = - ScaledShiftedLogistic(promote(scale, shift)...) - -# Switch to muladd and now it does have a DiffRule defined -transform(t::ScaledShiftedLogistic, x::Real) = muladd(logistic(x), t.scale, t.shift) - -transform_and_logjac(t::ScaledShiftedLogistic, x) = - transform(t, x), log(t.scale) + logistic_logjac(x) - -function inverse(t::ScaledShiftedLogistic, y) - @argcheck y > t.shift DomainError - @argcheck y < t.scale + t.shift DomainError - logit((y - t.shift)/t.scale) -end - -# NOTE: inverse_and_logjac interface experimental and sporadically implemented for now -function inverse_and_logjac(t::ScaledShiftedLogistic, y) - @argcheck y > t.shift DomainError - @argcheck y < t.scale + t.shift DomainError - z = (y - t.shift) / t.scale - logit(z), logit_logjac(z) - log(t.scale) -end - #### #### to_interval interface #### diff --git a/test/runtests.jl b/test/runtests.jl index d151e76..a454d38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ using LogDensityProblemsAD using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation, unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, - NOLOGJAC, transform_with, ShiftedExp, ScaledShiftedLogistic + NOLOGJAC, transform_with import ChangesOfVariables, InverseFunctions using Enzyme: autodiff, ReverseWithPrimal, Active, Const using Unitful: @u_str, ustrip, uconvert @@ -74,13 +74,6 @@ end test_transformation(as(Real, a, ∞), y -> y > a) b = a + 0.5 + rand(Float64) + exp(randn() * 10) test_transformation(as(Real, a, b), y -> a < y < b) - - posexp = TVShift(a) ∘ TVExp() - negexp = TVShift(a) ∘ TVNeg() ∘ TVExp() - finite = TVShift(a) ∘ TVScale(b-a) ∘ TVLogistic() - test_transformation(posexp, y -> y > a) - test_transformation(negexp, y -> y < a) - test_transformation(finite, y -> a < y < b) end test_transformation(as(Real, -∞, ∞), _ -> true) end @@ -634,14 +627,14 @@ end #### @testset "sum dimensions allocations" begin - shifted = TransformVariables.ShiftedExp{true,Float64}(0.0) + shifted = TVShift(0.0) ∘ TVExp() tr = (a = shifted, b = TransformVariables.Identity(), c = shifted, d = shifted, e = shifted, f = shifted) @test iszero(@allocated TransformVariables._sum_dimensions(tr)) end if VERSION >= v"1.7" @testset "inverse_eltype allocations" begin - trf = as((x0 = TransformVariables.ShiftedExp{true, Float32}(0f0), x1 = TransformVariables.Identity(), x2 = UnitSimplex(7), x3 = TransformVariables.CorrCholeskyFactor(5), x4 = as(Real, -∞, 1), x5 = as(Array, 10, 2), x6 = as(Array, as𝕀, 10), x7 = as((a = asℝ₊, b = as𝕀)), x8 = TransformVariables.UnitVector(10), x9 = TransformVariables.ShiftedExp{true, Float32}(0f0), x10 = TransformVariables.ShiftedExp{true, Float32}(0f0), x11 = TransformVariables.ShiftedExp{true, Float32}(0f0), x12 = TransformVariables.ShiftedExp{true, Float32}(0f0), x13 = TransformVariables.Identity(), x14 = TransformVariables.ShiftedExp{true, Float32}(0f0), x15 = TransformVariables.ShiftedExp{true, Float32}(0f0), x16 = TransformVariables.ShiftedExp{true, Float32}(0f0), x17 = TransformVariables.ShiftedExp{true, Float64}(0.0))); + trf = as((x0 = TVShift(0f0) ∘ TVExp(), x1 = TransformVariables.Identity(), x2 = UnitSimplex(7), x3 = TransformVariables.CorrCholeskyFactor(5), x4 = as(Real, -∞, 1), x5 = as(Array, 10, 2), x6 = as(Array, as𝕀, 10), x7 = as((a = asℝ₊, b = as𝕀)), x8 = TransformVariables.UnitVector(10), x9 = TVShift(0f0) ∘ TVExp(), x10 = TVShift(0f0) ∘ TVExp(), x11 = TVShift(0f0) ∘ TVExp(), x12 = TVShift(0f0) ∘ TVExp(), x13 = TransformVariables.Identity(), x14 = TVShift(0f0) ∘ TVExp(), x15 = TVShift(0f0) ∘ TVExp(), x16 = TVShift(0f0) ∘ TVExp(), x17 = TVShift(0.0) ∘ TVExp())); vx = randn(@inferred(TransformVariables.dimension(trf))); x = TransformVariables.transform(trf, vx); From 090b08f20c1efd812216f371afe578e7c9dab20f Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Wed, 16 Apr 2025 10:49:19 +0200 Subject: [PATCH 58/61] Add another note to docs about extending scaling for custom number types --- docs/src/index.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 271a888..d0d7cec 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -125,6 +125,11 @@ produces positive quantities with the dimension of length. !!! note Because the log-Jacobian of a transform that adds units is not defined, `transform_and_logjac` and `inverse_and_logjac` only have methods defined for `TVScale{T} where {T<:Real}`. +!!! note + The inverse transform of `TVScale(scale)` divides by `scale`, which is the correct inverse for adding units to a number, but may be inappropriate for other custom number types. A transform that doesn't just multiply or an inverse that extracts a float from an exotic number type could be defined by adding methods to `transform` and `inverse` like the following: + ``` + transform(t::TVScale{T}, x) where T<:MyCustomNumberType = MyCustomNumberType(x) + inverse(t::TVScale{T}, x) where T<:MyCustomNumberType = get_the_float_part(x)``` ## Special arrays From f0e3d9080e041d0c10eeb8581f1af0dce43b3e00 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:10:20 +0200 Subject: [PATCH 59/61] Update docs to clarify use of non-Real numbers Co-authored-by: Tamas K. Papp --- docs/src/index.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index d0d7cec..e5e355e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -114,7 +114,9 @@ If you are working in an editor where typing Unicode is difficult, `TransformVar This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform. This is useful especially for making sure that values near 0, when transformed, yield usefully-scaled values for a given variable. -In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition) to add Unitful units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`). +In addition, the `TVScale` transform accepts arbitrary types. It can be used as the outermost transform (so leftmost in the composition) to add, for example, `Unitful` units to a number (or to create other exotic number types which can be constructed by multiplying, such as a `ForwardDiff.Dual`). + +However, note that calculating log Jacobian determinants may error for types that are not real numbers. For example, ```julia From 7daeb1eb2fb41d7de0dbe6e7476b76743002eb33 Mon Sep 17 00:00:00 2001 From: Isaac Wheeler <47340776+Ickaser@users.noreply.github.com> Date: Thu, 24 Apr 2025 09:50:30 +0200 Subject: [PATCH 60/61] Trim a whitespace Co-authored-by: David Widmann --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a454d38..2cd5859 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ using LogDensityProblemsAD using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation, unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, - NOLOGJAC, transform_with + NOLOGJAC, transform_with import ChangesOfVariables, InverseFunctions using Enzyme: autodiff, ReverseWithPrimal, Active, Const using Unitful: @u_str, ustrip, uconvert From 9018fb0957152f1b1d8b346673386df53da8cdcd Mon Sep 17 00:00:00 2001 From: Isaac Wheeler Date: Thu, 24 Apr 2025 10:16:33 +0200 Subject: [PATCH 61/61] Catch CompositeScalarTransforms inside the Vararg composite method --- src/scalar.jl | 2 +- test/runtests.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 65abd57..4ab3370 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -174,7 +174,7 @@ Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) -Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = CompositeScalarTransform((t, tt...)) +Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = foldl(∘, tt; init=t) #### #### to_interval interface diff --git a/test/runtests.jl b/test/runtests.jl index a454d38..438fbdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -106,7 +106,7 @@ end @test t ∘ t isa TransformVariables.CompositeScalarTransform{Tuple{typeof(t1), typeof(t2), typeof(t3), typeof(t1), typeof(t2), typeof(t3)}} @test t == t1 ∘ (t2 ∘ t3) @test t == ∘(t1, t2, t3) - @test t == TransformVariables.compose(t1, t2, t3) + @test t == TransformVariables.compose( ∘(t1), ∘(t2), ∘(t3)) end end @@ -117,6 +117,7 @@ end for s3 in same_domain_transforms, s4 in same_domain_transforms # don't worry about valid output here, let inverse check that test_transformation(s1 ∘ s2 ∘ n ∘ s3 ∘ s4, _ -> true; N=100) + @test TransformVariables.compose(s1, ∘(s2), n, ∘(s3), s4) == s1 ∘ s2 ∘ n ∘ s3 ∘ s4 end end end