Skip to content

Commit bc8c543

Browse files
authored
Merge pull request #8 from tpapp/tp/fix-threaded-use
Don't preallocate GradientConfig in ForwardDiff backend by default
2 parents 1b828a3 + cba4985 commit bc8c543

File tree

4 files changed

+91
-21
lines changed

4 files changed

+91
-21
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ SimpleUnPack = "1"
3434

3535
[extras]
3636
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
37+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
3738
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3839
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3940
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -43,4 +44,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4344
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4445

4546
[targets]
46-
test = ["BenchmarkTools", "Enzyme", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
47+
test = ["BenchmarkTools", "ComponentArrays", "Enzyme", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]

ext/LogDensityProblemsADForwardDiffExt.jl

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,49 +18,95 @@ end
1818
# Load DiffResults helpers
1919
include("DiffResults_helpers.jl")
2020

21-
struct ForwardDiffLogDensity{L, C} <: ADGradientWrapper
21+
struct ForwardDiffLogDensity{L, C <: ForwardDiff.Chunk, T <: Union{Nothing,ForwardDiff.Tag},
22+
G <: Union{Nothing,ForwardDiff.GradientConfig}} <: ADGradientWrapper
23+
"supports zero-order evaluation `logdensity(ℓ, x)`"
2224
ℓ::L
23-
gradientconfig::C
25+
"chunk size for ForwardDiff"
26+
chunk::C
27+
"tag, or `nothing` for the default"
28+
tag::T
29+
"gradient config, or `nothing` if created for each evaluation"
30+
gradient_config::G
2431
end
2532

2633
function Base.show(io::IO, ℓ::ForwardDiffLogDensity)
2734
print(io, "ForwardDiff AD wrapper for ", ℓ.ℓ,
28-
", w/ chunk size ", length(ℓ.gradientconfig.seeds))
35+
", w/ chunk size ", ForwardDiff.chunksize(ℓ.chunk))
2936
end
3037

3138
_chunk(chunk::ForwardDiff.Chunk) = chunk
3239
_chunk(chunk::Integer) = ForwardDiff.Chunk(chunk)
3340

3441
_default_chunk(ℓ) = _chunk(dimension(ℓ))
3542

36-
_default_gradientconfig(ℓ, chunk, ::Nothing) = _default_gradientconfig(ℓ, chunk, zeros(dimension(ℓ)))
37-
function _default_gradientconfig(ℓ, chunk, x::AbstractVector)
38-
return ForwardDiff.GradientConfig(Base.Fix1(logdensity, ℓ), x, _chunk(chunk))
43+
function Base.copy(fℓ::ForwardDiffLogDensity{L,C,T,<:ForwardDiff.GradientConfig}) where {L,C,T}
44+
@unpack ℓ, chunk, tag, gradient_config = fℓ
45+
ForwardDiffLogDensity(ℓ, chunk, tag, copy(gradient_config))
3946
end
4047

4148
"""
42-
ADgradient(:ForwardDiff, ℓ; x, chunk, gradientconfig)
43-
ADgradient(Val(:ForwardDiff), ℓ; x, chunk, gradientconfig)
49+
$(SIGNATURES)
50+
51+
Make a `ForwardDiff.GradientConfig` for function `f` and input `x`. `tag = nothing` generates the default tag.
52+
"""
53+
function _make_gradient_config(f::F, x, chunk, tag) where {F}
54+
c = _chunk(chunk)
55+
gradient_config = if tag ≡ nothing
56+
ForwardDiff.GradientConfig(f, x, c)
57+
else
58+
ForwardDiff.GradientConfig(f, x, c, tag)
59+
end
60+
gradient_config
61+
end
62+
63+
"""
64+
ADgradient(:ForwardDiff, ℓ; chunk, tag, x)
65+
ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x)
4466
4567
Wrap a log density that supports evaluation of `Value` to handle `ValueGradient`, using
4668
`ForwardDiff`.
4769
48-
Keywords are passed on to `ForwardDiff.GradientConfig` to customize the setup. In
49-
particular, chunk size can be set with a `chunk` keyword argument (accepting an integer or a
50-
`ForwardDiff.Chunk`), and the underlying vector used by `ForwardDiff` can be set with the
51-
`x` keyword argument (accepting an `AbstractVector`).
70+
Keyword arguments:
71+
72+
- `chunk` can be used to set the chunk size, an integer or a `ForwardDiff.Chunk`
73+
74+
- `tag` (default: `nothing`) can be used to set a tag for `ForwardDiff`
75+
76+
- `x` (default: `nothing`) will be used to preallocate a `ForwardDiff.GradientConfig` with
77+
the given vector. With the default, one is created for each evaluation.
78+
79+
Note that **pre-allocating a `ForwardDiff.GradientConfig` is not thread-safe**. You can
80+
[`copy`](@ref) the results for concurrent evaluation:
81+
```julia
82+
∇ℓ1 = ADgradient(:ForwardDiff, ℓ; x = zeros(dimension(ℓ)))
83+
∇ℓ2 = copy(∇ℓ1) # you can now use both, in different threads
84+
```
85+
86+
See also the ForwardDiff documentation regarding
87+
[`ForwardDiff.GradientConfig`](https://juliadiff.org/ForwardDiff.jl/stable/user/api/#Preallocating/Configuring-Work-Buffers)
88+
and [chunks and tags](https://juliadiff.org/ForwardDiff.jl/stable/user/advanced/).
5289
"""
5390
function ADgradient(::Val{:ForwardDiff}, ℓ;
54-
x::Union{Nothing,AbstractVector} = nothing,
5591
chunk::Union{Integer,ForwardDiff.Chunk} = _default_chunk(ℓ),
56-
gradientconfig::ForwardDiff.GradientConfig = _default_gradientconfig(ℓ, chunk, x))
57-
ForwardDiffLogDensity(ℓ, gradientconfig)
92+
tag::Union{Nothing,ForwardDiff.Tag} = nothing,
93+
x::Union{Nothing,AbstractVector} = nothing)
94+
gradient_config = if x ≡ nothing
95+
nothing
96+
else
97+
_make_gradient_config(Base.Fix1(logdensity, ℓ), x, chunk, tag)
98+
end
99+
ForwardDiffLogDensity(ℓ, chunk, tag, gradient_config)
58100
end
59101

