cpu: aarch64: Enable matmul bf16f32 format desc #343

Open · wants to merge 1 commit into base: ideep_pytorch

Conversation

@aditew01 commented Nov 8, 2024

This will allow generating a desc for matmul with format src: bf16, wei: bf16, dst: f32, instead of reordering dst to bf16 and back to f32.
These kernels are used directly by cpu_flash_attention (SDPA).
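
For context, the descriptor this enables corresponds, in plain oneDNN terms, to a matmul whose src and weights are bf16 while the dst stays f32. A minimal sketch using the oneDNN 3.x C++ API (shapes and format tags are illustrative; this is not the ideep code path itself):

```cpp
#include <dnnl.hpp>
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    const memory::dim M = 128, K = 64, N = 256;  // illustrative shapes

    // src and weights stay in bf16; dst is f32, so the f32 destination no
    // longer has to be reordered to bf16 and back.
    memory::desc src_md({M, K}, memory::data_type::bf16, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::bf16, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32,  memory::format_tag::ab);

    // oneDNN dispatches a bf16bf16f32 kernel (e.g. an ACL kernel on aarch64)
    // when one is available for this data type combination.
    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
    matmul prim(pd);
    return 0;
}
```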

@aditew01 (Author) commented Nov 8, 2024

cc: @yanbing-j, can you please help with the review?

@yanbing-j

Request @jgong5 for review.

Comment on lines +830 to +834
dst_data_type = src.get_data_type() == data_type::bf16
    ? ((dst.get_data_type() == data_type::f32) ? data_type::f32
                                               : data_type::bf16)
    : ((src.get_data_type() == data_type::f16) ? data_type::f16
                                               : data_type::f32);

The concern is how this would impact the original callers, since it is a BC-breaking change. @yanbing-j, can you double-check the existing callers?


In the current implementation, the existing callers use the same dtype for src and dst, for example https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LinearAlgebra.cpp#L1531-L1543, as verified with this reproducer:

import torch

mat1 = torch.randn(20, 30).to(torch.bfloat16)
mat2 = torch.randn(30, 30).to(torch.bfloat16)
torch.matmul(mat1, mat2)

@aditew01 (Author)

For broader context, this will allow picking up the right kernel from oneDNN for matmul where src: bf16, wei: bf16, and dst: f32.

https://github.com/pytorch/pytorch/blob/1185975c6e7f23f0cb318e47166fceebd2995c98/aten/src/ATen/native/CPUBlas.cpp#L400
The current operator that will use this is scaled_dot_product_attention.


@aditew01 Sorry, I cannot find the calling of mkldnn_bf16f32_gemm in the main branch https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/CPUBlas.cpp#L405. Could you try to update your code base?

@aditew01 (Author)

> @aditew01 Sorry, I cannot find the calling of mkldnn_bf16f32_gemm in the main branch https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/CPUBlas.cpp#L405. Could you try to update your code base?

Hi @yanbing-j, the change I referred to is in a draft PR I've raised in PyTorch. It waits on specific oneDNN and ideep changes.

pytorch/pytorch#140159

src_desc = (src.get_data_type() == data_type::bf16 &&
            dst_data_type == data_type::f32)
    ? src.get_desc()
    : src.get_desc().to_type(dst_data_type);

If the logic is the same for both paths, I suggest avoiding the duplication here.


I have a question here: why is this logic applied outside the __aarch64__ macro? Although it gives the same result when src and dst are bf16, it is unnecessary.

@aditew01 (Author)

> If the logic is the same for both paths, I suggest avoiding the duplication here.

I reverted this; it should only be for aarch64. I pushed a revert that preserves the previous state.


Would this incur different semantics for the API when it works on different CPU archs? Normally, we shouldn't design an API that way. Is it possible to add an extra arg to specify the behavior you added (dst type does not follow src)?

@aditew01 (Author)

In this case, I think this logic was generally used before; I added a different condition specific to aarch64.

> Would this incur different semantics for the API when it works on different CPU archs?

dst_data_type = src.get_data_type() == data_type::bf16 ?

I moved the logic in ideep_pytorch, which previously applied generically to all archs, so that it now applies only when cpu_arch != aarch64.
Please let me know your thoughts.


That means the semantics are different between aarch64 and x86, right? I noticed that the dst_type arg is not used. Can we use it to override dst_data_type so that the behavior is consistent between aarch64 and x86? Of course, we would have to redefine the default value for dst_type (it is currently fp32, which doesn't seem right as the default behavior).

@aditew01 (Author)

I see, understood.
I have updated the dst_type default value to undef. The logic that uses dst_type in this context,

dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type;

should make sense now. Please let me know your thoughts.
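
As a minimal sketch of that override, assuming a hypothetical resolve_dst_data_type() helper rather than the actual do_prepare signature: the default undef keeps the src-derived choice, while an explicit f32 from a caller such as the SDPA path overrides it.

```cpp
#include <cassert>

// Stand-in for ideep/oneDNN data_type values; illustrative only.
enum class data_type { undef, f16, bf16, f32 };

// Hypothetical helper mirroring the line quoted above:
// dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type;
data_type resolve_dst_data_type(data_type derived_dst_dt,
                                data_type dst_type = data_type::undef) {
  return dst_type == data_type::undef ? derived_dst_dt : dst_type;
}

int main() {
  // Default: dst follows the src-derived data type (bf16 stays bf16).
  assert(resolve_dst_data_type(data_type::bf16) == data_type::bf16);
  // Caller opts in: an explicit f32 dst_type yields the bf16f32 kernel shape.
  assert(resolve_dst_data_type(data_type::bf16, data_type::f32) == data_type::f32);
  return 0;
}
```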


Then, can we specify dst_type from the caller side? That way you wouldn't have to use different logic to initialize dst_data_type on aarch64 and x64.

@jgong5 left a comment

Suggest making the semantics consistent between aarch64 and x86.

@jgong5 left a comment

Can we override things from the caller side?

@aditew01 (Author)

> Can we override things from the caller side?

Maybe we can, but if it's only for this specific case, that implies there will be two sets of logic for bf16-f32 and f16 (or the other cases that are still there).
Unless we want to refactor the code (we could open an issue and look at the problem more broadly), isn't it better to have the specific logic in the do_prepare call?

@jgong5 commented Nov 18, 2024

> Can we override things from the caller side?
>
> Maybe we can, but if it's only for this specific case, that implies there will be two sets of logic for bf16-f32 and f16 (or the other cases that are still there). Unless we want to refactor the code (we could open an issue and look at the problem more broadly), isn't it better to have the specific logic in the do_prepare call?

I think I have already mentioned my major concern: this change makes the semantics inconsistent between aarch64 and x86. I list the differences in the tables below:
aarch64

| Source Type (src) | Destination Type (dst) | Destination Data Type (dst_data_type) |
| --- | --- | --- |
| bf16 | fp32 | fp32 |
| bf16 | not fp32 | bf16 |
| fp16 | any | fp16 |
| other | any | fp32 |

x86

| Source Type (src) | Destination Type (dst) | Destination Data Type (dst_data_type) |
| --- | --- | --- |
| bf16 | any | bf16 |
| fp16 | any | fp16 |
| other | any | fp32 |
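
To make the divergence concrete, here is a sketch restating the two tables as code; the enum and function names are illustrative, not the ideep implementation.

```cpp
#include <cassert>

// Illustrative restatement of the two mapping tables above; not ideep code.
enum class data_type { undef, f16, bf16, f32 };

// aarch64 mapping proposed in this PR: a bf16 src keeps an f32 dst.
data_type dst_dt_aarch64(data_type src, data_type dst) {
  if (src == data_type::bf16)
    return dst == data_type::f32 ? data_type::f32 : data_type::bf16;
  return src == data_type::f16 ? data_type::f16 : data_type::f32;
}

// Existing x86 mapping: dst always follows the src data type.
data_type dst_dt_x86(data_type src, data_type /*dst*/) {
  if (src == data_type::bf16) return data_type::bf16;
  return src == data_type::f16 ? data_type::f16 : data_type::f32;
}

int main() {
  // bf16 src with an f32 dst is where the two mappings diverge.
  assert(dst_dt_aarch64(data_type::bf16, data_type::f32) == data_type::f32);
  assert(dst_dt_x86(data_type::bf16, data_type::f32) == data_type::bf16);
  return 0;
}
```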

How can we make them aligned?

@aditew01 (Author) commented Nov 18, 2024

Apologies if this is repetitive.
That will be tricky to test as well, right? This is a very specific change, and if we were to align x86 in a similar fashion, we may lose performance if there's no specific kernel of that format available.
If it's about aligning the semantics, I'm not sure where we can do that. Even if we push this code higher (change the caller), e.g.:

do_prepare</*with_bias=*/false>(param, src, weights, bias, dst, dst_coeff, sum_coeff,

the logic will be similar, right? Or do you think that's a better place for the change, design-wise?
cc: @jgong5

@jgong5 commented Nov 19, 2024

> Apologies if this is repetitive. That will be tricky to test as well, right? This is a very specific change, and if we were to align x86 in a similar fashion, we may lose performance if there's no specific kernel of that format available. If it's about aligning the semantics, I'm not sure where we can do that. Even if we push this code higher (change the caller), e.g.:
>
> do_prepare</*with_bias=*/false>(param, src, weights, bias, dst, dst_coeff, sum_coeff,
>
> the logic will be similar, right? Or do you think that's a better place for the change, design-wise?
> cc: @jgong5

Frankly speaking, I don't quite understand why the data type semantics have to differ across CPU archs. It is not just a question of runtime efficiency, but also of the precision and accuracy visible to users. Can you explain why we have to make things different here?
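
To illustrate the precision aspect: with a bf16 destination the f32 accumulator is narrowed to bf16 and later widened back to f32, whereas a bf16f32 kernel keeps the f32 result. A self-contained sketch (simple bf16 truncation; real kernels round to nearest even, but the effect is the same):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Emulate an f32 -> bf16 -> f32 round trip by keeping the upper 16 bits
// (sign, 8 exponent bits, 7 mantissa bits); enough to show the precision loss.
float through_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float acc = 0.0f;
  for (int k = 0; k < 1000; ++k) acc += 0.001f;  // stand-in for a dot product

  float f32_dst  = acc;                // bf16f32 path: keep the f32 result
  float bf16_dst = through_bf16(acc);  // bf16 dst, widened back to f32 later

  std::cout << f32_dst << " vs " << bf16_dst << "\n";  // values differ
  return 0;
}
```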

@aditew01 (Author) commented Nov 19, 2024

> Apologies if this is repetitive. That will be tricky to test as well, right? This is a very specific change, and if we were to align x86 in a similar fashion, we may lose performance if there's no specific kernel of that format available. If it's about aligning the semantics, I'm not sure where we can do that. Even if we push this code higher (change the caller), e.g.:
>
> do_prepare</*with_bias=*/false>(param, src, weights, bias, dst, dst_coeff, sum_coeff,
>
> the logic will be similar, right? Or do you think that's a better place for the change, design-wise?
> cc: @jgong5

> Frankly speaking, I don't quite understand why the data type semantics have to differ across CPU archs. It is not just a question of runtime efficiency, but also of the precision and accuracy visible to users. Can you explain why we have to make things different here?

The scaled-dot-product-attention op implemented in PyTorch calls cpublas::gemm. For dtype bf16, the gemm operator takes bf16 input matrices and produces an fp32 output. PyTorch code: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L750

For x86, this calls the underlying MKL kernel mkl_gemm_bf16bf16f32: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/CPUBlas.cpp#L420

The logic implemented here enables oneDNN to pick the ACL kernel. I hope this makes sense.

@jgong5 commented Nov 20, 2024

> The scaled-dot-product-attention op implemented in PyTorch calls cpublas::gemm. For dtype bf16, the gemm operator takes bf16 input matrices and produces an fp32 output. PyTorch code: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L750
>
> For x86, this calls the underlying MKL kernel mkl_gemm_bf16bf16f32: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/CPUBlas.cpp#L420
>
> The logic implemented here enables oneDNN to pick the ACL kernel. I hope this makes sense.

Yes, that makes sense. It means the precision stays the same between aarch64 and x86 for SDPA. I also understand that the ideep API needs to handle two scenarios: one that returns the same data type as the input and one that returns fp32. We can extend the semantics of the API to support fp32, but I want the ideep API to behave the same on aarch64 and x86, i.e., with the same data type mapping between input and output. I see two options (please comment if you have other ideas):

  1. Let the caller specify that it wants fp32 output instead of following the input data type. This keeps the original callers unchanged and is not BC breaking. To me, it is a good choice. How are you going to invoke oneDNN from SDPA? Can you specify from the caller that you want fp32 output?
  2. Change the behavior of the API to return fp32 by default under some reasonable conditions. This is BC breaking. We need to double-check the existing callers to see if it would break them.

@aditew01 (Author)

> The scaled-dot-product-attention op implemented in PyTorch calls cpublas::gemm. For dtype bf16, the gemm operator takes bf16 input matrices and produces an fp32 output. PyTorch code: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L750
> For x86, this calls the underlying MKL kernel mkl_gemm_bf16bf16f32: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/CPUBlas.cpp#L420
> The logic implemented here enables oneDNN to pick the ACL kernel. I hope this makes sense.

> Yes, that makes sense. It means the precision stays the same between aarch64 and x86 for SDPA. I also understand that the ideep API needs to handle two scenarios: one that returns the same data type as the input and one that returns fp32. We can extend the semantics of the API to support fp32, but I want the ideep API to behave the same on aarch64 and x86, i.e., with the same data type mapping between input and output. I see two options (please comment if you have other ideas):
>
> 1. Let the caller specify that it wants fp32 output instead of following the input data type. This keeps the original callers unchanged and is not BC breaking. To me, it is a good choice. How are you going to invoke oneDNN from SDPA? Can you specify from the caller that you want fp32 output?
> 2. Change the behavior of the API to return fp32 by default under some reasonable conditions. This is BC breaking. We need to double-check the existing callers to see if it would break them.

Thanks for the inputs. Alternatively, would it not be better to enable this for both aarch64 and x86? I believe that would be the best way to handle this. It would align with the MKL kernel that is already being called, and it would ensure the same precision for specific operators like SDPA.
https://github.com/intel/ideep/pull/343#issuecomment-2482833306

The suggested mechanism works, but in this case the caller is the descriptor that ideep generates, right? And by setting the src and dst data types here, we enable the respective oneDNN kernels to make dispatch decisions based on them. Reference:
https://github.com/oneapi-src/oneDNN/blob/d94cc8d4fbed06867ed3bebb04ac91573175ebfa/src/cpu/aarch64/matmul/acl_matmul.cpp#L79

@jgong5 commented Nov 21, 2024

> Thanks for the inputs. Alternatively, would it not be better to enable this for both aarch64 and x86? I believe that would be the best way to handle this.

That sounds good to me. We'd better not break existing code though.

@jgong5 left a comment

@aditew01 Please let me know if you have any questions/concerns. Please tag me for review after you make the corresponding changes.
