diff --git a/src/utils.jl b/src/utils.jl index c1888829d4..e2b5380fd5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -382,8 +382,9 @@ to the constructor's keyword `bias=bias`. * `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted. """ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) - bias ? fill!(similar(weights, dims...), 0) : Zeros() + bias ? fill!(similar(weights, dims...), 0) : Zeros(dims...) end + function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) bias diff --git a/src/zeros.jl b/src/zeros.jl index 1281f4c87a..1c4733e9a8 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -14,39 +14,183 @@ Useful to turn bias off for a forward pass of a layer. julia> bias_less_conv = Conv((2,2), 1=>3; bias = false) Conv((2, 2), 1=>3) -julia> params(bias_less_conv) |> length -1 - julia> bias_less_conv.bias Flux.Zeros() ``` """ -struct Zeros end -# To allow for things like Dense(10, 2, initb = Zeros) -Zeros(args...) = Zeros() +mutable struct Zeros{T,N} <: AbstractArray{T,N} + dims::NTuple{N,Int} +end + +Zeros(::Type{T}, dims...) where T = Zeros{T,length(dims)}(dims) +Zeros(dims...) = Zeros(Bool, dims...) + +Base.reshape(x::Zeros{T}, dims::Union{Colon,Int}...) where T = Zeros(T, Base._reshape_uncolon(x, dims)...) + +function Base.getindex(z::Zeros{T}, args...) where T + Base.checkbounds(z, args...) + zero(T) +end + +Base.collect(x::Zeros{T}) where T = zeros(T, x.dims...) + +Base.size(xs::Zeros) = xs.dims +Base.copyto!(a::Zeros, b::Zeros) = b + +# Base.print_array(io::IO, z::Zeros{T}) where T = print(io, "Zeros object with size $(z.dims)") + +Flux.CUDA.Adapt.adapt(to, x::Zeros) = x + +@adjoint reshape(xs::Zeros{T}, dims...) where T = + reshape(xs, dims...), _ -> nothing + +# Define basic ops +for f in (:+, :-) + @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) + @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) + a + end +end + ++(a::Zeros, b::AbstractArray) = b + a +-(a::Zeros, b::AbstractArray) = -b + a + +Base.copy(xs::Zeros{T,N}) where {T,N} = xs + +for op in (:+, :-) + @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) + bs = Broadcast.broadcast_shape(size(a), size(b)) + size(a) == bs && return a + sz = similar(a, bs) + sz .= a + end +end +broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) +broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) -Base.reshape(x::Zeros, dims...) = x +# a * b +function *(a::Flux.Zeros, b::AbstractMatrix{<: Number}) + sa = size(a) + sb = size(b) + @assert sa[2] == sb[1] throw(DimensionMismatch("dimensions must match")) + zero(b) +end +*(a::AbstractMatrix{<: Number}, b::Zeros) = b * a -+(::Zeros, b::AbstractArray) = b -+(a::AbstractArray, ::Zeros) = a -+(a::Zeros, ::Zeros) = a +function broadcasted(::typeof(*), a::AbstractArray{T}, b::Zeros) where {T} + bs = Broadcast.broadcast_shape(size(a), size(b)) + fill!(similar(a, bs), zero(T)) +end +broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = b .* a --(::Zeros, b::AbstractArray) = -b --(a::AbstractArray, ::Zeros) = a --(a::Zeros, ::Zeros) = a +# Adjoints -# Some opportunities to avoid scalar indexing, intermediaries +# grad( a $op b ) +for op in (:+, :-, :*) + @eval @adjoint function $op(a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} + $op(a, b), Δ -> begin + (Δ, nothing) + end + end + + if op === :- + continue + end + @eval @adjoint function $op(a::Zeros, b::AbstractArray) + $op(a, b), Δ -> begin + @show size(a), size(b) + (nothing, Δ) + end + end +end +@adjoint function -(a::Zeros, b::AbstractArray) + a - b, Δ -> (nothing, -Δ) +end + + +# grad( broadcast($op, a, b) ) +for op in (:+, :-, :*) + @eval @adjoint function Base.broadcasted(::typeof($op), a::AbstractArray{T,N}, b::Zeros{S,M}) where {T <: Number, S <: Number, N,M} + Base.broadcasted($op, a, b), Δ -> begin + dims = M > N ? tuple(setdiff(1:M, 1:N)...) : tuple(setdiff(1:N, 1:M)...) + da = dims == Tuple(1:N) ? Δ : M == N ? broadcast($op, Δ, b) : dropdims(sum(Δ, dims = dims), dims = dims) + (nothing, da, nothing) + end + end + + if op === :- + continue + end + @eval @adjoint function Base.broadcasted(::typeof($op), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N} + Base.broadcasted($op, a, b), Δ -> begin + dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...) + db = dims == Tuple(1:M) ? (Δ .* a) : M == N ? broadcast($op, a, Δ) : dropdims(sum(Δ .* a, dims = dims), dims = dims) + (nothing, nothing, db) + end + end +end + + +@adjoint function Base.broadcasted(::typeof(-), a::Zeros{T, N}, b::AbstractArray{S, M}) where {T <: Number, S <: Number, M, N} + a .- b, Δ -> begin + dims = N > M ? tuple(setdiff(1:N, 1:M)...) : tuple(setdiff(1:M, 1:N)...) + da = dims == Tuple(1:M) ? Δ : M == N ? Δ .- a : dropdims(sum(Δ, dims = dims), dims = dims) + (nothing, nothing, -da) + end +end + +# / +_div(a, sa, f) = fill!(similar(a, sa), f) +function /(a::Zeros, b::AbstractMatrix{T}) where T + sa = size(a) + sb = size(b) + @assert sa[end] == sb[end] throw(DimensionMismatch()) + _div(b, size(b), zero(T)) +end +function /(a::AbstractMatrix, b::Zeros) + sa = size(a) + sb = size(b) + @assert sa[end] == sb[end] throw(DimensionMismatch()) + _div(a, size(a), Inf) +end +function broadcasted(::typeof(/), a::Zeros, b::AbstractArray{T}) where T + bs = Broadcast.broadcast_shape(size(a), size(b)) + _div(b, bs, zero(T)) +end +function broadcasted(::typeof(/), a::AbstractArray, b::Zeros) + bs = Broadcast.broadcast_shape(size(a), size(b)) + _div(a, bs, Inf) +end + +# grad( a / b) +@adjoint function /(a::Zeros, b::AbstractArray) + a / b, Δ -> (nothing, zero(Δ)) +end +@adjoint function /(a::AbstractArray, b::Zeros) + a / b, Δ -> (zero(Δ), nothing) +end + +# grad( broadcast(/, a, b) ) +@adjoint function broadcasted(::typeof(/), a::Zeros{<: Number}, b::AbstractArray{<: Number}) # where T <: Number + sa, sb = size(a), size(b) + T = eltype(b) + a ./ b, Δ -> (nothing, nothing, _div(b, Broadcast.broadcast_shape(sa, sb), zero(T))) +end +@adjoint function broadcasted(::typeof(/), a::AbstractArray{<: Number}, b::Zeros{<: Number}) + sa, sb = size(a), size(b) + T = eltype(b) + a ./ b, Δ -> (nothing, _div(a, sa, Inf), nothing) +end + +Base.sum(z::Zeros{T}) where T = zero(T) + +# Some opportunities to avoid scalar indexing/ intermediaries # Since it replicates a little of what we expect Base to do, # it should be possible to remove in the future, but for now, # these help with performance. -broadcasted(::typeof(+), a::AbstractArray, b::Zeros) = a -broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b -broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a -broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b -# Need adjoints for these or else the gradient w.r.t to the non-Zeros arg will be nothing as well -@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing) -@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) -@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) - -# Pass-through for layer constructors -create_bias(weights::AbstractArray, bias::Flux.Zeros, dims::Integer...) = bias +broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b +broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b + +broadcasted(::typeof(conj), z::Zeros) = z diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl index 8ed3d66eb4..129f8dde6e 100644 --- a/test/cuda/runtests.jl +++ b/test/cuda/runtests.jl @@ -5,6 +5,27 @@ using Zygote: pullback @info "Testing GPU Support" CUDA.allowscalar(false) +function gpu_gradtest(f, args...) + args_gpu = gpu.(args) + + l_cpu, back_cpu = pullback((x...) -> f(x...), args...) + g_cpu = back_cpu(1f0)[1] + + l_gpu, back_gpu = pullback((x...) -> f(x...), args_gpu...) + g_gpu = back_gpu(1f0)[1] + + @test l_cpu ≈ l_gpu rtol=1e-4 atol=1e-4 + @test g_gpu isa CuArray + @test g_cpu ≈ collect(g_gpu) rtol=1e-4 atol=1e-4 +end + +@testset "Moving Zeros to GPU" begin + z = Flux.Zeros() + z2 = Flux.Zeros(3,3) + @test z === gpu(z) + @test z2 === gpu(z2) +end + include("test_utils.jl") include("cuda.jl") include("losses.jl") diff --git a/test/utils.jl b/test/utils.jl index 6b487e7854..952b5127a4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -247,21 +247,58 @@ end end @testset "Zeros" begin - m = Dense(3,2; bias=false) + + m = Dense(3, 2; bias = false) @test f64(m).bias === m.bias === Zeros() @test f32(m).bias === m.bias === Zeros() - @testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3)) + @testset "Gradients for broadcasted $op with sizes $s" for op in (+, -, *), s in ((1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() + Z = Zeros(s...) + a = ones(3,3) + b = zeros(3,3) + b′ = Zeros(3,3) + + + @testset "Basic operations" begin + a = rand(3,3) + b = zeros(3,3) + bz = Zeros(3,3) + + for op in (+, -) + @test op(a, b) == op(a, bz) + end + + for op in (+, -) + gs = gradient((a, b) -> sum(op(a, b)), a, b) + gsz = gradient((a, b) -> sum(op(a, b)), a, bz) + @test gs[1] == gsz[1] + @test gsz[2] === nothing + end + + # Check with broadcasting + b = zeros(3,3,3) + bz = Zeros(3,3,3) + + for op in (+, -) + @test broadcast(op, a, b) == broadcast(op, a, bz) + end + + for op in (+, -) + gs = gradient((a,b) -> sum(broadcast(op, a, b)), a, b) + gsz = gradient((a,b) -> sum(broadcast(op, a, b)), a, bz) + @test gs[1] == gsz[1] + @test gsz[2] === nothing + end + end @testset "Explicit" begin - gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) + gfun(args...) = gradient((x, y) -> sum(op.(x, y)), args...) g = gfun(o, z) @test gfun(o, Z) == (g[1], nothing) - g = gfun(z, o) + g = gfun(z, o) @test gfun(Z, o) == (nothing, g[2]) end @@ -271,19 +308,19 @@ end gres = gfun(o, Z) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end @testset "Gradients for broadcasted / with sizes $s" for s in ((1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() # Only defined for 0-dim + Z = Zeros(s...) # Only defined for 0-dim @testset "Explicit" begin gfun(args...) = gradient((x, y) -> sum(x ./ y), args...) @@ -297,14 +334,14 @@ end g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end - @testset "Gradients for $op with sizes $s" for op in (+,-), s in (tuple(), (1,), (2,3)) + @testset "Gradients for $op with sizes $s" for op in (+, -), s in (tuple(), (1,), (2,3)) o = ones(s) z = zeros(s) - Z = Zeros() + Z = Zeros(s...) @testset "Explicit" begin @@ -322,12 +359,12 @@ end g = gfun(o, z) gres = gfun(o, Z) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params g = gfun(z, o) gres = gfun(Z, o) @test gres[o] == g[o] - @test Z ∉ gres.params + # @test Z ∉ gres.params end end end @@ -368,7 +405,7 @@ end dl(4, 3, bias) ) - nobias(n) = Zeros() + nobias(n) = Zeros(n) testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt))) @test l1.weight == l2.weight @test l1.bias == l2.bias