Commit

Update README.md and setup.py
Tony-Y authored May 6, 2024
1 parent dd2e08e commit 74dd8eb
Showing 2 changed files with 38 additions and 7 deletions.
34 changes: 27 additions & 7 deletions README.md
@@ -11,13 +11,13 @@ This library contains PyTorch implementations of the warmup schedules described

## Installation

-Make sure you have Python 3.6+ and PyTorch 1.1+. Then, run the following command:
+Make sure you have Python 3.7+ and PyTorch 1.1+. Then, run the following command in the project directory:

```
-python setup.py install
+python -m pip install .
```

-or
+or install the latest version from the Python Package Index:

```
pip install -U pytorch_warmup
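
# Optional sanity check (not part of the README; shown here as an assumed extra step):
# confirm that the package imports and that the installed PyTorch meets the 1.1+ requirement.
python -c "import pytorch_warmup, torch; print(torch.__version__)"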
@@ -92,12 +92,12 @@ When the learning rate schedule uses the epoch number, the warmup schedule can be
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs//3], gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
for epoch in range(1,num_epochs+1):
-    for iter, batch in enumerate(dataloader):
+    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
-        if iter < len(dataloader)-1:
+        if i < len(dataloader)-1:
            with warmup_scheduler.dampening():
                pass
    with warmup_scheduler.dampening():
@@ -108,16 +108,36 @@ This code can be rewritten more compactly:

```python
for epoch in range(1,num_epochs+1):
-    for iter, batch in enumerate(dataloader):
+    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
-            if iter + 1 == len(dataloader):
+            if i + 1 == len(dataloader):
                lr_scheduler.step()
```
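
For reference, here is the same compact loop assembled into a self-contained sketch. The optimizer choice (AdamW), `model`, `dataloader`, `num_epochs`, and the `compute_loss` helper are placeholders assumed for illustration, not taken from this commit:

```python
import torch
import pytorch_warmup as warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)  # assumed optimizer choice for this sketch
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[num_epochs//3], gamma=0.1)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

for epoch in range(1, num_epochs+1):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = compute_loss(batch)  # placeholder for the actual forward pass and loss
        loss.backward()
        optimizer.step()
        # The warmup dampening is applied on every iteration; the epoch-wise
        # scheduler steps only on the last batch of each epoch.
        with warmup_scheduler.dampening():
            if i + 1 == len(dataloader):
                lr_scheduler.step()
```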

#### Approach 3
When you use `CosineAnnealingWarmRestarts`, the warmup schedule can be used as follows:

```python
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
warmup_period = 2000
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)
iters = len(dataloader)
warmup_epochs = ... # for example, (warmup_period + iters - 1) // iters
for epoch in range(epochs+warmup_epochs):
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        loss = ...
        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
            if epoch >= warmup_epochs:
                lr_scheduler.step(epoch-warmup_epochs + i / iters)
```
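
As a concrete illustration of the `warmup_epochs` heuristic in the snippet above, here is the ceiling-division arithmetic with hypothetical numbers (not taken from this commit):

```python
# Hypothetical example: 300 batches per epoch and a 2000-step warmup period.
warmup_period = 2000
iters = 300  # stands in for len(dataloader)

# Ceiling division: the smallest whole number of epochs that covers the warmup period.
warmup_epochs = (warmup_period + iters - 1) // iters
print(warmup_epochs)  # 7, since 6 * 300 = 1800 < 2000 <= 7 * 300 = 2100
```

During those first `warmup_epochs` epochs the cosine annealing schedule is held back; afterwards `lr_scheduler.step()` receives the fractional epoch `epoch - warmup_epochs + i / iters`, measured from the end of the warmup phase.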

### Warmup Schedules

#### Manual Warmup
11 changes: 11 additions & 0 deletions setup.py
@@ -14,11 +14,22 @@
    url="https://github.com/Tony-Y/pytorch_warmup",
    packages=['pytorch_warmup'],
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Operating System :: OS Independent",
        "License :: OSI Approved :: MIT License",
    ],
