Add progress bar argument (#107)
* Add progress bar argument

* Sort imports
araffin authored Oct 10, 2022
1 parent e9c9794 commit 52795a3
Showing 9 changed files with 43 additions and 48 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -15,6 +15,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added ``progress_bar`` argument in the ``learn()`` method, displayed using the tqdm and rich packages

Bug Fixes:
^^^^^^^^^^
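From the user's side the new feature is a single keyword argument. A minimal usage sketch (TQC and Pendulum-v1 are picked purely for illustration, and tqdm and rich must be installed):

from sb3_contrib import TQC

# Illustrative only: any sb3-contrib algorithm exposes the same flag.
model = TQC("MlpPolicy", "Pendulum-v1", verbose=0)
model.learn(total_timesteps=5_000, progress_bar=True)
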
12 changes: 11 additions & 1 deletion sb3_contrib/ars/ars.py
@@ -315,6 +315,7 @@ def learn(
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
async_eval: Optional[AsyncEval] = None,
progress_bar: bool = False,
) -> ARSSelf:
"""
Return a trained model.
@@ -333,11 +334,20 @@ def learn(
:param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param async_eval: The object for asynchronous evaluation of candidates.
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""

total_steps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)

callback.on_training_start(locals(), globals())
14 changes: 12 additions & 2 deletions sb3_contrib/ppo_mask/ppo_mask.py
@@ -10,7 +10,7 @@
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -184,6 +184,7 @@ def _init_callback(
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
use_masking: bool = True,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
@@ -196,6 +197,7 @@
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param use_masking: Whether or not to use invalid action masks during evaluation
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
@@ -206,6 +208,10 @@
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)

# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])

# Create eval callback in charge of the evaluation
if eval_env is not None:
# Avoid circular import error
@@ -236,6 +242,7 @@ def _setup_learn(
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
use_masking: bool = True,
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
@@ -253,6 +260,7 @@
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param use_masking: Whether or not to use invalid action masks during training
:param progress_bar: Display a progress bar using tqdm and rich.
:return:
"""

@@ -299,7 +307,7 @@ def _setup_learn(
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

# Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)

return total_timesteps, callback

@@ -563,6 +571,7 @@ def learn(
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
use_masking: bool = True,
progress_bar: bool = False,
) -> MaskablePPOSelf:
iteration = 0

@@ -576,6 +585,7 @@
reset_num_timesteps,
tb_log_name,
use_masking,
progress_bar,
)

callback.on_training_start(locals(), globals())
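The _init_callback hunk above attaches the progress bar without disturbing the user's own callback: both are wrapped in a CallbackList, so each receives every training event. A standalone sketch of that pattern (the no-op lambda callback is just a placeholder):

from stable_baselines3.common.callbacks import (
    CallbackList,
    ConvertCallback,
    ProgressBarCallback,
)

# A user-supplied callable, converted the same way _init_callback does it.
user_callback = ConvertCallback(lambda _locals, _globals: True)

# Appending ProgressBarCallback keeps the user callback intact.
callback = CallbackList([user_callback, ProgressBarCallback()])
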
54 changes: 11 additions & 43 deletions sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -1,7 +1,7 @@
import sys
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Optional, Type, TypeVar, Union

import gym
import numpy as np
@@ -198,47 +198,6 @@ def _setup_model(self) -> None:

self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

def _setup_learn(
self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "RecurrentPPO",
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param eval_env: Environment to use for evaluation.
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param callback: Callback(s) called at every step with state of the algorithm.
:param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
Caution, this parameter is deprecated and will be removed in the future.
Please use `EvalCallback` or a custom Callback instead.
:param n_eval_episodes: How many episodes to play per evaluation
:param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:return:
"""

total_timesteps, callback = super()._setup_learn(
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
log_path,
reset_num_timesteps,
tb_log_name,
)
return total_timesteps, callback

def collect_rollouts(
self,
env: VecEnv,
@@ -500,11 +459,20 @@ def learn(
tb_log_name: str = "RecurrentPPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> RecurrentPPOSelf:
iteration = 0

total_timesteps, callback = self._setup_learn(
total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
total_timesteps,
eval_env,
callback,
eval_freq,
n_eval_episodes,
eval_log_path,
reset_num_timesteps,
tb_log_name,
progress_bar,
)

callback.on_training_start(locals(), globals())
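The 43 lines removed from ppo_recurrent.py were a _setup_learn override that only forwarded its arguments to the parent class; deleting it lets new keyword arguments such as progress_bar flow through without touching the subclass. A toy illustration of that design choice (Parent and Child are hypothetical names):

class Parent:
    def _setup_learn(self, total_timesteps: int, progress_bar: bool = False) -> int:
        # New keyword arguments added here reach every subclass that does
        # not redefine the method.
        return total_timesteps

class Child(Parent):
    # A forwarding-only override would need a manual edit each time the
    # parent signature grows; omitting it inherits progress_bar for free.
    pass
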
2 changes: 2 additions & 0 deletions sb3_contrib/qrdqn/qrdqn.py
@@ -262,6 +262,7 @@ def learn(
tb_log_name: str = "QRDQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> QRDQNSelf:

return super().learn(
@@ -274,6 +275,7 @@
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
2 changes: 2 additions & 0 deletions sb3_contrib/tqc/tqc.py
@@ -299,6 +299,7 @@ def learn(
tb_log_name: str = "TQC",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TQCSelf:

return super().learn(
@@ -311,6 +312,7 @@
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
2 changes: 2 additions & 0 deletions sb3_contrib/trpo/trpo.py
@@ -415,6 +415,7 @@ def learn(
tb_log_name: str = "TRPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> TRPOSelf:

return super().learn(
@@ -427,4 +428,5 @@
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
2 changes: 1 addition & 1 deletion tests/test_invalid_actions.py
@@ -190,7 +190,7 @@ def test_callback(tmp_path):
model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))

model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False), progress_bar=True)


def test_child_callback():
2 changes: 1 addition & 1 deletion tests/test_run.py
@@ -20,7 +20,7 @@ def test_tqc(ent_coef):
create_eval_env=True,
ent_coef=ent_coef,
)
model.learn(total_timesteps=300, eval_freq=250)
model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)


@pytest.mark.parametrize("n_critics", [1, 3])
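The tests above only flip the flag; underneath, the progress bar is an ordinary callback. A rough sketch of how such a callback can be built on BaseCallback and tqdm (an illustration, not the ProgressBarCallback shipped with stable-baselines3):

from stable_baselines3.common.callbacks import BaseCallback
from tqdm import tqdm

class SimpleProgressBar(BaseCallback):
    """Illustrative tqdm-based progress bar."""

    def __init__(self) -> None:
        super().__init__()
        self.pbar = None

    def _on_training_start(self) -> None:
        # learn() hands its locals to the callback, so the requested
        # timestep budget is available here.
        self.pbar = tqdm(total=self.locals["total_timesteps"])

    def _on_step(self) -> bool:
        # Called once per step; account for parallel environments.
        self.pbar.update(self.training_env.num_envs)
        return True

    def _on_training_end(self) -> None:
        self.pbar.close()
        self.pbar = None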
