
Commit 1e34fa2

Add destructure, take II (#54)
* destructure, take II
* add a test
* tidy
* replace append! with reduce(vcat, ...)
* testset names
* rename everything
* tweak
* two broken tests
* make len positional, fix a bug
* second derivatives
* arrays of arrays
* more... the dimensionmismatch bug is not here
* warnings
1 parent 4155bcd commit 1e34fa2

8 files changed (+341, -9 lines)


Project.toml  (+1 -1)

@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Functors = "0.2.7"
+Functors = "0.2.8"
 julia = "1.6"
 
 [extras]

docs/src/api.md  (+7)

@@ -42,6 +42,13 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`:
 Optimisers.trainable
 ```
 
+Such restrictions are also obeyed by this function for flattening a model:
+
+```@docs
+Optimisers.destructure
+Optimisers.Restructure
+```
+
 ## Rule Definition
 
 ```@docs
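
To make the added line of documentation concrete: a field excluded by `trainable` is skipped by `destructure` and carried through unchanged when the model is rebuilt. A minimal sketch, assuming a user-defined `Layer` struct (illustrative, not part of the package):

```julia
using Optimisers, Functors

struct Layer; w; b; frozen; end           # hypothetical three-field model
Functors.@functor Layer
Optimisers.trainable(l::Layer) = (w = l.w, b = l.b)   # leave `frozen` out

m = Layer([1.0 2.0; 3.0 4.0], [5.0, 6.0], [7.0, 8.0])
v, re = destructure(m)   # v holds only w and b, i.e. 6 numbers
m0 = re(zero(v))         # w and b come back zeroed ...
m0.frozen                # ... while `frozen` is reused from the original model
```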

src/Optimisers.jl  (+4 -1)

@@ -4,8 +4,11 @@ using Functors: functor, fmap, isleaf
 using LinearAlgebra
 
 include("interface.jl")
-include("rules.jl")
 
+include("destructure.jl")
+export destructure, total, total2
+
+include("rules.jl")
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
   WeightDecay, ClipGrad, ClipNorm, OptimiserChain

src/destructure.jl  (new file, +152 lines)

using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()

"""
    destructure(model) -> vector, reconstructor

Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
to a vector, and returns also a function which reverses this transformation.
Differentiable.

# Example
```jldoctest
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5-im, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
```
"""
function destructure(x)
  flat, off, len = _flatten(x)
  flat, Restructure(x, off, len)
end

"""
    Restructure(Model, ..., length)

This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.

# Example
```julia
julia> using Flux, Optimisers

julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))

julia> m = re(-4:1)
Dense(2, 2, σ)  # 6 parameters

julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
true
```
"""
struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
  isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
  arrays = AbstractVector[]
  len = Ref(0)
  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    push!(arrays, _vec(y))
    o = len[]
    len[] = o + length(y)
    o
  end
  reduce(vcat, arrays), off, len[]
end

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)

function ChainRulesCore.rrule(::typeof(_flatten), x)
  flat, off, len = _flatten(x)
  _maybewarn()
  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
  (flat, off, len), _flatten_back
end

# This reconstructs either a model like x, or a gradient for it:
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    _getat(y, o, flat)
  end
end

_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes

function _trainable_biwalk(f, x, aux)
  ch, re = functor(typeof(x), x)
  au, _ = functor(typeof(x), aux)
  _trainmap(f, ch, _trainable(x), au) |> re
end

function _trainmap(f, ch, tr, aux)
  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
    isnothing(t) ? c : f(t, a)
  end
end

function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
  ch, re = functor(typeof(x), x)
  au, _ = functor(typeof(x), aux)
  y = _trainmap(f, ch, _trainable(x), au)
  y isa Tuple{} && return NoT
  p = ProjectTo(x)
  if p isa ProjectTo  # e.g. Array, NamedTuple
    p(y)
  else  # p === identity for unknown structs
    Tangent{typeof(x), typeof(y)}(y)
  end
end

function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
  _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
  _rebuild(x, off, flat, len; kw...), _rebuild_back
end

_zero(x) = map!(zero, similar(x, float(eltype(x))), x)  # mutable zero array for _grad!
ChainRulesCore.@non_differentiable _zero(x)

# This is the gradient of model reconstruction, accumulating duplicates:
function _grad!(x, dx, off, flat::AbstractVector)
  x′, _ = functor(typeof(x), x)
  dx′, _ = functor(typeof(x), base(dx))
  off′, _ = functor(typeof(x), off)
  foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
  flat
end
function _grad!(x, dx, off::Integer, flat::AbstractVector)
  @views flat[off .+ (1:length(x))] .+= dx  # must visit all tied nodes
  flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = dx
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx  # ambiguity

# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
  @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
  _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
  _grad!(x, dx, off, flat), _grad_back
end
base(dx::Tangent{<:Tangent}) = backing(dx).backing  # might be needed for gradient(gradient(destructure))
base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing)  # Zygote version
_maybewarn() = nothing
function ChainRulesCore.rrule(::typeof(_maybewarn))
  @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
  nothing, _ -> (NoT,)
end
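
The comment on `_grad!` ("accumulating duplicates") is the point of this rewrite: every tied occurrence of an array writes its gradient into the same slice of `flat`. A short sketch of the resulting behaviour, mirroring the "Flux issue 1826" testset below and assuming Zygote as the AD backend:

```julia
using Optimisers, Zygote

sh = [7.0, 7.0]
v, re = destructure((x = sh, y = [3.0, 4.0], z = sh))  # x and z share one array
# v == [7.0, 7.0, 3.0, 4.0]: the shared array appears only once in the flat vector

g = gradient(zero(v)) do w
  m = re(w)
  3 * sum(m.x) + 13 * sum(m.z)   # two uses of the one tied array
end
# g == ([16.0, 16.0, 0.0, 0.0],): both contributions accumulate into the shared slot
```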

src/interface.jl  (+3 -2)

@@ -70,9 +70,10 @@ trainable(x) = functor(x)[1]
 
 _trainable(x) = _trainable(functor(x)[1], trainable(x))
 _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
-_trainable(ch::Tuple, tr::Tuple) = tr
+_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
+_trainable(ch::AbstractArray, tr::AbstractArray) = tr
 function _trainable(ch::NamedTuple, tr::Tuple)  # for old Flux-style no-names tuple
-  @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple"
+  @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
   map(c -> c in tr ? c : nothing, ch)
 end
 
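
For context on the warning that now carries `maxlog=3` (so it prints at most three times): the deprecated style returned a plain Tuple of trainable children, while the preferred style names them. A sketch with a hypothetical `Affine` struct:

```julia
using Functors, Optimisers

struct Affine; W; b; flag; end
Functors.@functor Affine

# Old Flux-style definition -- still accepted, but triggers the warning above:
# Optimisers.trainable(a::Affine) = (a.W, a.b)

# Preferred: a NamedTuple using the field names, as the warning asks for:
Optimisers.trainable(a::Affine) = (W = a.W, b = a.b)
```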
test/destructure.jl  (new file, +166 lines)

m1 = collect(1:3.0)
m2 = (collect(1:3.0), collect(4:6.0))
m3 = (x = m1, y = sin, z = collect(4:6.0))
m4 = (x = m1, y = m1, z = collect(4:6.0))  # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]

@testset "flatten & rebuild" begin
  @test destructure(m1)[1] isa Vector{Float64}
  @test destructure(m1)[1] == 1:3
  @test destructure(m2)[1] == 1:6
  @test destructure(m3)[1] == 1:6
  @test destructure(m4)[1] == 1:6
  @test destructure(m5)[1] == vcat(1:6, 4:6)
  @test destructure(m6)[1] == vcat(1:3, 4 + im)

  @test destructure(m1)[2](7:9) == [7,8,9]
  @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
  @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
  m4′ = destructure(m4)[2](4:9)
  @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
  @test m4′.x === m4′.y
  m5′ = destructure(m5)[2](reverse(1:9))
  @test m5′.a[1].x === m5′.b[1]
  @test m5′.b[2] === false
  m6′ = destructure(m6)[2]((4:7) .+ (1:4) .* im)
  @test m6′.a isa Vector{Float64}
  @test m6′.a == 4:6
  @test m6′.a === m6′.c
  @test m6′.b == [7 + 4im]

  # struct, trainable
  @test destructure(m7)[1] == 1:3
  m7′ = destructure(m7)[2]([10,20,30])
  @test m7′.a == (sin, [10,20,30])
  @test m7′.b == (cos, [4,5,6])
  @test m7′.c == (tan, [7,8,9])

  @test destructure(m8)[1] == 1:5
  m8′ = destructure(m8)[2](1:5)
  @test m8′[1].x === m8′[1].y
  @test m8′[2].b.y === false
  @test m8′[3][1] == [5.0]

  # errors
  @test_throws Exception destructure(m7)[2]([10,20])
  @test_throws Exception destructure(m7)[2]([10,20,30,40])
end

@testset "gradient of flatten" begin
  @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
  @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
  @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
  @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
  @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

  g5 = gradient(m -> destructure(m)[1][3], m5)[1]
  @test g5.a[1].x == [0,0,1]
  @test g5.a[2] === nothing

  g6 = gradient(m -> imag(destructure(m)[1][4]), m6)[1]
  @test g6.a == [0,0,0]
  @test g6.a isa Vector{Float64}
  @test g6.b == [0+im]

  g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
  @test g8[1].x == [2,4,6]
  @test g8[2].b.x == [8]
  @test g8[3] == [[10.0]]

  @testset "second derivative" begin
    @test gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
    # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
    # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
    # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
    # With Zygote, instead:
    # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),)

    @test gradient([1,2,3.0]) do v
      sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
    end[1] == [378, 378, 378]

    @test_broken gradient([1,2,3.0]) do v
      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]
    # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
    # Diffractor error in perform_optic_transform
  end
end

@testset "gradient of rebuild" begin
  re1 = destructure(m1)[2]
  @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
  re2 = destructure(m2)[2]
  @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
  re3 = destructure(m3)[2]
  @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
  @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]

  re4 = destructure(m4)[2]
  @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
  @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
  @test gradient(rand(6)) do x
    m = re4(x)
    m.x[1] + 2*m.y[2] + 3*m.z[3]
  end[1] == [1,2,0, 0,0,3]

  re7 = destructure(m7)[2]
  @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
  @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
  @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

  v8, re8 = destructure(m8)
  @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

  @testset "second derivative" begin
    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
    end[1] ≈ [8,16,24,0,0,0]
    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
    # with Zygote, which can be fixed by:
    # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)

    @test_broken gradient(collect(1:6.0)) do y
      sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
    end[1] ≈ [0,0,0,32,40,48]
    # Not fixed by this:
    # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
  end
end

@testset "Flux issue 1826" begin
  v, re = destructure((x=[1,2.0], y=[3,4,5.0]))
  @test gradient(zero(v)) do w
    m = re(w)
    5 * sum(m.x) + 7 * sum(m[2])  # uses both x and y
  end == ([5.0, 5.0, 7.0, 7.0, 7.0],)
  # This, using only x, was broken on Flux:
  @test gradient(w -> sum(re(w).x), zero(v)) == ([1.0, 1.0, 0.0, 0.0, 0.0],)

  sh = [7,7.0];
  v, re = destructure((x=sh, y=[3.0,4.0], z=sh))  # shared array in the model
  @test v == [7, 7, 3, 4]
  @test re([1,10,100,1000]) == (x = [1, 10], y = [100, 1000], z = [1, 10])

  @test gradient(zero(v)) do w
    m = re(w)
    3 * sum(m.x) + 13 * sum(m.z)  # no dependence on y, but two distinct gradient arrays
  end == ([16, 16, 0, 0],)  # Flux gave ([3.0, 3.0, 13.0, 13.0],)

  @test gradient(zero(v)) do w
    m = re(w)
    4(sum(m.x) + sum(m.z))  # now two gradients are ===, so it eliminates one
  end == ([8,8,0,0],)

  @test gradient(zero(v)) do w
    m = re(w)
    4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
  end == ([17,17,4,4],)  # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end
