From 74ae95052781f3a5e8349ffeabc24c2b84b1b158 Mon Sep 17 00:00:00 2001
From: Dongxu Li
Date: Wed, 21 Jun 2023 02:09:16 +0000
Subject: [PATCH] add constant lr scheduler.

---
 lavis/common/optims.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/lavis/common/optims.py b/lavis/common/optims.py
index 28b68645d..fb7d56647 100644
--- a/lavis/common/optims.py
+++ b/lavis/common/optims.py
@@ -94,6 +94,28 @@ def step(self, cur_epoch, cur_step):
         )
 
 
+@registry.register_lr_scheduler("constant_lr")
+class ConstantLRScheduler:
+    def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs):
+        self.optimizer = optimizer
+        self.lr = init_lr
+        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+        self.warmup_steps = warmup_steps
+
+    def step(self, cur_epoch, cur_step):
+        if cur_epoch == 0:
+            warmup_lr_schedule(
+                step=cur_step,
+                optimizer=self.optimizer,
+                max_step=self.warmup_steps,
+                init_lr=self.warmup_start_lr,
+                max_lr=self.lr,
+            )
+        else:
+            for param_group in self.optimizer.param_groups:
+                param_group["lr"] = self.lr
+
+
 def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
     """Decay the learning rate"""
     lr = (init_lr - min_lr) * 0.5 * (
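
Usage sketch (illustrative, not part of the patch): the toy Linear model, AdamW
optimizer, and loop sizes below are placeholder choices. Only ConstantLRScheduler
and its "constant_lr" registry key come from this patch; warmup_lr_schedule is the
pre-existing helper in lavis/common/optims.py, which is assumed here to ramp the LR
linearly from warmup_start_lr up to init_lr over warmup_steps.

    # Drive the new scheduler by hand: warmup during epoch 0, constant LR after.
    import torch
    from lavis.common.optims import ConstantLRScheduler

    model = torch.nn.Linear(4, 4)  # stand-in model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Warm up from 1e-6 toward 1e-4 across the first 1000 steps of epoch 0,
    # then hold a constant 1e-4 for every later epoch.
    scheduler = ConstantLRScheduler(
        optimizer, init_lr=1e-4, warmup_start_lr=1e-6, warmup_steps=1000
    )

    for epoch in range(3):
        for step in range(1000):
            scheduler.step(cur_epoch=epoch, cur_step=step)
            # ... forward / backward / optimizer.step() would go here ...
        print(epoch, optimizer.param_groups[0]["lr"])  # exactly 1e-4 once epoch > 0

Since the class is registered under "constant_lr", it should also be selectable
from a LAVIS run config via the lr_sched key, like the other registered schedulers.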