Skip to content

Commit

Permalink
Update plots for tutorial; larger font;
Browse files Browse the repository at this point in the history
  • Loading branch information
cr-xu committed Feb 2, 2024
1 parent 600de2c commit 0cbd070
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
Binary file modified img/random_policy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified img/trained_meta_policy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 6 additions & 3 deletions meta-rl/read_out_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@

import matplotlib.pyplot as plt
import numpy as np
from sympy import root

from maml_rl.utils.reinforcement_learning import get_returns
from sympy import root

# from maml_rl.utils.torch_utils import to_numpy

Expand Down Expand Up @@ -175,7 +174,7 @@ def plot_progress(
ax.set_title(title)
ax.set_xlabel("Batches")
ax.set_ylabel("Returns")
ax.legend()
ax.legend(loc="lower right")
ax.grid(True)

if save_folder:
Expand Down Expand Up @@ -285,14 +284,18 @@ def setup_and_plot(base_folder, experiment_name, experiment_type, ax, label_pref
returns_mean_valid,
nr_total_interactions,
) = read_train_data(my_dir=progress_folder)
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)
plot_progress(
returns_train,
returns_valid,
returns_mean_train,
returns_mean_valid,
title=f"Statistics for exp: {args.experiment_type}, "
+ f"total {nr_total_interactions} steps",
ax=ax,
)
ax.set_ylim(-120, 0) # For tutorial purposes

data_train_individual, data_valid_individual = read_train_data_individual(
my_dir=progress_folder
Expand Down
4 changes: 2 additions & 2 deletions tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@
"\n",
"Run the following code to train the task policy $\\varphi_0^0$ for 500 steps:\n",
"\n",
"`python test.py --experiment-name tutorial --experiment-type adapt_from_scratch --num-batches=500 --plot-interval=50 --task-ids 0`\n",
"`python test.py --experiment-name tutorial --experiment-type adapt_from_scratch --num-batches 500 --plot-interval 50 --task-ids 0`\n",
"\n",
"Once it has run, you can look at the adaptation progress by running:\n",
"\n",
Expand Down Expand Up @@ -655,7 +655,7 @@
"\n",
"We will now use a pre-trained policy located in `awake/pretrained_policy.th` and evaluate it against a certain number of fixed tasks.\n",
"\n",
"`python test.py --experiment-name tutorial --experiment-type test_meta --use-meta-policy --policy awake/pretrained_policy.th --num-batches=500 --plot-interval=50 --task-ids 0 1 2 3 4`\n",
"`python test.py --experiment-name tutorial --experiment-type test_meta --use-meta-policy --policy awake/pretrained_policy.th --num-batches 500 --plot-interval 50 --task-ids 0 1 2 3 4`\n",
"\n",
"- use `--task-ids 0 1 2 3 4` to run evaluation against all 5 tasks, or e.g. `--task-ids 0` to evaluate only for task 0.\n",
"- here we set the flag `--use-meta-policy` so that it uses the pre-trained policy.\n",
Expand Down

0 comments on commit 0cbd070

Please sign in to comment.