Skip to content

Commit fce7f76

Browse files
authored
Do not seed structural zeros (#739)
* Fix gradient with `LowerTriangular` and `UpperTriangular` input * Only seed structurally non-zero entries * Base vector vs chunk mode on structural length * Fix tests * Check number of function evaluations and chunk sizes
1 parent fd2d4a9 commit fce7f76

File tree

5 files changed

+101
-22
lines changed

5 files changed

+101
-22
lines changed

src/apiutils.jl

+44-8
Original file line numberDiff line numberDiff line change
@@ -40,32 +40,68 @@ end
4040
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)
4141
end
4242

43+
# Only seed indices that are structurally non-zero
44+
# Indices to visit when seeding `x`. The one-argument form pairs `x` with itself.
function structural_eachindex(x::AbstractArray)
    return structural_eachindex(x, x)
end

# Generic fallback: no structure is assumed, so the shared indices of `x` and `y`
# as produced by `eachindex(x, y)` are returned. Both arrays are first passed
# through `require_one_based_indexing`.
function structural_eachindex(x::AbstractArray, y::AbstractArray)
    require_one_based_indexing(x, y)
    shared = eachindex(x, y)
    return shared
end
49+
# Indices of an upper-triangular matrix that are not structurally zero: every
# `CartesianIndex(row, col)` with `row <= col`, generated column by column.
# Throws a `DimensionMismatch` when `x` and `y` differ in size.
function structural_eachindex(x::UpperTriangular, y::AbstractArray)
    require_one_based_indexing(x, y)
    size(x) == size(y) || throw(DimensionMismatch())
    ncols = size(x, 1)
    return (CartesianIndex(row, col) for col in 1:ncols for row in 1:col)
end
57+
# Indices of a lower-triangular matrix that are not structurally zero: every
# `CartesianIndex(row, col)` with `row >= col`, generated column by column.
# Throws a `DimensionMismatch` when `x` and `y` differ in size.
function structural_eachindex(x::LowerTriangular, y::AbstractArray)
    require_one_based_indexing(x, y)
    size(x) == size(y) || throw(DimensionMismatch())
    nrows = size(x, 1)
    return (CartesianIndex(row, col) for col in 1:nrows for row in col:nrows)
end
65+
# For a `Diagonal` matrix only the main diagonal is structurally non-zero, so the
# result is `diagind(x)`. Throws a `DimensionMismatch` when `x` and `y` differ in size.
function structural_eachindex(x::Diagonal, y::AbstractArray)
    require_one_based_indexing(x, y)
    size(x) == size(y) || throw(DimensionMismatch())
    return diagind(x)
end
72+
4373
# Write a dual number into `duals` at every index produced by
# `structural_eachindex(duals, x)`: the value comes from the matching entry of `x`
# and the partials are `seed` (all zeros unless specified). Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    indices = structural_eachindex(duals, x)
    for I in indices
        duals[I] = Dual{T,V,N}(x[I], seed)
    end
    return duals
end
4880

4981
# Seed the first `N` indices produced by `structural_eachindex(duals, x)`: the k-th
# visited entry receives the value from `x` and `seeds[k]` as its partials.
# Iteration stops early if fewer than `N` indices are available. Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
               seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    # Walking the seeds tuple in lockstep with the indices pairs seed k with index k.
    for (partial_seed, I) in zip(seeds, structural_eachindex(duals, x))
        duals[I] = Dual{T,V,N}(x[I], partial_seed)
    end
    return duals
end
5588

5689
# Overwrite one chunk of `duals` with the single partials value `seed`.
#
# Starting at the `index`-th position of `structural_eachindex(duals, x)`, at most
# `N` entries receive the value from `x` and `seed` as partials (all zeros unless
# specified). Returns `duals`.
#
# The loop is bounded to `N` entries, the same way the `seeds`-taking method bounds
# its loop by `chunksize`. Without the bound, every call would rewrite all remaining
# structural entries, so the cost of each call would grow with the array length
# rather than with the chunk size.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    offset = index - 1
    idxs = Iterators.drop(structural_eachindex(duals, x), offset)
    # `zip` stops after `N` entries or when the indices run out, whichever is first.
    for (_, idx) in zip(1:N, idxs)
        duals[idx] = Dual{T,V,N}(x[idx], seed)
    end
    return duals
