Skip to content

Commit e451d15

Browse files
committed
add again changes made on website which got lost in a local rebase without checking first because I forgot about this for ages
1 parent 2bb637a commit e451d15

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
ChainRulesCore = "1"
1515
Functors = "0.3"
16-
Yota = "0.7.3"
16+
Yota = "0.8.1"
1717
Zygote = "0.6.40"
1818
julia = "1.6"
1919

test/destructure.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,15 @@ end
106106
end
107107

108108
@testset "using Yota" begin
109-
@test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1))
110-
# These are all broken!
109+
@test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
111110
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
112-
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent())
113-
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0])
114-
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0])
111+
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
112+
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
113+
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
115114

116115
g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
117116
@test g5.a[1].x == [0,0,1]
118-
@test g5.a[2] === ZeroTangent()
117+
@test g5.a[2] === nothing
119118

120119
g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
121120
@test g6.a == [0,0,0]
@@ -128,7 +127,7 @@ end
128127
@test g8[3] == [[10.0]]
129128

130129
g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
131-
@test g9.c === ZeroTangent()
130+
@test g9.c === nothing
132131
end
133132
end
134133

@@ -199,11 +198,11 @@ end
199198
@test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
200199

201200
v8, re8 = destructure(m8)
202-
@test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any})
203-
@test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr)
201+
@test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
202+
@test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
204203

205204
re9 = destructure(m9)[2]
206-
@test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array})
205+
@test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
207206
end
208207
end
209208

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ end
3939

4040
# Make Yota's output look like Zygote's:
4141

42-
Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
42+
Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
43+
y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
44+
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
45+
y2z(x) = x
4346

4447
@testset verbose=true "Optimisers.jl" begin
4548
@testset verbose=true "Features" begin

0 commit comments

Comments
 (0)