cpu: aarch64: Enable matmul bf16f32 format desc #343
base: ideep_pytorch
Conversation
cc: @yanbing-j, can you please help with the review?
Force-pushed from 7a10d7a to 4a026ff
Requesting @jgong5 for review.
dst_data_type = src.get_data_type() == data_type::bf16
    ? ((dst.get_data_type() == data_type::f32) ? data_type::f32
                                               : data_type::bf16)
    : ((src.get_data_type() == data_type::f16) ? data_type::f16
                                                : data_type::f32);
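To make the mapping above easier to follow, here is a minimal, self-contained sketch (the enum is a stand-in for dnnl's data_type, not the real ideep code) of what this aarch64 selection does:

#include <cstdio>

// Stand-in for dnnl's data_type; only the values needed here.
enum class data_type { f32, bf16, f16 };

// Restates the aarch64 selection from the diff above: a bf16 src keeps an
// f32 dst when the caller supplied one, otherwise stays bf16; an f16 src
// maps to f16; everything else falls back to f32.
data_type pick_dst_type(data_type src, data_type dst) {
  return src == data_type::bf16
      ? ((dst == data_type::f32) ? data_type::f32 : data_type::bf16)
      : ((src == data_type::f16) ? data_type::f16 : data_type::f32);
}

int main() {
  std::printf("bf16 src + f32 dst  -> %s\n",
              pick_dst_type(data_type::bf16, data_type::f32) == data_type::f32
                  ? "f32" : "not f32");
  std::printf("bf16 src + bf16 dst -> %s\n",
              pick_dst_type(data_type::bf16, data_type::bf16) == data_type::bf16
                  ? "bf16" : "not bf16");
  return 0;
}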
The concern is how this would impact the original callers, since it is a BC-breaking change. @yanbing-j, can you double check the existing callers?
In the current implementation, the existing callers use the same dtypes for src and dst; for example, https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LinearAlgebra.cpp#L1531-L1543, using the reproducer:
mat1 = torch.randn(20, 30).to(torch.bfloat16)
mat2 = torch.randn(30, 30).to(torch.bfloat16)
torch.matmul(mat1, mat2)
For broader context, this will allow picking up the right kernel from oneDNN for matmul where src: bf16, wei: bf16, and dst: f32.
https://github.com/pytorch/pytorch/blob/1185975c6e7f23f0cb318e47166fceebd2995c98/aten/src/ATen/native/CPUBlas.cpp#L400
Currently, the operator that will use this is scaled_dot_product_attention.
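As an illustration only: a minimal libtorch sketch of an SDPA call with bf16 inputs, which is the kind of workload that would sit on top of this path. Whether it actually reaches the bf16f32 oneDNN gemm depends on the build and on the draft PyTorch PR mentioned later in this thread.

#include <torch/torch.h>
#include <iostream>

int main() {
  // bf16 query/key/value; the CPU flash attention kernel accumulates in f32,
  // which is where a src: bf16, wei: bf16, dst: f32 matmul becomes useful.
  auto q = torch::randn({1, 4, 16, 64}).to(torch::kBFloat16);
  auto k = torch::randn({1, 4, 16, 64}).to(torch::kBFloat16);
  auto v = torch::randn({1, 4, 16, 64}).to(torch::kBFloat16);
  auto out = at::scaled_dot_product_attention(q, k, v);
  std::cout << out.sizes() << " " << out.dtype() << std::endl;
  return 0;
}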
@aditew01 Sorry, I cannot find the call to mkldnn_bf16f32_gemm in the main branch: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/CPUBlas.cpp#L405. Could you try to update your code base?
> @aditew01 Sorry, I cannot find the call to mkldnn_bf16f32_gemm in the main branch: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/CPUBlas.cpp#L405. Could you try to update your code base?
Hi @yanbing-j, the change I referred to is in a draft PR I've raised in PyTorch. It's a draft and waits on specific oneDNN and ideep changes.
include/ideep/operators/matmul.hpp (Outdated)
src_desc = (src.get_data_type() == data_type::bf16 &&
            dst_data_type == data_type::f32)
    ? src.get_desc()
    : src.get_desc().to_type(dst_data_type);
If the logic is the same for both paths, I suggest avoiding the duplication here.
I have a question here: why is this logic applied outside the __aarch64 macro? Although it gives the same result when src and dst are bf16, it is unnecessary.
> If the logic is the same for both paths, I suggest avoiding the duplication here.

I reverted this; it should only be for aarch64. Pushed a revert which preserves the previous state.
Would this incur different semantics for the API when it works on different CPU archs? Normally, we shouldn't design an API that way. Is it possible to add an extra arg to specify the behavior you added (dst type does not follow src)?
In this case I think this logic was generally used before. I added a different condition, specific to aarch64.
> Would this incur different semantics for the API when it works on different CPU archs?
ideep/include/ideep/operators/matmul.hpp, line 825 in 3215f5f:
dst_data_type = src.get_data_type() == data_type::bf16 ?
I shifted the logic in ideep_pytorch, which was generic for all archs, to apply only when cpu_arch != aarch64. Please let me know your thoughts.
That means the semantics are different between aarch64 and x86, right? I found that the arg dst_type is not used. Can we use it to override dst_data_type to ensure consistent behavior between aarch64 and x86? Of course, we would have to redefine the default value for dst_type (it is now fp32, which doesn't seem right for the default behavior).
I see, I do understand.
I have updated the dst_type default value to undef. The logic which uses dst_type in this context is here:
ideep/include/ideep/operators/matmul.hpp, line 884 in c54a3ed:
dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type;
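A small self-contained sketch (again with a stand-in enum, not the actual ideep function) of how this undef default composes with the mapping shown earlier: the inferred type is used unless the caller passes an explicit dst_type.

#include <cassert>

// Stand-in for dnnl's data_type; undef mirrors data_type::undef.
enum class data_type { undef, f32, bf16, f16 };

data_type resolve_dst_type(data_type src, data_type dst, data_type dst_type) {
  // Inferred mapping (see the aarch64 logic earlier in the thread).
  data_type dst_data_type = src == data_type::bf16
      ? ((dst == data_type::f32) ? data_type::f32 : data_type::bf16)
      : ((src == data_type::f16) ? data_type::f16 : data_type::f32);
  // A caller-supplied dst_type wins whenever it is not undef.
  return dst_type == data_type::undef ? dst_data_type : dst_type;
}

int main() {
  // Default (undef): the inferred mapping applies.
  assert(resolve_dst_type(data_type::bf16, data_type::f32, data_type::undef)
         == data_type::f32);
  // Explicit override: dst_type takes precedence.
  assert(resolve_dst_type(data_type::bf16, data_type::f32, data_type::bf16)
         == data_type::bf16);
  return 0;
}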
Then, can we specify the dst_type from the caller side? That way you wouldn't have to use different logic to initialize dst_data_type on aarch64 and x64.
Force-pushed from 4a026ff to c576e78
I suggest making the semantics consistent between aarch64 and x86.
Force-pushed from c576e78 to c54a3ed
> Can we override things from the caller side?

Maybe we can, but if it's for this specific case, that'll imply there will be two sets of logic.
I guess I have mentioned my major concern: the inconsistency of the semantics between aarch64 and x86 with this change. I list the differences in the tables below:
(dtype mapping tables for aarch64 and x86)
How can we make them aligned?
Apologies if this is repetitive. ideep/include/ideep/operators/matmul.hpp, line 289 in c54a3ed: the logic will be similar, right? Or do you think that's a better place to update, design-wise? cc: @jgong5
Frankly speaking, I don't quite understand why the datatype semantics have to be different across CPU archs. It is not just a question of runtime efficiency but also of the precision and accuracy visible to users. Can you explain why we have to make things different here?
The scaled-dot-product-attention op implemented in PyTorch calls into this gemm. For x86, it calls the underlying MKL kernel; the logic implemented here enables oneDNN to pick the ACL kernel. I hope this makes sense.
Yes, that makes sense. It means the precision does stay the same between aarch64 and x86 for SDPA. And I also understand that we need the ideep API to handle two scenarios: one that returns the same data type as the input, and one that returns fp32. We can extend the semantics of the API to support fp32, but I want the ideep API designed in a way that behaves the same between aarch64 and x86, i.e., the same data type mapping between input and output. I see two options (please comment if you have other ideas):
Thanks for the inputs. Alternatively, would it not be better to enable this for both aarch64 and x86? I believe that'll be the best way to handle this, and it will align with the already existing behavior.
The above suggested mechanism works, but in this case the caller is the descriptors which ideep generates, right? And while setting the …
That sounds good to me. We'd better not break existing code, though.
@aditew01 Please let me know if you have any questions/concerns. Please tag me for review after you make the corresponding changes.
This will allow generating a desc for matmul of format src: bf16; wei: bf16; dst: f32, instead of reordering dst to bf16 and back to f32. These kernels are directly used by cpu_flash_attention (SDPA).
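To illustrate the numerical intent of the description (keep the matmul result in f32 rather than rounding the destination through bf16), here is a tiny reference gemm, unrelated to the oneDNN/ACL kernels themselves, that consumes bf16 inputs and writes f32 outputs:

#include <cstdint>
#include <cstring>
#include <cstdio>
#include <vector>

// Decode a bf16 value (stored as the high 16 bits of an IEEE f32).
static float bf16_to_f32(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Reference gemm: src (MxK, bf16) * wei (KxN, bf16) -> dst (MxN, f32).
// The point of the bf16f32 format: the f32 accumulation result is written
// out directly, with no reorder of dst to bf16 and back to f32.
static void gemm_bf16f32(const uint16_t* src, const uint16_t* wei,
                         float* dst, int M, int N, int K) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k)
        acc += bf16_to_f32(src[m * K + k]) * bf16_to_f32(wei[k * N + n]);
      dst[m * N + n] = acc;
    }
}

int main() {
  const int M = 2, N = 2, K = 2;
  // 0x3F80 is 1.0f in bf16 (high half of 0x3F800000).
  std::vector<uint16_t> src(M * K, 0x3F80), wei(K * N, 0x3F80);
  std::vector<float> dst(M * N, 0.f);
  gemm_bf16f32(src.data(), wei.data(), dst.data(), M, N, K);
  std::printf("dst[0] = %f\n", dst[0]);  // prints 2.0
  return 0;
}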