
Fix cosine LR scheduler for warmup #2312

Closed
wants to merge 2 commits

Conversation

@sinahmr (Contributor) commented Oct 24, 2024

I noticed that when using the cosine scheduler with warmup (and warmup_prefix=True), the LR never reaches lr_min, which becomes more problematic the larger warmup_t is. For example, for epochs, initial_lr, lr_min, warmup_lr, warmup_t = 500, 1e-4, 1e-7, 1e-7, 100, we get the following LR progression:
[plot: LR progression with the current code ("default")]

I propose changing the code so that it generates the following instead:
[plot: LR progression with the proposed change ("proposed")]

I hope I've changed the correct lines in the code.
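
For reference, the reported progression can be reproduced with a minimal sketch along these lines (assuming timm's CosineLRScheduler API and a dummy optimizer; not part of the original comment):

```python
# Minimal sketch (assumes timm's CosineLRScheduler and a dummy optimizer) to
# reproduce the LR progression described above and check the final value.
import torch
from timm.scheduler import CosineLRScheduler

epochs, initial_lr, lr_min, warmup_lr, warmup_t = 500, 1e-4, 1e-7, 1e-7, 100

model = torch.nn.Linear(1, 1)  # dummy model so the optimizer has parameters
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=epochs,
    lr_min=lr_min,
    warmup_t=warmup_t,
    warmup_lr_init=warmup_lr,
    warmup_prefix=True,
)

lrs = []
for epoch in range(epochs):
    scheduler.step(epoch)
    lrs.append(optimizer.param_groups[0]["lr"])

# With warmup_prefix=True the cosine phase only sees epochs - warmup_t steps of
# its t_initial-long curve, so the final LR stays above lr_min.
print(f"final LR: {lrs[-1]:.2e}  (lr_min: {lr_min:.0e})")
```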

@rwightman (Collaborator)

@sinahmr this does look like a valid thing to fix, just pondering how to approach it with respect to backwards compatibility of old hparam sets... similar to warmup prefix.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rwightman (Collaborator)

@sinahmr so I looked at this more closely, it's a bit messy. Your fix improved the behaviour you were looking for w/ warmup_prefix=True, but it made warmup_prefix=False worse. It also appears to make the cycles unworkable.

The current state of things actually works fine IF you run for warmup_epochs + num_epochs. So, I was thinking the most sensible fix is to adjust get_cycle_length() to add warmup epochs/steps if warmup_prefix=True.
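
For illustration, a sketch of what that adjustment could look like, following the structure of the existing get_cycle_length() in timm's cosine scheduler (a sketch of the suggested direction, not the code that was actually merged):

```python
import math

# Sketch only: include the warmup steps in the reported cycle length when
# warmup_prefix=True, so callers schedule warmup_t + t_initial steps in total.
def get_cycle_length(self, cycles=0):
    cycles = max(1, cycles or self.cycle_limit)
    if self.cycle_mul == 1.0:
        t = self.t_initial * cycles
    else:
        t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
    # With warmup_prefix=True the warmup runs before the cosine curve starts,
    # so it adds to the total schedule length.
    return t + self.warmup_t if self.warmup_prefix else t
```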

@sinahmr (Contributor, Author) commented Nov 7, 2024

@rwightman Thanks for taking the time. You're right, sorry about those mistakes. I updated the code and I think it should resolve those problems. Can you please have a look?
I also tried to fix the cycles, but since I don't have experience using them, could you please take a closer look to make sure I didn't break anything?

Below are the plots for comparison, using:
epochs, initial_lr, lr_min, warmup_lr, warmup_t, cycle_mul, cycle_decay, cycle_limit = 500, 1e-4, 1e-7, 1e-7, 100, 0.9, 0.5, 3
[four plots comparing the resulting LR schedules]
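
Plots like these can be generated with a sketch along the following lines (assuming timm's CosineLRScheduler and matplotlib; not taken from the PR):

```python
# Sketch: plot the LR schedule with the cycle hyperparameters listed above.
import torch
import matplotlib.pyplot as plt
from timm.scheduler import CosineLRScheduler

epochs, initial_lr, lr_min, warmup_lr, warmup_t = 500, 1e-4, 1e-7, 1e-7, 100
cycle_mul, cycle_decay, cycle_limit = 0.9, 0.5, 3

optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=initial_lr)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=epochs,
    lr_min=lr_min,
    warmup_t=warmup_t,
    warmup_lr_init=warmup_lr,
    warmup_prefix=True,
    cycle_mul=cycle_mul,
    cycle_decay=cycle_decay,
    cycle_limit=cycle_limit,
)

# Run long enough to cover the warmup plus all three (shrinking) cycles.
total_t = warmup_t + scheduler.get_cycle_length()
lrs = []
for epoch in range(total_t):
    scheduler.step(epoch)
    lrs.append(optimizer.param_groups[0]["lr"])

plt.plot(lrs)
plt.xlabel("epoch")
plt.ylabel("learning rate")
plt.show()
```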

@rwightman (Collaborator) commented Nov 7, 2024

@sinahmr I have an alternative PR that I feel addresses the issue adequately: as long as the number of epochs/steps the schedule is run for is extended by the warmup when warmup_prefix=True, the schedule completes correctly without any additional alterations. See #2325.

EDIT: I also feel that extending the schedule to finish, vs squishing the first cycle, is a less significant change for backwards compat of hparams (it does not alter early training), while allowing more time to pick off good checkpoints at the end (if training hasn't petered out by then). Hence, no old result would be worse, only potentially better.

@sinahmr (Contributor, Author) commented Nov 8, 2024

@rwightman I agree that your proposal is more backward compatible. The only concern is that it might confuse users that the model runs for 330 epochs if they set --epochs 300 --warmup-epochs 30. Maybe it should be documented that if --warmup-prefix is set, the user should adjust the --epochs value manually.
Thanks for taking the time to fix the issue; feel free to close this PR if #2325 is merged.

@rwightman (Collaborator)

Updated the log in train, will merge the other PR shortly, so closing this. Thanks!

Scheduled epochs: 28 (warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True. LR stepped per update.

Scheduled epochs: 25 (epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False. LR stepped per update.
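
For reference, the totals in the two log lines above can be reproduced with a small hypothetical helper (the function name and example values below are illustrative, not taken from timm's train.py):

```python
# Hypothetical helper: how the "Scheduled epochs" totals above could be derived.
def scheduled_epochs(epochs, warmup_epochs, cooldown_epochs, warmup_prefix):
    total = epochs + cooldown_epochs
    if warmup_prefix:
        # warmup runs before the main schedule, so it extends the total
        total += warmup_epochs
    return total

# Illustrative values only: 3 + 20 + 5 = 28 with warmup_prefix, 20 + 5 = 25 without.
print(scheduled_epochs(epochs=20, warmup_epochs=3, cooldown_epochs=5, warmup_prefix=True))   # 28
print(scheduled_epochs(epochs=20, warmup_epochs=5, cooldown_epochs=5, warmup_prefix=False))  # 25
```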

@rwightman closed this Nov 8, 2024