From 3a2d0f3c718f654046822bbc6a204f463fb55edc Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 17:04:33 +0100 Subject: [PATCH 1/9] logstatsexp --- Project.toml | 3 ++- src/LogExpFunctions.jl | 4 +++- src/logstatsexp.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/logstatsexp.jl | 30 ++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 src/logstatsexp.jl create mode 100644 test/logstatsexp.jl diff --git a/Project.toml b/Project.toml index f044b713..dcfea4f2 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Test"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Statistics", "Test"] diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 0b3385df..28a43aac 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -8,7 +8,8 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax, - softmax!, logcosh, logabssinh, cloglog, cexpexp + softmax!, logcosh, logabssinh, cloglog, cexpexp, + logmeanexp, logvarexp, logstdexp # expm1(::Float16) is not defined in older Julia versions, # hence for better Float16 support we use an internal function instead @@ -22,6 +23,7 @@ end include("basicfuns.jl") include("logsumexp.jl") +include("logstatsexp.jl") if !isdefined(Base, :get_extension) include("../ext/LogExpFunctionsChainRulesCoreExt.jl") diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl new file mode 100644 index 00000000..0131a9d4 --- /dev/null +++ b/src/logstatsexp.jl @@ -0,0 +1,38 @@ +""" + logmeanexp(A; dims=:) + +Computes `log.(mean(exp.(A); dims))`, in a numerically stable way. +""" +function logmeanexp(A::AbstractArray; dims=:) + R = logsumexp(A; dims) + N = length(A) ÷ length(R) + return R .- log(N) +end + +""" + logvarexp(A; dims=:) + +Computes `log.(var(exp.(A); dims))`, in a numerically stable way. +""" +function logvarexp( + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims) +) + R = logsumexp(2logsubexp.(A, logmean); dims) + N = length(A) ÷ length(R) + if corrected + return R .- log(N - 1) + else + return R .- log(N) + end +end + +""" + logstdexp(A; dims=:) + +Computes `log.(std(exp.(A); dims))`, in a numerically stable way. +""" +function logstdexp( + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims) +) + return logvarexp(A; dims, corrected, logmean) / 2 +end diff --git a/test/logstatsexp.jl b/test/logstatsexp.jl new file mode 100644 index 00000000..ba4f42a6 --- /dev/null +++ b/test/logstatsexp.jl @@ -0,0 +1,30 @@ +using Test: @test, @testset, @inferred +using Statistics: mean, var, std +using LogExpFunctions: logmeanexp, logvarexp, logstdexp + +@testset "logmeanexp, logvarexp" begin + A = randn(5,3,2) + for dims in (2, (1,2), :) + @test logmeanexp(A; dims) ≈ log.(mean(exp.(A); dims)) + for corrected in (true, false) + @test logvarexp(A; dims, corrected) ≈ log.(var(exp.(A); dims, corrected)) + @test logstdexp(A; dims, corrected) ≈ log.(std(exp.(A); dims, corrected)) + end + end + @test logvarexp(A; dims=2) ≈ log.(var(exp.(A); dims=2)) + @test logstdexp(A; dims=2) ≈ log.(std(exp.(A); dims=2)) + @test logmeanexp(A) ≈ log.(mean(exp.(A))) + @test logvarexp(A) ≈ log.(var(exp.(A))) + @test logstdexp(A) ≈ log.(std(exp.(A))) +end + +@testset "logmeanexp properties" begin + X = randn(1000, 1000) + @test only(logmeanexp(logmeanexp(X; dims=1); dims=2)) ≈ logmeanexp(X) + @test only(logmeanexp(-logmeanexp(-X; dims=1); dims=2)) ≤ only(-logmeanexp(-logmeanexp(X; dims=1); dims=2)) + x = randn() + @test logmeanexp([x]) ≈ x + X = randn(1000, 1000, 1) + @test logmeanexp(X; dims=3) ≈ X + @test -logmeanexp(-X) ≤ logmeanexp(X) +end diff --git a/test/runtests.jl b/test/runtests.jl index b9665e71..bf9c3441 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,3 +16,4 @@ include("basicfuns.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") +include("logstatsexp.jl") From 0a68325e25838ee8fc8745bad624806a9c09a500 Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 17:15:11 +0100 Subject: [PATCH 2/9] docs --- docs/src/index.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index bb0340bc..a2bde13b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -31,4 +31,7 @@ softmax! softmax cloglog cexpexp +logmeanexp +logvarexp +logstdexp ``` From 4376e9cf5dec81fce3d3d086302bef925131a7bc Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:30:54 +0100 Subject: [PATCH 3/9] 1.0 --- src/logstatsexp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl index 0131a9d4..9e6f138f 100644 --- a/src/logstatsexp.jl +++ b/src/logstatsexp.jl @@ -15,7 +15,7 @@ end Computes `log.(var(exp.(A); dims))`, in a numerically stable way. """ function logvarexp( - A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims) + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) ) R = logsumexp(2logsubexp.(A, logmean); dims) N = length(A) ÷ length(R) From 5aef521769309e29b51d07679bcab7bba207ac8c Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:32:17 +0100 Subject: [PATCH 4/9] kwargs --- src/logstatsexp.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl index 9e6f138f..46ace2a1 100644 --- a/src/logstatsexp.jl +++ b/src/logstatsexp.jl @@ -4,7 +4,7 @@ Computes `log.(mean(exp.(A); dims))`, in a numerically stable way. """ function logmeanexp(A::AbstractArray; dims=:) - R = logsumexp(A; dims) + R = logsumexp(A; dims=dims) N = length(A) ÷ length(R) return R .- log(N) end @@ -17,7 +17,7 @@ Computes `log.(var(exp.(A); dims))`, in a numerically stable way. function logvarexp( A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) ) - R = logsumexp(2logsubexp.(A, logmean); dims) + R = logsumexp(2logsubexp.(A, logmean); dims=dims) N = length(A) ÷ length(R) if corrected return R .- log(N - 1) @@ -32,7 +32,7 @@ end Computes `log.(std(exp.(A); dims))`, in a numerically stable way. """ function logstdexp( - A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims) + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) ) - return logvarexp(A; dims, corrected, logmean) / 2 + return logvarexp(A; dims=dims, corrected=corrected, logmean=logmean) / 2 end From d28a5ce94889c6cd83990dc78fbe0369cf771e1a Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:35:09 +0100 Subject: [PATCH 5/9] types --- src/logstatsexp.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl index 46ace2a1..63159ea5 100644 --- a/src/logstatsexp.jl +++ b/src/logstatsexp.jl @@ -5,7 +5,7 @@ Computes `log.(mean(exp.(A); dims))`, in a numerically stable way. """ function logmeanexp(A::AbstractArray; dims=:) R = logsumexp(A; dims=dims) - N = length(A) ÷ length(R) + N = convert(eltype(R), length(A) ÷ length(R)) return R .- log(N) end @@ -18,7 +18,7 @@ function logvarexp( A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) ) R = logsumexp(2logsubexp.(A, logmean); dims=dims) - N = length(A) ÷ length(R) + N = convert(eltype(R), length(A) ÷ length(R)) if corrected return R .- log(N - 1) else From 3fdd3683eee47809c1b48a59e07f890150d33c57 Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:41:16 +0100 Subject: [PATCH 6/9] kwargs again --- test/logstatsexp.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/logstatsexp.jl b/test/logstatsexp.jl index ba4f42a6..f21dca7e 100644 --- a/test/logstatsexp.jl +++ b/test/logstatsexp.jl @@ -5,10 +5,10 @@ using LogExpFunctions: logmeanexp, logvarexp, logstdexp @testset "logmeanexp, logvarexp" begin A = randn(5,3,2) for dims in (2, (1,2), :) - @test logmeanexp(A; dims) ≈ log.(mean(exp.(A); dims)) + @test logmeanexp(A; dims=dims) ≈ log.(mean(exp.(A); dims=dims)) for corrected in (true, false) - @test logvarexp(A; dims, corrected) ≈ log.(var(exp.(A); dims, corrected)) - @test logstdexp(A; dims, corrected) ≈ log.(std(exp.(A); dims, corrected)) + @test logvarexp(A; dims=dims, corrected=corrected) ≈ log.(var(exp.(A); dims=dims, corrected=corrected)) + @test logstdexp(A; dims=dims, corrected=corrected) ≈ log.(std(exp.(A); dims=dims, corrected=corrected)) end end @test logvarexp(A; dims=2) ≈ log.(var(exp.(A); dims=2)) From 1154e2ed49897eef316a00b8dfbb5df17f6f5e57 Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:50:22 +0100 Subject: [PATCH 7/9] only --- test/logstatsexp.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/logstatsexp.jl b/test/logstatsexp.jl index f21dca7e..93ef117d 100644 --- a/test/logstatsexp.jl +++ b/test/logstatsexp.jl @@ -20,8 +20,8 @@ end @testset "logmeanexp properties" begin X = randn(1000, 1000) - @test only(logmeanexp(logmeanexp(X; dims=1); dims=2)) ≈ logmeanexp(X) - @test only(logmeanexp(-logmeanexp(-X; dims=1); dims=2)) ≤ only(-logmeanexp(-logmeanexp(X; dims=1); dims=2)) + @test first(logmeanexp(logmeanexp(X; dims=1); dims=2) ≈ logmeanexp(X) + @test first(logmeanexp(-logmeanexp(-X; dims=1); dims=2)) ≤ first(-logmeanexp(-logmeanexp(X; dims=1); dims=2)) x = randn() @test logmeanexp([x]) ≈ x X = randn(1000, 1000, 1) From 4f5c996853449846b64e07b0016af1dd949cad04 Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:53:33 +0100 Subject: [PATCH 8/9] ups --- test/logstatsexp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/logstatsexp.jl b/test/logstatsexp.jl index 93ef117d..7f6aa075 100644 --- a/test/logstatsexp.jl +++ b/test/logstatsexp.jl @@ -20,7 +20,7 @@ end @testset "logmeanexp properties" begin X = randn(1000, 1000) - @test first(logmeanexp(logmeanexp(X; dims=1); dims=2) ≈ logmeanexp(X) + @test first(logmeanexp(logmeanexp(X; dims=1); dims=2)) ≈ logmeanexp(X) @test first(logmeanexp(-logmeanexp(-X; dims=1); dims=2)) ≤ first(-logmeanexp(-logmeanexp(X; dims=1); dims=2)) x = randn() @test logmeanexp([x]) ≈ x From 689ea928535be7d2722dae086292d4b59ca1ec01 Mon Sep 17 00:00:00 2001 From: Jorge FdCD Date: Sat, 11 Nov 2023 19:56:42 +0100 Subject: [PATCH 9/9] docstring --- src/logstatsexp.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl index 63159ea5..c686c1a4 100644 --- a/src/logstatsexp.jl +++ b/src/logstatsexp.jl @@ -1,5 +1,5 @@ """ - logmeanexp(A; dims=:) + logmeanexp(A::AbstractArray; dims=:) Computes `log.(mean(exp.(A); dims))`, in a numerically stable way. """ @@ -10,7 +10,7 @@ function logmeanexp(A::AbstractArray; dims=:) end """ - logvarexp(A; dims=:) + logvarexp(A::AbstractArray; dims=:) Computes `log.(var(exp.(A); dims))`, in a numerically stable way. """ @@ -27,7 +27,7 @@ function logvarexp( end """ - logstdexp(A; dims=:) + logstdexp(A::AbstractArray; dims=:) Computes `log.(std(exp.(A); dims))`, in a numerically stable way. """