MultiHeadAttention implementation #2146

Merged
merged 21 commits on Mar 11, 2023
5 changes: 4 additions & 1 deletion src/Flux.jl
@@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
-Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+Dropout, AlphaDropout,
+LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+MultiHeadAttention,
Upsample, PixelShuffle,
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
testmode!, trainmode!
@@ -59,6 +61,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/attention.jl")
include("layers/show.jl")

include("loading.jl")
133 changes: 133 additions & 0 deletions src/layers/attention.jl
@@ -0,0 +1,133 @@

const A3{T} = AbstractArray{T, 3}
const IntOrDims{N} = Union{Int, Dims{N}}

"""
MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])

The multi-head dot-product attention layer used in Transformer architectures [1].

Returns the transformed input sequence and the attention scores.

[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.

# Arguments

- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
In the most general case, it is given as
a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
It can also take the simpler forms
b) `dims::Int`;
c) `in_dim::Int => (qk_dim, v_dim) => out_dim`;
d) `in_dim::Int => qkv_dim => out_dim`.
- `nheads`: number of heads. Default `8`.
- `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
- `bias`: whether the pointwise QKVO dense transforms use a bias. Default `false`.
- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.

# Forward

(mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])

The arguments of the forward pass are:

- `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
- `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
- `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
It will be added to the attention scores before the softmax.
Default `nothing`.
- `mask`: Input array broadcastable to size
`(kv_len, q_len, nheads, batch_size)`.
The mask is applied to the attention scores just before the softmax.
See [`NNlib.make_causal_mask`](@ref) for creating causal masks.
Default `nothing`.

Alternative calling signatures are `mha(q_in)`, equivalent to `mha(q_in, q_in, q_in)` (self-attention),
and `mha(q_in, k_in)`, equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).

See also [`NNlib.dot_product_attention`](@ref).

# Examples

```julia
mha = MultiHeadAttention(64, nheads = 8)
q = rand(Float32, (64, 10, 32))
k = rand(Float32, (64, 20, 32))
v = rand(Float32, (64, 20, 32))
y, α = mha(q, k, v)
# [y] = [64, 10, 32]
# [α] = [20, 10, 8, 32]

mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
y, α = mha(q) # self-attention
# [y] = [1024, 10, 32]
# [α] = [10, 10, 8, 32]
```
"""
struct MultiHeadAttention{P1, D, P2}
nheads::Int
q_proj::P1
k_proj::P1
v_proj::P1
attn_drop::D
out_proj::P2
end

@functor MultiHeadAttention

function MultiHeadAttention(dims;
nheads::Int = 8,
bias::Bool = false,
init = glorot_uniform,
dropout_prob = 0.0)

dims = normalize_mha_dims(dims)
@assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
@assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
q_proj = Dense(dims.q_in => dims.qk; bias, init)
k_proj = Dense(dims.k_in => dims.qk; bias, init)
v_proj = Dense(dims.v_in => dims.v; bias, init)
attn_drop = Dropout(dropout_prob)
out_proj = Dense(dims.v => dims.out; bias, init)
return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
end

# turns the dims argument into a named tuple
normalize_mha_dims(dims::Int) =
(; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}})
if in isa Int
q_in = k_in = v_in = in
else
q_in, k_in, v_in = in
end
if qkv isa Int
qk = v = qkv
else
qk, v = qkv
end
return (; q_in, k_in, v_in, qk, v, out)
end
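
# Worked examples (illustration only; derived from the methods above):
#   normalize_mha_dims(64)                              == (; q_in=64, k_in=64, v_in=64, qk=64, v=64, out=64)
#   normalize_mha_dims(64 => 128 => 32)                 == (; q_in=64, k_in=64, v_in=64, qk=128, v=128, out=32)
#   normalize_mha_dims(64 => (128, 96) => 32)           == (; q_in=64, k_in=64, v_in=64, qk=128, v=96, out=32)
#   normalize_mha_dims((64, 48, 32) => (128, 96) => 16) == (; q_in=64, k_in=48, v_in=32, qk=128, v=96, out=16)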

# self-attention
(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

# key and value are the same
(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
bias=nothing; mask=nothing)
## [q_in] = [q_in_dim, q_len, batch_size]
## [k_in] = [k_in_dim, kv_len, batch_size]
## [v_in] = [v_in_dim, kv_len, batch_size]
q = mha.q_proj(q_in) # [q] = [qk_dim, q_len, batch_size]
k = mha.k_proj(k_in) # [k] = [qk_dim, kv_len, batch_size]
v = mha.v_proj(v_in) # [v] = [v_dim, kv_len, batch_size]
x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
x = mha.out_proj(x)
# [x] = [out_dim, q_len, batch_size]
# [α] = [kv_len, q_len, nheads, batch_size]
return x, α
end
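
To make the forward pass above concrete, here is a minimal usage sketch of causal self-attention. It is illustrative only: it assumes Flux (with this layer) and NNlib are loaded, and the sizes are arbitrary.

```julia
using Flux, NNlib

dim, len, batch_size = 64, 10, 32
mha = MultiHeadAttention(dim; nheads = 8)
x = rand(Float32, dim, len, batch_size)

mask = NNlib.make_causal_mask(x)  # (len, len) Bool mask over (kv, q) positions
y, α = mha(x; mask)               # self-attention: query = key = value = x
# [y] = [64, 10, 32]
# [α] = [10, 10, 8, 32], with α[kv, q, :, :] == 0 wherever kv > q
```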
26 changes: 26 additions & 0 deletions test/cuda/layers.jl
@@ -338,3 +338,29 @@ end
@test eltype(pool(reshape(gx,3,4,1))) == Float16
end
end

@testset "MultiHeadAttention" begin
dim = 4; nheads = 2; len = 3; batch_size = 5
mha_cpu = MultiHeadAttention(dim; nheads)
x_cpu = rand(Float32, (dim, len, batch_size))
y_cpu, α_cpu = mha_cpu(x_cpu)

mha_gpu = mha_cpu |> gpu
x_gpu = x_cpu |> gpu
y_gpu, α_gpu = mha_gpu(x_gpu)
@test y_gpu isa CuArray{Float32}
@test α_gpu isa CuArray{Float32}
@test Array(y_gpu) ≈ y_cpu atol=1e-4
@test Array(α_gpu) ≈ α_cpu atol=1e-4

gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
y, α = mha(x)
return sum(y.^2) + sum(α.^2)
end
gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
y, α = mha(x)
return sum(y.^2) + sum(α.^2)
end
check_grad(gm_gpu, gm_cpu)
check_grad(gx_gpu, gx_cpu)
end
65 changes: 65 additions & 0 deletions test/layers/attention.jl
@@ -0,0 +1,65 @@


@testset "attention" begin
dim = 4; nheads = 2; len = 3; batch_size = 5
mha = MultiHeadAttention(dim; nheads)
q = rand(Float32, (dim, len, batch_size))
k = rand(Float32, (dim, len, batch_size))
v = rand(Float32, (dim, len, batch_size))

y, α = mha(q, k, v)
@test y isa Array{Float32, 3}
@test size(y) == (dim, len, batch_size)
@test α isa Array{Float32, 4}
@test size(α) == (len, len, nheads, batch_size)

@testset "self-attention" begin
y1, α1 = mha(q)
y2, α2 = mha(q, q, q)
@test y1 ≈ y2
@test α1 ≈ α2
end

@testset "key and value are the same" begin
y1, α1 = mha(q, k)
y2, α2 = mha(q, k, k)
@test y1 ≈ y2
@test α1 ≈ α2
end

@testset "change dims" begin
dims = 4 => 10 => 5
nheads2 = 5  # head count must divide qk_dim = 10 and v_dim = 10
mha2 = MultiHeadAttention(dims; nheads = nheads2)
y2, _ = mha2(q, k, v)
@test size(y2) == (dims.second.second, len, batch_size)
end

@testset "mask" begin
mask = NNlib.make_causal_mask(q)
y, α = mha(q; mask)
@test all(α[2, 1, :, :] .== 0)
@test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
end

@testset "bias" begin
# use bias to produce a causal mask
b = zeros(Float32, (len, len))
for i in 1:len, j in i:len
b[i, j] = typemax(Float32)
end
y, α = mha(q, k, v, b)
@test all(α[2, 1, :, :] .== 0)
@test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
end
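
# Why the huge positive bias above acts as a mask (illustration only): NNlib's
# softmax subtracts the maximum along the softmax dimension, so scores carrying
# `typemax(Float32)` shift to 0 (exp gives 1) while unbiased scores shift to
# roughly -typemax(Float32) (exp gives 0). For example,
#   softmax(Float32[0.3, 0.7, 0.1] .+ Float32[typemax(Float32), typemax(Float32), 0]) ≈ [0.5, 0.5, 0.0]
# Note the surviving scores become uniform, which the assertions above do not check.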

@testset "gradient" begin
gm, gq = gradient(mha, q) do mha, q
y, α = mha(q)
return sum(y.^2) + sum(α.^2)
end
check_grad_type(gm, mha)
check_grad_type(gq, q)
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
@@ -33,6 +33,7 @@ Random.seed!(0)
end

@testset "Layers" begin
include("layers/attention.jl")
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
57 changes: 41 additions & 16 deletions test/test_utils.jl
@@ -1,27 +1,33 @@
-function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu, g_cpu;
+rtol=1e-4, atol=1e-4,
+allow_nothing::Bool=false)
allow_nothing && return
@show g_gpu g_cpu
@test false
end
-check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol; allow_nothing::Bool) =
-check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing)
-check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
+check_grad(g_gpu[], g_cpu[]; rtol, atol, allow_nothing)
+
+check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
@test true
-check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
@test g_cpu ≈ g_gpu rtol=rtol atol=atol
-check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) =
+
+check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
@test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol

-function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
for (v1, v2) in zip(g_gpu, g_cpu)
-check_grad(v1, v2, atol, rtol; allow_nothing)
+check_grad(v1, v2; rtol, atol, allow_nothing)
end
end

-function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool)
+function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu))
@test k1 == k2
-check_grad(v1, v2, atol, rtol; allow_nothing)
+check_grad(v1, v2; rtol, atol, allow_nothing)
end
end

@@ -31,10 +37,14 @@ check_type(x::CuArray{Float32}) = true
check_type(x::Array{Float32}) = true

function gpu_autodiff_test(
-f_cpu, xs_cpu::Array{Float32}...;
-test_equal=true, rtol=1e-4, atol=1e-4,
-checkgrad::Bool = true, allow_nothing::Bool = false,
-)
+f_cpu,
+xs_cpu::Array{Float32}...;
+test_equal=true,
+rtol=1e-4, atol=1e-4,
+checkgrad::Bool = true,
+allow_nothing::Bool = false,
+)
+
# Compare CPU & GPU function outputs.
f_gpu = f_cpu |> gpu
xs_gpu = gpu.(xs_cpu)
@@ -60,7 +70,7 @@ function gpu_autodiff_test(
if test_equal
@test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu)
-check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing)
+check_grad(g_gpu, g_cpu; atol, rtol, allow_nothing)
end
end

@@ -78,7 +88,22 @@ function gpu_autodiff_test(
@test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
@assert length(ps_gpu) == length(ps_cpu)
for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu)
-check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing)
+check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu]; atol, rtol, allow_nothing)
end
end
end

+# check_grad_type checks that the gradient type matches the primal type.
+
+check_grad_type(g::Nothing, x) = nothing
+
+function check_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2}
+@test T1 == T2
+@test size(g) == size(x)
+end
+
+function check_grad_type(g::NamedTuple, x::T) where T
+for f in fieldnames(T)
+check_grad_type(g[f], getfield(x, f))
+end
+end