diff --git a/Project.toml b/Project.toml index f044b713..dcfea4f2 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Test"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "ChangesOfVariables", "FiniteDifferences", "ForwardDiff", "InverseFunctions", "OffsetArrays", "Random", "Statistics", "Test"] diff --git a/docs/src/index.md b/docs/src/index.md index bb0340bc..a2bde13b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -31,4 +31,7 @@ softmax! softmax cloglog cexpexp +logmeanexp +logvarexp +logstdexp ``` diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 0b3385df..28a43aac 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -8,7 +8,8 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax, - softmax!, logcosh, logabssinh, cloglog, cexpexp + softmax!, logcosh, logabssinh, cloglog, cexpexp, + logmeanexp, logvarexp, logstdexp # expm1(::Float16) is not defined in older Julia versions, # hence for better Float16 support we use an internal function instead @@ -22,6 +23,7 @@ end include("basicfuns.jl") include("logsumexp.jl") +include("logstatsexp.jl") if !isdefined(Base, :get_extension) include("../ext/LogExpFunctionsChainRulesCoreExt.jl") diff --git a/src/logstatsexp.jl b/src/logstatsexp.jl new file mode 100644 index 00000000..c686c1a4 --- /dev/null +++ b/src/logstatsexp.jl @@ -0,0 +1,38 @@ +""" + logmeanexp(A::AbstractArray; dims=:) + +Computes `log.(mean(exp.(A); dims))`, in a numerically stable way. +""" +function logmeanexp(A::AbstractArray; dims=:) + R = logsumexp(A; dims=dims) + N = convert(eltype(R), length(A) ÷ length(R)) + return R .- log(N) +end + +""" + logvarexp(A::AbstractArray; dims=:) + +Computes `log.(var(exp.(A); dims))`, in a numerically stable way. +""" +function logvarexp( + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) +) + R = logsumexp(2logsubexp.(A, logmean); dims=dims) + N = convert(eltype(R), length(A) ÷ length(R)) + if corrected + return R .- log(N - 1) + else + return R .- log(N) + end +end + +""" + logstdexp(A::AbstractArray; dims=:) + +Computes `log.(std(exp.(A); dims))`, in a numerically stable way. +""" +function logstdexp( + A::AbstractArray; dims=:, corrected::Bool=true, logmean=logmeanexp(A; dims=dims) +) + return logvarexp(A; dims=dims, corrected=corrected, logmean=logmean) / 2 +end diff --git a/test/logstatsexp.jl b/test/logstatsexp.jl new file mode 100644 index 00000000..7f6aa075 --- /dev/null +++ b/test/logstatsexp.jl @@ -0,0 +1,30 @@ +using Test: @test, @testset, @inferred +using Statistics: mean, var, std +using LogExpFunctions: logmeanexp, logvarexp, logstdexp + +@testset "logmeanexp, logvarexp" begin + A = randn(5,3,2) + for dims in (2, (1,2), :) + @test logmeanexp(A; dims=dims) ≈ log.(mean(exp.(A); dims=dims)) + for corrected in (true, false) + @test logvarexp(A; dims=dims, corrected=corrected) ≈ log.(var(exp.(A); dims=dims, corrected=corrected)) + @test logstdexp(A; dims=dims, corrected=corrected) ≈ log.(std(exp.(A); dims=dims, corrected=corrected)) + end + end + @test logvarexp(A; dims=2) ≈ log.(var(exp.(A); dims=2)) + @test logstdexp(A; dims=2) ≈ log.(std(exp.(A); dims=2)) + @test logmeanexp(A) ≈ log.(mean(exp.(A))) + @test logvarexp(A) ≈ log.(var(exp.(A))) + @test logstdexp(A) ≈ log.(std(exp.(A))) +end + +@testset "logmeanexp properties" begin + X = randn(1000, 1000) + @test first(logmeanexp(logmeanexp(X; dims=1); dims=2)) ≈ logmeanexp(X) + @test first(logmeanexp(-logmeanexp(-X; dims=1); dims=2)) ≤ first(-logmeanexp(-logmeanexp(X; dims=1); dims=2)) + x = randn() + @test logmeanexp([x]) ≈ x + X = randn(1000, 1000, 1) + @test logmeanexp(X; dims=3) ≈ X + @test -logmeanexp(-X) ≤ logmeanexp(X) +end diff --git a/test/runtests.jl b/test/runtests.jl index b9665e71..bf9c3441 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,3 +16,4 @@ include("basicfuns.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") +include("logstatsexp.jl")