# Shared fixtures used by the testsets in this file.
# NOTE: `TwoThirds` and `Foo` are not defined in this section; they must already be in scope.

# A plain Float64 vector, reused (by identity) inside several fixtures below.
m1 = collect(1:3.0)
# Tuple of two independent vectors.
m2 = (collect(1:3.0), collect(4:6.0))
# NamedTuple mixing arrays with a non-array leaf (`sin`).
m3 = (x = m1, y = sin, z = collect(4:6.0))
# Tied: `x` and `y` are the very same array.
m4 = (x = m1, y = m1, z = collect(4:6.0))
# Nested containers that reuse m3, m1 and m4 (so m1 recurs), each paired with a Bool leaf.
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
# `a` and `c` are the same real array; `b` is a complex vector.
m6 = (a = m1, b = [4.0 + im], c = m1)
# A struct holding three (function, vector) tuples.
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
# A Vector mixing a struct (with both fields tied to m1), a NamedTuple and a vector of vectors.
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]

# Checks both halves of `destructure(m)`: element [1] is the flat parameter vector,
# element [2] is a function that rebuilds a model of the same structure from such a vector.
@testset "flatten & rebuild" begin
  # Flattening: an array that appears more than once contributes to the vector only once,
  # and non-array leaves (functions, Bools) contribute nothing.
  @test destructure(m1)[1] isa Vector{Float64}
  @test destructure(m1)[1] == 1:3
  @test destructure(m2)[1] == 1:6
  @test destructure(m3)[1] == 1:6
  @test destructure(m4)[1] == 1:6
  @test destructure(m5)[1] == vcat(1:6, 4:6)
  @test destructure(m6)[1] == vcat(1:3, 4 + im)

  # Rebuilding: non-array leaves are carried over, and tied arrays stay tied (`===`).
  @test destructure(m1)[2](7:9) == [7,8,9]
  @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
  @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
  m4′ = destructure(m4)[2](4:9)
  @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
  @test m4′.x === m4′.y
  m5′ = destructure(m5)[2](reverse(1:9))
  @test m5′.a[1].x === m5′.b[1]
  @test m5′.b[2] === false
  # Complex input: the real array keeps a real eltype, the complex array keeps the full value.
  m6′ = destructure(m6)[2]((4:7) .+ (1:4) .* im)
  @test m6′.a isa Vector{Float64}
  @test m6′.a == 4:6
  @test m6′.a === m6′.c
  @test m6′.b == [7 + 4im]

  # struct, trainable
  # Only the array in field `a` is flattened; `b` and `c` keep their original arrays.
  @test destructure(m7)[1] == 1:3
  m7′ = destructure(m7)[2]([10,20,30])
  @test m7′.a == (sin, [10,20,30])
  @test m7′.b == (cos, [4,5,6])
  @test m7′.c == (tan, [7,8,9])

  @test destructure(m8)[1] == 1:5
  m8′ = destructure(m8)[2](1:5)
  @test m8′[1].x === m8′[1].y
  @test m8′[2].b.y === false
  @test m8′[3][1] == [5.0]

  # errors
  # A flat vector that is too short or too long for the model must be rejected.
  @test_throws Exception destructure(m7)[2]([10,20])
  @test_throws Exception destructure(m7)[2]([10,20,30,40])
end

