-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
137 lines (122 loc) · 4.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import cv2
import torch
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
def _find_in_first_folder(path, keyword):
    """Return ``"<folder>/<entry>"`` for an entry of *path*'s top folder whose name contains *keyword*.

    Only the first ``/``-separated component of *path* is used as the folder
    to search; the rest of *path* is ignored. A path with no ``/`` is treated
    as the folder name itself, and an absolute path (whose first component is
    the empty string) does not work as intended.

    Matching entry names are sorted before the first one is picked, so the
    choice is deterministic rather than dependent on ``os.listdir`` order.

    Raises:
        FileNotFoundError: if no entry name in the folder contains *keyword*.
        OSError: propagated from ``os.listdir`` if the folder cannot be read.
    """
    folder_path = path.split('/')[0]
    matches = sorted(
        entry for entry in os.listdir(folder_path) if keyword in entry
    )
    if not matches:
        raise FileNotFoundError(
            f"no entry containing {keyword!r} in folder {folder_path!r}"
        )
    return folder_path + "/" + matches[0]


def get_driver_path(driver_path):
    """Return ``"<folder>/<name>"`` of the first (sorted) entry containing 'driver'.

    The folder searched is the first ``/``-separated component of
    *driver_path*.

    Raises:
        FileNotFoundError: if the folder has no entry whose name contains 'driver'.
    """
    return _find_in_first_folder(driver_path, 'driver')


def get_model_weights(weights_path):
    """Return ``"<folder>/<name>"`` of the first (sorted) entry containing 'model_weights'.

    The folder searched is the first ``/``-separated component of
    *weights_path*.

    Raises:
        FileNotFoundError: if the folder has no entry whose name contains 'model_weights'.
    """
    return _find_in_first_folder(weights_path, 'model_weights')
