Skip to content

Commit 1de7b9c

Browse files
committed
make extract_gradient[_chunk]! GPU compatible
1 parent 6a19554 commit 1de7b9c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/gradient.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ function extract_gradient!(::Type{T}, result::DiffResult, dual::Dual) where {T}
8080
end
8181

8282
extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
83-
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}= copyto!(result, partials(T, dual))
83+
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
84+
result[:] .= partials(T, dual)
85+
return result
86+
end
8487

8588
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
86-
offset = index - 1
87-
for i in 1:chunksize
88-
result[i + offset] = partials(T, dual, i)
89-
end
89+
result[index:index+chunksize-1] .= partials.(T, dual, 1:chunksize)
9090
return result
9191
end
9292

0 commit comments

Comments
 (0)