diff --git a/src/scalar.jl b/src/scalar.jl index 509bc5e..3f5c0c4 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -54,32 +54,33 @@ inverse(::Identity, x::Real) = x """ $(TYPEDEF) -Shifted exponential. When `D::Bool == true`, maps to `(shift, ∞)` using `x ↦ -shift + eˣ`, otherwise to `(-∞, shift)` using `x ↦ shift - eˣ`. +Shifted exponential. When `D::Bool == true`, maps to `(shift, scale, ∞)` using `x ↦ +shift + exp(x/scale)`, otherwise to `(-∞, shift)` using `x ↦ shift - exp(x/scale)`. """ -struct ShiftedExp{D, T <: Real} <: ScalarTransform +struct ShiftedExp{D, T <: Real, S <: Real} <: ScalarTransform shift::T - function ShiftedExp{D,T}(shift::T) where {D, T <: Real} + scale::S + function ShiftedExp{D,T,S}(shift::T, scale::S) where {D, T <: Real, S <: Real} @argcheck D isa Bool - new(shift) + new(shift, scale) end end - -ShiftedExp(D::Bool, shift::T) where {T <: Real} = ShiftedExp{D,T}(shift) +ShiftedExp(D::Bool, shift::T) where {T <: Real} = ShiftedExp{D,T,T}(shift, one(T)) +ShiftedExp(D::Bool, shift::T, scale::S) where {T <: Real, S <: Real} = ShiftedExp{D,T,S}(shift, scale) transform(t::ShiftedExp{D}, x::Real) where D = - D ? t.shift + exp(x) : t.shift - exp(x) + D ? t.shift + exp(x/t.scale) : t.shift - exp(x/t.scale) -transform_and_logjac(t::ShiftedExp, x::Real) = transform(t, x), x +transform_and_logjac(t::ShiftedExp, x::Real) = transform(t, x), x/t.scale - log(t.scale) function inverse(t::ShiftedExp{D}, x::Real) where D - @unpack shift = t + @unpack shift, scale = t if D @argcheck x > shift DomainError - log(x - shift) + scale*log(x - shift) else @argcheck x < shift DomainError - log(shift - x) + scale*log(shift - x) end end @@ -155,23 +156,32 @@ Return a transformation that transforms a single real number to the given (open) interval. `left < right` is required, but may be `-∞` or `∞`, respectively, in which case -the appropriate transformation is selected. See [`∞`](@ref). +the appropriate transformation is selected. See [`∞`](@ref). If `left` or `right` +are infinite, optionally the scale of the variable can be provied, e.g.: +``` +as(Real, left, ∞; scale=10) +```` Some common transformations are predefined as constants, see [`asℝ`](@ref), [`asℝ₋`](@ref), [`asℝ₊`](@ref), [`as𝕀`](@ref). !!! note - The finite arguments are promoted to a common type and affect promotion. Eg - `transform(as(0, ∞), 0f0) isa Float32`, but `transform(as(0.0, ∞), 0f0) isa Float64`. + The finite arguments are promoted to a common type and affect promotion. E.g. + `transform(as(0, ∞; scale=10f0), 0f0) isa Float32`, but + `transform(as(0.0, ∞), 0f0) isa Float64`. """ as(::Type{Real}, left, right) = throw(ArgumentError("($(left), $(right)) must be an interval")) as(::Type{Real}, ::Infinity{false}, ::Infinity{true}) = Identity() -as(::Type{Real}, left::Real, ::Infinity{true}) = ShiftedExp(true, left) +function as(::Type{Real}, left::T, ::Infinity{true}; scale=1) where T <: Real + ShiftedExp(true, left, scale) +end -as(::Type{Real}, ::Infinity{false}, right::Real) = ShiftedExp(false, right) +function as(::Type{Real}, ::Infinity{false}, right::T; scale=1) where T <: Real + ShiftedExp(false, right, scale) +end function as(::Type{Real}, left::Real, right::Real) @argcheck left < right "the interval ($(left), $(right)) is empty" @@ -220,9 +230,9 @@ Base.show(io::IO, t::ShiftedExp) = elseif t === asℝ₋ print(io, "asℝ₋") elseif t isa ShiftedExp{true} - print(io, "as(Real, ", t.shift, ", ∞)") + print(io, "as(Real, ", t.shift, ", ∞; scale = ", t.scale, ")") else - print(io, "as(Real, -∞, ", t.shift, ")") + print(io, "as(Real, -∞, ", t.shift, "; scale = ", t.scale, ")") end Base.show(io::IO, t::ScaledShiftedLogistic) = diff --git a/test/runtests.jl b/test/runtests.jl index df948a0..a1d2fbf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,15 @@ end a = randn() * 100 test_transformation(as(Real, -∞, a), y -> y < a) test_transformation(as(Real, a, ∞), y -> y > a) + + a = randn() * 1000 + test_transformation(as(Real, a, ∞, scale = 10.0), y -> y > a) + test_transformation(as(Real, -∞, a, scale = 10.0), y -> y < a) + + a = randn() * 10 + test_transformation(as(Real, a, ∞, scale = 0.1), y -> y > a) + test_transformation(as(Real, -∞, a, scale = 0.1), y -> y < a) + b = a + 0.5 + rand(Float64) + exp(randn() * 10) test_transformation(as(Real, a, b), y -> a < y < b) end @@ -61,9 +70,13 @@ end t = as(Real, 1.0, ∞) @test_throws DomainError inverse(t, 0.5) + t = as(Real, 1.0, ∞; scale = 0.1) + @test_throws DomainError inverse(t, 0.5) t = as(Real, -∞, 10.0) @test_throws DomainError inverse(t, 11.0) + t = as(Real, -∞, 10.0; scale=0.1) + @test_throws DomainError inverse(t, 11.0) t = as(Real, 1.0, 10.0) @test_throws DomainError inverse(t, 0.5) @@ -85,6 +98,8 @@ end @test transform(asℝ₊, a) isa Float32 @test transform(asℝ₋, a) isa Float32 @test transform(as𝕀, a) isa Float32 + @test transform(as(Real, 0, ∞), a) isa Float32 + @test transform(as(Real, -∞, 0), a) isa Float32 end #### @@ -529,24 +544,24 @@ end @test string(asℝ₋) == "asℝ₋" @test string(as𝕀) == "as𝕀" @test string(as(Real, 0.0, 2.0)) == "as(Real, 0.0, 2.0)" - @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞)" - @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0)" + @test string(as(Real, 1.0, ∞)) == "as(Real, 1.0, ∞; scale = 1)" + @test string(as(Real, -∞, 1.0)) == "as(Real, -∞, 1.0; scale = 1)" end @testset "sum dimensions allocations" begin - shifted = TransformVariables.ShiftedExp{true,Float64}(0.0) + shifted = TransformVariables.ShiftedExp{true,Float64,Float64}(0.0, 0.0) tr = (a = shifted, b = TransformVariables.Identity(), c = shifted, d = shifted, e = shifted, f = shifted) @test iszero(@allocated TransformVariables._sum_dimensions(tr)) end if VERSION >= v"1.7" @testset "inverse_eltype allocations" begin - trf = as((x0 = TransformVariables.ShiftedExp{true, Float32}(0f0), x1 = TransformVariables.Identity(), x2 = UnitSimplex(7), x3 = TransformVariables.CorrCholeskyFactor(5), x4 = as(Real, -∞, 1), x5 = as(Array, 10, 2), x6 = as(Array, as𝕀, 10), x7 = as((a = asℝ₊, b = as𝕀)), x8 = TransformVariables.UnitVector(10), x9 = TransformVariables.ShiftedExp{true, Float32}(0f0), x10 = TransformVariables.ShiftedExp{true, Float32}(0f0), x11 = TransformVariables.ShiftedExp{true, Float32}(0f0), x12 = TransformVariables.ShiftedExp{true, Float32}(0f0), x13 = TransformVariables.Identity(), x14 = TransformVariables.ShiftedExp{true, Float32}(0f0), x15 = TransformVariables.ShiftedExp{true, Float32}(0f0), x16 = TransformVariables.ShiftedExp{true, Float32}(0f0), x17 = TransformVariables.ShiftedExp{true, Float64}(0.0))); - + trf = as((x0 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x1 = TransformVariables.Identity(), x2 = UnitSimplex(7), x3 = TransformVariables.CorrCholeskyFactor(5), x4 = as(Real, -∞, 1), x5 = as(Array, 10, 2), x6 = as(Array, as𝕀, 10), x7 = as((a = asℝ₊, b = as𝕀)), x8 = TransformVariables.UnitVector(10), x9 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x10 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x11 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x12 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x13 = TransformVariables.Identity(), x14 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x15 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x16 = TransformVariables.ShiftedExp{true, Float32, Float32}(0f0, 0f0), x17 = TransformVariables.ShiftedExp{true, Float64, Float64}(0.0, 0.0))); + vx = randn(@inferred(TransformVariables.dimension(trf))); x = TransformVariables.transform(trf, vx); @test @inferred(TransformVariables.inverse_eltype(trf, x)) === Float64 - + end end