chore: more torch optimizers
beniz committed May 26, 2022
1 parent 8ca5acf commit cdab9e3
Showing 1 changed file with 18 additions and 6 deletions.
src/backends/torch/torchsolver.cc (24 changes: 18 additions & 6 deletions)
@@ -97,23 +97,28 @@ namespace dd
                   _params, torch::optim::AdamOptions(_base_lr)
                                .betas(std::make_tuple(_beta1, _beta2))
                                .weight_decay(_weight_decay)));
-          this->_logger->info("base_lr: {}", _base_lr);
         }
+      else if (_solver_type == "ADAMW")
+        {
+          _optimizer
+              = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::AdamW(
+                  _params, torch::optim::AdamWOptions(_base_lr)
+                               .betas(std::make_tuple(_beta1, _beta2))
+                               .weight_decay(_weight_decay)));
+        }
       else if (_solver_type == "RMSPROP")
         {
           _optimizer = std::unique_ptr<torch::optim::Optimizer>(
               new torch::optim::RMSprop(
                   _params, torch::optim::RMSpropOptions(_base_lr).weight_decay(
                       _weight_decay)));
-          this->_logger->info("base_lr: {}", _base_lr);
         }
       else if (_solver_type == "ADAGRAD")
         {
           _optimizer = std::unique_ptr<torch::optim::Optimizer>(
               new torch::optim::Adagrad(
                   _params, torch::optim::AdagradOptions(_base_lr).weight_decay(
                       _weight_decay)));
-          this->_logger->info("base_lr: {}", _base_lr);
         }
       else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
         {
@@ -131,7 +136,6 @@ namespace dd
                              .adamp(_adamp)
                              .lsteps(_lsteps)
                              .lalpha(_lalpha)));
-          this->_logger->info("base_lr: {}", _base_lr);
           this->_logger->info("beta_1: {}", _beta1);
           this->_logger->info("beta_2: {}", _beta2);
           this->_logger->info("weight_decay: {}", _weight_decay);
@@ -162,7 +166,6 @@ namespace dd
                              .lookahead(_lookahead)
                              .lsteps(_lsteps)
                              .lalpha(_lalpha)));
-          this->_logger->info("base_lr: {}", _base_lr);
           this->_logger->info("momentum: {}", _momentum);
           this->_logger->info("weight_decay: {}", _weight_decay);
           this->_logger->info("lookahead: {}", _lookahead);
@@ -180,7 +183,6 @@ namespace dd
           _optimizer
               = std::unique_ptr<torch::optim::Optimizer>(new torch::optim::SGD(
                   _params, torch::optim::SGDOptions(_base_lr)));
-          this->_logger->info("base_lr: {}", _base_lr);
         }
       this->_logger->info("clip: {}", _clip);
       if (_clip)
@@ -199,6 +201,8 @@ namespace dd
         }
     if (_sam)
       this->_logger->info("using Sharpness Aware Minimization (SAM)");
+    this->_logger->info("using optimizer " + _solver_type);
+    this->_logger->info("base_lr: {}", _base_lr);
   }

   void TorchSolver::sam_first_step()
@@ -417,6 +421,14 @@ namespace dd
             options.betas(std::make_tuple(_beta1, _beta2));
             options.weight_decay(_weight_decay);
           }
+        else if (_solver_type == "ADAMW")
+          {
+            auto &options = static_cast<torch::optim::AdamWOptions &>(
+                param_group.options());
+            options.lr(_base_lr);
+            options.betas(std::make_tuple(_beta1, _beta2));
+            options.weight_decay(_weight_decay);
+          }
         else if (_solver_type == "RMSPROP")
           {
             auto &options = static_cast<torch::optim::RMSpropOptions &>(
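For context on the libtorch API the new ADAMW branch relies on, here is a minimal standalone sketch of constructing an AdamW optimizer and taking one step. It is not code from this repository; the toy model, tensor shapes, and hyperparameter values are illustrative assumptions.

    #include <torch/torch.h>

    int main()
    {
      // Toy model and data, for illustration only.
      torch::nn::Linear model(4, 1);
      torch::Tensor x = torch::randn({8, 4});
      torch::Tensor y = torch::randn({8, 1});

      // Same option pattern as torchsolver.cc: learning rate via the
      // AdamWOptions constructor, betas and weight decay via setters.
      torch::optim::AdamW optimizer(
          model->parameters(),
          torch::optim::AdamWOptions(1e-3)
              .betas(std::make_tuple(0.9, 0.999))
              .weight_decay(1e-2));

      // One optimization step.
      optimizer.zero_grad();
      torch::Tensor loss = torch::mse_loss(model->forward(x), y);
      loss.backward();
      optimizer.step();
      return 0;
    }

Unlike Adam, AdamW applies weight decay decoupled from the gradient-based update, which is why it is exposed here as a separate solver type rather than an Adam flag.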

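The second hunk follows the existing pattern for re-applying hyperparameters to an already-constructed optimizer: the generic param-group options are downcast to the concrete options type before the setters become available. Below is a hedged sketch of that pattern outside torchsolver.cc, with an assumed helper name set_adamw_lr.

    #include <torch/torch.h>

    // Hypothetical helper: lower the learning rate of an existing AdamW
    // optimizer in place by rewriting each parameter group's options.
    void set_adamw_lr(torch::optim::Optimizer &optimizer, double new_lr)
    {
      for (auto &param_group : optimizer.param_groups())
        {
          // options() exposes only the OptimizerOptions base class; the
          // concrete AdamWOptions is needed to reach the lr() setter.
          auto &options
              = static_cast<torch::optim::AdamWOptions &>(param_group.options());
          options.lr(new_lr);
        }
    }

The cast is needed because Optimizer::param_groups() only exposes the OptimizerOptions base class; each concrete optimizer defines its own options type (AdamWOptions here), which is why torchsolver.cc branches on _solver_type before casting.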