Skip to content

Commit

Permalink
Adapting ConvDims to accomadate Groupwise and Depthwise Convolutions …
Browse files Browse the repository at this point in the history
…and removing seperate implementations of Depthwise and Groupwise.
  • Loading branch information
arhik committed Dec 8, 2019
1 parent 0afd23a commit 14da5b5
Show file tree
Hide file tree
Showing 12 changed files with 326 additions and 239 deletions.
10 changes: 5 additions & 5 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
#
# All methods require a `ConvDims` object to define the dimensions and optional
# elements of the convolution (padding, stride, dilation, kernel-flipping, etc...),
# which is easily constructable through something like `DenseConvDims(x, w)`. All
# which is easily constructable through something like `ConvDims(x, w)`. All
# methods take in the `ConvDims` of the associated normal, forward-pass convolution,
# that is, the following is legal:
#
Expand Down Expand Up @@ -123,7 +123,7 @@ for backend in (Symbol(), :_direct, :_im2col)
end
end

# This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone
# This filter back prop covers dense/depthwise/groupwise conv filter backprops, as groupcount alone
# is a deciding factor from cudnn's perspective. For backends im2col and direct needs to be handled.
@eval begin
function $(Symbol("∇conv_filter$(backend)"))(
Expand All @@ -140,7 +140,7 @@ end
# Use NNPACK if it is available and the operation is supported
if is_nnpack_available()
function conv(x::Array{xT, 4}, w::Array{wT, 4},
cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
cdims::ConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
kwargs...) where {xT, wT, K, C_in, C_out, S, P, F}
return conv_nnpack(x, w, cdims; kwargs...)
end
Expand All @@ -150,14 +150,14 @@ function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flip
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped)
return conv(x, w, cdims)
end

function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groupcount) where {T, N}
stride = expand(Val(N-2), stride)
pad = expand(Val(N-2), pad)
dilation = expand(Val(N-2), dilation)
cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
cdims = ConvDims(x, w; stride = stride, padding = pad, dilation = dilation, flipkernel = flipped, groupcount=groupcount)
return depthwiseconv(x, w, cdims)
end
5 changes: 2 additions & 3 deletions src/dim_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Various helper functions to calculate dimensions for operations
include("dim_helpers/AbstractDims.jl")
include("dim_helpers/ConvDims.jl")
include("dim_helpers/DenseConvDims.jl")
include("dim_helpers/DepthwiseConvDims.jl")
include("dim_helpers/PoolDims.jl")


Expand Down Expand Up @@ -45,7 +44,7 @@ function transpose_pad(cdims::ConvDims)
end

