Implement backward pass #2

Open · wants to merge 2 commits into `main`
55 changes: 42 additions & 13 deletions README.md
@@ -1,49 +1,78 @@
# flash-attention-minimal

A minimal re-implementation of Flash Attention with CUDA and PyTorch.
The official [implementation](https://github.com/Dao-AILab/flash-attention) can be quite daunting for a CUDA beginner
(like myself), so this repo tries to be small and educational.

* The entire forward pass is written in ~100 lines in `flash.cu` (a PyTorch sketch of the same tiling scheme is shown below).
* The variable names follow the notations from the original [paper](https://arxiv.org/abs/2205.14135).
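
If it helps to see the algorithm outside CUDA, here is a rough PyTorch sketch of the tiled forward pass with running softmax statistics; the notation (`Br`, `Bc`, `l`, `m`) follows the paper, but the function name `flash_forward_sketch` is made up and this is an illustration, not a line-for-line translation of `flash.cu`:

```python
import torch

def flash_forward_sketch(q, k, v, Br=32, Bc=32):
    # q, k, v: (..., N, d). K/V tiles are visited in the outer loop,
    # Q/O tiles in the inner loop, mirroring the kernel's structure.
    N, d = q.shape[-2], q.shape[-1]
    scale = 1.0 / (d ** 0.5)
    O = torch.zeros_like(q)
    l = torch.zeros(*q.shape[:-1], 1, dtype=q.dtype, device=q.device)                  # running softmax denominator
    m = torch.full((*q.shape[:-1], 1), float('-inf'), dtype=q.dtype, device=q.device)  # running row max

    for j in range(0, N, Bc):                    # outer loop over K/V tiles
        Kj = k[..., j:j + Bc, :]
        Vj = v[..., j:j + Bc, :]
        for i in range(0, N, Br):                # inner loop over Q/O tiles
            Qi = q[..., i:i + Br, :]
            S = (Qi @ Kj.transpose(-2, -1)) * scale          # Br x Bc score tile
            m_tile = S.max(dim=-1, keepdim=True).values
            P = torch.exp(S - m_tile)
            l_tile = P.sum(dim=-1, keepdim=True)

            m_old, l_old = m[..., i:i + Br, :], l[..., i:i + Br, :]
            m_new = torch.maximum(m_old, m_tile)
            l_new = torch.exp(m_old - m_new) * l_old + torch.exp(m_tile - m_new) * l_tile

            # Rescale the partial output so every tile shares one softmax normalization.
            O[..., i:i + Br, :] = (
                l_old * torch.exp(m_old - m_new) * O[..., i:i + Br, :]
                + torch.exp(m_tile - m_new) * (P @ Vj)
            ) / l_new
            m[..., i:i + Br, :], l[..., i:i + Br, :] = m_new, l_new
    return O
```

The result should match `torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1) @ v` up to float32 tolerance, without ever materializing the full N x N score matrix.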

## Usage

### Prerequisite

* PyTorch (with CUDA)
* `Ninja` for building the C++/CUDA extension (see the loading sketch below)
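
The `Ninja` requirement comes from PyTorch's JIT extension loader. A minimal sketch of how a kernel like `flash.cu` can be loaded this way (source file names and flags here are assumptions, not copied from this repo):

```python
from torch.utils.cpp_extension import load  # needs Ninja on PATH

minimal_attn = load(
    name='minimal_attn',
    sources=['main.cpp', 'flash.cu'],  # C++ bindings + CUDA kernel (names assumed)
    extra_cuda_cflags=['-O2'],
)
# minimal_attn.forward / minimal_attn.backward are then callable from Python.
```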

### Benchmark

Compare the wall-clock time between manual attention and minimal flash attention:

```bash
python bench.py
```

Sample output on a [T4](https://aws.amazon.com/ec2/instance-types/g4/) for the forward pass (Br = Bc = 32):

```
=== profiling manual attention (forward pass) ===
...
Self CPU time total: 52.389ms
Self CUDA time total: 52.545ms

=== profiling minimal flash attention (forward pass) ===
...
Self CPU time total: 11.452ms
Self CUDA time total: 3.908ms
```

That's a 13x speedup!

Sample output on an RTX 3060 for the backward pass (Br = Bc = 16):

```
=== profiling manual attention (backward pass) ===
...
Self CPU time total: 11.139ms
Self CUDA time total: 1.721ms

=== profiling minimal flash attention (backward pass) ===
...
Self CPU time total: 31.466ms
Self CUDA time total: 629.000us
```

That's a 2x speedup! Note, though, that we've only tested this on an RTX 3060, which has a smaller SRAM than the T4
(hence the reduction of the block size from 32 to 16). The speedup might differ on a T4.
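
For a rough sense of why the block size has to shrink, here is a back-of-the-envelope estimate of the per-block shared-memory footprint, assuming the kernel keeps the `Q`, `K`, `V` and score tiles resident in float32 at head dimension 64 (the backward kernel keeps additional tiles around, so its real footprint is larger):

```python
def tile_bytes(Br, Bc, d=64, itemsize=4):
    # Qi: Br x d, Kj: Bc x d, Vj: Bc x d, S: Br x Bc, all float32
    return (Br * d + 2 * Bc * d + Br * Bc) * itemsize

print(tile_bytes(32, 32) // 1024, 'KiB')  # 28 KiB per thread block
print(tile_bytes(16, 16) // 1024, 'KiB')  # 13 KiB per thread block
```

Halving `Br` and `Bc` roughly halves the `Q`/`K`/`V` tiles and quarters the score tile, which is what lets the backward kernel fit on a GPU with less shared memory per block.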

### I don't have a GPU

Try out this [online colab demo](https://colab.research.google.com/gist/tspeterkim/143bc7be7a845656817cf94c5228598e/demo-flash-attention-minimal.ipynb).

## Caveats

* In the inner loop, I assign each thread to a row of the output matrix. This differs from the original implementation (see the sketch after this list).
* This thread-per-row simplification makes the matrix multiplications very slow. This is probably why, for longer sequences and larger block sizes, this gets slower than the manual implementation.
* Q, K, and V are in float32, unlike the original implementation, which uses float16.
* The block size is [fixed](https://github.com/tspeterkim/flash-attention-minimal/blob/9b7ca8ef4e6afdbfeb149a9cd488c8dea9af9ad6/flash.cu#L85) at compile time to 32.
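
To make the thread-per-row caveat concrete, here is a hypothetical Python analogue of what each CUDA thread does for its score tile: a serial dot product per output element instead of a cooperative, blocked matrix multiply (an illustration only, not code from `flash.cu`):

```python
import torch

Br, Bc, d = 16, 16, 64
softmax_scale = 1.0 / d ** 0.5
Qi, Kj = torch.randn(Br, d), torch.randn(Bc, d)
S = torch.empty(Br, Bc)

for t in range(Br):              # one "thread" per row of the score tile
    for c in range(Bc):
        s = 0.0
        for x in range(d):       # serial dot product: no tensor cores, no warp-level tiling
            s += float(Qi[t, x]) * float(Kj[c, x])
        S[t, c] = softmax_scale * s

# A blocked kernel would instead compute the whole tile at once:
assert torch.allclose(S, softmax_scale * (Qi @ Kj.T), atol=1e-4)
```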

## Todos

* [ ] Speed up matmults
* [ ] Dynamically set block size

## Contributors

* [Peter Kim](https://github.com/tspeterkim), Lead Contributor
* [Franz Cesista](https://github.com/leloykun), Implemented the backward pass
48 changes: 41 additions & 7 deletions bench.py
@@ -13,11 +13,13 @@
seq_len = 64
head_embd = 64

q = torch.randn(batch_size, n_head, seq_len, head_embd, requires_grad=True).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd, requires_grad=True).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd, requires_grad=True).cuda()

print('====== profiling forward pass ======')

print('=== profiling manual attention (forward pass) ===')

# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
@@ -30,10 +32,42 @@ def manual_attn(q, k, v):
manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention (forward pass) === ')

with (
    torch.autograd.profiler.profile(use_cuda=True) as prof,
    torch.no_grad(),
):
    minimal_result, l, m = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))

print("\n\n\n")
print('====== profiling backward pass ======')

print('=== profiling manual attention (backward pass) ===')

y_grad = torch.ones_like(minimal_result)

def manual_attn_backward(q, k, v, y, y_grad):
    return torch.autograd.grad([y], [q, k, v], grad_outputs=[y_grad])

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_grad_q, manual_grad_k, manual_grad_v = manual_attn_backward(
        q, k, v, manual_result, y_grad
    )
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention (backward pass) === ')

with (
    torch.autograd.profiler.profile(use_cuda=True) as prof,
    torch.no_grad(),
):
    minimal_grad_q, minimal_grad_k, minimal_grad_v = minimal_attn.backward(q, k, v, minimal_result, y_grad, l, m)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('q grad sanity check:', torch.allclose(manual_grad_q, minimal_grad_q, rtol=0, atol=1e-02))
print('k grad sanity check:', torch.allclose(manual_grad_k, minimal_grad_k, rtol=0, atol=1e-02))
print('v grad sanity check:', torch.allclose(manual_grad_v, minimal_grad_v, rtol=0, atol=1e-02))