# Differentiates through the flattening half, `destructure(m)[1]`, with respect to the model.
@testset "gradient of flatten" begin
  # Picking one element of the flat vector gives a one-hot gradient in the matching array.
  # A repeated (tied) array gets its gradient at the first occurrence and `nothing` at later
  # ones; non-array leaves such as `sin` also get `nothing`.
  @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
  @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
  @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
  @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
  @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

  g5 = gradient(m -> destructure(m)[1][3], m5)[1]
  @test g5.a[1].x == [0,0,1]
  @test g5.a[2] === nothing

  # Element 4 of the flat vector is the complex entry of `m6.b`; the real array `a`
  # must still receive a real-valued (zero) gradient.
  g6 = gradient(m -> imag(destructure(m)[1][4]), m6)[1]
  @test g6.a == [0,0,0]
  @test g6.a isa Vector{Float64}
  @test g6.b == [0+im]

  # The flat vector of m8 is 1:5, so the gradient of the sum of squares is 2 .* (1:5),
  # split back across the nested structure.
  g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
  @test g8[1].x == [2,4,6]
  @test g8[2].b.x == [8]
  @test g8[3] == [[10.0]]

  @testset "second derivative" begin
    # Inner gradient w.r.t. `v` is 2v, so the outer function is sum(4v.^2) with gradient 8v.
    @test gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
    # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
    # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
    # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
    # With Zygote, instead:
    # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),)

    # With s = sum of all six entries = 21, the inner gradient is 3s^2 per entry; summing the
    # three entries of `v` gives 9s^2, whose derivative w.r.t. each v[i] is 18s = 378.
    @test gradient([1,2,3.0]) do v
      sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
    end[1] == [378, 378, 378]

    # Same computation as the first case but on a NamedTuple model; currently fails.
    @test_broken gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
    # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
    # Diffractor error in perform_optic_transform
  end
end

# Differentiates through the rebuilding half, `destructure(m)[2]`, with respect to the
# flat vector. Most inputs are `rand(n)`: the tested functions are linear in `x`, so the
# expected gradient does not depend on the input values.
@testset "gradient of rebuild" begin
  re1 = destructure(m1)[2]
  @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
  re2 = destructure(m2)[2]
  @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
  re3 = destructure(m3)[2]
  @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
  @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]

  # m4 has tied fields `x` and `y`: gradients through either land in the same first
  # three slots of the flat vector, and contributions through both accumulate there.
  re4 = destructure(m4)[2]
  @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
  @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
  @test gradient(rand(6)) do x
    m = re4(x)
    m.x[1] + 2*m.y[2] + 3*m.z[3]
  end[1] == [1,2,0, 0,0,3]

  # For m7 only field `a` is rebuilt from the flat vector, so outputs read from
  # `b` or `c` have zero gradient.
  re7 = destructure(m7)[2]
  @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
  @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
  @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

  # Here the input is the model's own flat vector (1:5), since these functions are quadratic.
  v8, re8 = destructure(m8)
  @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

  @testset "second derivative" begin
    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
    end[1] ≈ [8,16,24,0,0,0]
    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
    # with Zygote, which can be fixed by:
    # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)

    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
    end[1] ≈ [0,0,0,32,40,48]
    # Not fixed by this:
    # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
  end
end

# Regression tests for Flux issue 1826: gradients through a rebuilt model must be correct
# when only part of the model is used, and when the model contains a shared array.
@testset "Flux issue 1826" begin
  v, re = destructure((x=[1,2.0], y=[3,4,5.0]))
  @test gradient(zero(v)) do w
    m = re(w)
    5 * sum(m.x) + 7 * sum(m[2]) # uses both x and y
  end == ([5.0, 5.0, 7.0, 7.0, 7.0],)
  # This, using only x, was broken on Flux:
  @test gradient(w -> sum(re(w).x), zero(v)) == ([1.0, 1.0, 0.0, 0.0, 0.0],)

  sh = [7,7.0];
  v, re = destructure((x=sh, y=[3.0,4.0], z=sh)) # shared array in the model
  # The shared array is stored once in `v` and rebuilt into both `x` and `z`.
  @test v == [7, 7, 3, 4]
  @test re([1,10,100,1000]) == (x = [1, 10], y = [100, 1000], z = [1, 10])

  # Contributions through `x` and `z` must add up in the shared slots: 3 + 13 = 16.
  @test gradient(zero(v)) do w
    m = re(w)
    3 * sum(m.x) + 13 * sum(m.z) # no dependence on y, but two distinct gradient arrays
  end == ([16, 16, 0, 0],) # Flux gave ([3.0, 3.0, 13.0, 13.0],)

  # Expected total in the shared slots: 4 + 4 = 8.
  @test gradient(zero(v)) do w
    m = re(w)
    4(sum(m.x) + sum(m.z)) # now two gradients are ===, so it eliminates one
  end == ([8,8,0,0],)

  # Expected: 4 + 13 = 17 in the shared slots, 4 in the slots of `y`.
  @test gradient(zero(v)) do w
    m = re(w)
    4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
  end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end