Skip to content

Commit 73d430e

Browse files
authored
use Base.Fix1 instead of closures in ForwardDiffStaticArraysExt.jl (#735)
* use `Base.Fix1` instead of closures in `ForwardDiffStaticArraysExt.jl` * inference tets from issue #639 and pr #735
1 parent 7d63748 commit 73d430e

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

ext/ForwardDiffStaticArraysExt.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,16 @@ end
103103
T = typeof(Tag(f, eltype(x)))
104104
ydual = static_dual_eval(T, f, x)
105105
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
106-
result = DiffResults.value!(d -> value(T,d), result, ydual)
106+
result = DiffResults.value!(Base.Fix1(value, T), result, ydual)
107107
return result
108108
end
109109

110110
# Hessian
111-
ForwardDiff.hessian(f::F, x::StaticArray) where {F} = jacobian(y -> gradient(f, y), x)
111+
ForwardDiff.hessian(f::F, x::StaticArray) where {F} = jacobian(Base.Fix1(gradient, f), x)
112112
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian(f, x)
113113
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian(f, x)
114114

115-
ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where {F} = jacobian!(result, y -> gradient(f, y), x)
115+
ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where {F} = jacobian!(result, Base.Fix1(gradient, f), x)
116116

117117
ForwardDiff.hessian!(result::MutableDiffResult, f::F, x::StaticArray) where {F} = hessian!(result, f, x, HessianConfig(f, result, x))
118118

test/JacobianTest.jl

+18
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,24 @@ end
259259
@testset "type stability" begin
260260
g!(dy, y) = dy[1] = y[1]
261261
@inferred ForwardDiff.jacobian(g!, [1.0], [0.0])
262+
263+
@testset "issue 639" begin
264+
f(x) = SA[x[1]^2+x[2]^2, x[2]^2+x[3]^2]
265+
x = SA[1.0, 2.0, 3.0]
266+
y = f(x)
267+
imdr = DiffResults.JacobianResult(y, x)
268+
@inferred ForwardDiff.jacobian!(imdr, f, x)
269+
end
270+
271+
@testset "pr 735" begin
272+
f(x) = x .^ 2 ./ 2
273+
function withjacobian(x)
274+
res = DiffResults.JacobianResult(x)
275+
res = ForwardDiff.jacobian!(res, f, x)
276+
return DiffResults.value(res), DiffResults.jacobian(res)
277+
end
278+
@inferred withjacobian(SA[1.0, 2.0])
279+
end
262280
end
263281

264282
end # module

0 commit comments

Comments
 (0)