add trainables_with_path #173

Closed · wants to merge 5 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
 Manifest.toml
 .vscode/
+docs/build/
+.DS_Store
2 changes: 1 addition & 1 deletion Project.toml
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
 ChainRulesCore = "1"
-Functors = "0.4"
+Functors = "0.4.9"
 Statistics = "1"
 Zygote = "0.6.40"
 julia = "1.6"
Binary file added docs/.DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -2,7 +2,7 @@ using Documenter, Optimisers, Zygote, StaticArrays, Functors

 DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)

-makedocs(modules = [Optimisers, Zygote, StaticArrays, Functors],
+makedocs(modules = [Optimisers],
          doctest = false,
          sitename = "Optimisers.jl",
          pages = ["Home" => "index.md",
14 changes: 14 additions & 0 deletions docs/src/api.md
@@ -51,13 +51,15 @@ To further restrict this by ignoring some fields of a layer type, define `trainable`.

 ```@docs
 Optimisers.trainable
 Optimisers.isnumeric
+Optimisers.maywrite
 ```

 Such restrictions are also obeyed by this function for flattening a model:

 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.trainables
 ```

 ## Rule Definition
@@ -68,4 +70,16 @@ Optimisers.init
 Optimisers.@..
 Optimisers.@lazy
 Optimisers.adjust(::AbstractRule, ::Real)
+Optimisers.@def
 ```
+
+## KeyPath
+
+A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure.
+It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.
+
+```@docs
+Functors.KeyPath
+Functors.haskeypath
+Functors.getkeypath
+```
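For reviewers unfamiliar with these re-exported Functors.jl functions, a minimal sketch of how they compose (the nested value `x` below is an invented example, not part of this PR):

```julia
julia> using Optimisers  # KeyPath, haskeypath, getkeypath come from Functors.jl

julia> x = (a = ([1.0, 2.0], (b = 3.0,)),);  # NamedTuples index by symbol, Tuples by position

julia> kp = KeyPath(:a, 2, :b);

julia> haskeypath(x, kp)
true

julia> getkeypath(x, kp)
3.0
```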
26 changes: 26 additions & 0 deletions docs/src/index.md
@@ -290,3 +290,29 @@ flat, re = destructure(params)
end
```

## Collecting all trainable parameters

Sometimes it is useful to collect all trainable parameters in a model,
similarly to what [`destructure`](@ref Optimisers.destructure) does but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

```julia
julia> using Flux, Optimisers

julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));

julia> trainables(model)
6-element Vector{AbstractArray}:
Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
Float32[0.0, 0.0, 0.0]
Float32[0.0, 0.0, 0.0]
Float32[1.0, 1.0, 1.0]
Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
Float32[0.0, 0.0]

julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);

julia> g = gradient(l2reg, model)[1];
```
Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.
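The same example can be extended to recover *where* each parameter lives. A sketch (output omitted, since the exact `KeyPath`s depend on how `Chain` nests its layers):

```julia
julia> for (kp, p) in trainables(model, path = true)
           println(kp, " => ", size(p))  # each kp can be fed back to getkeypath(model, kp)
       end
```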
10 changes: 9 additions & 1 deletion src/Optimisers.jl
@@ -1,16 +1,24 @@
 module Optimisers

-using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
+using Functors: functor, fmap, fmap_with_path,
+                KeyPath, haskeypath, getkeypath,
+                isleaf, @functor, fmapstructure, children, AbstractWalk
 using LinearAlgebra

 include("interface.jl")
 export AbstractRule

+include("utils.jl")
+
 include("adjust.jl")

 include("destructure.jl")
 export destructure

+include("trainables.jl")
+export trainables
+export KeyPath, haskeypath, getkeypath  # from Functors.jl
+
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
9 changes: 5 additions & 4 deletions src/destructure.jl
@@ -66,19 +66,19 @@ function _flatten(x)
   isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
-  off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+  off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
     o
   end
   isempty(arrays) && return Bool[], off, 0
-  reduce(vcat, arrays), off, len[]
+  return reduce(vcat, arrays), off, len[]
 end

-struct _TrainableStructWalk <: AbstractWalk end
+struct TrainableStructWalk <: AbstractWalk end

-(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
+(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))

 _vec(x::Number) = LinRange(x,x,1)
 _vec(x::AbstractArray) = vec(x)
@@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
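An aside on the walk renamed here: in Functors.jl, a walk is a callable subtype of `AbstractWalk` that tells `fmap` which children to recurse into, and `TrainableStructWalk` restricts recursion to the trainable children. A minimal sketch of the same pattern `_flatten` uses, assuming the internals `Optimisers.isnumeric` and `Optimisers.TrainableStructWalk` remain reachable under these names (`count_params` is a hypothetical helper, counting parameters instead of flattening them):

```julia
using Functors, Optimisers

# Count trainable parameters by walking only trainable children and
# stopping at numeric array leaves -- mirrors _flatten's fmap call above.
function count_params(model)
    n = Ref(0)
    fmap(model; exclude = Optimisers.isnumeric, walk = Optimisers.TrainableStructWalk()) do y
        n[] += length(y)
        y  # return the leaf unchanged
    end
    return n[]
end
```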

17 changes: 5 additions & 12 deletions src/interface.jl
@@ -45,7 +45,7 @@ function _setup(rule, x; cache)
       cache[x] = ℓ
     end
   else
-    valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
+    mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
   end
 end

@@ -82,7 +82,7 @@ function _update!(tree, x; grads, params)
   haskey(params, (tree,x)) && return params[(tree,x)]
   isbits(tree) && return x  # means () is not cached, and also (((),),)
   x′, re = functor(x)
-  x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
+  x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
   if ismutable(x′′)
     params[(tree,x)] = x′′
   else  # no ties to preserve between immutable structs, right?

@@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
   # functor(typeof(tree), base(x̄)), for things like Transpose
   x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
   x′, _ = functor(typeof(x), x)
-  valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
+  foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
 end

 # default all rules to first order calls
@@ -167,26 +167,19 @@ and `trainable(x)` must contain a subset of these.
 """
 trainable(x) = functor(x)[1]

 # like trainable(x), but also tries to output non-trainable children giving value nothing
 _trainable(x) = _trainable(functor(x)[1], trainable(x))
 _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
 _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
 _trainable(ch::AbstractArray, tr::AbstractArray) = tr
-_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr)
+_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)

 function _trainable(ch::NamedTuple, tr::Tuple)  # for old Flux-style no-names tuple
   @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
   map(c -> c in tr ? c : nothing, ch)
 end

-
-valuemap(f, x...) = map(f, x...)
-valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
-valueforeach(f, x...) = foreach(f, x...)
-valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
-  f(v, (get(y, k, nothing) for y in ys)...)
-end
-
 ###
 ### rule definition helpers
 ###
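For reviewers tracing the `mapvalue` call sites above: `_trainable(x)` returns one entry per functor child, with non-trainable children replaced by `nothing`, so the tree of optimiser states stays aligned with the model's structure. A sketch with an invented layer (`Affine` is hypothetical, and `_trainable` is an internal, unexported function):

```julia
using Functors, Optimisers

struct Affine; W; b; flag; end  # invented example layer
Functors.@functor Affine
Optimisers.trainable(a::Affine) = (; W = a.W, b = a.b)  # flag is excluded

a = Affine(rand(2, 2), rand(2), true)
Optimisers._trainable(a)  # (W = ..., b = ..., flag = nothing)
```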
124 changes: 124 additions & 0 deletions src/trainables.jl
@@ -0,0 +1,124 @@

"""
trainables(x, path = false)

Return an iterable over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

If `path = false`, the output is a list of numerical arrays.

If `path = true`, the output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref Functors.KeyPath) is a type
representing the path to the array in the original structure.

See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.

# Examples

```jldoctest
julia> struct MyLayer
w
b
end

julia> Functors.@functor MyLayer

julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example

julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);

julia> trainables(x)
1-element Vector{AbstractArray}:
[1.0, 2.0, 3.0]

julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);

julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
```

```jldoctest
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));

julia> for (kp, y) in trainables(x, path = true)
println(kp, " => ", y)
end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]

julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
3.0
4.0
```
"""
function trainables(x; path = false)
if path
return _trainables_with_path(x)
else
return _trainables(x)
end
end


function _trainables(x)
arrays = AbstractArray[]
fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
push!(arrays, y)
return y
end
return arrays
end

function ∇trainables(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return Δ[i+=1]
end
end

function ChainRulesCore.rrule(::typeof(_trainables), x)
y = trainables(x)
trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
return y, trainables_back
end

function _trainables_with_path(x)
named_params = []
exclude(kp, x) = isnumeric(x)
fmap_with_path(x; exclude, walk = TrainableStructWalkWithPath()) do kp, y
push!(named_params, (kp, y))
return y
end
return named_params
end

struct TrainableStructWalkWithPath <: AbstractWalk end

function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x)
x_children = trainable(x)
kps = mapkey(c -> KeyPath(kp, c), x_children)
return mapvalue(recurse, kps, x_children)
end

function ChainRulesCore.rrule(::typeof(_trainables_with_path), x)
y = _trainables_with_path(x)
trainables_with_path_back(Δ) = (NoTangent(), ∇trainables_with_path(x, unthunk(Δ)))
return y, trainables_with_path_back
end

function ∇trainables_with_path(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
Δi = Δ[i+=1]
if isnothing(Δi)
return nothing

Check warning on line 119 in src/trainables.jl

View check run for this annotation

Codecov / codecov/patch

src/trainables.jl#L119

Added line #L119 was not covered by tests
else
return Δi[2]
end
end
end
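One behaviour of `_trainables` worth flagging in review: `fmap` caches each object it visits, so the docstring's claim that tied weights appear only once falls out for free. A quick sketch (the `Tied` struct is invented for illustration):

```julia
julia> using Functors, Optimisers

julia> struct Tied; enc; dec; end

julia> Functors.@functor Tied

julia> w = rand(2, 2);

julia> t = Tied(w, w);  # the same array reachable through two fields

julia> length(trainables(t))
1
```

Note this relies on identity (`===`): a copy of `w`, or a wrapper such as `w'`, would count separately.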
15 changes: 15 additions & 0 deletions src/utils.jl
@@ -0,0 +1,15 @@

mapvalue(f, x...) = map(f, x...)
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)

mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
mapkey(f, x::Tuple) = ntuple(i -> f(i), length(x))
mapkey(f, x::AbstractArray) = [f(i) for i=1:length(x)]

foreachvalue(f, x...) = foreach(f, x...)
foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
    f(v, (get(y, k, nothing) for y in ys)...)
end
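A short sketch of what these helpers do, since they are internal and undocumented (accessed with the module prefix; outputs follow directly from the definitions above):

```julia
julia> nt = (a = 1, b = 2);

julia> Optimisers.mapvalue(x -> 10x, nt)     # map over values, keep keys
(a = 10, b = 20)

julia> Optimisers.mapkey(string, nt)         # map over keys, keep shape
(a = "a", b = "b")

julia> Optimisers.mapkey(string, (4, 5, 6))  # Tuples and arrays are keyed by index
("1", "2", "3")
```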

5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -10,7 +10,7 @@ Random.seed!(1)

 struct Foo; x; y; end
 Functors.@functor Foo
-Optimisers.trainable(x::Foo) = (x.y, x.x)
+Optimisers.trainable(x::Foo) = (; x.y, x.x)

 struct TwoThirds a; b; c; end
 Functors.@functor TwoThirds (a, c)
@@ -539,6 +539,9 @@
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
 end
+@testset verbose=true "Trainables" begin
+  include("trainables.jl")
+end
 @testset verbose=true "Optimisation Rules" begin
   include("rules.jl")
 end