reduce for kqmax_new_j is unnecessary #1032

Open
wants to merge 2 commits into master
Conversation

mahorozte commented:

Using this patch, performance improves by about 1-2%, tested on an A800.

test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

I applied a hack so that nb=1,2,3,7 also use flash_attn_vec_ext_f16 (the A800 supports WMMA, so these cases would normally take the WMMA kernel), just to evaluate the vector kernel's performance; a sketch of the idea follows.
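For context, a minimal sketch of what such an override could look like. All function and type names below are hypothetical stand-ins for illustration, not the actual ggml-cuda dispatch code:

// Hypothetical benchmarking hack: force the vector kernel for small batches
// even on a GPU with WMMA support. All names here are illustrative only.
struct fattn_params { int nb; }; // nb = batch size (number of query columns)

static bool gpu_supports_wmma() { return true; } // true on an A800 (Ampere)

static void launch_flash_attn_vec_ext_f16(const fattn_params &) { /* vector kernel */ }
static void launch_flash_attn_wmma_f16  (const fattn_params &) { /* WMMA kernel   */ }

static void launch_flash_attn(const fattn_params & p) {
    const bool force_vec = true; // the hack: pretend WMMA is unavailable
    if (p.nb <= 8 && (force_vec || !gpu_supports_wmma())) {
        launch_flash_attn_vec_ext_f16(p); // taken for nb = 1,2,3,7 in the tests below
    } else {
        launch_flash_attn_wmma_f16(p);
    }
}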

Original:
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.51 us/run - 4.19 MFLOP/run - 364.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.38 us/run - 8.39 MFLOP/run - 583.51 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.18 us/run - 12.58 MFLOP/run - 594.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 34.17 us/run - 29.36 MFLOP/run - 859.30 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 19.03 us/run - 8.39 MFLOP/run - 440.74 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 25.04 us/run - 16.78 MFLOP/run - 670.07 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.84 us/run - 25.17 MFLOP/run - 683.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 67.79 us/run - 58.72 MFLOP/run - 866.21 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.68 us/run - 4.19 MFLOP/run - 330.79 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.16 us/run - 8.39 MFLOP/run - 519.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 22.11 us/run - 12.58 MFLOP/run - 569.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 37.08 us/run - 29.36 MFLOP/run - 791.89 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.47 us/run - 8.39 MFLOP/run - 509.47 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.93 us/run - 16.78 MFLOP/run - 765.03 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 35.48 us/run - 25.17 MFLOP/run - 709.22 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 60.34 us/run - 58.72 MFLOP/run - 973.20 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.23 us/run - 4.19 MFLOP/run - 373.53 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 13.85 us/run - 8.39 MFLOP/run - 605.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.40 us/run - 12.58 MFLOP/run - 616.89 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.79 us/run - 29.36 MFLOP/run - 719.80 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.64 us/run - 8.39 MFLOP/run - 450.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.73 us/run - 16.78 MFLOP/run - 707.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.49 us/run - 25.17 MFLOP/run - 729.75 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.55 us/run - 58.72 MFLOP/run - 882.32 GFLOPS

With this patch applied:
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.14 us/run - 4.19 MFLOP/run - 376.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.02 us/run - 8.39 MFLOP/run - 598.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.66 us/run - 12.58 MFLOP/run - 609.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 33.68 us/run - 29.36 MFLOP/run - 871.69 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.82 us/run - 8.39 MFLOP/run - 445.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 24.57 us/run - 16.78 MFLOP/run - 682.88 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.41 us/run - 25.17 MFLOP/run - 691.19 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.20 us/run - 58.72 MFLOP/run - 887.06 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.70 us/run - 4.19 MFLOP/run - 330.27 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 15.85 us/run - 8.39 MFLOP/run - 529.23 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.69 us/run - 12.58 MFLOP/run - 580.05 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 36.73 us/run - 29.36 MFLOP/run - 799.45 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.19 us/run - 8.39 MFLOP/run - 518.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.47 us/run - 16.78 MFLOP/run - 781.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.47 us/run - 25.17 MFLOP/run - 730.00 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 59.55 us/run - 58.72 MFLOP/run - 986.15 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 98304 runs - 10.93 us/run - 4.19 MFLOP/run - 383.64 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 13.52 us/run - 8.39 MFLOP/run - 620.46 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 19.85 us/run - 12.58 MFLOP/run - 633.98 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.10 us/run - 29.36 MFLOP/run - 732.12 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.36 us/run - 8.39 MFLOP/run - 456.96 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.40 us/run - 16.78 MFLOP/run - 716.87 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 33.66 us/run - 25.17 MFLOP/run - 747.69 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 65.55 us/run - 58.72 MFLOP/run - 895.82 GFLOPS

@JohannesGaessler (Collaborator) left a comment:

I think I left this in when I refactored the code at some point; I would suggest simply removing the call to warp_reduce_max without adding a comment explaining why it's gone, since it shouldn't have been there in the first place. Also, fattn-vec-f32.cuh has the same issue; please remove the corresponding line there too while you're at it.
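For reference, warp_reduce_max is a butterfly reduction over warp shuffle instructions, roughly as sketched below (a sketch in the style of ggml's helper, float variant shown for readability, not the exact code). After one such reduction every lane already holds the warp-wide maximum, which is why a second call is redundant:

// Butterfly max-reduction across the 32 lanes of a warp (sketch in the
// style of ggml's warp_reduce_max; float variant for readability).
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // XOR shuffle pairs every lane with a partner; after log2(32) = 5
        // steps, all 32 lanes hold the same maximum.
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x; // identical in every lane, so reducing again is a no-op
}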

@@ -220,7 +220,8 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j = 0; j < ncols; ++j) {
         half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];

-        kqmax_new_j = warp_reduce_max(kqmax_new_j);
+        /* kqmax_new_j is already the same in every thread of the warp after the reduction at line 199, so this reduce can be omitted */
@JohannesGaessler (Collaborator) commented on the diff:

Strictly speaking it is not exactly the same, since the order in which floating-point values are summed up affects the result, but this is negligible. And there is no reason why the maximum value in particular should be preferred.
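The non-associativity point can be illustrated with a minimal host-side example (values chosen so the rounding is visible):

// Floating-point addition is not associative: regrouping changes rounding,
// so different reduction orders can give slightly different sums.
#include <stdio.h>

int main(void) {
    float a = 1e8f, b = -1e8f, c = 1.0f;
    printf("%f\n", (a + b) + c); // 1.000000
    printf("%f\n", a + (b + c)); // 0.000000: c is absorbed, since 1 < ulp(1e8)
    return 0;
}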

@mahorozte (Author) replied:
fattn-vec-f32.cuh and fattn-vec-f16.cuh are both updated.
