bench.py
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load

# Load the CUDA kernel as a python module
minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'], extra_cuda_cflags=['-O2'])
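# Note (an addition, not in the original script): torch.utils.cpp_extension.load
# JIT-compiles the sources on first use and reuses the cached build afterwards.
# When debugging build problems, passing verbose=True prints the compiler/ninja
# output, e.g.:
#   minimal_attn = load(name='minimal_attn', sources=['main.cpp', 'flash.cu'],
#                       extra_cuda_cflags=['-O2'], verbose=True)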
# Use small model params, otherwise slower than manual attention. See caveats in README.
batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

q = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
k = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()
v = torch.randn(batch_size, n_head, seq_len, head_embd).cuda()

print('=== profiling manual attention ===')

# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def manual_attn(q, k, v):
    att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y
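
# Optional cross-check (an addition, assuming PyTorch >= 2.0): F.scaled_dot_product_attention
# computes the same softmax(QK^T / sqrt(d_k)) V, so it can serve as a second reference
# for manual_attn. A sketch, not part of the original benchmark:
sdpa_ref = F.scaled_dot_product_attention(q, k, v)
print('manual vs torch SDPA:', torch.allclose(manual_attn(q, k, v), sdpa_ref, atol=1e-5))
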
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = manual_attn(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling minimal flash attention ===')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    minimal_result = minimal_attn.forward(q, k, v)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))
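
# The profiler tables above break the cost down per CUDA kernel. For an end-to-end
# wall-clock comparison, CUDA events can be used. This is an added sketch, not part of
# the original script; it assumes the kernels were already warmed up by the runs above.
def time_fn(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print('manual attention: %.3f ms/iter' % time_fn(manual_attn))
print('minimal flash attention: %.3f ms/iter' % time_fn(minimal_attn.forward))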