end
6398

6499
# Seed one chunk of `duals`. Starting at the `index`-th position of
# `structural_eachindex(duals, x)`, up to `chunksize` entries are written; the k-th
# of them receives the value from `x` and `seeds[k]` as partials. Returns `duals`.
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
               seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    # Skip the `index - 1` structural entries that precede this chunk.
    remaining = Iterators.drop(structural_eachindex(duals, x), index - 1)
    for (k, I) in zip(1:chunksize, remaining)
        duals[I] = Dual{T,V,N}(x[I], seeds[k])
    end
    return duals
end

src/gradient.jl

+14-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
1616
function gradient(f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
1717
require_one_based_indexing(x)
1818
CHK && checktag(T, f, x)
19-
if chunksize(cfg) == length(x)
19+
if chunksize(cfg) == structural_length(x)
2020
return vector_mode_gradient(f, x, cfg)
2121
else
2222
return chunk_mode_gradient(f, x, cfg)
@@ -35,7 +35,7 @@ This method assumes that `isa(f(x), Real)`.
3535
function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK, F}
3636
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
3737
CHK && checktag(T, f, x)
38-
if chunksize(cfg) == length(x)
38+
if chunksize(cfg) == structural_length(x)
3939
vector_mode_gradient!(result, f, x, cfg)
4040
else
4141
chunk_mode_gradient!(result, f, x, cfg)
@@ -63,12 +63,19 @@ function extract_gradient!(::Type{T}, result::DiffResult, dual::Dual) where {T}
6363
end
6464

6565
# Method for an output `y` that is a `Real`: every entry of `result` is set to
# `zero(y)`, and `result` is returned.
function extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T}
    return fill!(result, zero(y))
end
66-
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}= copyto!(result, partials(T, dual))
66+
# Copy the partials of `dual` into `result`, visiting `result` in the order given by
# `structural_eachindex(result)`: the k-th visited entry receives the k-th partial.
# Iteration stops when either the partials or the indices run out. Returns `result`.
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
    slots = structural_eachindex(result)
    for (k, I) in zip(1:npartials(dual), slots)
        result[I] = partials(T, dual, k)
    end
    return result
end
6773

6874
# Copy the first `chunksize` partials of `dual` into `result`, beginning at the
# `index`-th position of `structural_eachindex(result)`. Returns `result`.
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
    # Skip the `index - 1` structural entries that precede this chunk.
    remaining = Iterators.drop(structural_eachindex(result), index - 1)
    for (k, I) in zip(1:chunksize, remaining)
        result[I] = partials(T, dual, k)
    end
    return result
end
@@ -106,10 +113,10 @@ end
106113

107114
function chunk_mode_gradient_expr(result_definition::Expr)
108115
return quote
109-
@assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))"
116+
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"
110117

111118
# precalculate loop bounds
112-
xlen = length(x)
119+
xlen = structural_length(x)
113120
remainder = xlen % N
114121
lastchunksize = ifelse(remainder == 0, N, remainder)
115122
lastchunkindex = xlen - lastchunksize + 1

