
Commit fffd297

add total
1 parent 1e34fa2 commit fffd297

4 files changed, +95 -7 lines changed

docs/src/api.md  (+3 -1)

@@ -42,11 +42,13 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`:
 Optimisers.trainable
 ```
 
-Such restrictions are also obeyed by this function for flattening a model:
+Such restrictions are also obeyed by this function for flattening a model,
+and one for applying a function to every parameter:
 
 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.total
 ```
 
 ## Rule Definition
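The doc entry above is what this commit adds to the API page. As a usage sketch of what it documents, `total` can be paired with Zygote as a parameter penalty; the toy model and loss below are assumptions for illustration, not part of the commit:

using Optimisers, Zygote, LinearAlgebra

model = (weight = [1.0 2.0; 3.0 4.0], bias = [0.1, -0.2], act = tanh)

# data term plus a small penalty applied to every trainable, numeric parameter
loss(m, x) = sum(abs2, m.weight * x .+ m.bias) + 0.01 * total(norm, m)

x = [0.5, -0.5]
grads = Zygote.gradient(m -> loss(m, x), model)[1]  # NamedTuple; the `act` field gets `nothing`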

src/Optimisers.jl  (+1 -1)

@@ -6,7 +6,7 @@ using LinearAlgebra
 include("interface.jl")
 
 include("destructure.jl")
-export destructure, total, total2
+export destructure, total
 
 include("rules.jl")
 export Descent, ADAM, Momentum, Nesterov, RMSProp,

src/destructure.jl  (+68 -5)

@@ -1,6 +1,6 @@
 
-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
-const NoT = NoTangent()
+using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+const NoT = ChainRulesCore.NoTangent()
 
 """
     destructure(model) -> vector, reconstructor
@@ -107,9 +107,11 @@ function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
-  if p isa ProjectTo  # e.g. Array, NamedTuple
-    p(y)
-  else  # p === identity for unknown structs
+  # if p isa ProjectTo  # e.g. Array, NamedTuple
+  #   p(y)  # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray}  # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+  else
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
@@ -150,3 +152,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
+"""
+    total(f, model)
+
+Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
+the model, and returns the sum. Differentiable. Counts shared weights once.
+
+# Examples
+```jldoctest
+julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
+
+julia> total(sum, m)
+12.0
+
+julia> total(norm, m)
+10.0
+
+julia> total(length, m) == length(destructure(m)[1])
+true
+```
+"""
+function total(f, x)
+  values = []
+  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
+  sum(values)
+end
+
+function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
+  z, backs = _total_hobbit(config, f, x)
+  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
+  z, total_back
+end
+
+function _total_hobbit(config::RuleConfig, f, x)
+  values = []
+  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+    val, back = rrule_via_ad(config, f, y)
+    push!(values, val)
+    back
+  end
+  sum(values), backs
+end
+
+function _total_grad(dz, x, backs)
+  dfs = []
+  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
+    df, dy = b(dz)
+    push!(dfs, df)
+    dy
+  end
+  sum(dfs), dx
+end
+
+function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
+  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
+  function grad_back((df, dx))
+    df isa Zero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
+    (NoT, total(dx), NoT, NoT)
+  end
+  _total_grad(dz, x, backs), grad_back
+end

test/destructure.jl  (+23 -0)

@@ -164,3 +164,26 @@ end
     4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
   end == ([17,17,4,4],)  # Flux gave ([4.0, 4.0, 13.0, 13.0],)
 end
+
+@testset "total" begin
+  @test total(sum, m1) == sum(1:3)
+  @test total(prod, m2) == prod(1:3) + prod(4:6)
+  @test total(sum, m3) == sum(1:6)
+  @test total(sum, m4) == sum(1:6)  # shared only counts once
+  @test total(sum, m6) == 6 + 4 + im
+
+  @test gradient(m -> total(sum, m), m1) == ([1,1,1],)
+  @test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  @test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
+  @test g6.a isa Vector{Float64}
+
+  @test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
+  @test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)
+
+  @testset "second derivatives" begin
+    f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
+    @test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0])
+    # typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
+  end
+end
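The `m4` line above is the weight-sharing case promised by the docstring ("counts shared weights once"). A standalone sketch of that behaviour with a made-up tied model, not the test's `m4`:

using Optimisers

shared = [1.0, 2.0, 3.0]
m = (x = shared, y = sin, z = shared)  # the same array appears in two places

total(sum, m)                   # 6.0, the tied array is summed once rather than twice
v, _ = destructure(m)
total(length, m) == length(v)   # true: both count 3 parameters, storing the shared copy once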
