@@ -1,6 +1,6 @@
 
-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
-const NoT = NoTangent()
+using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+const NoT = ChainRulesCore.NoTangent()
 
 """
     destructure(model) -> vector, reconstructor
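
The new imports feed the `rrule` for `total` added below: `RuleConfig{>:HasReverseMode}` restricts the rule to reverse-mode ADs, and `rrule_via_ad` lets a rule hand a user-supplied `f` back to that AD instead of differentiating it by hand. A minimal sketch of the pattern, assuming Zygote as the driving AD (`ZygoteRuleConfig` is Zygote's config type, not part of this patch):

```julia
using Zygote: ZygoteRuleConfig   # an AD whose RuleConfig advertises HasReverseMode
using ChainRulesCore: rrule_via_ad

f(x) = sum(abs2, x)
val, back = rrule_via_ad(ZygoteRuleConfig(), f, [1.0, 2.0])
val        # 5.0 == f([1.0, 2.0])
back(1.0)  # cotangents for (f, x): (NoTangent(), [2.0, 4.0])
```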
@@ -107,9 +107,11 @@ function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
-  if p isa ProjectTo  # e.g. Array, NamedTuple
-    p(y)
-  else  # p === identity for unknown structs
+  # if p isa ProjectTo  # e.g. Array, NamedTuple
+  #   p(y)  # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray}  # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+  else
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
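
The rewritten branch treats numbers and arrays with `ProjectTo`, which standardises the raw tangent `y` (restoring eltype and structure), while any other node gets a structural `Tangent`; this avoids calling `ProjectTo` on a `NamedTuple`, which hits the linked ChainRulesCore issue. A quick standalone illustration of the two cases (not part of the patch):

```julia
using ChainRulesCore: ProjectTo, Tangent

# Array leaf: projection converts the tangent back to the primal's eltype.
ProjectTo([1.0, 2.0])(Float32[3, 4])    # [3.0, 4.0]::Vector{Float64}

# Struct-like node: wrap the children directly, no projection involved.
nt = (a = [1.0],)
dnt = (a = [2.0],)
Tangent{typeof(nt), typeof(dnt)}(dnt)   # structural tangent for the NamedTuple
```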
|
@@ -150,3 +152,74 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
+"""
+    total(f, model)
+
+Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
+the model, and returns the sum. Differentiable. Counts shared weights once.
+
+# Examples
+```jldoctest
+julia> using LinearAlgebra: norm
+
+julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
+
+julia> total(sum, m)
+12.0
+
+julia> total(norm, m)
+10.0
+
+julia> total(length, m) == length(destructure(m)[1])
+true
+```
+"""
+function total(f, x)
+  values = []
+  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
+  sum(values)
+end
+
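+# `rrule` for `total`: the forward pass stores one pullback per trainable
+# leaf (`_total_hobbit`); the backward pass replays them (`_total_grad`).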
+function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
+  z, backs = _total_hobbit(config, f, x)
+  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
+  z, total_back
+end
+
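+# Forward pass: applies `f` to each trainable leaf with `rrule_via_ad`,
+# collecting the values and keeping the pullbacks in a tree shaped like `x`.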
+function _total_hobbit(config::RuleConfig, f, x)
+  values = []
+  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+    val, back = rrule_via_ad(config, f, y)
+    push!(values, val)
+    back
+  end
+  sum(values), backs
+end
+
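+# Reverse pass: walks `x` and the stored pullbacks together, building a
+# structural Tangent for `x` and accumulating the gradient of `f` itself.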
+function _total_grad(dz, x, backs)
+  dfs = []
+  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
+    df, dy = b(dz)
+    push!(dfs, df)
+    dy
+  end
+  sum(dfs), dx
+end
+
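+# Marks `_total_grad` itself as only approximately differentiable: the
+# cotangent it returns for `dz` ignores the curvature of `f`.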
+function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
+  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
+  function grad_back((df, dx))
+    df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
+    (NoT, total(sum, dx), NoT, NoT)  # cotangent for `dz`; exact only for linear `f` such as `sum`
+  end
+  _total_grad(dz, x, backs), grad_back
+end
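
Taken together, these rules let `total` participate in reverse-mode AD: the first `rrule` computes per-leaf pullbacks on the way forward, and `_total_grad` replays them. A usage sketch, assuming this patch is loaded and Zygote drives the differentiation:

```julia
using Zygote, LinearAlgebra

m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7))

total(norm, m)                          # 10.0, as in the docstring
g = gradient(m -> total(norm, m), m)[1]
g.x                                     # [0.6, 0.8], the gradient of norm at [3, 4]
g.y[2]                                  # [1.0], the gradient of norm at [5]
```

Differentiating this gradient again reaches the `rrule` for `_total_grad` above, which warns and is only approximate.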