diff --git a/Project.toml b/Project.toml
index 83b5223b8..b409c12a9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -16,6 +16,7 @@ julia = "1"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Zygote"]
diff --git a/src/activation.jl b/src/activation.jl
index 9dffa5b18..b40a42403 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -72,7 +72,8 @@ elu(x, α = one(x)) = ifelse(x ≥ 0, x/one(x), α * (exp(x) - one(x)))
 activation function.
 """
 function gelu(x::Real)
-    λ = oftype(x/1, √(2/π))
+    p = oftype(x/1, π)
+    λ = oftype(x/1, √(2/p))
     α = oftype(x/1, 0.044715)
     h = oftype(x/1, 0.5)
     h * x * (one(x) + tanh(λ * (x + α * x^3)))
@@ -126,12 +127,6 @@
 Return `log(cosh(x))` which is computed in a numerically stable way.
 """
 logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
 
-# Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
-  @eval $(f)(x::AbstractArray, args...) =
-  error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
-end
-
 """
     mish(x) = x * tanh(softplus(x))
 
@@ -140,3 +135,10 @@ Self Regularized Non-Monotonic Neural Activation Function
 See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
 """
 mish(x::Real) = x * tanh(softplus(x))
+
+
+# Provide an informative error message if activation functions are called with an array
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh, :mish)
+  @eval $(f)(x::AbstractArray, args...) =
+  error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
+end
\ No newline at end of file
diff --git a/test/activation.jl b/test/activation.jl
index 1c53e17e3..68da3cde7 100644
--- a/test/activation.jl
+++ b/test/activation.jl
@@ -1,6 +1,6 @@
-using NNlib, Test
+using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh, mish];
 
 function test_value_float_precision_preserving(a)
   @testset "$(a): " begin
@@ -24,6 +24,17 @@ function test_value_int_input_forces_float64(a)
   end
 end
 
+function test_gradient_float_precision_preserving(a)
+  @testset "$(a): " begin
+    for T in [Float32, Float64]
+      for val in [-10, -1, 0, 1, 10]
+        val = @inferred a'(T(val))
+        @test typeof(val) == T
+      end
+    end
+  end
+end
+
 @testset "Activation Functions" begin
   @test σ(0.0) == 0.5
   @test relu(0.0) == 0.0
@@ -83,6 +94,10 @@ end
       @test typeof(relu(Int32(1))) == Int32
     end
   end
+
+  @testset "Float gradient inference" begin
+    test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
+  end
 
   @testset "softmax" begin
     xs = rand(5,5)