"""
insert_singleton_spatial_dimension(cdims::DenseConvDims)
insert_singleton_spatial_dimension(cdims::ConvDims)
When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton
spatial dimension at the end of the spatial dimensions. This does so for a ConvDims.
Expand Down
120 changes: 120 additions & 0 deletions src/dim_helpers/AbstractDims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export AbstractDims

"""
AbstractDims
Type system-level information about convolution dimensions. Critical for things like
`im2col!()` to generate efficient code, and helpful to reduce the number of kwargs
getting passed around.
We don't want to specialize on things like image size/channel count, so we generally
store those as fields, just for convenience, and to allow for non-breaking changes when
we decide we _do_ want to specialize on those values. We always want to specialize on
things like stride, padding, dilation, and kernel flipping though.
"""
abstract type AbstractDims{N, S, P, D, F} end

# Hack to get rid of type parameters
function basetype(::Type{C}) where {C <: AbstractDims}
if C <: ConvDims
return ConvDims
elseif C <: PoolDims
return PoolDims
else
return nothing
end
end

# Obvious getter definitions for the type system-level definitions
spatial_dims(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = N
stride(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = S
padding(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = P
dilation(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = D
flipkernel(c::AbstractDims{N,S,P,D,F}) where {N, S, P, D, F} = F

"""
im2col_dims(c::AbstractDims)
im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
number of output channels, by doing a matrix multiply. The dimensions of that matrix
are given by this function.
"""
im2col_dims(c::AbstractDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c))

# Protect your skin, kids. Also do common validation of stride, padding, etc...
function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
# Number of spatial dimensions in `x` and `w`.
nd = N - 2

# Given a number, duplicate it out to have `nd` length. If it's already a collection,
# just splat it out into a tuple so it's always a tuple. We'll lint length later.
expand_size(p::Number) = ntuple(_ -> Int(p), nd)
expand_size(p) = tuple(p...)

# Convert stride, padding, dilation, etc.. to fully-specified tuples
pstride = expand_size(stride)
pdilation = expand_size(dilation)
ppadding = expand_size(padding)

if length(pstride) != nd
throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!"))
end
if length(pdilation) != nd
throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!"))
end

# padding is kind of a special case; we allow it to be either 2-length or 4-length,
# since we support asymmetrical padding
if length(ppadding) != 2*nd
if length(ppadding) == nd
# Do this repeat dance so that we get lo/hi symmetrical padding
ppadding = tuple(repeat(collect(ppadding), inner=2)...)
else
throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
end
end

# Assert that kernel size * dilation is <= padded input size
for idx in 1:nd
Is = x_size[idx]
Pl = ppadding[(idx - 1)*2 + 1]
Ph = ppadding[(idx - 1)*2 + 2]
Ks = w_size[idx]
Ds = pdilation[idx]
if Is + Pl + Ph < (Ks - 1)*Ds + 1
throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
end
end

return pstride, ppadding, pdilation
end

"""
output_size(c::AbstractDims)
Calculate the output (spatial) dimensions of the convolution. Get channel count via
`channels_out(c)`, and batch count is unknowable.
"""
function output_size(c::AbstractDims)
I = input_size(c)
K = kernel_size(c)
S = stride(c)
P = padding(c)
D = dilation(c)

return ntuple(spatial_dims(c)) do i
return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
end
end

# Override show() for these beauties
function Base.show(io::IO, cdims::C) where {C <: AbstractDims}
I = (input_size(cdims)..., channels_in(cdims))
O = (output_size(cdims)..., channels_out(cdims))
K = kernel_size(cdims)
S = stride(cdims)
P = padding(cdims)
D = dilation(cdims)
F = flipkernel(cdims)
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F")
end
152 changes: 60 additions & 92 deletions src/dim_helpers/ConvDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,109 +12,77 @@ store those as fields, just for convenience, and to allow for non-breaking chang
we decide we _do_ want to specialize on those values. We always want to specialize on
things like stride, padding, dilation, and kernel flipping though.
"""
abstract type ConvDims{N, S, P, D, F} end

# Hack to get rid of type parameters
function basetype(::Type{C}) where {C <: ConvDims}
if C <: DenseConvDims
return DenseConvDims
elseif C <: PoolDims
return PoolDims
else
return nothing
end
struct ConvDims{N,K,C_in,C_out,S,P,D,F,G} <: AbstractDims{N,S,P,D,F}
I::NTuple{N,Int}
end

# Obvious getter definitions for the type system-level definitions
spatial_dims(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = N
stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S
padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P
dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D
flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F

"""
im2col_dims(c::ConvDims)
im2col calculates, for each output pixel, the "convolution" of N kernels where N is the
number of output channels, by doing a matrix multiply. The dimensions of that matrix
are given by this function.
"""
im2col_dims(c::ConvDims) = (prod(output_size(c)), prod(kernel_size(c))*channels_in(c))

# Protect your skin, kids. Also do common validation of stride, padding, etc...
function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilation) where {N}
# Number of spatial dimensions in `x` and `w`.
nd = N - 2

# Given a number, duplicate it out to have `nd` length. If it's already a collection,
# just splat it out into a tuple so it's always a tuple. We'll lint length later.
expand_size(p::Number) = ntuple(_ -> Int(p), nd)
expand_size(p) = tuple(p...)

# Convert stride, padding, dilation, etc.. to fully-specified tuples
pstride = expand_size(stride)
pdilation = expand_size(dilation)
ppadding = expand_size(padding)

