Commit 3da605f ("more")
1 parent 01d7386

2 files changed: +164 -24 lines

src/preallocated.jl  (+104 -22)

@@ -1,6 +1,6 @@
 using Flux, ChainRulesCore
 using LinearAlgebra: mul!
-# using FastBroadcast: @..
+using FastBroadcast: @..
 using Strided

 const NoT = NoTangent()
@@ -108,32 +108,62 @@ function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias,
 end

 #####
-##### softmax
+##### Conv
 #####

-function PreLayer(::typeof(softmax))
-    fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
-    PreLayer(softmax, nothing, fwd, rev)
+function PreLayer(c::Conv)
+    grad = _struct_sim(c)
+    fwd, rev = similar(c.weight, 0), similar(c.weight, 0)
+    PreLayer(c, grad, fwd, rev)
 end

-function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
-    y, dx = _pre_setup(p, x)  # generic version
-    _softmaxcall!(y, p, x, dx)
+function (p::PreLayer{<:Conv})(x::AbstractArray{<:Real})
+    y, dx = _pre_setup(p, x)
+    _convcall!(y, p, x, dx)
 end

-_softmaxcall!(y, p, x, dx) = softmax!(y, x)
+using Flux: conv_dims, conv_reshape_bias
+using Flux.NNlib: fast_act, conv!, output_size, channels_out

-function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
-    y = _softmaxcall!(y, p, x, dx)
-    function back(dy)
-        # TODO: CHECK THIS!
-        dx .= dy .* y
-        dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
-        return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
+function _pre_setup(p::PreLayer{<:Conv}, x)
+    cdims = conv_dims(p.layer, x)
+    ysize = (output_size(cdims)..., channels_out(cdims), size(x)[end])
+    if prod(ysize) != length(p.fwd)
+        resize!(p.fwd, prod(ysize))
+        resize!(p.rev, length(x))
     end
-    y, back
+    y = _pre_reshape(p.fwd, ysize)
+    dx = _pre_reshape(p.rev, size(x))
+    (; y, dx)
+end
+
+function _convcall!(y, p, x, dx)
+    cdims = conv_dims(p.layer, x)
+    conv!(y, x, p.layer.weight, cdims)
+    if p.layer.bias isa AbstractArray
+        y .+= conv_reshape_bias(p.layer)
+    end
+    act!(y, fast_act(p.layer.σ, x))
 end

+# function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+#     y = _densecall!(y, p, x, dx)
+#     function back(dy)
+#         dy = unthunk(dy)
+#         dy = ∇act!(y, dy, p.layer.σ)
+#         # layer
+#         weight = mul!(p.grad.weight, dy, x')
+#         bias = ∇bias!(p.grad.bias, dy)
+#         tang = Tangent{Dense}(; weight, bias)
+#         # input
+#         dx = mul!(dx, p.layer.weight', dy)
+#         return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+#     end
+#     y, back
+# end
+
+
+
 #####
 ##### BatchNorm
 #####
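
Editorial aside, not part of this commit: the commented-out rrule in the hunk above still has the Dense-layer body (it calls _densecall! and builds a Tangent{Dense}). A rough sketch of a Conv-specific backward is below. It assumes NNlib's ∇conv_data! and ∇conv_filter!, 4-d inputs, and that the p.grad object built by _struct_sim carries preallocated weight and bias buffers shaped like the layer's own; treat it as a sketch, not the package's method.

using Flux.NNlib: ∇conv_data!, ∇conv_filter!

function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
    y = _convcall!(y, p, x, dx)
    function back(dy)
        dy = unthunk(dy)
        dy = ∇act!(y, dy, p.layer.σ)                 # activation pullback, in place
        cdims = conv_dims(p.layer, x)
        # layer gradients, written into the preallocated buffers from _struct_sim
        weight = ∇conv_filter!(p.grad.weight, x, dy, cdims)
        if p.layer.bias isa AbstractArray
            sum!(reshape(p.grad.bias, 1, 1, :, 1), dy)   # reduce over space and batch (4-d case)
            bias = p.grad.bias
        else
            bias = NoT
        end
        tang = Tangent{Conv}(; weight, bias)
        # input gradient, written into the preallocated dx buffer
        dx = ∇conv_data!(dx, dy, p.layer.weight, cdims)
        return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
    end
    y, back
end

The point is the same as for Dense: both the layer gradient and the input gradient land in buffers owned by the PreLayer, so the backward pass allocates nothing beyond what NNlib itself needs.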
@@ -201,6 +231,33 @@ function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ²
     y, back
 end

+#####
+##### softmax
+#####
+
+function PreLayer(::typeof(softmax))
+    fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
+    PreLayer(softmax, nothing, fwd, rev)
+end
+
+function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
+    y, dx = _pre_setup(p, x)  # generic version
+    _softmaxcall!(y, p, x, dx)
+end
+
+_softmaxcall!(y, p, x, dx) = softmax!(y, x)
+
+function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
+    y = _softmaxcall!(y, p, x, dx)
+    function back(dy)
+        # TODO: CHECK THIS!
+        dx .= dy .* y
+        dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
+        return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
+    end
+    y, back
+end
+

 #####
 ##### activation functions
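
For reference, not in the diff: the `# TODO: CHECK THIS!` can be settled numerically. The two in-place statements amount to dx = y .* (dy .- sum(y .* dy; dims=1)), which is the standard softmax pullback along dims=1. A quick check against Zygote's own pullback (x and dy here are arbitrary test data, not names from the package):

using Flux, Zygote, Test   # Zygote pulled in only for this check

x  = randn(Float32, 5, 3)
dy = randn(Float32, 5, 3)

y, back = Zygote.pullback(softmax, x)
dx_rule = y .* (dy .- sum(y .* dy; dims=1))   # what the in-place rrule computes

@test back(dy)[1] ≈ dx_rule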
@@ -212,8 +269,8 @@ function act!(y, act::F) where F
     # y .= σ.(y)
     # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
     # maybe you could patch Strided.jl to avoid it? Or use another package...
-    @strided y .= σ.(y)
-    # FastBroadcast.@.. y = σ(y)
+    # @strided y .= σ.(y)
+    @.. y = σ(y)
 end

 # Piracy, disable @strided on CuArrays:
@@ -223,10 +280,31 @@ Strided.maybestrided(x::Flux.CuArray) = x
 ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dz -> (NoT, ∇act!(y, dy, f), NoT)

 ∇act!(y, dy, ::typeof(identity)) = dy
-∇act!(y, dy, ::typeof(relu)) = @. y = ifelse(y>0, dy, 0f0)
-∇act!(y, dy, ::typeof(tanh)) = @. y = (1 - y^2)
-∇act!(y, dy, ::typeof(sigmoid)) = @. y = y * (1 - y)
+∇act!(y, dy, ::typeof(relu)) = @.. y = ifelse(y>0, dy, 0f0)
+∇act!(y, dy, ::typeof(tanh)) = @.. y = (1 - y^2)
+∇act!(y, dy, ::typeof(sigmoid)) = @.. y = y * (1 - y)
+
+
+function PreLayer(::typeof(relu))
+    fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
+    PreLayer(relu, nothing, fwd, rev)
+end
+
+function (p::PreLayer{typeof(relu)})(x::AbstractArray{<:Real})
+    y, dx = _pre_setup(p, x)  # generic version
+    _relucall!(y, p, x, dx)
+end

+_relucall!(y, p, x, dx) = y .= relu.(x)
+
+function ChainRulesCore.rrule(::typeof(_relucall!), y, p, x, dx)
+    y = _relucall!(y, p, x, dx)
+    function back(dy)
+        @. dx = ifelse(y>0, dy, 0f0)
+        return (NoT, NoT, NoT, dx, NoT)
+    end
+    y, back
+end

 #####
 ##### PreLayer utils
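
Editorial aside, not in this commit: the relu method of ∇act! scales by the incoming gradient dy, but the tanh and sigmoid methods as written compute only the local derivative and drop dy (and the act! rrule's closure refers to dy where its argument is named dz, though that context line is untouched here). If the intent is the same in-place pullback as for relu, a possible correction, offered only as a sketch, would be:

# Hypothetical fix, assuming these should be full pullbacks like the relu method:
∇act!(y, dy, ::typeof(tanh))    = @.. y = dy * (1 - y^2)
∇act!(y, dy, ::typeof(sigmoid)) = @.. y = dy * y * (1 - y)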
@@ -249,10 +327,14 @@ ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)

 # Cannot use reshape(::Array), as that prevents later resize!
 _pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
+# _pre_reshape(x::Array, size::Tuple) = Base.__reshape((x, Base.IndexStyle(x)), size)  # what Base does, no better
 # Must use reshape(::CuArray) as mul! rejects ReshapedArray
 _pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
 _pre_reshape(x, size::Tuple) = reshape(x, size)

+# Base piracy! to prevent ReshapedArray from going missing
+Base._reshape(R::Base.ReshapedArray, dims::Base.Dims) = Base.ReshapedArray(R.parent, dims, ())
+
 ∇bias!(::Bool, dx) = NoT
 ∇bias!(bias, dx) = sum!(bias, dx)

test/preallocated.jl  (+60 -2)

@@ -16,10 +16,10 @@ g2 = gradient((m,x) -> m(x)[1], m2, x)
 #=

 julia> @btime gradient((m,x) -> m(x)[1], $m1, $x);
-  min 52.167 μs, mean 2.519 ms (58 allocations, 355.41 KiB)
+  min 50.167 μs, mean 88.796 μs (58 allocations, 355.41 KiB)

 julia> @btime gradient((m,x) -> m(x)[1], $m2, $x);
-  min 58.750 μs, mean 190.440 μs (109 allocations, 17.44 KiB)
+  min 57.792 μs, mean 66.050 μs (115 allocations, 17.75 KiB)

@@ -33,9 +33,14 @@ let data = [(x,) for _ in 1:1000]
     nothing
 end

+# Yesterday:
 # min 1.799 s, mean 1.802 s (177001 allocations, 352.94 MiB)
 # min 146.713 ms, mean 251.041 ms (295001 allocations, 25.71 MiB)

+# Today, wtf? Maybe threading changes have hurt.
+# min 244.235 ms, mean 251.582 ms (177001 allocations, 352.94 MiB)
+# min 224.760 ms, mean 227.594 ms (301001 allocations, 26.02 MiB)
+

 m1cu = m1 |> gpu
 m2cu = m2 |> gpu
@@ -78,3 +83,56 @@ julia> @btime $m4($x);

 =#

+x4 = randn(Float32, 28, 28, 1, 13);
+
+m5 = @autosize (size(x4)...,) Chain(
+    Conv((3,3), 1 => 7, relu, stride=2, pad=1),
+    Conv((3,3), _ => 9, relu, stride=2),
+    Conv((3,3), _ => 5, tanh, stride=2, bias=false),
+    Flux.flatten,
+    Dense(_ => 10),
+    )
+m6 = m5 |> pre
+
+@test m5(x4) ≈ m6(x4)
+
+#=
+
+julia> @btime $m5($x4);
+  min 139.125 μs, mean 191.653 μs (179 allocations, 262.73 KiB)
+
+julia> @btime $m6($x4);
+  min 140.125 μs, mean 196.337 μs (160 allocations, 86.39 KiB)
+
+=#
+
+
+using Metalhead
+m50 = Metalhead.ResNet(50)  # 100MB
+m50pre = m50 |> pre  # 200MB
+
+
+# First run
+
+julia> @time m50(randn(Float32, 100,100,3,32)) |> size
+  5.543590 seconds (6.11 M allocations: 1.963 GiB, 14.14% gc time, 96.22% compilation time)
+(1000, 32)
+
+julia> @time m50pre(randn(Float32, 100,100,3,32)) |> size
+ 16.098089 seconds (15.84 M allocations: 2.576 GiB, 62.26% gc time, 69.06% compilation time)
+(1000, 32)
+
+# Later
+
+julia> @time m50(randn(Float32, 100,100,3,32)) |> size
+ 11.541100 seconds (4.40 k allocations: 1.570 GiB, 85.73% gc time)
+(1000, 32)
+
+julia> @time m50pre(randn(Float32, 100,100,3,32)) |> size
+  4.664626 seconds (4.09 k allocations: 381.454 MiB, 61.15% gc time)
+(1000, 32)
+
+
+m50pre  # now 1.340 GiB
+
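
Editorial aside, not part of this commit: the new test only compares the forward passes of m5 and m6. Once a Conv rrule along the lines sketched earlier exists, a gradient check in the spirit of the g2 = gradient(...) line already in this file could follow; the snippet below is hypothetical and uses only names defined above.

# Hypothetical gradient check for the conv chain; m6's backward needs a working
# rrule for _convcall! before this can pass.
g5 = gradient(x -> sum(abs2, m5(x)), x4)[1]
g6 = gradient(x -> sum(abs2, m6(x)), x4)[1]
@test g5 ≈ g6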
