From 57a3fa26cbf3a0478142e93ec16ebb25a62b5847 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 25 Mar 2022 21:18:26 -0400 Subject: [PATCH 1/2] attempt 62 --- src/destructure.jl | 6 +++--- test/destructure.jl | 47 ++++++++++++++++++++++++++++++++++++++++++--- test/runtests.jl | 17 ++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 2b91983d..15a4bb64 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -91,7 +91,7 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) = function _trainable_biwalk(f, x, aux) ch, re = functor(typeof(x), x) - au, _ = functor(typeof(x), aux) + au, _ = functor(aux) _trainmap(f, ch, _trainable(x), au) |> re end @@ -103,7 +103,7 @@ end function _Tangent_biwalk(f, x, aux) # use with prune = NoT ch, re = functor(typeof(x), x) - au, _ = functor(typeof(x), aux) + au, _ = functor(aux) y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT p = ProjectTo(x) @@ -126,7 +126,7 @@ ChainRulesCore.@non_differentiable _zero(x) function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), base(dx)) - off′, _ = functor(typeof(x), off) + off′, _ = functor(off) foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat end diff --git a/test/destructure.jl b/test/destructure.jl index 043315b3..fe5699cb 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -49,7 +49,7 @@ m9 = (a = m1, b = mat, c = [mat, m1]) m8′ = destructure(m8)[2](1:5) @test m8′[1].x === m8′[1].y @test m8′[2].b.y === false - @test m8′[3][1] == [5.0] + @test m8′[3][1] == [5.0] # broken m9′ = destructure(m9)[2](10:10:70) @test m9′.b === m9′.c[1] @@ -79,7 +79,7 @@ end g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] @test g8[1].x == [2,4,6] @test g8[2].b.x == [8] - @test g8[3] == [[10.0]] + @test g8[3] == [[10.0]] # fails g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] @test g9.c === nothing @@ -130,7 +130,7 @@ end v8, re8 = destructure(m8) @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] - @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # fails re9 = destructure(m9)[2] @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] @@ -180,3 +180,44 @@ end 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) end + +@testset "issue 62" begin + # Flux.Chain used to have children which aren't its own fields, which Skip immitates. + + sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0])) + @test fmap(identity, sk) == sk + + gk = gradient(x -> sum(x[2].y), sk)[1] + @test fmap(Zygote.accum, sk, gk) isa Skip # this relies on functor(typeof(x), dx) + + st = fmapstructure(identity, sk) + @test st isa Tuple{Vector, NamedTuple} + @test_throws Exception fmap(+, sk, st) # this fails because of functor(typeof(x), dx) + + v, re = destructure(sk) + @test v == [1,2,4,5] + @test re(10v) isa Skip + @test re(10v)[1] == [10, 20] + + @test gradient(zero(v)) do w + re(w)[2].y[1] + end == ([0,0,1,0],) + + # gradient(sk) do x + # w, _ = destructure(x) + # w[1] + # end +#= + +ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}. +Stacktrace: + [1] _backing_error(P::Type, G::Type, E::Type) + @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:62 + [2] ChainRulesCore.Tangent{Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}}(backing::Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}) + @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:36 + [3] _Tangent_biwalk(f::Function, x::Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, aux::Tuple{Int64, NamedTuple{(:x, :y), Tuple{Tuple{}, Int64}}}) + @ Optimisers ~/.julia/dev/Optimisers/src/destructure.jl:116 + +=# + +end diff --git a/test/runtests.jl b/test/runtests.jl index d47bce08..1a54c5e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) +struct Skip{T} # like Flux 0.12's Chain + layers::T + Skip(ls...) = new{typeof(ls)}(ls) +end +Base.getindex(x::Skip, i::Integer) = x.layers[i] +Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...) + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin @@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test_throws ArgumentError Optimisers.setup(ADAMW(), m2) end + @testset "issue 62" begin + m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0]) + s62 = Optimisers.setup(Descent(), m62) + g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62) + s, m = Optimisers.update(s62, m62, g62...) + @test m.s isa Skip + @test m.s[2].x ≈ [2.9] + @test m.t ≈ [4, 4.7] + end + end @testset verbose=true "Destructure" begin include("destructure.jl") From 81e41fe69eb550e618cf4e5b9639962d5d048d9f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 25 Mar 2022 22:29:14 -0400 Subject: [PATCH 2/2] next idea --- src/destructure.jl | 9 ++++++--- test/destructure.jl | 12 ++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index 15a4bb64..36928134 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -91,10 +91,13 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) = function _trainable_biwalk(f, x, aux) ch, re = functor(typeof(x), x) - au, _ = functor(aux) + au = _aux_children(aux) _trainmap(f, ch, _trainable(x), au) |> re end +_aux_children(off) = functor(off)[1] +_aux_children(off::AbstractArray) = off # leaflike according to Functors, but we need to see each offset + function _trainmap(f, ch, tr, aux) map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c) isnothing(t) ? c : f(t, a) @@ -103,7 +106,7 @@ end function _Tangent_biwalk(f, x, aux) # use with prune = NoT ch, re = functor(typeof(x), x) - au, _ = functor(aux) + au = _aux_children(aux) y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT p = ProjectTo(x) @@ -126,7 +129,7 @@ ChainRulesCore.@non_differentiable _zero(x) function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), base(dx)) - off′, _ = functor(off) + off′ = _aux_children(off) foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat end diff --git a/test/destructure.jl b/test/destructure.jl index fe5699cb..df1ecffb 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -49,7 +49,7 @@ m9 = (a = m1, b = mat, c = [mat, m1]) m8′ = destructure(m8)[2](1:5) @test m8′[1].x === m8′[1].y @test m8′[2].b.y === false - @test m8′[3][1] == [5.0] # broken + @test m8′[3][1] == [5.0] m9′ = destructure(m9)[2](10:10:70) @test m9′.b === m9′.c[1] @@ -130,7 +130,7 @@ end v8, re8 = destructure(m8) @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] - @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # fails + @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] re9 = destructure(m9)[2] @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] @@ -203,10 +203,10 @@ end re(w)[2].y[1] end == ([0,0,1,0],) - # gradient(sk) do x - # w, _ = destructure(x) - # w[1] - # end + gradient(sk) do x + w, _ = destructure(x) + w[1] + end #= ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.