Skip to content

Commit

Permalink
Merge pull request #23 from bhavnicksm/main
Browse files Browse the repository at this point in the history
Add BaseOptimizer, RMSProp, Adagrad, Adamax, Adadelta
  • Loading branch information
bhavnicksm authored Mar 4, 2023
2 parents 5ac061c + fbf3ff0 commit 3a9719c
Show file tree
Hide file tree
Showing 13 changed files with 466 additions and 184 deletions.
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ PyTorch is a popular machine learning framework that provides a flexible and eff

## Supported Optimisers

| Optimiser | Paper |
|:---------: |:-----: |
| **SGD** | |
| **Adam** | |
| Optimiser | Paper |
|:---------: |:-----: |
| **SGD** | https://paperswithcode.com/method/sgd |
| **Momentum** | https://paperswithcode.com/method/sgd-with-momentum |
| **Adagrad** | https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf |
| **RMSProp** | https://paperswithcode.com/method/rmsprop |
| **Adam** | https://arxiv.org/abs/1412.6980v9 |
| **Adamax** | https://arxiv.org/abs/1412.6980v9 |
| **Adadelta** | https://arxiv.org/abs/1212.5701v1 |




## Installation
Expand Down
31 changes: 17 additions & 14 deletions examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torch'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/ec3dev/personal/github/nadir/examples/example.ipynb Cell 1\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/ec3dev/personal/github/nadir/examples/example.ipynb#W0sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ec3dev/personal/github/nadir/examples/example.ipynb#W0sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39margparse\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/ec3dev/personal/github/nadir/examples/example.ipynb#W0sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mwandb\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec3dev/personal/github/pyenv/nadir/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
Expand All @@ -34,15 +31,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from nadir import tests\n",
"from nadir import nadir as optim\n",
"from nadir.tests import mnist"
"import nadir as nd\n",
"from nadir import SGD, SGDConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -138,9 +141,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "nadir",
"language": "python",
"name": "python3"
"name": "nadir"
},
"language_info": {
"codemirror_mode": {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nadir"
version = "0.0.1.dev1"
version = "0.0.1"
authors = [
{ name="Bhavnick Minhas", email="bhavnicksm@gmail.com" },
]
Expand Down
83 changes: 0 additions & 83 deletions src/nadir/SGD.py

This file was deleted.

30 changes: 26 additions & 4 deletions src/nadir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,33 @@

from typing import Dict, List, Type
from torch.optim.optimizer import Optimizer
from .BaseOptimiser import BaseOptimizer, BaseConfig

from .SGD import SGD, SGDConfig

from .adadelta import Adadelta, AdadeltaConfig
from .adagrad import Adagrad, AdagradConfig
from .adam import Adam, AdamConfig
from .adamax import Adamax, AdamaxConfig
from .base import BaseOptimizer, BaseConfig
from .momentum import Momentum, MomentumConfig
from .rmsprop import RMSProp, RMSPropConfig
from .sgd import SGD, SGDConfig


__all__ = ('SGD', 'SGDConfig', 'Adam', 'AdamConfig')
__version__ = "0.0.1"

__version__ = "0.0.1.dev2"
__all__ = ('Adadelta',
'AdadeltaConfig',
'Adagrad',
'AdagradConfig',
'Adam',
'AdamConfig',
'Adamax',
'AdamaxConfig',
'BaseOptimizer',
'BaseConfig',
'Momentum',
'MomentumConfig',
'RMSProp',
'RMSPropConfig',
'SGD',
'SGDConfig')
60 changes: 60 additions & 0 deletions src/nadir/adadelta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
### Copyright 2023 [Dawn Of Eve]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Any, Optional
from dataclasses import dataclass

import torch

from .base import BaseOptimizer
from .base import BaseConfig


__all__ = ['AdadeltaConfig', 'Adadelta']

@dataclass
class AdadeltaConfig(BaseConfig):
    """Hyper-parameter configuration for the Adadelta optimiser."""
    # Adadelta is largely learning-rate free; 1 is the canonical default.
    lr: float = 1
    # Flags this optimiser as adaptive for the shared base-class machinery.
    adaptive: bool = True
    # Decay rate of the running squared-gradient / squared-update averages.
    rho: float = 0.90
    # Mirror of rho — presumably consumed by the base class; Adadelta.__init__
    # keeps it synchronised with rho. TODO(review): confirm against base.py.
    beta_2: float = 0.90
    # Small constant added inside the square roots for numerical stability.
    eps: float = 1E-6


class Adadelta(BaseOptimizer):
    """Adadelta optimiser (Zeiler, 2012; arXiv:1212.5701v1).

    Scales each step by the ratio of the RMS of accumulated past updates
    to the RMS of accumulated past squared gradients, which makes the
    method largely insensitive to the global learning rate.
    """

    def __init__(self, params, config: Optional[AdadeltaConfig] = None):
        """Args:
            params: iterable of parameters (or param groups) to optimise.
            config: optional AdadeltaConfig; a fresh default is built when
                omitted.
        """
        # Bug fix: the original used `config=AdadeltaConfig()` — a single
        # shared default instance that is *mutated* below (beta_2 = rho),
        # leaking config state across optimiser instances. Use a None
        # sentinel and build a fresh config per instance instead.
        if config is None:
            config = AdadeltaConfig()
        super().__init__(params, config)

        self.config = config
        # Keep beta_2 in lockstep with rho so callers only set one of the
        # two; beta_2 is presumably what the base class reads. TODO confirm.
        if self.config.rho != self.config.beta_2:
            self.config.beta_2 = self.config.rho

    def init_state(self, state, group, param):
        """Initialise per-parameter state: step counter plus the two
        accumulators (squared gradients and squared updates)."""
        state['adaptive_step'] = 0
        state['adaptivity'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        state['acc_delta'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def update(self, state, group, grad, param):
        """Apply one Adadelta step to `param` in place.

        delta = sqrt(acc_delta + eps) / denom * grad, where `denom` comes
        from the base-class `adaptivity` hook (the RMS of past squared
        gradients — TODO confirm against base.py), then
        param -= lr * delta and acc_delta is decayed towards delta**2.
        """
        eps = self.config.eps
        rho = self.config.rho
        lr = group['lr']
        m = state['acc_delta']

        denom = self.adaptivity(state, grad)

        # m.add(eps) allocates a fresh tensor, so the in-place sqrt_/div_/
        # mul_ chain never touches the stored accumulator.
        delta = m.add(eps).sqrt_().div_(denom).mul_(grad)

        param.data.add_(delta, alpha=-1 * lr)

        # In-place decay of the stored accumulator: m *= rho; m += (1-rho)*delta².
        # (The original also reassigned state['acc_delta'] = m — a no-op,
        # since m IS the stored tensor; removed.)
        m.mul_(rho).addcmul_(delta, delta, value=(1 - rho))
48 changes: 48 additions & 0 deletions src/nadir/adagrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
### Copyright 2023 [Dawn Of Eve]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Any, Optional
from dataclasses import dataclass

import torch

from .base import BaseOptimizer
from .base import BaseConfig


__all__ = ['AdagradConfig', 'Adagrad']

@dataclass
class AdagradConfig(BaseConfig):
    """Hyper-parameter configuration for the Adagrad optimiser."""
    lr: float = 1E-3        # step size
    adaptive: bool = True   # must stay True — Adagrad is inherently adaptive
    eps: float = 1E-8       # numerical-stability constant under the sqrt

class Adagrad(BaseOptimizer):
    """Adagrad optimiser (Duchi et al., 2011).

    Accumulates the element-wise sum of squared gradients and divides the
    step by its square root, so frequently-updated coordinates receive
    progressively smaller steps.
    """

    def __init__(self, params, config: Optional[AdagradConfig] = None):
        """Args:
            params: iterable of parameters (or param groups) to optimise.
            config: optional AdagradConfig; a fresh default is built when
                omitted. `config.adaptive` must be True.

        Raises:
            ValueError: if `config.adaptive` is False.
        """
        # Bug fix: avoid a single shared mutable default AdagradConfig()
        # instance leaking across optimiser instances.
        if config is None:
            config = AdagradConfig()
        if not config.adaptive:
            # Bug fix: the original passed TWO positional args to
            # ValueError, so the message rendered as a tuple; use one string.
            raise ValueError(
                f"Invalid value for adaptive in config: {config.adaptive}. "
                "Value must be True"
            )
        super().__init__(params, config)
        self.config = config

    def adaptivity(self, state, grad):
        """Accumulate squared gradients in state and return the denominator.

        Updates state['adaptivity'] in place (v += grad**2) and returns a
        new tensor sqrt(v + eps).
        """
        v = state['adaptivity']
        # In-place accumulation: v is the tensor stored in state, so the
        # original's `state['adaptivity'] = v` reassignment was a no-op
        # and has been removed.
        v.add_(torch.pow(grad, 2))

        return torch.sqrt(v + self.config.eps)
Loading

0 comments on commit 3a9719c

Please sign in to comment.