
Vote on new features in Discussions #694

Open
tianyu-l opened this issue Nov 23, 2024 · 6 comments

Comments

@tianyu-l
Contributor

Hi torchtitanists,

Thank you for your interest in torchtitan!

We created #693 for the community to add feature requests and vote on them. We'll try to prioritize the most requested features. Please share what you'd like to see next!

@zigzagcai

zigzagcai commented Nov 28, 2024

@tianyu-l

Hi developers,

Firstly, thanks for the great work that can demonstrate the power of PyTorch newly released features!

I have one question about the usage of FSDP2 fully_shard.
Does FSDP2 support mixed precision within one wrapped module, e.g., both torch.float32 and torch.bfloat16 within a single FSDPParamGroup?

To put it more clearly: in most LLM training setups, such as Llama 2, RMSNorm is usually kept in torch.float32 while the other components of the decoder layer are in torch.bfloat16. When training such a model with FSDP2, we have to wrap RMSNorm separately because it has a different dtype, which introduces additional all-gathers and reduce-scatters.
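
To make the workaround concrete, here is a rough sketch of how we wrap things today (module names like `attention_norm`/`ffn_norm` and the `model.layers` layout follow a Llama-style block and are just for illustration; on some torch versions the imports live under `torch.distributed._composable.fsdp`):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

def apply_fsdp_with_fp32_norms(model: nn.Module) -> None:
    """Shard a Llama-style model, keeping norm params in fp32 and the rest in bf16."""
    bf16 = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    fp32 = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
    for layer in model.layers:
        # Wrap the norms on their own so their params stay in fp32 ...
        fully_shard(layer.attention_norm, mp_policy=fp32)
        fully_shard(layer.ffn_norm, mp_policy=fp32)
        # ... then wrap the rest of the decoder layer in bf16. Each fully_shard
        # call creates its own FSDPParamGroup, hence the extra (small)
        # all-gathers and reduce-scatters for the norm weights.
        fully_shard(layer, mp_policy=bf16)
    fully_shard(model, mp_policy=bf16)
```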

From profiling results, we found that this approach (wrapping RMSNorm separately) leads to poor computation-communication overlap, especially in the backward pass.

Apart from that, there are other similar use cases: MoE gating layers are required to be in torch.float32 while the rest of the decoder layer is in torch.bfloat16. We also found that wrapping the MoE gate separately causes poor computation-communication overlap.

So, is mixed precision within an FSDPParamGroup supported? Or could this be added as a new feature in the future?

Thanks!

@mayank31398

@zigzagcai RMSNorm only has its activations in fp32; the weights are still bf16.
Also, I believe FSDP2 is quite flexible about having different dtypes for different tensors.
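
For reference, a minimal Llama-style RMSNorm sketch (not copied from torchtitan) showing that the upcast happens on the activations inside `forward`, so the weight itself can stay in bf16 under a single bf16 `MixedPrecisionPolicy`:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Llama-style RMSNorm: compute in fp32, keep the weight in the model dtype."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # may be bf16 under an FSDP2 mp_policy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()  # upcast activations for numerical stability
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * self.weight.float()).type_as(x)
```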

@tianyu-l
Contributor Author

tianyu-l commented Dec 2, 2024

cc: @awgu

@aniltrkkn

It should be simple, but:

Gradient accumulation

It is very useful for SFT (supervised fine-tuning) of big models.
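
For anyone hand-rolling it in the meantime, gradient accumulation is roughly the sketch below (`model`, `optimizer`, `dataloader`, and `loss_fn` are placeholders, not torchtitan APIs):

```python
def train_with_grad_accum(model, optimizer, dataloader, loss_fn, accum_steps: int = 8):
    """Accumulate gradients over `accum_steps` micro-batches per optimizer step."""
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        # Scale the loss so that the summed gradients average over the micro-batches.
        loss = loss_fn(model(inputs), labels) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```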
