-
Notifications
You must be signed in to change notification settings - Fork 0
/
PERL_ML1.py
152 lines (136 loc) · 6.3 KB
/
PERL_ML1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import click
import metaworld
from garage import wrap_experiment
from garage.envs import MetaWorldSetTaskEnv, normalize
from garage.experiment.deterministic import set_seed
from garage.experiment.task_sampler import SetTaskSampler
from garage.sampler import LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import PEARL
from garage.torch.algos.pearl import PEARLWorker
from garage.torch.embeddings import MLPEncoder
from garage.torch.policies import (ContextConditionedPolicy,
TanhGaussianMLPPolicy)
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.trainer import Trainer
@click.command()
@click.option('--num_epochs', default=1000)
@click.option('--num_train_tasks', default=50)
@click.option('--encoder_hidden_size', default=200)
@click.option('--net_size', default=300)
@click.option('--num_steps_per_epoch', default=4000)
@click.option('--num_initial_steps', default=4000)
@click.option('--num_steps_prior', default=750)
@click.option('--num_extra_rl_steps_posterior', default=750)
@click.option('--batch_size', default=256)
@click.option('--embedding_batch_size', default=64)
@click.option('--embedding_mini_batch_size', default=64)
@wrap_experiment
def pearl_metaworld_ml1_push(ctxt=None,
seed=1,
num_epochs=1000,
num_train_tasks=50,
latent_size=7,
encoder_hidden_size=200,
net_size=300,
meta_batch_size=16,
num_steps_per_epoch=4000,
num_initial_steps=4000,
num_tasks_sample=15,
num_steps_prior=750,
num_extra_rl_steps_posterior=750,
batch_size=256,
embedding_batch_size=64,
embedding_mini_batch_size=64,
reward_scale=10.,
use_gpu=False):
"""Train PEARL with ML1 environments.
Args:
ctxt (garage.experiment.ExperimentContext): The experiment
configuration used by Trainer to create the snapshotter.
seed (int): Used to seed the random number generator to produce
determinism.
num_epochs (int): Number of training epochs.
num_train_tasks (int): Number of tasks for training.
latent_size (int): Size of latent context vector.
encoder_hidden_size (int): Output dimension of dense layer of the
context encoder.
net_size (int): Output dimension of a dense layer of Q-function and
value function.
meta_batch_size (int): Meta batch size.
num_steps_per_epoch (int): Number of iterations per epoch.
num_initial_steps (int): Number of transitions obtained per task before
training.
num_tasks_sample (int): Number of random tasks to obtain data for each
iteration.
num_steps_prior (int): Number of transitions to obtain per task with
z ~ prior.
num_extra_rl_steps_posterior (int): Number of additional transitions
to obtain per task with z ~ posterior that are only used to train
the policy and NOT the encoder.
batch_size (int): Number of transitions in RL batch.
embedding_batch_size (int): Number of transitions in context batch.
embedding_mini_batch_size (int): Number of transitions in mini context
batch; should be same as embedding_batch_size for non-recurrent
encoder.
reward_scale (int): Reward scale.
use_gpu (bool): Whether or not to use GPU for training.
"""
set_seed(seed)
encoder_hidden_sizes = (encoder_hidden_size, encoder_hidden_size,
encoder_hidden_size)
# create multi-task environment and sample tasks
ml1 = metaworld.ML1('button-press-v2')
train_env = MetaWorldSetTaskEnv(ml1, 'train')
env_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
env=train_env,
wrapper=lambda env, _: normalize(env))
env = env_sampler.sample(num_train_tasks)
test_env = MetaWorldSetTaskEnv(ml1, 'test')
test_env_sampler = SetTaskSampler(MetaWorldSetTaskEnv,
env=test_env,
wrapper=lambda env, _: normalize(env))
trainer = Trainer(ctxt)
# instantiate networks
augmented_env = PEARL.augment_env_spec(env[0](), latent_size)
qf = ContinuousMLPQFunction(env_spec=augmented_env,
hidden_sizes=[net_size, net_size, net_size])
vf_env = PEARL.get_env_spec(env[0](), latent_size, 'vf')
vf = ContinuousMLPQFunction(env_spec=vf_env,
hidden_sizes=[net_size, net_size, net_size])
inner_policy = TanhGaussianMLPPolicy(
env_spec=augmented_env, hidden_sizes=[net_size, net_size, net_size])
sampler = LocalSampler(agents=None,
envs=env[0](),
max_episode_length=env[0]().spec.max_episode_length,
n_workers=1,
worker_class=PEARLWorker)
pearl = PEARL(
env=env,
policy_class=ContextConditionedPolicy,
encoder_class=MLPEncoder,
inner_policy=inner_policy,
qf=qf,
vf=vf,
sampler=sampler,
num_train_tasks=num_train_tasks,
latent_dim=latent_size,
encoder_hidden_sizes=encoder_hidden_sizes,
test_env_sampler=test_env_sampler,
meta_batch_size=meta_batch_size,
num_steps_per_epoch=num_steps_per_epoch,
num_initial_steps=num_initial_steps,
num_tasks_sample=num_tasks_sample,
num_steps_prior=num_steps_prior,
num_extra_rl_steps_posterior=num_extra_rl_steps_posterior,
batch_size=batch_size,
embedding_batch_size=embedding_batch_size,
embedding_mini_batch_size=embedding_mini_batch_size,
reward_scale=reward_scale,
)
set_gpu_mode(use_gpu, gpu_id=0)
if use_gpu:
pearl.to()
trainer.setup(algo=pearl, env=env[0]())
trainer.train(n_epochs=num_epochs, batch_size=batch_size)
pearl_metaworld_ml1_push()