Add cautious MARS, improve test reliability by skipping grad diff for first step
rwightman committed Dec 2, 2024
1 parent 82e8677 commit 303f769
Showing 3 changed files with 69 additions and 36 deletions.
2 changes: 2 additions & 0 deletions tests/test_optim.py
@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
lr = (1e-2,) * 4
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4
elif optimizer in ('cmars',):
lr = (1e-4,) * 4

try:
if not opt_info.second_order: # basic tests don't support second order right now
7 changes: 7 additions & 0 deletions timm/optim/_optim_factory.py
@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
defaults = {'caution': True}
),
OptimInfo(
name='cmars',
opt_class=Mars,
description='Cautious MARS',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,
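For context (not part of the commit), a minimal usage sketch of the newly registered 'cmars' entry. It assumes timm's existing create_optimizer_v2 factory API and that Mars is importable from timm.optim.mars; the toy model is hypothetical and only for illustration.

import torch.nn as nn
from timm.optim import create_optimizer_v2
from timm.optim.mars import Mars

# Hypothetical toy model, for illustration only.
model = nn.Linear(16, 4)

# 'cmars' resolves to Mars with defaults={'caution': True} per the registry
# entry above; lr=1e-4 mirrors the value used in the updated test.
optimizer = create_optimizer_v2(model, opt='cmars', lr=1e-4, weight_decay=0.05)

# Equivalent direct construction using the new `caution` flag on Mars.
optimizer = Mars(model.parameters(), lr=1e-4, weight_decay=0.05, caution=True)
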
96 changes: 60 additions & 36 deletions timm/optim/mars.py
@@ -14,38 +14,50 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer


def mars_single_tensor(
p,
grad,
exp_avg,
exp_avg_sq,
lr,
weight_decay,
beta1,
beta2,
last_grad,
eps,
step,
gamma,
mars_type,
is_grad_2d,
optimize_1d,
lr_1d_factor,
betas_1d,
from ._types import ParamsT


def _mars_single_tensor_step(
p: torch.Tensor,
grad: torch.Tensor,
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
lr: float,
weight_decay: float,
beta1: float,
beta2: float,
last_grad: torch.Tensor,
eps: float,
step: int,
gamma: float,
mars_type: str,
is_grad_2d: bool,
optimize_1d: bool,
lr_1d_factor: float,
betas_1d: Tuple[float, float],
caution: bool,
):
# optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
# optimize_1d ==> use MARS for 1d param, else use AdamW
if optimize_1d or is_grad_2d:
one_minus_beta1 = 1. - beta1
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
if step == 1:
# this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
c_t = grad
else:
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
if mars_type == "adamw":
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
bias_correction1 = 1.0 - beta1_1d ** step
bias_correction2 = 1.0 - beta2_1d ** step
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
p.add_(update, alpha=-(lr * lr_1d_factor))
return exp_avg, exp_avg_sq
@@ -78,16 +94,17 @@ class Mars(Optimizer):
"""
def __init__(
self,
params,
lr=3e-3,
betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0.,
gamma=0.025,
mars_type="adamw",
optimize_1d=False,
lr_1d_factor=1.0,
betas_1d=None,
params: ParamsT,
lr: float = 3e-3,
betas: Tuple[float, float] = (0.9, 0.99),
eps: float = 1e-8,
weight_decay: float = 0.,
gamma: float = 0.025,
mars_type: str = "adamw",
optimize_1d: bool = False,
lr_1d_factor: float = 1.0,
betas_1d: Optional[Tuple[float, float]] = None,
caution: bool = False
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ def __init__(
optimize_1d=optimize_1d,
lr_1d_factor=lr_1d_factor,
betas_1d=betas_1d or betas,
caution=caution,
)
super(Mars, self).__init__(params, defaults)

def __setstate__(self, state):
super(Mars, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

state = self.state[p]
# ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
# State initialization
if len(state) <= 1:
state['step'] = 0
@@ -155,7 +177,8 @@
beta1, beta2 = group['betas']
is_grad_2d = grad.ndim >= 2

mars_single_tensor(
# FIXME add multi-tensor (if usage warrants), make more standard
_mars_single_tensor_step(
p,
grad,
exp_avg,
@@ -173,6 +196,7 @@
optimize_1d=group['optimize_1d'],
lr_1d_factor=group['lr_1d_factor'],
betas_1d=group['betas_1d'],
caution=group['caution'],
)

state['last_grad'] = grad

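As a reading aid (not part of the commit), a self-contained sketch of the two behavioral changes in mars.py: the first-step shortcut that skips the gradient-difference correction when there is no gradient history, and the cautious mask that suppresses momentum entries whose sign disagrees with the current gradient. Tensor names mirror the function arguments; the values are dummies.

import torch

torch.manual_seed(0)
grad = torch.randn(8, 8)
last_grad = torch.zeros_like(grad)  # no gradient history yet
exp_avg = torch.zeros_like(grad)
beta1, gamma, step = 0.9, 0.025, 1

# First-step shortcut from the diff: with no previous gradient, use the raw
# gradient instead of the (grad - last_grad) correction term.
if step == 1:
    c_t = grad
else:
    c_t = (grad - last_grad).mul_(gamma * (beta1 / (1. - beta1))).add_(grad)
    c_t_norm = torch.norm(c_t)
    if c_t_norm > 1.:
        c_t = c_t / c_t_norm

exp_avg = exp_avg.mul(beta1).add(c_t, alpha=1. - beta1)

# Cautious mask: keep momentum entries that agree in sign with the gradient,
# rescaled by the surviving fraction so the update magnitude stays comparable.
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
masked_exp_avg = exp_avg * mask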