-
Notifications
You must be signed in to change notification settings - Fork 9
/
policies.py
executable file
·61 lines (51 loc) · 2.13 KB
/
policies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import random
import torch
from torchvision import transforms
import models
class DQNPolicy:
    """Epsilon-greedy action selection on top of a subclass-provided network.

    Subclasses supply the network via :meth:`build_network`. This base class
    picks the device, optionally restores saved weights, puts the network in
    training or evaluation mode, and chooses actions in :meth:`step`.
    """

    def __init__(self, cfg, action_space, train=False, random_seed=None):
        """Build the network and optionally restore its weights.

        Args:
            cfg: Configuration object. This class reads ``checkpoint_path``,
                ``model_path`` and ``final_exploration`` from it.
            action_space: Integer handed to ``random.randrange`` when a random
                action is drawn in :meth:`step`.
            train: If true, the network is put in training mode; otherwise it
                is put in evaluation mode.
            random_seed: If not ``None``, seeds Python's global ``random``
                module (which :meth:`step` draws from).
        """
        self.cfg = cfg
        self.action_space = action_space
        self.train = train
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.policy_net = self.build_network()
        self.transform = transforms.ToTensor()

        # Resume from checkpoint if applicable.
        # NOTE: the load is gated on ``checkpoint_path`` but the weights are
        # read from ``model_path``.
        if self.cfg.checkpoint_path is not None:
            # NOTE(security): torch.load relies on pickle; only load files
            # that come from a trusted source.
            model_checkpoint = torch.load(self.cfg.model_path, map_location=self.device)
            self.policy_net.load_state_dict(model_checkpoint['state_dict'])
            print("=> loaded model '{}'".format(self.cfg.model_path))

        # Set the mode whether or not weights were loaded. Modules start out
        # in training mode, so a policy created with train=False and no
        # checkpoint would otherwise never be switched to evaluation mode.
        if self.train:
            self.policy_net.train()
        else:
            self.policy_net.eval()

        if random_seed is not None:
            random.seed(random_seed)

    def build_network(self):
        """Return the network used by this policy; subclasses must override."""
        raise NotImplementedError

    def apply_transform(self, s):
        """Apply ``self.transform`` to ``s`` and add a leading dimension."""
        return self.transform(s).unsqueeze(0)

    def step(self, state, exploration_eps=None, debug=False):
        """Choose an action for ``state`` using an epsilon-greedy rule.

        Args:
            state: Raw state; it is passed through :meth:`apply_transform`
                and moved to ``self.device`` before the forward pass.
            exploration_eps: Threshold compared against ``random.random()``;
                a random action is taken when the draw is below it. Defaults
                to ``cfg.final_exploration`` when ``None``.
            debug: If true, the network output is returned in ``info``.

        Returns:
            A tuple ``(action, info)``. ``action`` is an int: either a draw
            from ``random.randrange(self.action_space)`` or the index of the
            largest value in the flattened network output. ``info`` is a dict
            holding ``'output'`` when ``debug`` is true, and is empty otherwise.
        """
        if exploration_eps is None:
            exploration_eps = self.cfg.final_exploration

        state = self.apply_transform(state).to(self.device)
        # The forward pass runs even when a random action ends up being
        # chosen, so the output is always available for the debug info.
        with torch.no_grad():
            output = self.policy_net(state).squeeze(0)

        if random.random() < exploration_eps:
            action = random.randrange(self.action_space)
        else:
            # Index of the maximum over the flattened output.
            action = output.view(1, -1).max(1)[1].item()

        info = {}
        if debug:
            info['output'] = output.squeeze(0)
        return action, info
class SteeringCommandsPolicy(DQNPolicy):
    """Policy whose network is a ``models.SteeringCommandsDQN``."""

    def build_network(self):
        """Create the network, wrap it in ``DataParallel`` and move it to ``self.device``.

        The input channel count comes from ``cfg.num_input_channels`` and the
        output channel count from ``self.action_space``.
        """
        network = models.SteeringCommandsDQN(
            num_input_channels=self.cfg.num_input_channels,
            num_output_channels=self.action_space,
        )
        parallel_network = torch.nn.DataParallel(network)
        return parallel_network.to(self.device)
class DenseActionSpacePolicy(DQNPolicy):
    """Policy whose network is a ``models.DenseActionSpaceDQN``."""

    def build_network(self):
        """Create the network, wrap it in ``DataParallel`` and move it to ``self.device``.

        The input channel count comes from ``cfg.num_input_channels``.
        """
        network = models.DenseActionSpaceDQN(
            num_input_channels=self.cfg.num_input_channels,
        )
        parallel_network = torch.nn.DataParallel(network)
        return parallel_network.to(self.device)