Skip to content

Commit 7daeb7c

Browse files
committed
add tests for AbstractGPUArray
1 parent 1de7b9c commit 7daeb7c

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ julia = "1.6"
3232
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3333
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
3434
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
35+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3536
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3637
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3738

3839
[targets]
39-
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"]
40+
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils", "JLArrays"]

test/GradientTest.jl

+15
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ForwardDiff
88
using ForwardDiff: Dual, Tag
99
using StaticArrays
1010
using DiffTests
11+
using JLArrays
1112

1213
include(joinpath(dirname(@__FILE__), "utils.jl"))
1314

@@ -149,6 +150,20 @@ end
149150
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
150151
end
151152

153+
154+
##############################################
155+
# test GPUArray compatibility (via JLArrays) #
156+
##############################################
157+
158+
println(" ...testing GPUArray compatibility (via JLArrays)")
159+
160+
@testset "size = $(size(x))" for x in JLArray.([rand(3), rand(13), rand(3,3), rand(13,13), rand(10,10,10)])
161+
162+
@test ForwardDiff.gradient(prod, x) isa typeof(x)
163+
164+
end
165+
166+
152167
#############
153168
# bug fixes #
154169
#############

0 commit comments

Comments
 (0)