using Flux, ChainRulesCore
using LinearAlgebra: mul!
- # using FastBroadcast: @..
+ using FastBroadcast: @..
using Strided

const NoT = NoTangent()
@@ -108,32 +108,62 @@ function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias,
end

# ####
- # #### softmax
+ # #### Conv
# ####

- function PreLayer(::typeof(softmax))
-     fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
-     PreLayer(softmax, nothing, fwd, rev)
+ function PreLayer(c::Conv)
+     grad = _struct_sim(c)
+     fwd, rev = similar(c.weight, 0), similar(c.weight, 0)
+     PreLayer(c, grad, fwd, rev)
end

- function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
-     y, dx = _pre_setup(p, x)  # generic version
-     _softmaxcall!(y, p, x, dx)
+ function (p::PreLayer{<:Conv})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x)
+     _convcall!(y, p, x, dx)
end

- _softmaxcall!(y, p, x, dx) = softmax!(y, x)
+ using Flux: conv_dims, conv_reshape_bias
+ using Flux.NNlib: fast_act, conv!, output_size, channels_out

- function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
-     y = _softmaxcall!(y, p, x, dx)
-     function back(dy)
-         # TODO: CHECK THIS!
-         dx .= dy .* y
-         dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
-         return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
+ function _pre_setup(p::PreLayer{<:Conv}, x)
+     cdims = conv_dims(p.layer, x)
+     ysize = (output_size(cdims)..., channels_out(cdims), size(x)[end])
+     if prod(ysize) != length(p.fwd)
+         resize!(p.fwd, prod(ysize))
+         resize!(p.rev, length(x))
    end
-     y, back
+     y = _pre_reshape(p.fwd, ysize)
+     dx = _pre_reshape(p.rev, size(x))
+     (; y, dx)
+ end
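+ # Note: p.fwd and p.rev stay flat vectors precisely so that the resize! calls above are
+ # legal when the input shape changes; _pre_reshape (see PreLayer utils below) then wraps
+ # them in the required shape without copying.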
+
+ function _convcall!(y, p, x, dx)
+     cdims = conv_dims(p.layer, x)
+     conv!(y, x, p.layer.weight, cdims)
+     if p.layer.bias isa AbstractArray
+         y .+= conv_reshape_bias(p.layer)
+     end
+     act!(y, fast_act(p.layer.σ, x))
end

+ # function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+ #     y = _densecall!(y, p, x, dx)
+ #     function back(dy)
+ #         dy = unthunk(dy)
+ #         dy = ∇act!(y, dy, p.layer.σ)
+ #         # layer
+ #         weight = mul!(p.grad.weight, dy, x')
+ #         bias = ∇bias!(p.grad.bias, dy)
+ #         tang = Tangent{Dense}(; weight, bias)
+ #         # input
+ #         dx = mul!(dx, p.layer.weight', dy)
+ #         return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+ #     end
+ #     y, back
+ # end
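+
+ # The commented-out draft above is copied from the Dense rrule; a Conv version might look
+ # like the sketch below. Untested: it assumes NNlib's ∇conv_filter! / ∇conv_data!, that
+ # p.grad has a matching `weight` buffer, and it fills the bias gradient for 2-D convs only.
+ # function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+ #     y = _convcall!(y, p, x, dx)
+ #     function back(dy)
+ #         dy = unthunk(dy)
+ #         dy = ∇act!(y, dy, p.layer.σ)  # re-uses y as scratch, see ∇act! below
+ #         cdims = conv_dims(p.layer, x)
+ #         # layer
+ #         weight = Flux.NNlib.∇conv_filter!(p.grad.weight, x, dy, cdims)
+ #         bias = p.layer.bias isa AbstractArray ? vec(sum(dy; dims=(1,2,4))) : NoT
+ #         tang = Tangent{Conv}(; weight, bias)
+ #         # input
+ #         dx = Flux.NNlib.∇conv_data!(dx, dy, p.layer.weight, cdims)
+ #         return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+ #     end
+ #     y, back
+ # end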
+
+
+
# ####
# #### BatchNorm
# ####
@@ -201,6 +231,33 @@ function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ²
    y, back
end

+ # ####
+ # #### softmax
+ # ####
+
+ function PreLayer(::typeof(softmax))
+     fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
+     PreLayer(softmax, nothing, fwd, rev)
+ end
+
+ function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x)  # generic version
+     _softmaxcall!(y, p, x, dx)
+ end
+
+ _softmaxcall!(y, p, x, dx) = softmax!(y, x)
+
+ function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
+     y = _softmaxcall!(y, p, x, dx)
+     function back(dy)
+         # TODO: CHECK THIS!
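+         # (per column, the softmax Jacobian is Diagonal(y) - y*y', so the pullback is
+         #  dx = y .* dy .- y .* sum(y .* dy; dims=1); the next two lines compute exactly this)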
+         dx .= dy .* y
+         dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
+         return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
+     end
+     y, back
+ end
+

# ####
# #### activation functions
@@ -212,8 +269,8 @@ function act!(y, act::F) where F
    # y .= σ.(y)
    # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
    # maybe you could patch Strided.jl to avoid it? Or use another package...
-     @strided y .= σ.(y)
-     # FastBroadcast.@.. y = σ(y)
+     # @strided y .= σ.(y)
+     @.. y = σ(y)
end

# Piracy, disable @strided on CuArrays:
@@ -223,10 +280,31 @@ Strided.maybestrided(x::Flux.CuArray) = x
ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dy -> (NoT, ∇act!(y, dy, f), NoT)

∇act!(y, dy, ::typeof(identity)) = dy
- ∇act!(y, dy, ::typeof(relu)) = @. y = ifelse(y>0, dy, 0f0)
- ∇act!(y, dy, ::typeof(tanh)) = @. y = (1 - y^2)
- ∇act!(y, dy, ::typeof(sigmoid)) = @. y = y * (1 - y)
+ ∇act!(y, dy, ::typeof(relu)) = @.. y = ifelse(y>0, dy, 0f0)
+ ∇act!(y, dy, ::typeof(tanh)) = @.. y = dy * (1 - y^2)
+ ∇act!(y, dy, ::typeof(sigmoid)) = @.. y = dy * y * (1 - y)
+
+
+ function PreLayer(::typeof(relu))
+     fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
+     PreLayer(relu, nothing, fwd, rev)
+ end
+
+ function (p::PreLayer{typeof(relu)})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x)  # generic version
+     _relucall!(y, p, x, dx)
+ end

+ _relucall!(y, p, x, dx) = y .= relu.(x)
+
+ function ChainRulesCore.rrule(::typeof(_relucall!), y, p, x, dx)
+     y = _relucall!(y, p, x, dx)
+     function back(dy)
+         @. dx = ifelse(y>0, dy, 0f0)
+         return (NoT, NoT, NoT, dx, NoT)
+     end
+     y, back
+ end
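+ # Note: `back` masks on y > 0 rather than x > 0; the two agree for relu, so the saved
+ # input x is never read in the backward pass.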

# ####
# #### PreLayer utils
@@ -249,10 +327,14 @@ ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)

# Cannot use reshape(::Array), as that prevents later resize!
_pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
+ # _pre_reshape(x::Array, size::Tuple) = Base.__reshape((x, Base.IndexStyle(x)), size)  # what Base does, no better
# Must use reshape(::CuArray) as mul! rejects ReshapedArray
_pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
_pre_reshape(x, size::Tuple) = reshape(x, size)

+ # Base piracy! to prevent ReshapedArray from going missing
+ Base._reshape(R::Base.ReshapedArray, dims::Base.Dims) = Base.ReshapedArray(R.parent, dims, ())
+

∇bias!(::Bool, dx) = NoT
∇bias!(bias, dx) = sum!(bias, dx)