custom_kernel: fix shape mismatch by sharding segment_ids in flash attn. #8333

Draft
wants to merge 2 commits into
base: master
from

Conversation

dudulightricks
Contributor

@dudulightricks dudulightricks commented Oct 29, 2024

Description: This PR addresses an issue where segment_ids were not considered when adding sharding support in this module. The absence of segment_ids handling results in a shape mismatch failure when using them in sharded Flash Attention.
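As a rough illustration of the shape mismatch described above, the sketch below computes per-shard shapes in plain Python. The mesh size, tensor shapes, partition spec, and the helpers `local_shape` and `shapes_agree` are all hypothetical; the thread does not show the actual spec used in the PR, so batch-dimension sharding is assumed here.

```python
def local_shape(global_shape, partition_spec, mesh_shape):
    """Per-shard shape when each dim is split over a named mesh axis (None = replicated)."""
    shard_shape = []
    for size, axis in zip(global_shape, partition_spec):
        parts = 1 if axis is None else mesh_shape[axis]
        if size % parts != 0:
            raise ValueError(f"dim of size {size} is not divisible by {parts}")
        shard_shape.append(size // parts)
    return tuple(shard_shape)


def shapes_agree(q_local, seg_local):
    """q is (batch, heads, seq, head_dim); segment ids are (batch, seq)."""
    return q_local[0] == seg_local[0] and q_local[2] == seg_local[1]


# Hypothetical setup: 4 devices on one mesh axis, q sharded along batch.
mesh_shape = {"data": 4}
q_shape = (8, 2, 128, 64)
seg_shape = (8, 128)
q_spec = ("data", None, None, None)

q_local = local_shape(q_shape, q_spec, mesh_shape)

# Segment ids left unsharded keep their global batch size.
seg_unsharded = local_shape(seg_shape, (None, None), mesh_shape)

# Segment ids sharded with the batch and seq entries of the q spec.
seg_spec = (q_spec[0], q_spec[2])
seg_sharded = local_shape(seg_shape, seg_spec, mesh_shape)

print(q_local, seg_unsharded, shapes_agree(q_local, seg_unsharded))
print(q_local, seg_sharded, shapes_agree(q_local, seg_sharded))
```

Under these assumptions, the unsharded segment ids keep a batch size of 8 while each q shard has a batch size of 2, which is the kind of disagreement a kernel would reject.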

Edit:
During training with dummy data using this fix, the loss stalls at 0.2 and does not converge to 0 as expected. Further adjustments are needed to resolve this convergence issue.

Collaborator

@miladm miladm left a comment

Thanks for the contribution! It looks like your use case runs into numerical inaccuracies, which may suggest that the enable_manual_sharding API calls need correction. I suggest adding a test case to verify and debug the issue in a small example.

Here is a reference for kernel tests you can refer to.

@dudulightricks dudulightricks force-pushed the bug-fix/shard-segment-ids-flash-attention branch 27 times, most recently from 4ba9067 to 0608900 on November 21, 2024 15:11
@dudulightricks
Contributor Author

@miladm Hi! We have added a test that currently fails (16% of the values are correct, the others are not). I hope it will help you understand what's wrong.
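A figure like "16% of the values are correct" can be obtained by counting how many elements fall within a tolerance of the reference output. The sketch below shows one way to do that with the standard library; `fraction_close`, the tolerances, and the sample numbers are made up for illustration and are not taken from the PR's test.

```python
import math


def fraction_close(actual, expected, rel_tol=1e-5, abs_tol=1e-8):
    """Fraction of positions where actual matches expected within tolerance."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    hits = sum(
        math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, e in zip(actual, expected)
    )
    return hits / len(actual)


# Made-up flattened outputs: only the first element agrees.
expected = [0.10, 0.20, 0.30, 0.40]
actual = [0.10, 0.25, 0.35, 0.45]

ratio = fraction_close(actual, expected)
print(f"{ratio:.0%} of the values match")
```

Reporting the matching fraction alongside a pass/fail assertion can help tell a small precision drift apart from a structural error, where whole blocks of the output are wrong.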

@miladm
Collaborator

miladm commented Nov 21, 2024

@dudulightricks Thanks for submitting the test code. We reviewed your code internally. It seems that if you shard both the KV and Q segment_ids, each query no longer attends to all KV elements in the matmul, hence the numerical inconsistency. Have you tried sharding only the query segment_ids?
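The effect described in this comment can be reproduced with a toy single-head attention in plain Python. This is only a sketch under stated assumptions: it assumes the split is along the sequence dimension (the thread does not say which dimension is sharded), it is not the Pallas kernel, and `segment_attention`, `shard`, and the sample data are hypothetical.

```python
import math


def segment_attention(q, k, v, q_seg, kv_seg):
    """Toy attention: query i may attend to key j only if q_seg[i] == kv_seg[j]."""
    out = []
    for qi, qs in zip(q, q_seg):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) if qs == ks else -math.inf
            for kj, ks in zip(k, kv_seg)
        ]
        top = max(scores)  # assumes every query has at least one visible key
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) / total
             for d in range(len(v[0]))]
        )
    return out


def shard(xs, n):
    """Split a sequence into n equal contiguous chunks."""
    step = len(xs) // n
    return [xs[i * step:(i + 1) * step] for i in range(n)]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
k = [row[:] for row in q]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
seg = [0, 1, 0, 1]  # segments interleave across the sequence

full = segment_attention(q, k, v, seg, seg)

# Shard only the query side: each shard still sees every key and value.
q_only = [
    row
    for qs, ss in zip(shard(q, 2), shard(seg, 2))
    for row in segment_attention(qs, k, v, ss, seg)
]

# Shard both sides: each query shard sees only its local keys and values.
both = [
    row
    for qs, ks, vs, ss in zip(shard(q, 2), shard(k, 2), shard(v, 2), shard(seg, 2))
    for row in segment_attention(qs, ks, vs, ss, ss)
]

print("query-only sharding diff:", max_abs_diff(full, q_only))
print("query+kv sharding diff:  ", max_abs_diff(full, both))
```

In this toy setup, splitting only the queries reproduces the unsharded output, while splitting the keys and values as well hides part of each segment from every query. Sharding along an independent dimension such as batch would not have this effect, which may be the consistency the next reply is asking about.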

@dudulightricks
Contributor Author

@miladm I just did and the test still fails, but why would something like this happen anyway? We are sharding the model and the data and expect consistency in the results in any sharding case. Can't we trust the result in any sharding case?

@miladm miladm added the pallas label Nov 21, 2024
when adding the sharding support in this module, segment_ids weren't
taken into account, which causes a failure with a shape mismatch when using
them in sharded flash attention.

2 participants