def check_device(device=None):
    """Return the torch device string to run inference on.

    Args:
        device: Ignored. It is accepted (and now optional) only for backward
            compatibility with callers that pass a value; the result depends
            solely on ``torch.cuda.is_available()``.

    Returns:
        ``'cuda'`` when CUDA is available, otherwise ``'cpu'`` (lowercase).
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'
class NeuralNet(nn.Module):
    """ResNet-18 image features concatenated with an inventory vector, mapped to action logits.

    ``forward`` takes a pair ``x``. ``x[0]`` goes through the ResNet-18
    backbone, whose ``fc`` classifier is replaced by ``nn.Flatten`` so it
    yields image features rather than ImageNet class scores. Those features
    are concatenated with ``x[1]`` along dim 1. A single linear layer then
    produces ``num_actions`` logits per sample.

    Args:
        num_ftrs: Width of the backbone feature vector (must match the
            backbone's output; 512 for ResNet-18).
        num_inv: Length of the per-sample vector passed as ``x[1]``.
        pretrained: Forwarded to ``torchvision.models.resnet18``. Pass False
            to skip downloading ImageNet weights, e.g. when a full state dict
            is loaded afterwards anyway.
        num_actions: Number of output logits.
    """

    def __init__(self, num_ftrs=512, num_inv=16, pretrained=True, num_actions=8):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)
        # Drop the ImageNet classifier so the backbone returns its features.
        self.backbone.fc = nn.Flatten()
        # Kept as a Sequential so state-dict keys stay "head.0.*" and
        # previously saved weights keep loading.
        self.head = nn.Sequential(
            nn.Linear(
                in_features=num_ftrs + num_inv,
                out_features=num_actions
            )
        )

    def forward(self, x):
        """Return logits for ``x = (images, inventory)``; shape (batch, num_actions)."""
        images, inventory = x[0], x[1]
        features = self.backbone(images)
        z = torch.cat([features, inventory], dim=1)
        return self.head(z)
class BRAIN():
    """Choose an action for ``agent`` from a screenshot using 8-fold test-time augmentation.

    The screenshot is processed in these steps:

    1. Resize it to 84x84 and reverse its channel order.
    2. Expand it into 8 views: 4 rotations, each with and without a
       horizontal flip.
    3. Score all 8 views in one batch together with the agent's 16-value
       inventory vector.
    4. Permute each view's 8 scores back into the frame of the original
       image via ``dir_aug``.
    5. Sum the aligned scores and return the arg-max index.

    Args:
        agent: Object exposing the status attributes read in
            ``_inventory`` (hp, sp, left_bullets, band, ...).
        model: Callable taking ``(images, inventory)`` and returning a
            (8, 8) score tensor, e.g. ``NeuralNet``.
        device: Torch device string. The default ``'CPU'`` leaves the model
            where it is and keeps inputs on the CPU; any other value is
            passed to ``.to()``.
        plot_state: When True, every resized frame is shown with matplotlib.
    """

    def __init__(self, agent, model, device='CPU', plot_state=True):
        # Row k describes augmentation k (same order as ``_augment``).
        # Column d (1..8) holds the label, in the augmented view, of original
        # direction d; column 0 is an unused placeholder. The table is
        # consistent with labels advancing clockwise in 45-degree steps
        # (a 90-degree clockwise rotation adds 2, mod 8).
        self.dir_aug = np.array(
            [[0, 1, 2, 3, 4, 5, 6, 7, 8],   # original image
             [0, 3, 4, 5, 6, 7, 8, 1, 2],   # rotated 90 deg clockwise
             [0, 5, 6, 7, 8, 1, 2, 3, 4],   # rotated 180 deg
             [0, 7, 8, 1, 2, 3, 4, 5, 6],   # rotated 270 deg clockwise
             [0, 1, 8, 7, 6, 5, 4, 3, 2],   # original + horizontal flip
             [0, 7, 6, 5, 4, 3, 2, 1, 8],   # rotated 90 deg + horizontal flip
             [0, 5, 4, 3, 2, 1, 8, 7, 6],   # rotated 180 deg + horizontal flip
             [0, 3, 2, 1, 8, 7, 6, 5, 4]])  # rotated 270 deg + horizontal flip
        self.agent = agent
        self.model = model
        self.device = device
        # The default 'CPU' acts as "do not move"; anything else goes to .to().
        if self.device != 'CPU':
            self.model = self.model.to(self.device)
        self.plot_state = plot_state

    def choose_action(self, scr):
        """Return the index (0-7) of the highest-scoring action for screenshot ``scr``.

        ``scr`` is presumably an (H, W, 3) BGR image; its channel order is
        reversed after resizing. TODO: confirm against the caller.
        """
        frame = self._preprocess(scr)
        if self.plot_state:
            self.see_plot(frame)
        images = self._augment(frame)
        inv = self._inventory()
        if self.device != 'CPU':
            images = images.to(self.device)
            inv = inv.to(self.device)
        predictions = self._predict(images, inv)
        # aligned[k, j] = predictions[k, dir_aug[k][j + 1] - 1]: the view-k score
        # of the direction that original direction j+1 was mapped to.
        # An explicit long dtype is required by gather (np int may be int32).
        index = torch.as_tensor(
            self.dir_aug[:, 1:] - 1, dtype=torch.long, device=predictions.device)
        aligned = torch.gather(predictions, 1, index)
        total = aligned.sum(dim=0).view(1, -1)
        _, action = torch.max(total, 1)
        return action.item()

    @staticmethod
    def _preprocess(scr):
        """Resize ``scr`` to 84x84 (area interpolation) and reverse its channel order."""
        return cv2.resize(
            scr, (84, 84), interpolation=cv2.INTER_AREA)[:, :, ::-1]

    @staticmethod
    def _augment(frame):
        """Return the 8 views of ``frame`` as an (8, 3, 84, 84) float32 tensor scaled to [-1, 1].

        The order matches the rows of ``dir_aug``: original, 90 deg clockwise,
        180 deg, 90 deg counter-clockwise, then each of those flipped
        horizontally.
        """
        rotated = [
            frame,
            cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE),
            cv2.rotate(frame, cv2.ROTATE_180),
            cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE),
        ]
        views = rotated + [cv2.flip(view, 1) for view in rotated]
        stacked = np.stack(views).astype(np.float32) / 127.5 - 1
        return torch.tensor(stacked).permute(0, 3, 1, 2)  # 8x84x84x3 -> 8x3x84x84

    def _inventory(self):
        """Return the agent's 16-value status vector, repeated once per view: shape (8, 16)."""
        a = self.agent
        features = np.array([
            a.hp / 100,             # hp
            a.sp / 100,             # sp
            a.left_bullets > 0,     # weapon_mag
            a.band / 30,            # bandage
            a.medk / 4,             # medicine
            a.cola / 15,            # cola
            a.pill / 4,             # pills
            a.helmet / 3,           # helmet
            a.vest / 3,             # vest
            a.backpack / 3,         # backpack
            a.zoom / 15,            # zoom
            a.use_band,             # use_band
            a.use_medk,             # use_medk
            a.use_cola,             # use_cola
            a.use_pill,             # use_pill
            a.reloading,            # reloading
        ]).astype(np.float32)
        return torch.tensor(features).view(1, -1).repeat(8, 1)

    def _predict(self, images, inv):
        """Run the model on ``(images, inv)`` without gradients.

        If the model is an ``nn.Module``, it is put in eval mode for the call
        and its previous train/eval mode is restored afterwards. In train
        mode, layers such as BatchNorm would normalise with batch statistics
        and update their running averages even under ``no_grad``.
        """
        if not isinstance(self.model, nn.Module):
            with torch.no_grad():
                return self.model((images, inv))
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model((images, inv))
        finally:
            self.model.train(was_training)

    def see_plot(self, pict, size=(5, 5)):
        """Show ``pict`` with a grid in a new matplotlib figure of ``size`` inches.

        Note: depending on the matplotlib backend, ``plt.show()`` may block
        until the window is closed.
        """
        plt.figure(figsize=size)
        plt.imshow(pict)
        plt.grid()
        plt.show()