60102
function logdensity_and_gradient(fℓ::ForwardDiffLogDensity, x::AbstractVector)
61-
@unpack ℓ, gradientconfig = fℓ
103+
@unpack ℓ, chunk, tag, gradient_config = fℓ
62104
buffer = _diffresults_buffer(x)
63-
result = ForwardDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x, gradientconfig)
105+
ℓ′ = Base.Fix1(logdensity, ℓ)
106+
if gradient_config ≡ nothing
107+
gradient_config = _make_gradient_config(ℓ′, x, chunk, tag)
108+
end
109+
result = ForwardDiff.gradient!(buffer, ℓ′, x, gradient_config)
64110
_diffresults_extract(result)
65111
end
66112

src/LogDensityProblemsAD.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using LogDensityProblems: LogDensityOrder
1111

1212
import SimpleUnPack
1313

14-
1514
#####
1615
##### AD wrappers --- interface and generic code
1716
#####
@@ -34,6 +33,8 @@ dimension(ℓ::ADGradientWrapper) = dimension(ℓ.ℓ)
3433

3534
Base.parent(ℓ::ADGradientWrapper) = ℓ.ℓ
3635

36+
Base.copy(x::ADGradientWrapper) = x # no-op, except for ForwardDiff
37+
3738
"""
3839
$(SIGNATURES)
3940
@@ -57,6 +58,10 @@ ADgradient(:ForwardDiff, P)
5758
and should mostly be equivalent if the compiler manages to fold the constant.
5859
5960
The function `parent` can be used to retrieve the original argument.
61+
62+
!!! note
63+
With the default options, automatic differentiation preserves thread-safety. See
64+
exceptions and workarounds in the docstring for each backend.
6065
"""
6166
ADgradient(kind::Symbol, P; kwargs...) = ADgradient(Val{kind}(), P; kwargs...)
6267

test/runtests.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import LogDensityProblems: capabilities, dimension, logdensity
44
using LogDensityProblems: logdensity_and_gradient, LogDensityOrder
55
import ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # backends
66
import BenchmarkTools # load the heuristic chunks code
7+
using ComponentArrays: ComponentVector # test with other vector types
78

89
struct EnzymeTestMode <: Enzyme.Mode end
910

@@ -115,13 +116,30 @@ end
115116
(test_logdensity(x), test_gradient(x))
116117
end
117118

118-
# Make sure that other types are supported.
119+
# preallocated gradient config
119120
x = randexp(Float32, 3)
120-
∇ℓ = ADgradient(:ForwardDiff, ℓ; x=x)
121+
∇ℓ = ADgradient(:ForwardDiff, ℓ; x = x)
121122
@test eltype(first(logdensity_and_gradient(∇ℓ, x))) === Float32
122123
@test @inferred(logdensity(∇ℓ, x)) ≈ test_logdensity(x)
123124
@test @inferred(logdensity_and_gradient(∇ℓ, x)) ≈
124125
(test_logdensity(x), test_gradient(x))
126+
@test @inferred(copy(∇ℓ)).gradient_config ≢ ∇ℓ.gradient_config
127+
end
128+
129+
@testset "component vectors" begin
130+
# test with something else than `Vector`
131+
# cf https://github.com/tpapp/LogDensityProblemsAD.jl/pull/3
132+
ℓ = TestLogDensity()
133+
∇ℓ = ADgradient(:ForwardDiff, ℓ)
134+
x = zeros(3)
135+
y = ComponentVector(x = x)
136+
@test @inferred(logdensity(∇ℓ, y)) ≈ test_logdensity(x)
137+
@test @inferred(logdensity_and_gradient(∇ℓ, y)) ≈
138+
(test_logdensity(x), test_gradient(x))
139+
∇ℓ2 = ADgradient(:ForwardDiff, ℓ; x = y) # preallocate GradientConfig
140+
@test @inferred(logdensity(∇ℓ2, y)) ≈ test_logdensity(x)
141+
@test @inferred(logdensity_and_gradient(∇ℓ2, y)) ≈
142+
(test_logdensity(x), test_gradient(x))
125143
end
126144

127145
@testset "chunk heuristics for ForwardDiff" begin

0 commit comments

Comments
 (0)