
Implement backward pass #2

Open
leloykun wants to merge 2 commits into main

Conversation

leloykun

Description

This PR implements a minimal backward pass for flash attention.

I got these results on my RTX 2060

=== profiling manual attention (backward pass) ===
...
Self CPU time total: 11.139ms
Self CUDA time total: 1.721ms
=== profiling minimal flash attention (backward pass) === 
...
Self CPU time total: 31.466ms
Self CUDA time total: 629.000us

That's roughly a 2.7x speedup in CUDA time (1.721 ms vs. 629 us), although the CPU-side overhead is higher.

Though my GPU can only handle blocks of size 16 (vs. size 32 blocks on a T4).
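
For context, a minimal sketch of how the "manual attention (backward pass)" numbers above can be reproduced with torch.autograd.profiler; the shapes and the manual_attn reference here are illustrative assumptions, and the minimal flash attention backward would presumably be profiled the same way around its extension call:

import math
import torch
from torch.autograd import profiler

# Illustrative shapes; the actual benchmark shapes may differ.
B, nh, T, hd = 16, 12, 64, 64
q, k, v = (torch.randn(B, nh, T, hd, device='cuda', requires_grad=True) for _ in range(3))

def manual_attn(q, k, v):
    # Plain O(T^2) attention used as the reference implementation.
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = torch.nn.functional.softmax(att, dim=-1)
    return att @ v

y = manual_attn(q, k, v)
with profiler.profile(use_cuda=True) as prof:
    y.backward(torch.ones_like(y))
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

The table footer printed by the profiler is where the "Self CPU time total" and "Self CUDA time total" lines above come from.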

@hypertseng

@leloykun Hello Franz! I'm having some trouble with the code and flash attention. First, why does the attention-values sanity check return False when seq_len is lower than 32? This breaks inference, where seq_len is usually 1; I guess the block size may be causing this? Also, how do I choose an appropriate block size? Looking forward to your reply!

@leloykun
Author

leloykun commented Apr 17, 2024

Hi @hypertseng!

I believe it was because we weren't exiting the loops after going past the seq length. The forward pass should be fixed in my repo here: https://github.com/leloykun/flash-hyperbolic-attention-minimal
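For intuition, here is a pure-PyTorch illustration (not the actual CUDA kernel, and the padded values are just an assumption for the sketch): when seq_len is smaller than the key/value tile size and the tile loop doesn't exit or mask past seq_len, the out-of-range positions still take part in the softmax, which is enough to make the sanity check fail.

import math
import torch

torch.manual_seed(0)
Bc = 32                  # key/value tile width assumed by the kernel
seq_len, d = 8, 16       # seq_len < Bc: the case that failed the sanity check
q, k, v = (torch.randn(seq_len, d) for _ in range(3))

# Reference attention over the real sequence only.
ref = torch.softmax(q @ k.T / math.sqrt(d), dim=-1) @ v

# If the tile loop keeps going past seq_len, the padded key/value positions are
# never masked out, so every padded column still receives softmax weight and the
# output drifts away from the reference.
pad = Bc - seq_len
k_pad = torch.cat([k, torch.zeros(pad, d)])
v_pad = torch.cat([v, torch.zeros(pad, d)])
buggy = torch.softmax(q @ k_pad.T / math.sqrt(d), dim=-1) @ v_pad

print(torch.allclose(ref, buggy, rtol=0, atol=1e-2))  # False, like the failing sanity check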

@hypertseng

@leloykun Recently I found that the flash_attn_bwd implementation in your repo is slower than the manual implementation, and this is entirely because of the implicit call to cudaDeviceSynchronize, which increases the CPU time a lot. Do you have any idea how to solve this problem?
By the way, I found that changing the atomicAdd to a normal add decreases the cudaDeviceSynchronize occupancy, but I don't know why. I'm a beginner at CUDA, haha.

@2440020096

@hypertseng Most likely, the cudaDeviceSynchronize time includes the kernel execution time. You can use CUDA events to time it instead:

import torch

# q, k, v and the compiled minimal_attn extension are assumed to come from the benchmark script.
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
minimal_result = minimal_attn.forward(q, k, v)
end_event.record()
torch.cuda.synchronize()  # make sure both events have completed before reading them

elapsed_time_ms = start_event.elapsed_time(end_event)
max_vram_MB = torch.cuda.max_memory_allocated() / (1024 * 1024)
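
A short usage note (the print format is just illustrative): report the measurements afterwards, and the same event pattern can be wrapped around the backward call to time flash_attn_bwd in isolation.

print(f"minimal attention: {elapsed_time_ms:.3f} ms, peak VRAM {max_vram_MB:.1f} MB")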
