Figure 1: Information Loss in Tree-Based Reinforcement Learning on Pendulum. (a) SYMPOL (ours), (b) State-Action DT, (c) Discretized Soft DT. Existing methods for symbolic, tree-based RL (b, c) suffer from severe information loss when converting the differentiable policy (high train reward) into the symbolic policy (low test reward). With SYMPOL (a), we directly optimize the symbolic policy with PPO and therefore incur no information loss at application time (high train and test reward).
🔍 SYMPOL is a novel method for learning interpretable decision trees directly with policy gradients, allowing seamless integration into existing RL frameworks.
Figure 2: Overview of SYMPOL and its two main components: (1) an interpretable DT policy as actor and (2) a neural network as critic, which can be integrated into arbitrary actor-critic RL frameworks.
✅ Efficient learning of axis-aligned decision trees on-policy
✅ Enhanced training stability through a dynamic rollout buffer and batch size
✅ Weight decay to support dynamic adjustment of tree architecture
✅ Separate actor-critic network architecture, where the critic captures complexity, and the actor ensures interpretability
✔️ No need for pre-trained neural networks
✔️ No complex custom search procedures
✔️ No post-processing
✔️ No information loss
➡️ SYMPOL guarantees that the learned policy remains consistent from training to inference.
🚀 SYMPOL outperforms existing tree-based RL approaches and provides human-understandable explanations for every decision made.
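For intuition, here is a small, purely illustrative sketch (not the repository's implementation) of how an axis-aligned decision tree acts as a policy: every decision reduces to comparing a single observation feature with a threshold, which is what makes the resulting policy human-readable. The node structure, feature indices, and thresholds below are hypothetical.

```python
# Purely illustrative sketch of an axis-aligned decision tree policy
# (NOT the repository's implementation; structure and values are hypothetical).
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    feature: Optional[int] = None      # observation index used for the split
    threshold: Optional[float] = None  # axis-aligned split threshold
    left: Optional["Node"] = None      # subtree if obs[feature] <= threshold
    right: Optional["Node"] = None     # subtree if obs[feature] > threshold
    action: Optional[int] = None       # action stored at a leaf

def act(node: Node, obs) -> int:
    """Each step compares a single feature to a threshold -> fully traceable decision."""
    while node.action is None:
        node = node.left if obs[node.feature] <= node.threshold else node.right
    return node.action

# Depth-1 example for CartPole-like observations:
# push left (0) if the pole angle (obs[2]) is non-positive, otherwise push right (1).
tree = Node(feature=2, threshold=0.0, left=Node(action=0), right=Node(action=1))
print(act(tree, [0.1, -0.2, 0.05, 0.0]))  # -> 1
```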
📝 Read our detailed preprint on arXiv: https://arxiv.org/abs/2408.08761
💬 We’d love to hear your thoughts—feedback is always welcome!
You can contact me at sascha.marton@uni-mannheim.de
@article{marton2024sympol,
title={SYMPOL: Symbolic Tree-Based On-Policy Reinforcement Learning},
author={Marton, Sascha and Grams, Tim and Vogt, Florian and L{\"u}dtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
journal={arXiv preprint arXiv:2408.08761},
year={2024}
}
Create a virtual environment with any tool of your choice, e.g. conda, and activate it:
conda create -n sympol python=3.11.4
conda activate sympol
To install all necessary libraries, run
pip install -r requirements.txt
Occasionally it is necessary to install graphviz through anaconda, so please run
conda install graphviz
If you do not have a CUDA installation that matches your JAX version, compilation might be significantly slower. In this case, please consider updating CUDA, e.g. with
conda install cuda=12.6 -c nvidia
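As a generic sanity check (not part of the repository), you can verify that JAX actually sees the GPU:

```python
# Generic sanity check: verify that JAX picks up the GPU instead of the CPU.
import jax

print(jax.devices())           # e.g. [CudaDevice(id=0)] with a working CUDA setup
print(jax.default_backend())   # "gpu" if CUDA is used, otherwise "cpu"
```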
First, choose the environment you want to run using the --env_id argument. For example, the command below runs PPO with SYMPOL on CartPole-v1 (SYMPOL is the default actor; you can switch actors with the --actor flag, see args.py for more info). It already uses the optimized hyperparameters specified in configs.py.
python ppo_gymnax.py --env_id CartPole-v1 --gpu_number 0 --actor sympol --dynamic_buffer --use_best_config
For further possible arguments, check the args.py file.
Our code is based on CleanRL (https://github.com/vwxyzjn/cleanrl), with the adjustments for our method described in the paper, and on gymnax. We provide two scripts: ppo.py for standard Gymnasium environments and ppo_gymnax.py for vectorized environments with gymnax (https://github.com/RobertTLange/gymnax).
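For readers unfamiliar with gymnax, here is a minimal usage sketch of its functional environment API (illustrative, not copied from ppo_gymnax.py): reset and step take explicit PRNG keys and return the environment state, which is what makes these environments easy to jit/vmap.

```python
# Minimal gymnax usage sketch (illustrative, not copied from ppo_gymnax.py).
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("CartPole-v1")
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
obs, state, reward, done, info = env.step(key_step, state, action, env_params)
```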
To reproduce the experiments from the paper, we provide the file commands_to_run.txt, which contains the exact commands used to obtain the reported results (both with HPO and with the already documented optimized hyperparameters). We logged all results using wandb, which can be activated with the --track flag.
To avoid randomness, we used a predefined seed for all experiments. Unfortunately, due to the inherent randomness of some GPU operations, we were not able to remove randomness completely. Therefore, rerunning the experiments might yield results that differ slightly from those presented in the paper. However, these deviations should be minor, and the results and claims should remain consistent with those reported in the paper.
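The repository sets its seeds itself; purely as a general illustration, this is which libraries typically need seeding in a JAX/NumPy setup:

```python
# Generic seeding pattern (illustration only; the provided scripts handle
# seeding themselves via their configuration/arguments).
import random
import numpy as np
import jax

seed = 42
random.seed(seed)
np.random.seed(seed)
rng = jax.random.PRNGKey(seed)  # JAX draws all randomness from explicit keys
```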
Thanks to the significant speed-up from JAX, we can optimize the hyperparameters quite efficiently.
If you want to optimize a range of hyperparameters, first note that the range of each hyperparameter is specified in configs.py in the body of the function suggest_config, which is called for every optuna trial (see the sketch after the command below). You can add or remove hyperparameters there, or change their ranges. To run one optuna study, for example on CartPole-v1, you can use
python ppo.py --env_id CartPole-v1 --optimize_config --track
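As referenced above, here is a hedged sketch of what a suggest_config-style function looks like with optuna; the actual hyperparameter names, ranges, and return structure are defined in configs.py, and the values below are hypothetical placeholders.

```python
# Hedged sketch of a suggest_config-style function for optuna.
# The real hyperparameters live in configs.py; everything here is a placeholder.
import optuna

def suggest_config(trial: optuna.Trial) -> dict:
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        "gamma": trial.suggest_float("gamma", 0.9, 0.999),
        "depth": trial.suggest_int("depth", 4, 8),  # e.g. a maximum tree depth
    }
```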
I recommend using the --track flag so that each run gets logged in wandb; this makes it easy to find the best config after the optuna study in the wandb user interface. However, if you track your experiments with wandb, keep in mind that you can only run one job when creating the optuna study (see line 430):
study.optimize(objective_fn, n_trials=args.n_trials, n_jobs=1)
If you use more than one job, wandb will try to initialize multiple times and you will get a freeze (at least, that is what happened to us). The best fix we found was to simply run the script multiple times. For example, if you want to run 3 jobs, you can do the following:
Open a tmux session and run:
python ppo.py --env_id CartPole-v1 --optimize_config --track
Detach from the session, open another one, and run the same command again, as often as desired.
Now you are running jobs in parallel, and optuna shares the trial results between the jobs. This is also what the hpopt.db file in the directory is for (if it does not exist yet, it will be created the first time optuna runs); optuna uses it to store information about the individual trials.
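For illustration, this is roughly how an optuna study backed by a shared SQLite file lets several processes contribute trials to the same study; the study name below is hypothetical, and the exact call in the scripts may differ.

```python
# Rough illustration of sharing one optuna study via the hpopt.db SQLite file.
import optuna

study = optuna.create_study(
    study_name="sympol_hpo",          # hypothetical name
    storage="sqlite:///hpopt.db",
    direction="maximize",
    load_if_exists=True,              # every process attaches to the same study
)
# study.optimize(objective_fn, n_trials=args.n_trials, n_jobs=1)
```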
Alternatively, you can also use the run_ppo.sh script to parallelize the HPO.