src/jacobian.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
1818
function jacobian(f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
1919
require_one_based_indexing(x)
2020
CHK && checktag(T, f, x)
21-
if chunksize(cfg) == length(x)
21+
if chunksize(cfg) == structural_length(x)
2222
return vector_mode_jacobian(f, x, cfg)
2323
else
2424
return chunk_mode_jacobian(f, x, cfg)
@@ -36,7 +36,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
3636
function jacobian(f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
3737
require_one_based_indexing(y, x)
3838
CHK && checktag(T, f!, x)
39-
if chunksize(cfg) == length(x)
39+
if chunksize(cfg) == structural_length(x)
4040
return vector_mode_jacobian(f!, y, x, cfg)
4141
else
4242
return chunk_mode_jacobian(f!, y, x, cfg)
@@ -57,7 +57,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
5757
function jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
5858
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
5959
CHK && checktag(T, f, x)
60-
if chunksize(cfg) == length(x)
60+
if chunksize(cfg) == structural_length(x)
6161
vector_mode_jacobian!(result, f, x, cfg)
6262
else
6363
chunk_mode_jacobian!(result, f, x, cfg)
@@ -78,7 +78,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
7878
function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
7979
result isa DiffResult ? require_one_based_indexing(y, x) : require_one_based_indexing(result, y, x)
8080
CHK && checktag(T, f!, x)
81-
if chunksize(cfg) == length(x)
81+
if chunksize(cfg) == structural_length(x)
8282
vector_mode_jacobian!(result, f!, y, x, cfg)
8383
else
8484
chunk_mode_jacobian!(result, f!, y, x, cfg)
@@ -169,10 +169,10 @@ const JACOBIAN_ERROR = DimensionMismatch("jacobian(f, x) expects that f(x) is an
169169
function jacobian_chunk_mode_expr(work_array_definition::Expr, compute_ydual::Expr,
170170
result_definition::Expr, y_definition::Expr)
171171
return quote
172-
@assert length(x) >= N "chunk size cannot be greater than length(x) ($(N) > $(length(x)))"
172+
@assert structural_length(x) >= N "chunk size cannot be greater than ForwardDiff.structural_length(x) ($(N) > $(structural_length(x)))"
173173

174174
# precalculate loop bounds
175-
xlen = length(x)
175+
xlen = structural_length(x)
176176
remainder = xlen % N
177177
lastchunksize = ifelse(remainder == 0, N, remainder)
178178
lastchunkindex = xlen - lastchunksize + 1

src/prelude.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,15 @@ function Chunk(input_length::Integer, threshold::Integer = DEFAULT_CHUNK_THRESHO
1212
Base.@nif 12 d->(N == d) d->(Chunk{d}()) d->(Chunk{N}())
1313
end
1414

15+
# Generic arrays: the structural length is simply `length(x)`.
function structural_length(x::AbstractArray)
    return length(x)
end

# Triangular matrices: `n * (n + 1) / 2` for an `n`-by-`n` matrix, i.e. one
# triangle including the diagonal.
function structural_length(x::Union{LowerTriangular,UpperTriangular})
    n = size(x, 1)
    # `n * (n + 1)` is always even, so the integer division is exact.
    return div(n * (n + 1), 2)
end

# Diagonal matrices: one entry per row.
function structural_length(x::Diagonal)
    return first(size(x))
end
21+
1522
# Build a `Chunk` for the array `x` by forwarding `structural_length(x)` and
# `threshold` to the integer-length `Chunk` method.
function Chunk(x::AbstractArray, threshold::Integer = DEFAULT_CHUNK_THRESHOLD)
    nstructural = structural_length(x)
    return Chunk(nstructural, threshold)
end
1825

1926
# Constrained to `N <= threshold`, minimize (in order of priority):

test/GradientTest.jl

+29
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,33 @@ end
226226
@test dx sum(a * b)
227227
end
228228

229+
# issue #738
230+
@testset "LowerTriangular, UpperTriangular and Diagonal" begin
    for n in (3, 10, 20)
        M = rand(n, n)
        for T in (LowerTriangular, UpperTriangular, Diagonal)
            # Gradients must keep the input's structure.
            @test ForwardDiff.gradient(sum, T(randn(n, n))) == T(ones(n, n))
            @test ForwardDiff.gradient(x -> dot(M, x), T(randn(n, n))) == T(M)

            # Check number of function evaluations and chunk sizes
            ncalls = Ref(0)
            total_partials = Ref(0)
            counting_sum = function (x)
                ncalls[] += 1
                total_partials[] += ForwardDiff.npartials(eltype(x))
                return sum(x)
            end
            y = ForwardDiff.gradient(counting_sum, T(randn(n, n)))
            # `sum(y)` is compared against the number of partials accumulated
            # across all evaluations.
            if total_partials[] > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
                # Chunk mode (multiple evaluations)
                @test ncalls[] > 1
                @test sum(y) <= total_partials[] < sum(y) + ncalls[]
            else
                # Vector mode (single evaluation)
                @test ncalls[] == 1
                @test total_partials[] == sum(y)
            end
        end
    end
end
257+
229258
end # module

0 commit comments

Comments
 (0)