Skip to content

An attempt to combine dense, depthwise and groupwise conv through DenseConvDims #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 13 additions & 31 deletions src/conv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv,
depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter,
∇depthwiseconv_filter!
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!

## Convolution API
#
Expand All @@ -19,7 +17,7 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
#
# All methods require a `ConvDims` object to define the dimensions and optional
# elements of the convolution (padding, stride, dilation, kernel-flipping, etc...),
# which is easily constructable through something like `DenseConvDims(x, w)`. All
# which is easily constructable through something like `ConvDims(x, w)`. All
# methods take in the `ConvDims` of the associated normal, forward-pass convolution,
# that is, the following is legal:
#
Expand All @@ -36,9 +34,6 @@ for (front_name, backend) in (
:conv => :im2col,
:∇conv_data => :im2col,
:∇conv_filter => :im2col,
:depthwiseconv => :im2col,
:∇depthwiseconv_data => :im2col,
:∇depthwiseconv_filter => :im2col,
)

# These are the GEMM types we will accelerate with `im2col`
Expand All @@ -58,8 +53,7 @@ end
# Our strategy for 1d and 2d convolution is to reshape to 3d convolutions, which
# makes things MUCH EASIER for us on the backend side, and is in general pretty fast,
# since we can specialize on sizes.
for front_name in (:conv, :∇conv_data, :∇conv_filter,
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
for front_name in (:conv, :∇conv_data, :∇conv_filter)
for backend in (Symbol(), :_direct, :_im2col)
for N in (3, 4)
@eval begin
Expand Down Expand Up @@ -87,8 +81,7 @@ end
# We always support a fallback, non-accelerated path, where we use the direct, but
# slow, implementations. These should not typically be used, hence the `@debug`,
# but let's go ahead and define them first:
for front_name in (:conv, :∇conv_data, :∇conv_filter,
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
for front_name in (:conv, :∇conv_data, :∇conv_filter)
@eval begin
function $(Symbol("$(front_name)!"))(
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
Expand All @@ -106,7 +99,7 @@ end
# allocation. :P
for backend in (Symbol(), :_direct, :_im2col)
# First make auto-allocating versions of the conv()-like calls:
for name in (:conv, :depthwiseconv)
for name in (:conv,)
@eval begin
function $(Symbol("$(name)$(backend)"))(
x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
Expand All @@ -118,7 +111,7 @@ for backend in (Symbol(), :_direct, :_im2col)
end
end

for name in (:∇conv_data, :∇depthwiseconv_data)
for name in (:∇conv_data,)
@eval begin
function $(Symbol("$(name)$(backend)"))(
dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
Expand All @@ -130,35 +123,24 @@ for backend in (Symbol(), :_direct, :_im2col)
end
end

# We do the conv/depthwiseconv filter backprops separately, as the shape calculation
# for `w` is slightly different for depthwise than for normal dense convolution.
# This filter backprop covers the dense/depthwise/groupwise conv filter backprops, since groupcount
# alone is the deciding factor from cudnn's perspective. The im2col and direct backends still need to be handled.
@eval begin
function $(Symbol("∇conv_filter$(backend)"))(
x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
cdims::ConvDims; kwargs...) where {xT, yT, N}
dw = similar(dy, kernel_size(cdims)..., channels_in(cdims),
dw = similar(dy, kernel_size(cdims)..., div(channels_in(cdims),group_count(cdims)),
Copy link
Author

@arhik arhik Dec 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to have a succinct description of what to do when groupcount is not equal to 1 or input_channels(). E.g. what does groupcount == 7 look like? If you can succinctly describe what should happen, I can help you adjust the implementations of the direct and im2col implementations.

@staticfloat. This is the only change. Only the weight dimensions shrink, as shown here: the third axis is divided by the groupcount value. If groupcount == 1, nothing changes. When groupcount == 2 (say), then one group of weights operates on only half of the input channels and produces only one output channel. These weight groups have to be applied across all the channels, and we have to use a new group of weights (covering the input channels in blocks) until their outputs match the number of output channels.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When groupcount == 7, we should make sure the input channels can be divided exactly into 7 groups. Then we should check that the output channels are a multiple of groupcount (since one group can only produce one output).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should then distribute these 7 groups to operate on different input channels in blocks of div(input_channels(), 7)

channels_out(cdims))
return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
end
end

@eval begin
function $(Symbol("∇depthwiseconv_filter$(backend)"))(
x::AbstractArray{xT,N}, dy::AbstractArray{yT,N},
cdims::ConvDims; kwargs...) where {xT, yT, N}
dw = similar(dy, kernel_size(cdims)..., channel_multiplier(cdims),
channels_in(cdims))
return $(Symbol("∇depthwiseconv_filter$(backend)!"))(dw, x, dy, cdims;
kwargs...)
end
end
end


# Use NNPACK if it is available and the operation is supported
if is_nnpack_available()
function conv(x::Array{xT, 4}, w::Array{wT, 4},
cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
cdims::ConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
kwargs...) where {xT, wT, K, C_in, C_out, S, P, F}
return conv_nnpack(x, w, cdims; kwargs...)
end
Expand All @@ -168,14 +150,14 @@ function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flip
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
return conv(x, w, cdims)
end

function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, N}
function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groupcount) where {T, N}
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
return depthwiseconv(x, w, cdims)
end
5 changes: 2 additions & 3 deletions src/dim_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Various helper functions to calculate dimensions for operations
include("dim_helpers/AbstractDims.jl")
include("dim_helpers/ConvDims.jl")
include("dim_helpers/DenseConvDims.jl")
include("dim_helpers/DepthwiseConvDims.jl")
include("dim_helpers/PoolDims.jl")


Expand Down Expand Up @@ -45,7 +44,7 @@ function transpose_pad(cdims::ConvDims)
end

"""
insert_singleton_spatial_dimension(cdims::DenseConvDims)
insert_singleton_spatial_dimension(cdims::ConvDims)

When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton
spatial dimension at the end of the spatial dimensions. This does so for a ConvDims.
Expand Down
120 changes: 120 additions & 0 deletions src/dim_helpers/AbstractDims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export AbstractDims

"""
AbstractDims

Type system-level information about convolution dimensions. Critical for things like
`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs
getting passed around.

We don't want to specialize on things like image size/channel count, so we generally
store those as fields, just for convenience, and to allow for non-breaking changes when
we decide we _do_ want to specialize on those values. We always want to specialize on
things like stride, padding, dilation, and kernel flipping though.
"""
abstract type AbstractDims{N, S, P, D, F} end

# Strip the type parameters off a concrete dims type, yielding the bare
# `ConvDims`/`PoolDims` constructor name (or `nothing` for anything else).
function basetype(::Type{C}) where {C <: AbstractDims}
    C <: ConvDims && return ConvDims
    C <: PoolDims && return PoolDims
    return nothing
end

# Accessors that lift each type parameter of an `AbstractDims` back out to a value.
# The argument itself is unused; only its type matters, so we leave it unnamed.
spatial_dims(::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = N
stride(::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = S
padding(::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = P
dilation(::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = D
flipkernel(::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = F

"""
im2col_dims(c::AbstractDims)

im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
number of output channels, by doing a matrix multiply. The dimensions of that matrix
are given by this function.
"""
im2col_dims(c::AbstractDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c))

# Protect your skin, kids. Common validation of stride, padding and dilation for an
# `N`-dimensional x/w pair; returns the three normalized to fully-specified tuples.
function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
    # `x` and `w` carry channel and batch dims; everything before those is spatial.
    nd = N - 2

    # Duplicate a scalar out to `nd` entries; splat any collection into a tuple so the
    # result is always a tuple.  Lengths are validated below.
    normalize(v::Number) = ntuple(_ -> Int(v), nd)
    normalize(v) = tuple(v...)

    pstride = normalize(stride)
    pdilation = normalize(dilation)
    ppadding = normalize(padding)

    length(pstride) == nd ||
        throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!"))
    length(pdilation) == nd ||
        throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!"))

    # Padding is special: it may be `nd`-length (symmetric) or `2*nd`-length
    # (independent lo/hi padding per spatial dimension).
    if length(ppadding) == nd
        # Repeat with inner=2 so each dim's value becomes its own lo/hi pair.
        ppadding = tuple(repeat(collect(ppadding), inner=2)...)
    elseif length(ppadding) != 2*nd
        throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
    end

    # The dilated kernel extent must fit within the padded input along every dim.
    for idx in 1:nd
        Is = x_size[idx]
        Pl = ppadding[(idx - 1)*2 + 1]
        Ph = ppadding[(idx - 1)*2 + 2]
        Ks = w_size[idx]
        Ds = pdilation[idx]
        if Is + Pl + Ph < (Ks - 1)*Ds + 1
            throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
        end
    end

    return pstride, ppadding, pdilation
end

"""
output_size(c::AbstractDims)

Calculate the output (spatial) dimensions of the convolution. Get channel count via
`channels_out(c)`, and batch count is unknowable.
"""
function output_size(c::AbstractDims)
I = input_size(c)
K = kernel_size(c)
S = stride(c)
P = padding(c)
D = dilation(c)

return ntuple(spatial_dims(c)) do i
return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
end
end

# Pretty-print any dims object as "<Type>: <in> * <kernel> -> <out>, ..." using the
# unparameterized type name from basetype().
function Base.show(io::IO, cdims::C) where {C <: AbstractDims}
    in_shape = (input_size(cdims)..., channels_in(cdims))
    out_shape = (output_size(cdims)..., channels_out(cdims))
    print(io,
          "$(basetype(C)): $in_shape * $(kernel_size(cdims)) -> $out_shape, " *
          "stride: $(stride(cdims)) pad: $(padding(cdims)), " *
          "dil: $(dilation(cdims)), flip: $(flipkernel(cdims))")
end
Loading