
Stable-Baselines3 v2.4.0: New algorithm (CrossQ in SB3-Contrib) and Gymnasium v1.0 support

@araffin released this 18 Nov 10:33 (commit 020ee42)

Warning

Stable-Baselines3 (SB3) v2.4.0 will be the last release supporting Python 3.8 (end of life in October 2024)
and PyTorch < 2.3.
We highly recommend upgrading to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).

SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
Stable-Baselines Jax (SBX): https://github.com/araffin/sbx

To upgrade:

pip install stable_baselines3 sb3_contrib rl_zoo3 --upgrade

Note

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about truncation of optimizer state when loaded with SB3 >= 2.4.0.
To suppress the warning, simply save the model again.
More information is available in PR #1963.

Breaking Changes:

  • Increased minimum required version of Gymnasium to 0.29.1

New Features:

  • Added support for pre_linear_modules and post_linear_modules in create_mlp (useful for adding normalization layers, like in DroQ or CrossQ)
  • Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
  • Updated env checker to warn users when using a multi-dimensional array to define MultiDiscrete spaces
  • Added support for Gymnasium v1.0

Bug Fixes:

  • Fixed memory leak when loading learner from storage: set_parameters() no longer tries to load the object data
    and only loads the PyTorch parameters (@peteole)
  • Cast type in the GAE computation method to avoid an error when using torch.compile (@amjames)
  • CallbackList now sets the .parent attribute of child callbacks to its own .parent (@will-maclean)
  • Fixed error when loading a model that has net_arch manually set to None (@jak3122)
  • Set requirement numpy<2.0 until PyTorch is compatible (pytorch/pytorch#107302)
  • Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
  • Fixed test_buffers.py::test_device which was not actually checking the device of tensors (@rhaps0dy)

SB3-Contrib

  • Added CrossQ algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
  • Added BatchRenorm PyTorch layer used in CrossQ (@danielpalen)
  • Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
  • Fixed loading QRDQN changing target_update_interval (@jak3122)

RL Zoo

  • Updated default hyperparameters for TQC/SAC for Swimmer-v4 (decreased gamma for more consistent results)

SBX (SB3 + Jax)

  • Added CNN support for DQN
  • Bug fix for SAC and related algorithms: optimize the log of the entropy coefficient to be consistent with SB3

Others:

  • Fixed various typos (@cschindlbeck)
  • Removed unnecessary SDE noise resampling in PPO update (@brn-dev)
  • Updated PyTorch version on CI to 2.3.1
  • Added a warning recommending the CPU for on-policy algorithms (A2C/PPO) with MlpPolicy
  • Switched to uv to download packages faster on GitHub CI
  • Updated dependencies for Read the Docs
  • Removed unnecessary copy_obs_dict method for SubprocVecEnv, removed the use of OrderedDict and renamed flatten_obs to stack_obs

Documentation:

  • Updated PPO doc to recommend using CPU with MlpPolicy
  • Clarified documentation about planned features and citing software
  • Added a note explaining that SAC optimizes the log of the entropy coefficient
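To illustrate what "optimizing the log of the entropy coefficient" means, here is a pure-Python sketch (not the SB3 or SBX code). The loss uses log(alpha) where the original SAC objective uses alpha, giving `-log_alpha * mean(log_prob + target_entropy)`. Because alpha is recovered as `exp(log_alpha)`, it stays positive without any constraint. The helper name, batch values, and learning rate are hypothetical.

```python
import math


def update_log_alpha(log_alpha, log_probs, target_entropy, lr):
    """One gradient-descent step on the entropy-coefficient loss, parameterized by log(alpha)."""
    # Loss: L(log_alpha) = -log_alpha * mean(log_prob + target_entropy)
    # Derivative w.r.t. log_alpha: -mean(log_prob + target_entropy)
    grad = -sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    return log_alpha - lr * grad


log_alpha = math.log(1.0)  # start at alpha = 1.0
target_entropy = -1.0      # e.g. -dim(action space) for a 1-D action space
# Hypothetical batch of log-probs: entropy estimate -0.6 is above the target -1.0,
# so the update should shrink alpha
log_probs = [0.5, 0.7, 0.6]
for _ in range(100):
    log_alpha = update_log_alpha(log_alpha, log_probs, target_entropy, lr=0.01)

alpha = math.exp(log_alpha)  # positive for any value of log_alpha
print(round(alpha, 3))
```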

New Contributors

Full Changelog: v2.3.2...v2.4.0