From 114cfe90755df8591488c7d71bd3109be5325fb9 Mon Sep 17 00:00:00 2001 From: Clement Guillot Date: Wed, 4 Dec 2024 22:24:04 +0100 Subject: [PATCH 1/3] Added explicit derivatives for analytic functions --- src/dual.jl | 45 +++++++++++++++++++++++++++++++++++++++++- test/DerivativeTest.jl | 7 +++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/dual.jl b/src/dual.jl index 7e8ec110..da12568b 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -238,13 +238,30 @@ function unary_dual_definition(M, f) val = $Mf(x) deriv = $(DiffRules.diffrule(M, f, :x)) end) - return quote + real_diff_exp = quote @inline function $M.$f(d::$FD.Dual{T}) where T x = $FD.value(d) $work return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d)) end end + if (M, f) in ((:Base, :abs), (:Base, :abs2)) + return real_diff_exp + else + complex_diff_expr = quote + @inline function $M.$f(d::Complex{$FD.Dual{T, V, N}}) where{T, V, N} + x = complex($FD.value(real(d)), $FD.value(imag(d))) + $work + re_deriv, im_deriv = reim(deriv) + re_partials = $FD.partials(real(d)) + im_partials = $FD.partials(imag(d)) + re_retval = $FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = $FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) + end + end + return Expr(:block, real_diff_exp, complex_diff_expr) + end end function binary_dual_definition(M, f) @@ -727,6 +744,32 @@ end return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d))) end +function Base.sin(d::Complex{Dual{T, V, N}}) where{T, V, N} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = sin(x) + deriv = cos(x) + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + +function Base.cos(d::Complex{Dual{T, V, N}}) where{T, V, N} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = cos(x) + deriv = -sin(x) + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + # sincospi # #----------# diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index 4de1a6de..a105f084 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -113,4 +113,11 @@ end @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) end +@testset "analytic functions" begin + dexp(x) = ForwardDiff.derivative(y -> exp(complex(0, y)), x) + @test ForwardDiff.derivative(dexp, 0.0) ≈ -1 + @test ForwardDiff.derivative(x -> exp(1im*x), 0.7) ≈ im * cis(0.7) + @test ForwardDiff.derivative(x -> sqrt(im + (1+im) * x), 1.23) ≈ (1+im) / (2 * sqrt(im + (1+im)*1.23)) +end + end # module From b55fc60ae2aff8b94973012063a7095e75ee1801 Mon Sep 17 00:00:00 2001 From: Clement Guillot Date: Thu, 5 Dec 2024 10:34:59 +0100 Subject: [PATCH 2/3] Manual handling of inv --- src/dual.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/dual.jl b/src/dual.jl index da12568b..765e031e 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -726,6 +726,22 @@ end Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body ) +# inv(Complex) # +#--------# + +function Base.inv(d::Complex{<:Dual{T}}) where{T} + FD = ForwardDiff + x = complex(FD.value(real(d)), FD.value(imag(d))) + val = inv(x) + deriv = - val * val + re_deriv, im_deriv = reim(deriv) + re_partials = FD.partials(real(d)) + im_partials = FD.partials(imag(d)) + re_retval = FD.dual_definition_retval(Val{T}(), real(val), re_deriv, re_partials, -im_deriv, im_partials) + im_retval = FD.dual_definition_retval(Val{T}(), imag(val), im_deriv, re_partials, re_deriv, im_partials) + return complex(re_retval, im_retval) +end + # sin/cos # #--------# From e4aa3cf5bca9ad7aa17c4c6b55f2b71d2bfb431c Mon Sep 17 00:00:00 2001 From: Clement Guillot Date: Thu, 5 Dec 2024 10:35:43 +0100 Subject: [PATCH 3/3] More subtypes handled ? --- src/dual.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 765e031e..b29fe475 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -245,11 +245,11 @@ function unary_dual_definition(M, f) return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d)) end end - if (M, f) in ((:Base, :abs), (:Base, :abs2)) + if (M, f) in ((:Base, :abs), (:Base, :abs2), (:Base, :inv)) return real_diff_exp else complex_diff_expr = quote - @inline function $M.$f(d::Complex{$FD.Dual{T, V, N}}) where{T, V, N} + @inline function $M.$f(d::Complex{<:$FD.Dual{T}}) where{T} x = complex($FD.value(real(d)), $FD.value(imag(d))) $work re_deriv, im_deriv = reim(deriv) @@ -760,7 +760,7 @@ end return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d))) end -function Base.sin(d::Complex{Dual{T, V, N}}) where{T, V, N} +function Base.sin(d::Complex{<:Dual{T}}) where{T} FD = ForwardDiff x = complex(FD.value(real(d)), FD.value(imag(d))) val = sin(x) @@ -773,7 +773,7 @@ function Base.sin(d::Complex{Dual{T, V, N}}) where{T, V, N} return complex(re_retval, im_retval) end -function Base.cos(d::Complex{Dual{T, V, N}}) where{T, V, N} +function Base.cos(d::Complex{<:Dual{T}}) where{T} FD = ForwardDiff x = complex(FD.value(real(d)), FD.value(imag(d))) val = cos(x)