@@ -64,6 +64,14 @@ There is also [`Optimisers.update!`](@ref) which similarly returns a new model a
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
+ (The method of `apply!` above is likewise free to mutate arrays within its state;
+ they are defensively copied when this rule is used with `update`.)
+ For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+ ```julia
+ Base.summarysize(model) / 1024^2  # about 45MB
+ Base.summarysize(state) / 1024^2  # about 90MB
+ ```
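For a concrete picture, here is a minimal sketch; the array `w` and gradient `grad` below are made-up stand-ins, not objects defined on this page. It shows Adam's per-parameter momentum state and the in-place `update!` call:

```julia
using Optimisers

w = rand(Float32, 1000)                         # a toy "model": one parameter array
state = Optimisers.setup(Optimisers.Adam(), w)  # a Leaf holding Adam's per-parameter state
# For Adam this state holds two momentum arrays the same size as `w` (plus the running
# β powers), which is where the roughly 2× memory figure comes from.

grad = ones(Float32, length(w))                 # a stand-in gradient
state, w = Optimisers.update!(state, w, grad)   # free to mutate `w` and the arrays in `state`
```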
Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,6 +80,7 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
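To make the tree-structure point concrete, here is a minimal sketch in the explicit style; the tiny NamedTuple model and `loss` below are invented for illustration, not taken from this page:

```julia
using Optimisers, Zygote

model = (weight = rand(Float32, 2, 3), bias = zeros(Float32, 2))  # any nested structure of arrays
loss(m, x) = sum(abs2, m.weight * x .+ m.bias)

x = rand(Float32, 3)
∇model = Zygote.gradient(m -> loss(m, x), model)[1]  # explicit mode: a NamedTuple mirroring `model`

state = Optimisers.setup(Optimisers.Adam(), model)
state, model = Optimisers.update(state, model, ∇model)
```

Here `∇model` comes back as `(weight = ..., bias = ...)`, matching the model's own nesting, which is the form `Optimisers.update` expects.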
+
## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
Yota is another modern automatic differentiation package, an alternative to Zygote.
@@ -89,40 +98,6 @@ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
end;
```
- Unfortunately this example doesn't actually run right now. This is the error:
- ```
- julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
- sum(m(x))
- end;
- ┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
- │ (f, args) = Yota.RRULE_VIA_AD_STATE[]
- └ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
- ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
-
- ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
-
- Stacktrace:
- [1] error(s::String)
- @ Base ./error.jl:35
- [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
- @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
- [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
- @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
- [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
- @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
- [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
- @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
- [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
- @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
- [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
- @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
- ...
-
- (jl_GWa2lX) pkg> st
- Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
- ⌃ [587475ba] Flux v0.13.4
- [cd998857] Yota v0.7.4
- ```
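(A side note, not part of the change shown above: assuming `Yota.grad` accepts a one-argument closure in the same way as the two-argument call, the gradient of `image` can be skipped by closing over it.)

```julia
# sketch only: close over `image`, so only the closure's and the model's gradients are returned
loss, (_, ∇model) = Yota.grad(m -> sum(m(image)), model);
```

If the returned `∇model` mirrors the model's structure, it can then be passed to `Optimisers.update` just like a Zygote gradient.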
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
@@ -163,6 +138,14 @@ y, lux_state = Lux.apply(lux_model, images, params, lux_state);
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+
+ ```julia
+ Base.summarysize(lux_model) / 1024    # just 2KB
+ Base.summarysize(params) / 1024^2     # about 45MB, same as Flux model
+ Base.summarysize(lux_state) / 1024    # 40KB
+ Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
+ ```
+
If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
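(A minimal sketch of that simplified form, assuming the `lux_model`, `params`, `images`, `lux_state`, and `opt_state` bindings from above; the exact code may differ.)

```julia
∇params, _ = Zygote.gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # the returned state is discarded here
    sum(y)
end;

opt_state, params = Optimisers.update!(opt_state, params, ∇params);
```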