-
Notifications
You must be signed in to change notification settings - Fork 2
/
example.py
53 lines (46 loc) · 1.74 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def main():
    """Train a DreamerV3 agent on a Crafter environment.

    Steps, in order:
      1. Build the config from the ``defaults`` and ``medium`` presets plus
         the overrides below, then pass it through
         ``embodied.Flags(...).parse()`` (presumably so command-line flags
         can override values -- confirm against the embodied docs).
      2. Create a logger writing to the terminal, a JSONL file and
         TensorBoard under the configured log directory.
      3. Wrap a Crafter env for the agent and batch it (single env,
         non-parallel).
      4. Create the agent and a uniform replay buffer, then call
         ``embodied.run.train``.
    """
    # Imports are kept function-local, as in the original example.
    import warnings
    import dreamerv3
    from dreamerv3 import embodied
    warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

    # See configs.yaml for all options.
    overrides = {
        'logdir': '~/logdir/run1',
        'run.train_ratio': 64,
        'run.log_every': 30,  # Seconds
        'batch_size': 16,
        'jax.prealloc': False,
        # NOTE(review): '$^' looks like a regex that matches no key, i.e. no
        # MLP inputs are used -- confirm against configs.yaml.
        'encoder.mlp_keys': '$^',
        'decoder.mlp_keys': '$^',
        'encoder.cnn_keys': 'image',
        'decoder.cnn_keys': 'image',
        # 'jax.platform': 'cpu',
    }
    config = (
        embodied.Config(dreamerv3.configs['defaults'])
        .update(dreamerv3.configs['medium'])
        .update(overrides)
    )
    config = embodied.Flags(config).parse()

    # Logging: every output shares the same step counter.
    log_root = embodied.Path(config.logdir)
    step = embodied.Counter()
    outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(log_root, 'metrics.jsonl'),
        embodied.logger.TensorBoardOutput(log_root),
        # embodied.logger.WandBOutput(log_root.name, config),
        # embodied.logger.MLFlowOutput(log_root.name),
    ]
    logger = embodied.Logger(step, outputs)

    # Environment: Gym env -> embodied adapter -> DreamerV3 wrappers -> batch.
    import crafter
    from embodied.envs import from_gym
    gym_env = crafter.Env()  # Replace this with your Gym env.
    adapted = from_gym.FromGym(gym_env, obs_key='image')  # Or obs_key='vector'.
    wrapped = dreamerv3.wrap_env(adapted, config)
    env = embodied.BatchEnv([wrapped], parallel=False)

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length, config.replay_size, log_root / 'replay')

    # Run arguments: the ``run`` section of the config plus derived values.
    run_args = embodied.Config(
        **config.run,
        logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length,
    )
    embodied.run.train(agent, env, replay, logger, run_args)
    # embodied.run.eval_only(agent, env, logger, run_args)
# Run the example only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()