if length(pstride) != nd
throw(DimensionMismatch("Stride $(length(stride))d, should be $(nd)d!"))
end
if length(pdilation) != nd
throw(DimensionMismatch("Dilation $(length(pdilation))d, should be $(nd)d!"))
# Getters for the fields
input_size(c::ConvDims) = c.I
kernel_size(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = K
channels_in(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_in
channels_out(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = C_out
group_count(c::ConvDims{N,K,C_in,C_out,S,P,D,F,G}) where {N,K,C_in,C_out,S,P,D,F,G} = G

# Convenience wrapper to create ConvDims objects
function ConvDims(x_size::NTuple{M}, w_size::NTuple{M};
stride=1, padding=0, dilation=1, flipkernel::Bool=false, groupcount=1) where M
# Do common parameter validation
stride, padding, dilation = check_spdf(x_size, w_size, stride, padding, dilation)

# Ensure channels are equal
if x_size[M-1] != w_size[M-1]*groupcount
xs = x_size[M-1]
ws = w_size[M-1]*groupcount
throw(DimensionMismatch("Input channels must match! ($xs vs. $ws)"))
end

# padding is kind of a special case; we allow it to be either 2-length or 4-length,
# since we support asymmetrical padding
if length(ppadding) != 2*nd
if length(ppadding) == nd
# Do this repeat dance so that we get lo/hi symmetrical padding
ppadding = tuple(repeat(collect(ppadding), inner=2)...)
else
throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
end
end
# The type parameters are what
return ConvDims{
M - 2,
w_size[1:M-2],
x_size[M-1],
w_size[M],
stride,
padding,
dilation,
flipkernel,
groupcount
}(
# Input spatial size
x_size[1:M-2],
)
end

# Assert that kernel size * dilation is <= padded input size
for idx in 1:nd
Is = x_size[idx]
Pl = ppadding[(idx - 1)*2 + 1]
Ph = ppadding[(idx - 1)*2 + 2]
Ks = w_size[idx]
Ds = pdilation[idx]
if Is + Pl + Ph < (Ks - 1)*Ds + 1
throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
end
# Auto-extract sizes and sub out to big brother above
function ConvDims(x::AbstractArray, w::AbstractArray; kwargs...)
if ndims(x) != ndims(w)
throw(DimensionMismatch("Rank of x and w must match! ($(ndims(x)) vs. $(ndims(w)))"))
end

return pstride, ppadding, pdilation
return ConvDims(size(x), size(w); kwargs...)
end

"""
output_size(c::ConvDims)
# Useful for constructing a new ConvDims that has only a few elements different
# from the original progenitor object that it inherits shapes from.
function ConvDims(c::AbstractDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
P=padding(c), D=dilation(c), F=flipkernel(c), G=group_count(c))
return ConvDims{N, K, C_in, C_out, S, P, D, F, G}(I)
end

Calculate the output (spatial) dimensions of the convolution. Get channel count via
`channels_out(c)`, and batch count is unknowable.
"""
function output_size(c::ConvDims)
I = input_size(c)
K = kernel_size(c)
S = stride(c)
P = padding(c)
D = dilation(c)
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::ConvDims) where {M}
# First, check that channel counts are all correct:
@assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
@assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
@assert w[M-1] == channels_in(cdims)/group_count(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)/group_count(cdims)))")
@assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")

return ntuple(spatial_dims(c)) do i
return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
end
end
# Next, check that the spatial dimensions match up
@assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
@assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
@assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")

# Override show() for these beauties
function Base.show(io::IO, cdims::C) where {C <: ConvDims}
I = (input_size(cdims)..., channels_in(cdims))
O = (output_size(cdims)..., channels_out(cdims))
K = kernel_size(cdims)
S = stride(cdims)
P = padding(cdims)
D = dilation(cdims)
F = flipkernel(cdims)
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F")
# Finally, check that the batch size matches
@assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
end
Loading

0 comments on commit 14da5b5

Please sign in to comment.