Skip to content

Commit 4a62165

Browse files
KristofferCjrevels
authored andcommitted
specialize on input function on a few more wrapper functions
1 parent 04c751f commit 4a62165

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/apiutils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ end
2929

3030
@inline static_dual_eval(::Type{T}, f, x::SArray) where {T} = f(dualize(T, x))
3131

32-
function vector_mode_dual_eval(f, x, cfg::Union{JacobianConfig,GradientConfig})
32+
function vector_mode_dual_eval(f::F, x, cfg::Union{JacobianConfig,GradientConfig}) where {F}
3333
xdual = cfg.duals
3434
seed!(xdual, x, cfg.seeds)
3535
return f(xdual)
3636
end
3737

38-
function vector_mode_dual_eval(f!, y, x, cfg::JacobianConfig)
38+
function vector_mode_dual_eval(f!::F, y, x, cfg::JacobianConfig) where {F}
3939
ydual, xdual = cfg.duals
4040
seed!(xdual, x, cfg.seeds)
4141
seed!(ydual, y)

src/gradient.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Compute `∇f` evaluated at `x` and store the result(s) in `result`, assuming `f
2929
This method assumes that `isa(f(x), Real)`.
3030
3131
"""
32-
function gradient!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
32+
function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK, F}
3333
CHK && checktag(T, f, x)
3434
if chunksize(cfg) == length(x)
3535
vector_mode_gradient!(result, f, x, cfg)
@@ -92,13 +92,13 @@ end
9292
# vector mode #
9393
###############
9494

95-
function vector_mode_gradient(f, x, cfg::GradientConfig{T}) where {T}
95+
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
9696
ydual = vector_mode_dual_eval(f, x, cfg)
9797
result = similar(x, valtype(ydual))
9898
return extract_gradient!(T, result, ydual)
9999
end
100100

101-
function vector_mode_gradient!(result, f, x, cfg::GradientConfig{T}) where {T}
101+
function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T, F}
102102
ydual = vector_mode_dual_eval(f, x, cfg)
103103
result = extract_gradient!(T, result, ydual)
104104
return result

0 commit comments

Comments
 (0)