
Commit 2bb637a

tidy, add summarysize
1 parent a6e85fc commit 2bb637a


docs/src/index.md

Lines changed: 17 additions & 34 deletions
@@ -64,6 +64,14 @@ There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` for each rule is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
+(The method of `apply!` above is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.)
+For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+```julia
+Base.summarysize(model) / 1024^2  # about 45MB
+Base.summarysize(state) / 1024^2  # about 90MB
+```
 
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
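
To make the added lines concrete: a minimal sketch, not part of this commit, of why `Adam()`'s state is roughly twice the model's size. The toy `model` and `grad` below are illustrative stand-ins, not the docs' Flux model.

```julia
using Optimisers

# Illustrative toy model: any nested structure of arrays works.
model = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3))

# Each leaf of the state stores two momentum arrays (plus a small tuple of β powers),
# hence roughly twice the memory of the parameters themselves.
state = Optimisers.setup(Adam(), model)

grad = (weight = ones(Float32, 3, 3), bias = ones(Float32, 3))

# update! is free to mutate arrays inside `state` and `model` for efficiency.
state, model = Optimisers.update!(state, model, grad)

# Approaches 2 for large models, matching the 45MB vs 90MB figures above.
Base.summarysize(state) / Base.summarysize(model)
```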
@@ -72,6 +80,7 @@ This `∇model` is another tree structure, rather than the dictionary-like object from
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
+
 ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
 
 Yota is another modern automatic differentiation package, an alternative to Zygote.
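
To flesh out the Zygote lines in the hunk above, here is a hedged sketch of the explicit-mode workflow they describe; the toy `model` is an illustrative stand-in, not the docs' Flux `Chain`.

```julia
using Optimisers, Zygote

model = (weight = rand(Float32, 4, 4),)   # stand-in for a real model
state = Optimisers.setup(Adam(0.001), model)

# Explicit mode: the gradient is a tree (here a NamedTuple) matching `model`,
# not the dictionary-like object produced by the Flux.params-style implicit mode.
∇model, = Zygote.gradient(m -> sum(abs2, m.weight), model)

# The non-mutating update returns a new state and a new model.
state, model = Optimisers.update(state, model, ∇model)
```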
@@ -89,40 +98,6 @@ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
 end;
 ```
 
-Unfortunately this example doesn't actually run right now. This is the error:
-```
-julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-           sum(m(x))
-       end;
-┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
-│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
-└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
-ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
-
-ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
-
-Stacktrace:
- [1] error(s::String)
-   @ Base ./error.jl:35
- [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
- [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
- [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
- [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
- [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
- [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
-...
-
-(jl_GWa2lX) pkg> st
-Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
-⌃ [587475ba] Flux v0.13.4
- [cd998857] Yota v0.7.4
-```
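
Assuming the `Yota.grad` call in the hunk context above succeeds, its `∇model` is meant to feed the same update step as a Zygote gradient. A hedged sketch only; `state` here is assumed to be the result of `Optimisers.setup(Adam(), model)` from earlier in these docs:

```julia
# Sketch: apply the Yota-computed gradient exactly as with the Zygote one.
# (Assumption: if ∇model arrives as a ChainRules-style tangent, Optimisers can still walk it.)
state, model = Optimisers.update(state, model, ∇model);
```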
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
@@ -163,6 +138,14 @@ y, lux_state = Lux.apply(lux_model, images, params, lux_state);
 Besides the parameters stored in `params` and gradually optimised, any other model state
 is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
 This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+
+```julia
+Base.summarysize(lux_model) / 1024  # just 2KB
+Base.summarysize(params) / 1024^2  # about 45MB, same as Flux model
+Base.summarysize(lux_state) / 1024  # 40KB
+Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
+```
+
 If you are certain there is no model state, then the gradient calculation can
 be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
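
The simplification mentioned in the last two context lines might look like the sketch below, assuming a Lux model with no internal state (no BatchNorm); the names `lux_model`, `params`, `lux_state`, `images`, and `opt_state` follow the docs, everything else is illustrative.

```julia
# Sketch only: with no model state to carry along, Zygote.gradient is enough.
∇params, _ = Zygote.gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)
    sum(y)
end;

# Then update exactly as before.
opt_state, params = Optimisers.update(opt_state, params, ∇params);
```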