FSDP Training with Sentence Transformer #3023

ShengYun-Peng opened this issue Oct 27, 2024 · 9 comments
@ShengYun-Peng

Given that there are so many LLM-based models at the top of the MTEB benchmark nowadays, is there a canonical way to train with FSDP now? I'm trying to explore in this direction, but I want to ask whether some examples already exist before I reinvent the wheel.

ShengYun-Peng changed the title from "FSDP Training" to "FSDP Training with Sentence Transformer" on Oct 27, 2024
@ShengYun-Peng (Author)

I took a stab at training with FSDP and encountered quite a few issues: huggingface/accelerate#3201

@tomaarsen (Collaborator)

Hello!

There are some details here from when I originally got it running: https://sbert.net/docs/sentence_transformer/training/distributed.html#fsdp
But I stopped trying to build a neat and convenient integration once I realised that DDP outperformed FSDP for most small models. I'm definitely open to improving on it, though.
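
Roughly, a minimal sketch of that setup looks like the following (launched with accelerate launch; the model, dataset, and FSDP options here are illustrative, and the transformer layer class to wrap depends on your backbone):

```python
# Minimal FSDP training sketch with Sentence Transformers (illustrative).
# Launch with: accelerate launch train_fsdp.py
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("bert-base-uncased")
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=32,
    # FSDP options are inherited from transformers.TrainingArguments;
    # the layer class to wrap depends on the backbone (here: BERT).
    fsdp=["full_shard", "auto_wrap"],
    fsdp_config={"transformer_layer_cls_to_wrap": "BertLayer"},
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
```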

  • Tom Aarsen

@ShengYun-Peng (Author)

Thanks! I am working in this direction now and would like to hear your input. When you subclassed the transformers Trainer class to create the SentenceTransformerTrainer, was there a guideline you followed for writing the customized trainer? I notice that you override compute_loss, prepare_inputs, and other class methods. Do you follow some template or guidelines, or do you go through the Trainer source code line by line to make changes?

@tomaarsen (Collaborator)

I don't really check it line by line, but I'm somewhat familiar with the overall structure of the transformers Trainer. It's set up in quite a modular way, which makes it fairly feasible to override some "high level" methods like compute_loss and get_train_dataloader while leaving lower-level methods like training_step, _inner_training_loop, etc. intact.

That's why the Sentence Transformers trainer file is only ~900 lines long, compared to ~5k for the base Trainer.
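
For illustration, the subclassing pattern looks roughly like this (not the actual SentenceTransformerTrainer code; collect_features is a hypothetical helper):

```python
# Illustrative sketch of the subclassing pattern: override a few "high level"
# methods and leave training_step / _inner_training_loop untouched.
# Not the actual SentenceTransformerTrainer; collect_features is hypothetical.
from torch.utils.data import DataLoader
from transformers import Trainer


class MyEmbeddingTrainer(Trainer):
    def __init__(self, *args, loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = loss_fn  # e.g. an nn.Module loss that holds the model

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Delegate the loss computation to the loss module instead of
        # expecting the model's forward() to return a loss.
        features, labels = self.collect_features(inputs)
        loss = self.loss_fn(features, labels)
        return (loss, None) if return_outputs else loss

    def get_train_dataloader(self) -> DataLoader:
        # Custom batching / sampling logic would go here; the base class's
        # inner training loop remains untouched.
        return super().get_train_dataloader()

    def collect_features(self, inputs):
        # Hypothetical helper: split the raw batch into the per-column
        # features and labels that the loss module expects.
        labels = inputs.pop("label", None)
        return inputs, labels
```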

@ShengYun-Peng (Author)

Hi @tomaarsen, if the model is wrapped, can we directly update the model in the loss function with loss_fn.model = self.model here before calling loss_fn.forward? Basically, I'm curious about the purpose of the override_model_in_loss method in the SentenceTransformerTrainer.

@ShengYun-Peng (Author)

Based on my experiments, the evaluator does not work out of the box with FSDP; it keeps throwing RuntimeError: 'weight' must be 2-D. I also recall that the docs say the evaluator doesn't work with FSDP. I'm curious why that is the case.

@ShengYun-Peng (Author)

I have successfully fine-tuned Llama 3 for text embeddings with FSDP and Sentence Transformers, with some modifications.

@ShengYun-Peng (Author)

The core issue is making the model inside the loss function aware of the FSDP wrapping. I may create a PR if necessary.

@tomaarsen (Collaborator)

> Hi @tomaarsen, if the model is wrapped, can we directly update the model in the loss function with loss_fn.model = self.model here before calling loss_fn.forward? Basically, I'm curious about the purpose of the override_model_in_loss method in the SentenceTransformerTrainer.

Apologies for the delay! The override_model_in_loss method is necessary because the losses in Sentence Transformers are a bit unusual: they are torch.nn.Module subclasses that receive the model as an attribute. So, when trainer.model is wrapped/compiled, loss.model isn't. As a result, when we call loss(features, labels), the loss still calls the original, unwrapped model.

This is why we have to override the original loss.model if the Trainer wrapped the model, so that the actual inference happens with the wrapped model.
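
In sketch form, the pattern looks like this (simplified for illustration; not the actual trainer implementation):

```python
# Simplified illustration of why the loss's model reference must be
# re-pointed once the Trainer wraps/compiles the model.
import torch
from torch import nn


class ExampleLoss(nn.Module):
    """Stand-in for a Sentence Transformers loss: an nn.Module that is
    handed the model as an attribute and calls it in forward()."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, features, labels):
        # Inference runs through self.model. If self.model is still the
        # original model, any FSDP/torch.compile wrapper on trainer.model
        # is silently bypassed.
        embeddings = [self.model(f) for f in features]
        return torch.stack(embeddings).mean()  # placeholder loss


def override_model_in_loss(loss: nn.Module, wrapped_model: nn.Module) -> nn.Module:
    # Simplified version of the re-pointing step: make the loss call the
    # wrapped model so the actual inference happens inside the wrapper.
    if hasattr(loss, "model"):
        loss.model = wrapped_model
    return loss
```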

As for the evaluator: FSDP shards the model across the various devices. That happens inside the transformers Trainer code, which is fairly advanced. The evaluator, on the other hand, is much simpler and lives exclusively in the Sentence Transformers project. It's only calculated on the first process to avoid running the same evaluations multiple times:

```python
with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO):
    evaluator_metrics = self.evaluator(self.model)
```

Ideally, we'd have e.g. DDP support here so that the evaluator computations could be split across devices, but that hasn't been implemented. FSDP would be even nicer, but also a lot more complex. Either way, the evaluator breaks with FSDP because we only run it on the first process.

Nice work on getting it to work! I'm open to PRs.

  • Tom Aarsen
