diff --git a/src/conv.jl b/src/conv.jl index af26d5936..b7161c4e8 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -1,6 +1,4 @@ -export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv, - depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter, - ∇depthwiseconv_filter! +export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter! ## Convolution API # @@ -19,7 +17,7 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter! # # All methods require a `ConvDims` object to define the dimensions and optional # elements of the convolution (padding, stride, dilation, kernel-flipping, etc...), -# which is easily constructable through something like `DenseConvDims(x, w)`. All +# which is easily constructable through something like `ConvDims(x, w)`. All # methods take in the `ConvDims` of the associated normal, forward-pass convolution, # that is, the following is legal: # @@ -36,9 +34,6 @@ for (front_name, backend) in ( :conv => :im2col, :∇conv_data => :im2col, :∇conv_filter => :im2col, - :depthwiseconv => :im2col, - :∇depthwiseconv_data => :im2col, - :∇depthwiseconv_filter => :im2col, ) # These are the GEMM types we will accelerate with `im2col` @@ -58,8 +53,7 @@ end # Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which # makes things MUCH EASIER for us on the backend side, and is in general pretty fast, # since we can specialize on sizes. -for front_name in (:conv, :∇conv_data, :∇conv_filter, - :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) +for front_name in (:conv, :∇conv_data, :∇conv_filter) for backend in (Symbol(), :_direct, :_im2col) for N in (3, 4) @eval begin @@ -87,8 +81,7 @@ end # We always support a fallback, non-accelerated path, where we use the direct, but # slow, implementations. These should not typically be used, hence the `@debug`, # but let's ggo ahead and define them first: -for front_name in (:conv, :∇conv_data, :∇conv_filter, - :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) +for front_name in (:conv, :∇conv_data, :∇conv_filter) @eval begin function $(Symbol("$(front_name)!"))( y::AbstractArray{yT,N}, in1::AbstractArray{T1,N}, @@ -106,7 +99,7 @@ end # allocation. :P for backend in (Symbol(), :_direct, :_im2col) # First make auto-allocating versions of the conv()-like calls: - for name in (:conv, :depthwiseconv) + for name in (:conv,) @eval begin function $(Symbol("$(name)$(backend)"))( x::AbstractArray{xT,N}, w::AbstractArray{wT,N}, @@ -118,7 +111,7 @@ for backend in (Symbol(), :_direct, :_im2col) end end - for name in (:∇conv_data, :∇depthwiseconv_data) + for name in (:∇conv_data,) @eval begin function $(Symbol("$(name)$(backend)"))( dy::AbstractArray{yT,N}, w::AbstractArray{wT,N}, @@ -130,35 +123,24 @@ for backend in (Symbol(), :_direct, :_im2col) end end - # We do the conv/depthwiseconv filter backprops separately, as the shape calculation - # for `w` is slightly different for depthwise than for normal dense convolution. + # This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone + # is a deciding factor from cudnn's perspective. For backends im2col and direct needs to be handled. @eval begin function $(Symbol("∇conv_filter$(backend)"))( x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, cdims::ConvDims; kwargs...) 
where {xT, yT, N} - dw = similar(dy, kernel_size(cdims)..., channels_in(cdims), + dw = similar(dy, kernel_size(cdims)..., div(channels_in(cdims),group_count(cdims)), channels_out(cdims)) return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...) end end - - @eval begin - function $(Symbol("∇depthwiseconv_filter$(backend)"))( - x::AbstractArray{xT,N}, dy::AbstractArray{yT,N}, - cdims::ConvDims; kwargs...) where {xT, yT, N} - dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims), - channels_in(cdims)) - return $(Symbol("∇depthwiseconv_filter$(backend)!"))(dw, x, dy, cdims; - kwargs...) - end - end end # Use NNPACK if it is available and the operation is supported if is_nnpack_available() function conv(x::Array{xT, 4}, w::Array{wT, 4}, - cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; + cdims::ConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; kwargs...) where {xT, wT, K, C_in, C_out, S, P, F} return conv_nnpack(x, w, cdims; kwargs...) end @@ -168,14 +150,14 @@ function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flip stride = expand(Val(N-2), stride) pad = expand(Val(N-2), pad) dilation = expand(Val(N-2), dilation) - cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped) + cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped) return conv(x, w, cdims) end -function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, N} +function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groupcount) where {T, N} stride = expand(Val(N-2), stride) pad = expand(Val(N-2), pad) dilation = expand(Val(N-2), dilation) - cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped) + cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount) return depthwiseconv(x, w, cdims) end diff --git a/src/dim_helpers.jl b/src/dim_helpers.jl index 22d5636a7..7734ea3c9 100644 --- a/src/dim_helpers.jl +++ b/src/dim_helpers.jl @@ -1,7 +1,6 @@ # Various helper functions to calculate dimensions for operations +include("dim_helpers/AbstractDims.jl") include("dim_helpers/ConvDims.jl") -include("dim_helpers/DenseConvDims.jl") -include("dim_helpers/DepthwiseConvDims.jl") include("dim_helpers/PoolDims.jl") @@ -45,7 +44,7 @@ function transpose_pad(cdims::ConvDims) end """ - insert_singleton_spatial_dimension(cdims::DenseConvDims) + insert_singleton_spatial_dimension(cdims::ConvDims) When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton spatial dimension at the end of the spatial dimensions. This does so for a ConvDims. diff --git a/src/dim_helpers/AbstractDims.jl b/src/dim_helpers/AbstractDims.jl new file mode 100644 index 000000000..b1532c32a --- /dev/null +++ b/src/dim_helpers/AbstractDims.jl @@ -0,0 +1,120 @@ +export AbstractDims + +""" + AbstractDims + +Type system-level information about convolution dimensions. Critical for things like +`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs +getting passed around. + +We don't want to specialize on things like image size/channel count, so we generally +store those as fields, just for convenience, and to allow for non-breaking changes when +we decide we _do_ want to specialize on those values. 
We always want to specialize on +things like stride, padding, dilation, and kernel flipping though. +""" +abstract type AbstractDims{N, S, P, D, F} end + +# Hack to get rid of type parameters +function basetype(::Type{C}) where {C <: AbstractDims} + if C <: ConvDims + return ConvDims + elseif C <: PoolDims + return PoolDims + else + return nothing + end +end + +# Obvious getter definitions for the type system-level definitions +spatial_dims(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = N +stride(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = S +padding(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = P +dilation(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = D +flipkernel(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = F + +""" + im2col_dims(c::AbstractDims) + +im2col calculates, for each output pixel, the "convolution" of N kernels where N is the +number of output channels, by doing a matrix multiply. The dimensions of that matrix +are given by this function. +""" +im2col_dims(c::AbstractDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c)) + +# Protect your skin, kids. Also do common validation of stride, padding, etc... +function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} + # Number of spatial dimensions in `x` and `w`. + nd = N - 2 + + # Given a number, duplicate it out to have `nd` length. If it's already a collection, + # just splat it out into a tuple so it's always a tuple. We'll lint length later. + expand_size(p::Number) = ntuple(_ -> Int(p), nd) + expand_size(p) = tuple(p...) + + # Convert stride, padding, dilation, etc.. to fully-specified tuples + pstride = expand_size(stride) + pdilation = expand_size(dilation) + ppadding = expand_size(padding) + + if length(pstride) != nd + throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!")) + end + if length(pdilation) != nd + throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!")) + end + + # padding is kind of a special case; we allow it to be either 2-length or 4-length, + # since we support asymmetrical padding + if length(ppadding) != 2*nd + if length(ppadding) == nd + # Do this repeat dance so that we get lo/hi symmetrical padding + ppadding = tuple(repeat(collect(ppadding), inner=2)...) + else + throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!")) + end + end + + # Assert that kernel size * dilation is <= padded input size + for idx in 1:nd + Is = x_size[idx] + Pl = ppadding[(idx - 1)*2 + 1] + Ph = ppadding[(idx - 1)*2 + 2] + Ks = w_size[idx] + Ds = pdilation[idx] + if Is + Pl + Ph < (Ks - 1)*Ds + 1 + throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!")) + end + end + + return pstride, ppadding, pdilation +end + +""" + output_size(c::AbstractDims) + +Calculate the output (spatial) dimensions of the convolution. Get channel count via +`channels_out(c)`, and batch count is unknowable. 
+""" +function output_size(c::AbstractDims) + I = input_size(c) + K = kernel_size(c) + S = stride(c) + P = padding(c) + D = dilation(c) + + return ntuple(spatial_dims(c)) do i + return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1 + end +end + +# Override show() for these beauties +function Base.show(io::IO, cdims::C) where {C <: AbstractDims} + I = (input_size(cdims)..., channels_in(cdims)) + O = (output_size(cdims)..., channels_out(cdims)) + K = kernel_size(cdims) + S = stride(cdims) + P = padding(cdims) + D = dilation(cdims) + F = flipkernel(cdims) + print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F") +end diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl index 335cf4389..7202d0cf1 100644 --- a/src/dim_helpers/ConvDims.jl +++ b/src/dim_helpers/ConvDims.jl @@ -12,111 +12,77 @@ store those as fields, just for convenience, and to allow for non-breaking chang we decide we _do_ want to specialize on those values. We always want to specialize on things like stride, padding, dilation, and kernel flipping though. """ -abstract type ConvDims{N, S, P, D, F} end -# Hack to get rid of type parameters -function basetype(::Type{C}) where {C <: ConvDims} - if C <: DepthwiseConvDims - return DepthwiseConvDims - elseif C <: DenseConvDims - return DenseConvDims - elseif C <: PoolDims - return PoolDims - else - return nothing - end +struct ConvDims{N,K,C_in,C_out,S,P,D,F,G} <: AbstractDims{N,S,P,D,F} + I::NTuple{N,Int} end -# Obvious getter definitions for the type system-level definitions -spatial_dims(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = N -stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S -padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P -dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D -flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F - -""" - im2col_dims(c::ConvDims) - -im2col calculates, for each output pixel, the "convolution" of N kernels where N is the -number of output channels, by doing a matrix multiply. The dimensions of that matrix -are given by this function. -""" -im2col_dims(c::ConvDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c)) -# Protect your skin, kids. Also do common validation of stride, padding, etc... -function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N} - # Number of spatial dimensions in `x` and `w`. - nd = N - 2 - - # Given a number, duplicate it out to have `nd` length. If it's already a collection, - # just splat it out into a tuple so it's always a tuple. We'll lint length later. - expand_size(p::Number) = ntuple(_ -> Int(p), nd) - expand_size(p) = tuple(p...) - - # Convert stride, padding, dilation, etc.. 
to fully-specified tuples - pstride = expand_size(stride) - pdilation = expand_size(dilation) - ppadding = expand_size(padding) - - if length(pstride) != nd - throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!")) - end - if length(pdilation) != nd - throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!")) +# Getters for the fields +input_size(c::ConvDims) = c.I +kernel_size(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = K +channels_in(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_in +channels_out(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_out +group_count(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = G + +# Convenience wrapper to create ConvDims objects +function ConvDims(x_size::NTuple{M}, w_size::NTuple{M}; + stride=1, padding=0, dilation=1, flipkernel::Bool=false, groupcount=1) where M + # Do common parameter validation + stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) + + # Ensure channels are equal + if x_size[M-1] != w_size[M-1]*groupcount + xs = x_size[M-1] + ws = w_size[M-1]*groupcount + throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)")) end - # padding is kind of a special case; we allow it to be either 2-length or 4-length, - # since we support asymmetrical padding - if length(ppadding) != 2*nd - if length(ppadding) == nd - # Do this repeat dance so that we get lo/hi symmetrical padding - ppadding = tuple(repeat(collect(ppadding), inner=2)...) - else - throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!")) - end - end + # The type parameters are what + return ConvDims{ + M - 2, + w_size[1:M-2], + x_size[M-1], + w_size[M], + stride, + padding, + dilation, + flipkernel, + groupcount + }( + # Input spatial size + x_size[1:M-2], + ) +end - # Assert that kernel size * dilation is <= padded input size - for idx in 1:nd - Is = x_size[idx] - Pl = ppadding[(idx - 1)*2 + 1] - Ph = ppadding[(idx - 1)*2 + 2] - Ks = w_size[idx] - Ds = pdilation[idx] - if Is + Pl + Ph < (Ks - 1)*Ds + 1 - throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!")) - end +# Auto-extract sizes and sub out to big brother above +function ConvDims(x::AbstractArray, w::AbstractArray; kwargs...) + if ndims(x) != ndims(w) + throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) end - - return pstride, ppadding, pdilation + return ConvDims(size(x), size(w); kwargs...) end -""" - output_size(c::ConvDims) +# Useful for constructing a new ConvDims that has only a few elements different +# from the original progenitor object that it inherits shapes from. +function ConvDims(c::AbstractDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), + C_in=channels_in(c), C_out=channels_out(c), S=stride(c), + P=padding(c), D=dilation(c), F=flipkernel(c), G=group_count(c)) + return ConvDims{N, K, C_in, C_out, S, P, D, F, G}(I) +end -Calculate the output (spatial) dimensions of the convolution. Get channel count via -`channels_out(c)`, and batch count is unknowable. 
-""" -function output_size(c::ConvDims) - I = input_size(c) - K = kernel_size(c) - S = stride(c) - P = padding(c) - D = dilation(c) +function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::ConvDims) where {M} + # First, check that channel counts are all correct: + @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))") + @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))") + @assert w[M-1] == channels_in(cdims)/group_count(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)/group_count(cdims)))") + @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))") - return ntuple(spatial_dims(c)) do i - return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1 - end -end + # Next, check that the spatial dimensions match up + @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))") + @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))") + @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))") -# Override show() for these beauties -function Base.show(io::IO, cdims::C) where {C <: ConvDims} - I = (input_size(cdims)..., channels_in(cdims)) - O = (output_size(cdims)..., channels_out(cdims)) - K = kernel_size(cdims) - S = stride(cdims) - P = padding(cdims) - D = dilation(cdims) - F = flipkernel(cdims) - print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F") + # Finally, check that the batch size matches + @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") end diff --git a/src/dim_helpers/DenseConvDims.jl b/src/dim_helpers/DenseConvDims.jl index 9509f5b42..8aee9d1d5 100644 --- a/src/dim_helpers/DenseConvDims.jl +++ b/src/dim_helpers/DenseConvDims.jl @@ -1,77 +1,79 @@ -export DenseConvDims - -""" - DenseConvDims - -Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d. -""" -struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F} - I::NTuple{N,Int} -end - -# Getters for the fields -input_size(c::DenseConvDims) = c.I -kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = K -channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_in -channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_out - -# Convenience wrapper to create DenseConvDims objects -function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; - stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M - # Do common parameter validation - stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) - - # Ensure channels are equal - if x_size[end-1] != w_size[end-1] - xs = x_size[end-1] - ws = w_size[end-1] - throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)")) - end - - # The type parameters are what - return DenseConvDims{ - M - 2, - w_size[1:end-2], - x_size[end-1], - w_size[end], - stride, - padding, - dilation, - flipkernel - }( - # Input spatial size - x_size[1:end-2], - ) -end - -# Auto-extract sizes and sub out to big brother above -function DenseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) 
- if ndims(x) != ndims(w) - throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) - end - return DenseConvDims(size(x), size(w); kwargs...) -end - -# Useful for constructing a new DenseConvDims that has only a few elements different -# from the original progenitor object that it inherits shapes from. -function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), - C_in=channels_in(c), C_out=channels_out(c), S=stride(c), - P=padding(c), D=dilation(c), F=flipkernel(c)) - return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I) -end - -function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M} - # First, check that channel counts are all correct: - @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))") - @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))") - @assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))") - @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))") - - # Next, check that the spatial dimensions match up - @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))") - @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))") - @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))") - - # Finally, check that the batch size matches - @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") -end +# export DenseConvDims +# +# """ +# DenseConvDims +# +# Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d. +# """ +# struct DenseConvDims{N,K,C_in,C_out,S,P,D,F,G} <: ConvDims{N,S,P,D,F} +# I::NTuple{N,Int} +# end +# +# # Getters for the fields +# input_size(c::DenseConvDims) = c.I +# kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = K +# channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_in +# channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_out +# group_count(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = G +# +# # Convenience wrapper to create DenseConvDims objects +# function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; +# stride=1, padding=0, dilation=1, flipkernel::Bool=false, groupcount=1) where M +# # Do common parameter validation +# stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) +# +# # Ensure channels are equal +# if x_size[end-1] != w_size[end-1]*groupcount +# xs = x_size[end-1] +# ws = w_size[end-1]*groupcount +# throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)")) +# end +# +# # The type parameters are what +# return DenseConvDims{ +# M - 2, +# w_size[1:end-2], +# x_size[end-1], +# w_size[end], +# stride, +# padding, +# dilation, +# flipkernel, +# groupcount +# }( +# # Input spatial size +# x_size[1:end-2], +# ) +# end +# +# # Auto-extract sizes and sub out to big brother above +# function DenseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) +# if ndims(x) != ndims(w) +# throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. 
$(ndims(w)))")) +# end +# return DenseConvDims(size(x), size(w); kwargs...) +# end +# +# # Useful for constructing a new DenseConvDims that has only a few elements different +# # from the original progenitor object that it inherits shapes from. +# function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), +# C_in=channels_in(c), C_out=channels_out(c), S=stride(c), +# P=padding(c), D=dilation(c), F=flipkernel(c), G=group_count(c)) +# return DenseConvDims{N, K, C_in, C_out, S, P, D, F, G}(I) +# end +# +# function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M} +# # First, check that channel counts are all correct: +# @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))") +# @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))") +# @assert w[M-1] == channels_in(cdims)/group_count(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)/group_count(cdims)))") +# @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))") +# +# # Next, check that the spatial dimensions match up +# @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))") +# @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))") +# @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))") +# +# # Finally, check that the batch size matches +# @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") +# end diff --git a/src/dim_helpers/DepthwiseConvDims.jl b/src/dim_helpers/DepthwiseConvDims.jl index 4c25eea6f..b778f5e73 100644 --- a/src/dim_helpers/DepthwiseConvDims.jl +++ b/src/dim_helpers/DepthwiseConvDims.jl @@ -1,84 +1,90 @@ -export DepthwiseConvDims - -""" - DepthwiseConvDims - -Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to -characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from -DenseConvDims primarily for channel calculation differences. -""" -struct DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F} <: ConvDims{N,S,P,D,F} - I::NTuple{N, Int} -end - -# Getters for the fields -input_size(c::DepthwiseConvDims) = c.I -kernel_size(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = K -channels_in(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in -channels_out(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in * C_mult -channel_multiplier(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_mult - - -# Convenience wrapper to create DepthwiseConvDims objects -function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; - stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M - # Do common parameter validation - stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) - - # Ensure channels are equal - if x_size[end-1] != w_size[end] - xs = x_size[end-1] - ws = w_size[end] - throw(DimensionMismatch("Input channels must match! ($xs vs. 
$ws)")) - end - - return DepthwiseConvDims{ - M - 2, - # Kernel spatial size - w_size[1:end-2], - # Input channels - x_size[end-1], - # Channel multiplier - w_size[end-1], - stride, - padding, - dilation, - flipkernel - }( - # Image spatial size - x_size[1:end-2], - ) -end - -# Auto-extract sizes and just pass those directly in -function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) - if ndims(x) != ndims(w) - throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) - end - return DepthwiseConvDims(size(x), size(w); kwargs...) -end - -# Useful for constructing a new DepthwiseConvDims that has only a few elements different -# from the original progenitor object. -function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), - C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c), - P=padding(c), D=dilation(c), F=flipkernel(c)) - return DepthwiseConvDims{N, K, C_in, C_m, S, P, D, F}(I) -end - -# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count -function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M} - # First, check that channel counts are all correct: - @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))") - @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))") - @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims))") - @assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))") - - # Next, check that the spatial dimensions match up - @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))") - @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))") - @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))") - - # Finally, check that the batch size matches - @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") -end \ No newline at end of file +# export DepthwiseConvDims +# +# """ +# DepthwiseConvDims +# +# Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily due to +# characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from +# DenseConvDims primarily for channel calculation differences. 
+# """ +# struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F} +# I::NTuple{N, Int} +# K::NTuple{N, Int} +# C_in::Int +# C_mult::Int +# end +# +# # Getters for the fields +# input_size(c::DepthwiseConvDims) = c.I +# kernel_size(c::DepthwiseConvDims) = c.K +# channels_in(c::DepthwiseConvDims) = c.C_in +# channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c) +# channel_multiplier(c::DepthwiseConvDims) = c.C_mult +# +# +# # Convenience wrapper to create DepthwiseConvDims objects +# function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; +# stride=1, padding=0, dilation=1, flipkernel::Bool=false) where M +# # Do common parameter validation +# stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation) +# +# # Ensure channels are equal +# if x_size[end-1] != w_size[end] +# xs = x_size[end-1] +# ws = w_size[end] +# throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)")) +# end +# +# return DepthwiseConvDims{ +# M - 2, +# stride, +# padding, +# dilation, +# flipkernel +# }( +# # Image spatial size +# x_size[1:end-2], +# +# # Kernel spatial size +# w_size[1:end-2], +# +# # Input channels +# x_size[end-1], +# +# # Channel multiplier +# w_size[end-1], +# ) +# end +# +# # Auto-extract sizes and just pass those directly in +# function DepthwiseConvDims(x::AbstractArray, w::AbstractArray; kwargs...) +# if ndims(x) != ndims(w) +# throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))")) +# end +# return DepthwiseConvDims(size(x), size(w); kwargs...) +# end +# +# # Useful for constructing a new DepthwiseConvDims that has only a few elements different +# # from the original progenitor object. +# function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), +# C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c), +# P=padding(c), D=dilation(c), F=flipkernel(c)) +# return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m) +# end +# +# # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count +# function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M} +# # First, check that channel counts are all correct: +# @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))") +# @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))") +# @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))") +# @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))") +# +# # Next, check that the spatial dimensions match up +# @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))") +# @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))") +# @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))") +# +# # Finally, check that the batch size matches +# @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. 
$(y[end]))") +# end diff --git a/src/dim_helpers/PoolDims.jl b/src/dim_helpers/PoolDims.jl index 97144968e..220cdfbf4 100644 --- a/src/dim_helpers/PoolDims.jl +++ b/src/dim_helpers/PoolDims.jl @@ -7,7 +7,7 @@ Dimensions for a "pooling" operation that can have an arbitrary input size, kern stride, dilation, and channel count. Used to dispatch onto efficient implementations at compile-time. """ -struct PoolDims{N,K,S,P,D} <: ConvDims{N, S, P, D, false} +struct PoolDims{N,K,S,P,D} <: AbstractDims{N, S, P, D, false} I::NTuple{N,Int} C_in::Int end @@ -19,7 +19,7 @@ channels_in(c::PoolDims) = c.C_in channels_out(c::PoolDims) = c.C_in -# Convenience wrapper to create DenseConvDims objects +# Convenience wrapper to create ConvDims objects function PoolDims(x_size::NTuple{M}, k::Union{NTuple{L, Int}, Int}; stride=k, padding=0, dilation=1) where {M, L} # Expand `k` up to a tuple @@ -53,7 +53,7 @@ end # Useful for constructing a new PoolDims that has only a few elements different # from the original progenitor object that it inherits shapes from. -function PoolDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), +function PoolDims(c::AbstractDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), C_in=channels_in(c), S=stride(c), P=padding(c), D=dilation(c)) return PoolDims{N, K, S, P, D}(I, C_in) end @@ -62,7 +62,7 @@ function check_dims(x::NTuple{M}, y::NTuple{M}, pdims::PoolDims) where {M} # First, check that channel counts are all correct: @assert x[end-1] == channels_in(pdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(pdims)))") @assert y[end-1] == channels_out(pdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(pdims)))") - + # Next, check that the spatial dimensions match up @assert x[1:end-2] == input_size(pdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(pdims)))") @assert y[1:end-2] == output_size(pdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(pdims)))") diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 5f5b7c4c3..f771ebb4d 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -46,7 +46,7 @@ wrapper methods are available. conv_direct! function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, - w::AbstractArray{wT,5}, cdims::DenseConvDims; + w::AbstractArray{wT,5}, cdims::ConvDims; alpha::yT = yT(1), beta = false) where {yT, xT, wT} check_dims(size(x), size(w), size(y), cdims) @@ -57,14 +57,14 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, dil_w, dil_h, dil_d = dilation(cdims) stride_w, stride_h, stride_d = stride(cdims) out_width, out_height, out_depth = output_size(cdims) - + # Create a method that, at compile-time, determines how we're going to index into `w` kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1 - + # A helper function to project from output (w, h) to input (input_w, input_h) project(idx, stride, pad) = (idx - 1)*stride - pad + 1 - + # Use `calc_padding_regions` to determine where we do or don't need to worry about padding padded_regions, central_region = calc_padding_regions(cdims) @@ -153,11 +153,11 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`. ∇conv_data_direct! 
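For readers following the adapter below: the data gradient of a correlation-style convolution y = conv(x, w) is itself a convolution of the (pre-dilated, re-padded) upstream gradient dy with the spatially reversed kernel whose channel axes are swapped, which is exactly the transformation ∇conv_data_direct! performs before handing off to conv_direct!. A minimal 1-d, stride-1, unpadded sketch of that identity; naive_xcorr is a hypothetical stand-in, not part of this package:

    # Local illustration only: `naive_xcorr` plays the role of the forward pass.
    naive_xcorr(x, w) = [sum(x[i + j - 1] * w[j] for j in eachindex(w)) for i in 1:(length(x) - length(w) + 1)]

    x  = randn(8);  w = randn(3)
    y   = naive_xcorr(x, w)               # forward pass, length 6
    dy  = ones(length(y))                 # upstream gradient of sum(y)
    dyp = vcat(zeros(2), dy, zeros(2))    # pad by K - 1 on both sides ("full" correlation)
    dx  = naive_xcorr(dyp, reverse(w))    # gradient w.r.t. x, length 8

Stride and dilation in the real implementation reduce to this same identity via predilate and transpose_pad.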
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, - w::AbstractArray{wT,5}, cdims::DenseConvDims; + w::AbstractArray{wT,5}, cdims::ConvDims; alpha::xT=xT(1), beta=false) where {xT, yT, wT} w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = predilate(dy, stride(cdims)) - ctdims = DenseConvDims(dy, w; padding=transpose_pad(cdims), + ctdims = ConvDims(dy, w; padding=transpose_pad(cdims), dilation=dilation(cdims), flipkernel=flipkernel(cdims)) dx = conv_direct!(dx, dy, w, ctdims; alpha=alpha, beta=beta) @@ -172,11 +172,11 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`. ∇conv_filter_direct! function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, - dy::AbstractArray{yT,5}, cdims::DenseConvDims; + dy::AbstractArray{yT,5}, cdims::ConvDims; alpha::wT=wT(1), beta=false) where {xT, yT, wT} x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :]) dy = transpose_swapbatch(predilate(dy, stride(cdims))) - ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims), + ctdims = ConvDims(dy, x; padding=transpose_pad(cdims), stride=dilation(cdims)) conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta) if flipkernel(cdims) diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index 17e6aaa61..ef9cb0731 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -24,7 +24,7 @@ which should eliminate any need for large allocations within this method. """ function conv_im2col!( y::AbstractArray{T,5}, x::AbstractArray{T,5}, - w::AbstractArray{T,5}, cdims::DenseConvDims; + w::AbstractArray{T,5}, cdims::ConvDims; col::AbstractArray{T,2}=similar(x, im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(w), size(y), cdims) @@ -45,7 +45,7 @@ function conv_im2col!( M = prod(output_size(cdims)) N = channels_out(cdims) K = prod(kernel_size(cdims))*channels_in(cdims) - + @threads for batch_idx in 1:size(x,5) # We invoke `@timeit_debug` on the outside of `im2col!()` because inference # doesn't like us putting it on the inside. @@ -68,7 +68,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter """ function ∇conv_filter_im2col!( dw::AbstractArray{T,5}, x::AbstractArray{T,5}, - dy::AbstractArray{T,5}, cdims::DenseConvDims; + dy::AbstractArray{T,5}, cdims::ConvDims; col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(x), size(dw), size(dy), cdims) @@ -88,12 +88,12 @@ function ∇conv_filter_im2col!( # input pixel that touched a particular element of the kernel. # # This is identical to a convolution between x and a dimension-permuted dY, - # where we - + # where we + M = prod(kernel_size(cdims))*channels_in(cdims) N = channels_out(cdims) K = prod(output_size(cdims)) - + @threads for batch_idx in 1:size(x,5) im2col!(col, view(x, :, :, :, :, batch_idx), cdims) GC.@preserve col, dw, dy, begin @@ -118,7 +118,7 @@ See the documentation for `conv_im2col!()` for explanation of other parameters. 
""" function ∇conv_data_im2col!( dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, - w::AbstractArray{T,5}, cdims::DenseConvDims; + w::AbstractArray{T,5}, cdims::ConvDims; col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)), alpha::T=T(1), beta::T=T(0)) where {T} check_dims(size(dx), size(w), size(dy), cdims) @@ -188,7 +188,7 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, out_width, out_height, out_depth, - + # By input patch size kernel_w, kernel_h, @@ -201,7 +201,7 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, # A helper function to project from output (w, h) to input (input_w, input_h) @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 - + # We begin by copying the central region of the image which requires no padding at all. # Eliminating the branches of the fully generalized version below gives us a nice # speedup on the majority of the data. @@ -215,7 +215,7 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, d in d_region, h in h_region, w in w_region - + input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w @@ -225,8 +225,8 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, col_reshaped[w, h, d, kidxs..., c] = xval end end - - + + # For each "padded region", we run the fully general version @inbounds for (w_region, h_region, d_region) in padded_regions for c in 1:C_in, @@ -272,7 +272,7 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, xval::T = x[input_kw, input_kh, input_kd, c] col_reshaped[w, h, d, kidxs..., c] = xval end - end + end end @@ -302,7 +302,7 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, dil_w, dil_h, dil_d = dilation(cdims) stride_w, stride_h, stride_d = stride(cdims) out_width, out_height, out_depth = output_size(cdims) - + # TODO: Rewrite this method so we don't have this fill!() at the beginning! # Calculate each output pixel once rather than accumulating into it? fill!(x, T(0)) @@ -313,7 +313,7 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, out_width, out_height, out_depth, - + # By input patch size kernel_w, kernel_h, @@ -331,7 +331,7 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, for d in 1:out_depth input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d - + # If this d is off the edge, then deal with the entire plane # in one fell swoop, like a ravenous flock of crows. CAW CAW. if input_kd <= 0 || input_kd > depth @@ -340,21 +340,21 @@ function col2im!(x::AbstractArray{T,4}, col::AbstractArray{T,2}, for h in 1:out_height input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h - + # Same for `h`, but in this case it's only a line, not a plane. # This results in slightly less caw'ing. if input_kh <= 0 || input_kh > height continue end - + for w in 1:out_width input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w - + # If this `w` is off the edge, only it gets cleared out. 
if input_kw <= 0 || input_kw > width continue end - + # Copy the data over kidxs = kernel_index(kw, kh, kd, cdims) cval::T = col_reshaped[w, h, d, kidxs..., c] diff --git a/src/impl/depthwiseconv_direct.jl b/src/impl/depthwiseconv_direct.jl index 7e2e02bd5..67b92b88e 100644 --- a/src/impl/depthwiseconv_direct.jl +++ b/src/impl/depthwiseconv_direct.jl @@ -1,195 +1,158 @@ -## This file contains direct Julia implementations of depwthwise convolutions - -""" - depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0) - -Direct depthwise convolution implementation; used for debugging, tests, and mixing/ -matching of strange datatypes within a single convolution. Uses naive nested for loop -implementation and does not attempt to optimize performance. Rather, this implementation -is intended to be maximally understandable and debuggable, to aid in testing other, more -performant implementations. We also explicitly support mixing and matching of strange -datatypes, so that if the user really wants to convolve an image of `UInt8`'s with a -`Float16` kernel, storing the result in a `Float32` output, there is at least a function -call for that madness. - -One subtlety about depthwise convolutions; the shape of a depthwise convolutional kernel -is `(spatial_dims..., C_mult, C_in)`, so the axis that must match with the number of -channels in `x` is the last, not the second-to-last, as in a normal dense convolution. - -See the docstring for `conv_direct!()` for more on the optional parameters. -""" -function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, - w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - alpha::yT = yT(1), beta = false) where {yT, xT, wT} - check_dims(size(x), size(w), size(y), cdims) - - width, height, depth = input_size(cdims) - kernel_w, kernel_h, kernel_d = kernel_size(cdims) - out_c = channels_out(cdims) - pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) - dil_w, dil_h, dil_d = dilation(cdims) - stride_w, stride_h, stride_d = stride(cdims) - out_width, out_height, out_depth = output_size(cdims) - - # Create a method that, at compile-time, determines how we're going to index into `w` - kproj(k, M, cdims::DepthwiseConvDims{N,K,C_mult,C_in,S,P,D,true}) where {N, K, C_mult, C_in, S, P, D} = k - kproj(k, M, cdims::DepthwiseConvDims{N,K,C_mult,C_in,S,P,D,false}) where {N, K, C_mult, C_in, S, P, D} = M - k + 1 - - # A helper function to project from output (w, h) to input (input_w, input_h) - project(idx, stride, pad) = (idx - 1)*stride - pad + 1 - - # Use `calc_padding_regions` to determine where we do or don't need to worry about padding - padded_regions, central_region = calc_padding_regions(cdims) - - # Start with the central region - w_region, h_region, d_region = central_region - @inbounds for batch in 1:size(x)[end], - c_mult in 1:channel_multiplier(cdims), - c_in in 1:channels_in(cdims), - d_idx in d_region, - h_idx in h_region, - w_idx in w_region - - # Since we're in the central region, we don't need to worry about clamping - dotprod = yT(0) - c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult - for kd in 1:kernel_d, - kh in 1:kernel_h, - kw in 1:kernel_w - - # Hoist me, you coward. 
- x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d - x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h - x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w - - x_val = x[x_w, x_h, x_d, c_in, batch] - w_val = w[kproj(kw, kernel_w, cdims), - kproj(kh, kernel_h, cdims), - kproj(kd, kernel_d, cdims), - c_mult, c_in] - dotprod = muladd(x_val, w_val, dotprod) - end - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] - end - - # Next, do potentially-padded regions: - @inbounds for (w_region, h_region, d_region) in padded_regions, - batch in 1:size(x)[end], - c_mult in 1:channel_multiplier(cdims), - c_in in 1:channels_in(cdims), - d_idx in d_region, - h_idx in h_region, - w_idx in w_region - - # Probe for out-of-bounds accesses on `x` and `continue` if we hit one - dotprod = yT(0) - c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult - for c_in in 1:channels_in(cdims), - kd in 1:kernel_d - - x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d - if x_d <= 0 || x_d > depth - continue - end - - for kh in 1:kernel_h - x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h - if x_h <= 0 || x_h > height - continue - end - - for kw in 1:kernel_w - x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w - if x_w <= 0 || x_w > width - continue - end - - x_val = x[x_w, x_h, x_d, c_in, batch] - w_val = w[kproj(kw, kernel_w, cdims), - kproj(kh, kernel_h, cdims), - kproj(kd, kernel_d, cdims), - c_mult, c_in] - dotprod = muladd(x_val, w_val, dotprod) - end - end - end - - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] - end - - return y -end - -""" - ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) - -Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`. -We make use of the fact that a depthwise convolution is equivalent to `C_in` separate -normal convolutions between that channel of `x` and the `C_mult` different kernels that -get applied to it. The output of such a convolution is the gradient imposed upon that -particular channel of `x`, and so we simply walk through `x`, calculating the gradient -for each batch and channel independently. -""" -∇depthwiseconv_data_direct! - -function ∇depthwiseconv_data_direct!( - dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, - w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; - alpha::xT=xT(1), beta::xT=xT(0)) where {xT, yT, wT} - # We do a separate convolution for each channel in x - @inbounds for cidx in 1:channels_in(cdims) - # For this batch and in-channel, we have a normal transposed convolution - # between this slice of `x` and the corresponding slices of `w` and `dy`: - dx_slice = view(dx, :, :, :, cidx:cidx, :) - C_mult = channel_multiplier(cdims) - dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) - w_slice = permutedims(view(w, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) - - # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out - # channels appropriately for this one convolution. - cdims_slice = DenseConvDims(cdims; - C_in=1, - C_out=channel_multiplier(cdims), - ) - - ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice; - alpha=alpha, beta=beta) - end - return dx -end - -""" - ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) - -Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`. -""" -∇depthwiseconv_filter_direct! 
- -function ∇depthwiseconv_filter_direct!( - dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, - dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims; - alpha::wT=wT(1),beta::wT=wT(0)) where {xT, yT, wT} - # We do a separate convolution for each channel in x - @inbounds for cidx in 1:channels_in(cdims) - # For this batch and in-channel, we have a normal transposed convolution - # between this slice of `x` and the corresponding slices of `w` and `dy`: - x_slice = view(x, :, :, :, cidx:cidx, :) - C_mult = channel_multiplier(cdims) - dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) - dw_slice = permutedims(view(dw, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) - - # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out - # channels appropriately for this one convolution. - cdims_slice = DenseConvDims(cdims; - C_in=1, - C_out=channel_multiplier(cdims), - ) - - ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice; - alpha=alpha, beta=beta) - dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4)) - end - return dw -end - - +# ## This file contains direct Julia implementations of depwthwise convolutions +# +# """ +# depthwiseconv_direct!(y, x, w, cdims; alpha=1, beta=0) +# +# Direct depthwise convolution implementation; used for debugging, tests, and mixing/ +# matching of strange datatypes within a single convolution. Uses naive nested for loop +# implementation and does not attempt to optimize performance. Rather, this implementation +# is intended to be maximally understandable and debuggable, to aid in testing other, more +# performant implementations. We also explicitly support mixing and matching of strange +# datatypes, so that if the user really wants to convolve an image of `UInt8`'s with a +# `Float16` kernel, storing the result in a `Float32` output, there is at least a function +# call for that madness. +# +# One subtlety about depthwise convolutions; the shape of a depthwise convolutional kernel +# is `(spatial_dims..., C_mult, C_in)`, so the axis that must match with the number of +# channels in `x` is the last, not the second-to-last, as in a normal dense convolution. +# +# See the docstring for `conv_direct!()` for more on the optional parameters. +# """ +# function depthwiseconv_direct!( +# y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, +# w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; +# alpha::yT = yT(1), beta::yT = yT(0)) where {yT, xT, wT} +# check_dims(size(x), size(w), size(y), cdims) +# +# width, height, depth = input_size(cdims) +# kernel_w, kernel_h, kernel_d = kernel_size(cdims) +# out_c = channels_out(cdims) +# pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims) +# dil_w, dil_h, dil_d = dilation(cdims) +# stride_w, stride_h, stride_d = stride(cdims) +# out_width, out_height, out_depth = output_size(cdims) +# +# # If we're doing crosscorr instead of conv, then don't bother to flip `w` +# if !flipkernel(cdims) +# w = w[end:-1:1, end:-1:1, end:-1:1, :, :] +# end +# +# # A helper function to project from output (w, h) to input (input_w, input_h) +# @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 +# +# # explicit formulation of convolution. Oh hoisting gods, hear my plea. 
+# @inbounds for batch in 1:size(x)[end], +# c_mult in 1:channel_multiplier(cdims), +# c_in in 1:channels_in(cdims), +# h_idx in 1:out_height, +# w_idx in 1:out_width, +# d_idx in 1:out_depth +# +# # Starting points of the window of x we're going to grab +# x_w = project(w_idx, stride_w, pad_w_lo) +# x_h = project(h_idx, stride_h, pad_h_lo) +# x_d = project(d_idx, stride_d, pad_d_lo) +# +# # Grow that starting point into ranges +# x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1)) +# x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1)) +# x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1)) +# w_widxs = 1:kernel_w +# w_hidxs = 1:kernel_h +# w_didxs = 1:kernel_d +# +# # Clamp the ranges to simulate padding +# x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs) +# x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width) +# x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs) +# x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height) +# x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs) +# x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth) +# +# # Grab our slices (for a single channel pairing, as this is depthwise) +# c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult +# x_slice = view(x, x_widxs, x_hidxs, x_didxs, c_in, batch) +# w_slice = view(w, w_widxs, w_hidxs, w_didxs, c_mult, c_in) +# +# # Do the dotproduct dance, then weight by alpha/beta and git 'er done +# dotprod = sum(x_slice .* w_slice) +# prev_yval::yT = beta*y[w_idx, h_idx, d_idx, c_out, batch] +# y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + prev_yval +# end +# +# return y +# end +# +# """ +# ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0) +# +# Calculate the gradient imposed upon `x` in the depthwise convolution `y = x * w`. +# We make use of the fact that a depthwise convolution is equivalent to `C_in` separate +# normal convolutions between that channel of `x` and the `C_mult` different kernels that +# get applied to it. The output of such a convolution is the gradient imposed upon that +# particular channel of `x`, and so we simply walk through `x`, calculating the gradient +# for each batch and channel independently. +# """ +# ∇depthwiseconv_data_direct! +# +# function ∇depthwiseconv_data_direct!( +# dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5}, +# w::AbstractArray{wT,5}, cdims::DepthwiseConvDims; +# alpha::xT=xT(1), beta::xT=xT(0)) where {xT, yT, wT} +# # We do a separate convolution for each channel in x +# @inbounds for cidx in 1:channels_in(cdims) +# # For this batch and in-channel, we have a normal transposed convolution +# # between this slice of `x` and the corresponding slices of `w` and `dy`: +# dx_slice = view(dx, :, :, :, cidx:cidx, :) +# C_mult = channel_multiplier(cdims) +# dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) +# w_slice = permutedims(view(w, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) +# +# # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out +# # channels appropriately for this one convolution. +# cdims_slice = DenseConvDims(cdims; +# C_in=1, +# C_out=channel_multiplier(cdims), +# ) +# +# ∇conv_data_direct!(dx_slice, dy_slice, w_slice, cdims_slice; +# alpha=alpha, beta=beta) +# end +# return dx +# end +# +# """ +# ∇depthwiseconv_filter_direct!(dw, x, dy, cdims; alpha=1, beta=0) +# +# Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`. +# """ +# ∇depthwiseconv_filter_direct! 
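The depthwise kernels commented out in this file are the paths the unified, grouped ConvDims is intended to subsume: a depthwise convolution over C_in channels with channel multiplier M is a grouped convolution with groupcount = C_in and C_out = C_in * M. A hedged 2-d reference sketch of that equivalence, emulating the grouped operation with C_in independent dense convolutions, much as the per-channel adapters in this file do; depthwise_by_slices is a hypothetical helper, not part of this PR:

    using NNlib

    # x :: (W, H, C_in, N), w :: (W, H, M, C_in) -- the depthwise kernel layout
    function depthwise_by_slices(x, w)
        C_in = size(x, 3)
        M    = size(w, 3)
        ys = map(1:C_in) do c
            xc = view(x, :, :, c:c, :)                                 # one input channel
            wc = reshape(w[:, :, :, c], size(w, 1), size(w, 2), 1, M)  # dense (W, H, 1, M) kernel
            conv(xc, wc)                                               # ordinary dense convolution
        end
        return cat(ys...; dims = 3)    # output channel (c_in - 1) * M + c_mult, matching the code above
    end

Once the im2col and direct backends honour group_count, the same bookkeeping is carried by ConvDims(x, w; groupcount = size(x, 3)) and these dedicated depthwise paths can be dropped entirely.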
+# +# function ∇depthwiseconv_filter_direct!( +# dw::AbstractArray{wT,5}, x::AbstractArray{xT,5}, +# dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims; +# alpha::wT=wT(1),beta::wT=wT(0)) where {xT, yT, wT} +# # We do a separate convolution for each channel in x +# @inbounds for cidx in 1:channels_in(cdims) +# # For this batch and in-channel, we have a normal transposed convolution +# # between this slice of `x` and the corresponding slices of `w` and `dy`: +# x_slice = view(x, :, :, :, cidx:cidx, :) +# C_mult = channel_multiplier(cdims) +# dy_slice = view(dy, :, :, :, ((cidx-1)*C_mult + 1):cidx*C_mult, :) +# dw_slice = permutedims(view(dw, :, :, :, :, cidx:cidx), (1, 2, 3, 5, 4)) +# +# # Adapt a DenseConvDims out of this DepthwiseConvDims, setting the in/out +# # channels appropriately for this one convolution. +# cdims_slice = DenseConvDims(cdims; +# C_in=1, +# C_out=channel_multiplier(cdims), +# ) +# +# ∇conv_filter_direct!(dw_slice, x_slice, dy_slice, cdims_slice; +# alpha=alpha, beta=beta) +# dw[:, :, :, :, cidx:cidx] .= permutedims(dw_slice, (1, 2, 3, 5, 4)) +# end +# return dw +# end +# +# diff --git a/src/impl/depthwiseconv_im2col.jl b/src/impl/depthwiseconv_im2col.jl index 387efdb29..684980ac3 100644 --- a/src/impl/depthwiseconv_im2col.jl +++ b/src/impl/depthwiseconv_im2col.jl @@ -1,120 +1,120 @@ -## This file contains adapter code for doing depthwise convolutions with im2col. - - -""" - depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) - -Perform a depthwise convolution using im2col and GEMM, store the result in `y`. - -See `conv_im2col!()` for an explanation of optional parameters. -""" -depthwiseconv_im2col! - -function depthwiseconv_im2col!( - y::AbstractArray{T,5}, x::AbstractArray{T,5}, - w::AbstractArray{T,5}, cdims::DepthwiseConvDims; - col::AbstractArray{T,2} = similar(x, im2col_dims(cdims)), - alpha=T(1), beta=T(0)) where T - check_dims(size(x), size(w), size(y), cdims) - - # This functions exactly the same as conv_im2col!(), except that we shard the - # incoming data into slices of single channels. This means that we need to walk - # each pointer forward individually, as done below, taking a single input channel - # and combining it with each kernel individually, before walking forward and doing - # the next input channel. - M = prod(output_size(cdims)) - N = channel_multiplier(cdims) - K = prod(kernel_size(cdims)) - - dcdims = DenseConvDims(cdims) - @inbounds for batch_idx in 1:size(x)[end] - im2col!(col, view(x, :, :, :, :, batch_idx), dcdims) - - # We do a separate convolution for each channel in x, as we must - for c_in in 1:channels_in(cdims) - # Walk each pointer forward as we process each input channel - GC.@preserve col, w, y, begin - col_ptr = pointer(col, (c_in-1)*M*K+1) - w_ptr = pointer(w, (c_in-1)*K*N+1) - y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1) - gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr) - end - end - end - return y -end - -""" - ∇depthwiseconv_filter_im2col!(dw, w, dy, cdims, col=similar(dw); alpha=1, beta) - -Depthwise conv2d backward pass onto the weights using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. -""" -∇depthwiseconv_filter_im2col! 
- -function ∇depthwiseconv_filter_im2col!( - dw::AbstractArray{T,5}, x::AbstractArray{T,5}, - dy::AbstractArray{T,5}, cdims::DepthwiseConvDims; - col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)), - alpha=T(1), beta=T(0)) where T - check_dims(size(x), size(dw), size(dy), cdims) - - M = prod(kernel_size(cdims)) - N = channel_multiplier(cdims) - K = prod(output_size(cdims)) - - @inbounds for batch_idx in 1:size(x)[end] - im2col!(col, view(x, :, :, :, :, batch_idx), cdims) - - # We do a separate convolution for each channel in x, as we must - for c_in in 1:channels_in(cdims) - # Walk each pointer forward as we process each input channel - GC.@preserve col, dw, dy, begin - col_ptr = pointer(col, (c_in - 1)*M*K + 1) - dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1) - dw_ptr = pointer(dw, (c_in - 1)*M*N + 1) - gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr) - end - end - - # Because we accumulate over batches in this loop, we must set `beta` equal - # to `1.0` from this point on. - beta = T(1) - end - return dw -end - -""" - depthwiseconv2d_Δx_im2col!(dx, w, dy, cdims, col=similar(dx); alpha=1, beta=0) - -Depwthwise conv2d backward pass onto the input using im2col and GEMM. -See the documentation for `conv_im2col!()` for explanation of optional parameters. -""" -∇depthwiseconv_data_im2col! - -function ∇depthwiseconv_data_im2col!( - dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, - w::AbstractArray{T,5}, cdims::DepthwiseConvDims; - col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)), - alpha=T(1), beta=T(0)) where T - check_dims(size(dx), size(w), size(dy), cdims) - - M = prod(output_size(cdims)) - N = prod(kernel_size(cdims)) - K = channel_multiplier(cdims) - - @inbounds for batch_idx in 1:size(dx)[end] - # We do a separate convolution for each channel in x, as we must - for cidx in 1:channels_in(cdims) - GC.@preserve col, w, dy, begin - # Walk each pointer forward as we process each input channel - dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1) - w_ptr = pointer(w, (cidx - 1)*K*N + 1) - col_ptr = pointer(col, (cidx - 1)*M*N + 1) - gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr) - end - end - col2im!(view(dx, :, :, :, :, batch_idx), col, cdims) - end - return dx -end +# ## This file contains adapter code for doing depthwise convolutions with im2col. +# +# +# """ +# depthwiseconv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0) +# +# Perform a depthwise convolution using im2col and GEMM, store the result in `y`. +# +# See `conv_im2col!()` for an explanation of optional parameters. +# """ +# depthwiseconv_im2col! +# +# function depthwiseconv_im2col!( +# y::AbstractArray{T,5}, x::AbstractArray{T,5}, +# w::AbstractArray{T,5}, cdims::DepthwiseConvDims; +# col::AbstractArray{T,2} = similar(x, im2col_dims(cdims)), +# alpha=T(1), beta=T(0)) where T +# check_dims(size(x), size(w), size(y), cdims) +# +# # This functions exactly the same as conv_im2col!(), except that we shard the +# # incoming data into slices of single channels. This means that we need to walk +# # each pointer forward individually, as done below, taking a single input channel +# # and combining it with each kernel individually, before walking forward and doing +# # the next input channel. 
+# M = prod(output_size(cdims))
+# N = channel_multiplier(cdims)
+# K = prod(kernel_size(cdims))
+#
+# dcdims = DenseConvDims(cdims)
+# @inbounds for batch_idx in 1:size(x)[end]
+# im2col!(col, view(x, :, :, :, :, batch_idx), dcdims)
+#
+# # We do a separate convolution for each channel in x, as we must
+# for c_in in 1:channels_in(cdims)
+# # Walk each pointer forward as we process each input channel
+# GC.@preserve col, w, y, begin
+# col_ptr = pointer(col, (c_in-1)*M*K+1)
+# w_ptr = pointer(w, (c_in-1)*K*N+1)
+# y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
+# gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+# end
+# end
+# end
+# return y
+# end
+#
+# """
+# ∇depthwiseconv_filter_im2col!(dw, x, dy, cdims, col=similar(dw); alpha=1, beta=0)
+#
+# Depthwise conv2d backward pass onto the weights using im2col and GEMM.
+# See the documentation for `conv_im2col!()` for explanation of optional parameters.
+# """
+# ∇depthwiseconv_filter_im2col!
+#
+# function ∇depthwiseconv_filter_im2col!(
+# dw::AbstractArray{T,5}, x::AbstractArray{T,5},
+# dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
+# col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)),
+# alpha=T(1), beta=T(0)) where T
+# check_dims(size(x), size(dw), size(dy), cdims)
+#
+# M = prod(kernel_size(cdims))
+# N = channel_multiplier(cdims)
+# K = prod(output_size(cdims))
+#
+# @inbounds for batch_idx in 1:size(x)[end]
+# im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
+#
+# # We do a separate convolution for each channel in x, as we must
+# for c_in in 1:channels_in(cdims)
+# # Walk each pointer forward as we process each input channel
+# GC.@preserve col, dw, dy, begin
+# col_ptr = pointer(col, (c_in - 1)*M*K + 1)
+# dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
+# dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
+# gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
+# end
+# end
+#
+# # Because we accumulate over batches in this loop, we must set `beta` equal
+# # to `1.0` from this point on.
+# beta = T(1)
+# end
+# return dw
+# end
+#
+# """
+# ∇depthwiseconv_data_im2col!(dx, dy, w, cdims, col=similar(dx); alpha=1, beta=0)
+#
+# Depthwise conv2d backward pass onto the input using im2col and GEMM.
+# See the documentation for `conv_im2col!()` for explanation of optional parameters.
+# """
+# ∇depthwiseconv_data_im2col!
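(Editor's aside, not part of the diff.) For readers following the pointer arithmetic in the commented-out im2col kernels above: per batch, each input channel contributes one (M×K)·(K×N) GEMM between its slice of the im2col buffer and its own block of kernels. The sketch below is a view-based equivalent of that inner loop written against plain arrays with explicit sizes instead of a dims object; the layout assumptions mirror the pointer offsets above, and `depthwise_channel_gemms!` is a name invented for this example.

# Illustrative sketch only -- invented function name; layout assumed from the pointer math above.
function depthwise_channel_gemms!(y::AbstractMatrix, col::AbstractMatrix,
                                  w::AbstractVector, M::Int, N::Int, K::Int,
                                  C_in::Int; alpha = 1.0, beta = 0.0)
    # y:   (M, N*C_in)  -- one batch of output, channel blocks side by side
    # col: (M, K*C_in)  -- im2col buffer for that batch, channel blocks side by side
    # w:   flat vector of length K*N*C_in, one K×N kernel block per input channel
    for c_in in 1:C_in
        col_c = view(col, :, (c_in - 1)*K .+ (1:K))                # M×K slice of the buffer
        w_c   = reshape(view(w, (c_in - 1)*K*N .+ (1:K*N)), K, N)  # this channel's kernels
        y_c   = view(y, :, (c_in - 1)*N .+ (1:N))                  # M×N output block
        y_c .= alpha .* (col_c * w_c) .+ beta .* y_c
    end
    return y
end

# Example sizes: 5 output positions, 2 kernels per channel, kernel length 3, 4 channels.
M, N, K, C_in = 5, 2, 3, 4
y   = zeros(M, N*C_in)
col = randn(M, K*C_in)
w   = randn(K*N*C_in)
depthwise_channel_gemms!(y, col, w, M, N, K, C_in)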
+#
+# function ∇depthwiseconv_data_im2col!(
+# dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
+# w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
+# col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)),
+# alpha=T(1), beta=T(0)) where T
+# check_dims(size(dx), size(w), size(dy), cdims)
+#
+# M = prod(output_size(cdims))
+# N = prod(kernel_size(cdims))
+# K = channel_multiplier(cdims)
+#
+# @inbounds for batch_idx in 1:size(dx)[end]
+# # We do a separate convolution for each channel in x, as we must
+# for cidx in 1:channels_in(cdims)
+# GC.@preserve col, w, dy, begin
+# # Walk each pointer forward as we process each input channel
+# dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
+# w_ptr = pointer(w, (cidx - 1)*K*N + 1)
+# col_ptr = pointer(col, (cidx - 1)*M*N + 1)
+# gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+# end
+# end
+# col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
+# end
+# return dx
+# end
diff --git a/src/nnpack/performance.jl b/src/nnpack/performance.jl
index 24abdb411..01c039b74 100644
--- a/src/nnpack/performance.jl
+++ b/src/nnpack/performance.jl
@@ -1,5 +1,5 @@
-function select_threadpool(cdims::DenseConvDims, batch_size::Int)
- inp_size = input_size(cdims)[1]
+function select_threadpool(cdims::ConvDims, batch_size::Int)
+ inp_size = input_size(cdims)[1]
 if batch_size >= 32
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
 elseif batch_size >= 16 && inp_size >= 64
@@ -10,12 +10,12 @@ function select_threadpool(cdims::DenseConvDims, batch_size::Int)
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
 elseif inp_size * batch_size >= 256
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
- end
+ end
 return C_NULL
 end

 function select_threadpool(pdims::PoolDims, batch_size::Int)
- inp_size = input_size(pdims)[1]
+ inp_size = input_size(pdims)[1]
 if batch_size >= 32
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
 elseif batch_size >= 16 && inp_size >= 64
@@ -26,6 +26,6 @@ function select_threadpool(pdims::PoolDims, batch_size::Int)
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
 elseif inp_size * batch_size >= 256
 return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][]
- end
+ end
 return C_NULL
 end
diff --git a/test/conv.jl b/test/conv.jl
index 2df60cbec..24c0b54d6 100644
--- a/test/conv.jl
+++ b/test/conv.jl
@@ -3,13 +3,13 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
 stride, padding, dilation, flipkernel, output_size

 @testset "ConvDims" begin
- for T in (DenseConvDims, DepthwiseConvDims)
+ for T in (ConvDims,)
 @testset "$(T)" begin
 x = randn(5,4,3,2)

- if T == DenseConvDims
+ if T == ConvDims
 w = randn(1,2,3,4)
- elseif T == DepthwiseConvDims
+ elseif T == ConvDims
 w = randn(1,2,4,3)
 end

@@ -25,7 +25,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
 @test output_size(cdims) == (5,3)

 # Special-case channel output tests
- if T == DenseConvDims
+ if T == ConvDims
 @test channels_out(cdims) == size(w, 4)
 elseif T == DepthwiseConvDims
 @test channel_multiplier(cdims) == size(w, 3)
@@ -63,7 +63,7 @@ using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multipl
 # Dilation will cause us to reach beyond the end of input + padding here:
 @test_throws DimensionMismatch T(x, w; dilation=(1, 5))
 # Channel mismatch:
- if T == DenseConvDims
+ if T == ConvDims
 @test_throws DimensionMismatch T(x, w[:,:,1:1,:])
 elseif T == DepthwiseConvDims
 @test_throws DimensionMismatch T(x, w[:,:,:,1:1])
@@
-277,7 +277,7 @@ conv_answer_dict = Dict( for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct) @testset "$(conv)" begin # First, your basic convolution with no parameters - cdims = DenseConvDims(x, w) + cdims = ConvDims(x, w) @test isapprox(ddims(conv(x, w, cdims)), y_plain, rtol = 1.0e-7) # Next, test convolution on views and alternate datatypes: @@ -285,19 +285,19 @@ conv_answer_dict = Dict( @test isapprox(ddims(conv(Float32.(x), Float32.(w), cdims)), Float32.(y_plain), rtol = 1.0e-7) # Next, introduce stride: - cdims = DenseConvDims(x, w; stride=2) + cdims = ConvDims(x, w; stride=2) @test isapprox(ddims(conv(x, w, cdims)), y_stride, rtol = 1.0e-7) # Next, introduce dilation: - cdims = DenseConvDims(x, w; dilation=2) + cdims = ConvDims(x, w; dilation=2) @test isapprox(ddims(conv(x, w, cdims)), y_dil, rtol = 1.0e-7) # Next, introduce padding: - cdims = DenseConvDims(x, w; padding=1) + cdims = ConvDims(x, w; padding=1) @test isapprox(ddims(conv(x, w, cdims)), y_pad, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel - cdims = DenseConvDims(x, w; flipkernel=true) + cdims = ConvDims(x, w; flipkernel=true) @test isapprox(ddims(conv(x, w, cdims)), y_flip, rtol = 1.0e-7) end end @@ -310,7 +310,7 @@ conv_answer_dict = Dict( ) @testset "$(∇conv_filter)/$(∇conv_data)" begin # First, your basic convolution with no parameters - cdims = DenseConvDims(x, w) + cdims = ConvDims(x, w) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx, rtol = 1.0e-7) @@ -323,25 +323,25 @@ conv_answer_dict = Dict( @test isapprox(ddims(∇conv_data(Float32.(dy), Float32.(w), cdims)), dx, rtol = 1.0e-7) # Next, introduce stride: - cdims = DenseConvDims(x, w; stride=2) + cdims = ConvDims(x, w; stride=2) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_stride, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_stride, rtol = 1.0e-7) # Next, introduce dilation: - cdims = DenseConvDims(x, w; dilation=2) + cdims = ConvDims(x, w; dilation=2) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_dil, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_dil, rtol = 1.0e-7) # Next, introduce padding: - cdims = DenseConvDims(x, w; padding=1) + cdims = ConvDims(x, w; padding=1) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_pad, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_pad, rtol = 1.0e-7) # Next, test crosscor/conv with a flipped kernel - cdims = DenseConvDims(x, w; flipkernel=true) + cdims = ConvDims(x, w; flipkernel=true) dy = NNlib.conv(x, w, cdims) @test isapprox(ddims(∇conv_filter(x, dy, cdims)), dw_flip, rtol = 1.0e-7) @test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7) @@ -393,7 +393,7 @@ conv_answer_dict = Dict( # Skip tests that are impossible due to mismatched sizes try - DenseConvDims(x, w; + ConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size, ) catch e @@ -404,7 +404,7 @@ conv_answer_dict = Dict( end # Do the actual convolution, comparing convolution implementations - cdims = DenseConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) + cdims = ConvDims(x, w; stride=S_size, padding=P_size, dilation=D_size) # We use mutating calls with explicitly different initial values, so as # to be sure to catch when we're leaving pieces of the output untouched. 
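(Editor's aside, not part of the diff.) The comment above about "mutating calls with explicitly different initial values" is a general trick for testing in-place kernels: pre-fill the outputs with two different sentinel values, run the `!` call on both, and compare the results; any element the kernel forgets to overwrite keeps its sentinel and breaks the comparison. A self-contained toy version of the pattern, with the deliberately buggy `halve!` invented for the example:

# Illustrative sketch only -- `halve!` is invented and intentionally buggy.
using Test

# An in-place kernel that "forgets" to write the last element.
function halve!(y::AbstractVector, x::AbstractVector)
    for i in 1:length(x) - 1
        y[i] = x[i] / 2
    end
    return y
end

x  = collect(1.0:8.0)
y1 = fill( 1234.5678, 8)   # first sentinel
y2 = fill(-1234.5678, 8)   # deliberately different sentinel
halve!(y1, x)
halve!(y2, x)
@test !(y1 ≈ y2)           # the untouched last slot still holds different sentinels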
diff --git a/test/inference.jl b/test/inference.jl index 896ff22a1..7ba3bac52 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -9,6 +9,6 @@ using NNlib: conv_direct, conv_im2col NNlib.is_nnpack_available() && push!(impl, NNlib.conv_nnpack) for T in impl - @inferred T(x, w, DenseConvDims(x, w)) + @inferred T(x, w, ConvDims(x, w)) end end diff --git a/test/perf/perf_report.jl b/test/perf/perf_report.jl index 943d487f3..20fb33b45 100644 --- a/test/perf/perf_report.jl +++ b/test/perf/perf_report.jl @@ -29,20 +29,20 @@ for rank in (2,), padding in (0, 2) benchmark_items = [ - (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, DenseConvDims, "direct"), - (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, DenseConvDims, "im2col"), + (NNlib.conv_direct!, NNlib.∇conv_data_direct!, NNlib.∇conv_filter_direct!, ConvDims, "direct"), + (NNlib.conv_im2col!, NNlib.∇conv_data_im2col!, NNlib.∇conv_filter_im2col!, ConvDims, "im2col"), (NNlib.depthwiseconv_direct!, NNlib.∇depthwiseconv_data_direct!, NNlib.∇depthwiseconv_filter_direct!, DepthwiseConvDims, "direct"), (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"), ] if NNlib.is_nnpack_available() - push!(benchmark_items, (NNlib.conv_nnpack!, NNlib.∇conv_data_nnpack!, NNlib.∇conv_filter_nnpack!, DenseConvDims, "nnpack")) + push!(benchmark_items, (NNlib.conv_nnpack!, NNlib.∇conv_data_nnpack!, NNlib.∇conv_filter_nnpack!, ConvDims, "nnpack")) end for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items x = zeros(Float32, repeat([N], rank)..., C_in, 1) - if cT == DenseConvDims + if cT == ConvDims w = zeros(Float32, repeat([K], rank)..., C_in, C_out) else w = zeros(Float32, repeat([K], rank)..., C_out, C_in) @@ -53,7 +53,7 @@ for rank in (2,), continue end - if cT == DenseConvDims + if cT == ConvDims y = zeros(Float32, NNlib.output_size(cdims)..., C_out, 1) else y = zeros(Float32, NNlib.output_size(cdims)..., C_out*C_in, 1)
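(Editor's aside, not part of the diff.) The `test/inference.jl` change above keeps using `@inferred`, which runs a call through Julia's type inference and throws if the inferred return type does not match the type of the value actually returned -- a cheap guard against type instabilities in the conv entry points. A small illustration, with the two helper functions invented for the example:

# Illustrative sketch only -- `stable` and `unstable` are invented for the example.
using Test

stable(x)   = x + 1                    # return type follows from the argument type
unstable(x) = x > 0 ? x : "negative"   # inference can only conclude Union{Int, String}

@inferred stable(1)                                  # passes: inferred Int, returned Int
@test_throws ErrorException @inferred unstable(1)    # inference check fails as expected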