import copy
from math import ceil, exp, sqrt
from typing import Dict, Optional, Tuple
import numpy as np
from omegaconf import DictConfig
import torch
from torch import Tensor, nn
from torch.distributions import Distribution, Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform
from torch.nn import Parameter, functional as F
from torch.nn.utils import parametrizations
from memory import ReplayMemory
ACTIVATION_FUNCTIONS = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh}
# Concatenates the state and action
def _join_state_action(state: Tensor, action: Tensor) -> Tensor:
  return torch.cat([state, action], dim=1)
# Computes the squared distance between two sets of vectors
def _squared_distance(x: Tensor, y: Tensor) -> Tensor:
  n_1, n_2, d = x.size(0), y.size(0), x.size(1)
  tiled_x, tiled_y = x.view(n_1, 1, d).expand(n_1, n_2, d), y.view(1, n_2, d).expand(n_1, n_2, d)
  return (tiled_x - tiled_y).pow(2).mean(dim=2)
# Gaussian/radial basis function/exponentiated quadratic kernel
def _gaussian_kernel(x: Tensor, gamma: float=1) -> Tensor:
  return torch.exp(-gamma * x)

def _weighted_similarity(XY: Tensor, w_x: Tensor, w_y: Tensor, gamma: float=1) -> Tensor:
  return torch.einsum('i,ij,j->i', [w_x, _gaussian_kernel(XY, gamma=gamma), w_y])
def _weighted_median(x: Tensor, weights: Tensor) -> Tensor:
  x_sorted, indices = torch.sort(x.flatten())
  weights_norm_sorted = (weights.flatten() / weights.sum())[indices]  # Normalise and rearrange weights according to sorting
  median_index = torch.min((torch.cumsum(weights_norm_sorted, dim=0) >= 0.5).nonzero())
  return x_sorted[median_index]
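# Illustrative usage sketch (not part of the original module): how the kernel helpers above fit together,
# i.e. squared distances -> weighted-median-heuristic bandwidth -> weighted Gaussian kernel similarity.
# The uniform weights and data sizes below are arbitrary.
def _example_kernel_helpers() -> Tensor:
  x, y = torch.randn(8, 4), torch.randn(6, 4)
  w_x, w_y = torch.full((8, ), 1 / 8), torch.full((6, ), 1 / 6)  # Uniform sample weights
  sq_dist = _squared_distance(x, y)  # 8 x 6 matrix of mean squared coordinate differences
  gamma = 1 / (_weighted_median(sq_dist, torch.outer(w_x, w_y)).item() + 1e-8)  # Median-heuristic bandwidth
  return _weighted_similarity(sq_dist, w_x, w_y, gamma=gamma)  # Per-sample weighted kernel similarity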
# Creates a sequential fully-connected network
def _create_fcnn(input_size: int, hidden_size: int, depth: int, output_size: int, activation_function: str, input_dropout: float=0, dropout: float=0, final_gain: float=1, spectral_norm: bool=False) -> nn.Module:
  assert activation_function in ACTIVATION_FUNCTIONS.keys()
  network_dims, layers = (input_size, *[hidden_size] * depth), []
  if input_dropout > 0:
    layers.append(nn.Dropout(p=input_dropout))
  for l in range(len(network_dims) - 1):
    layer = nn.Linear(network_dims[l], network_dims[l + 1])
    nn.init.orthogonal_(layer.weight, gain=nn.init.calculate_gain(activation_function))
    nn.init.constant_(layer.bias, 0)
    if spectral_norm: layer = parametrizations.spectral_norm(layer)
    layers.append(layer)
    if dropout > 0: layers.append(nn.Dropout(p=dropout))
    layers.append(ACTIVATION_FUNCTIONS[activation_function]())
  final_layer = nn.Linear(network_dims[-1], output_size)
  nn.init.orthogonal_(final_layer.weight, gain=final_gain)
  nn.init.constant_(final_layer.bias, 0)
  if spectral_norm: final_layer = parametrizations.spectral_norm(final_layer)
  layers.append(final_layer)
  return nn.Sequential(*layers)
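# Illustrative usage sketch (not part of the original module): building a small fully-connected network
# from a config. The config fields mirror those read by the model classes below; the values are arbitrary.
def _example_create_fcnn() -> Tensor:
  from omegaconf import OmegaConf
  model_cfg = OmegaConf.create({'hidden_size': 64, 'depth': 2, 'activation': 'relu'})
  network = _create_fcnn(input_size=3, hidden_size=model_cfg.hidden_size, depth=model_cfg.depth, output_size=1, activation_function=model_cfg.activation)
  return network(torch.randn(5, 3))  # 5 x 1 output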
def create_target_network(network: nn.Module) -> nn.Module:
  target_network = copy.deepcopy(network)
  for param in target_network.parameters():
    param.requires_grad = False
  return target_network

def update_target_network(network: nn.Module, target_network: nn.Module, polyak_factor: float):
  for param, target_param in zip(network.parameters(), target_network.parameters()):
    target_param.data.mul_(polyak_factor).add_((1 - polyak_factor) * param.data)
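# Illustrative usage sketch (not part of the original module): the standard pattern of creating a frozen
# target network and then updating it with Polyak averaging after each optimisation step.
def _example_target_network_update(network: nn.Module) -> nn.Module:
  target_network = create_target_network(network)  # Frozen copy of the online network
  update_target_network(network, target_network, polyak_factor=0.995)  # target ← 0.995 * target + 0.005 * online
  return target_network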
class SoftActor(nn.Module):
  def __init__(self, state_size: int, action_size: int, model_cfg: DictConfig):
    super().__init__()
    self.log_std_dev_min, self.log_std_dev_max = -20, 2  # Constrain range of standard deviations to prevent very deterministic/stochastic policies
    self.actor = _create_fcnn(state_size, model_cfg.hidden_size, model_cfg.depth, output_size=2 * action_size, activation_function=model_cfg.activation, input_dropout=model_cfg.get('input_dropout', 0), dropout=model_cfg.get('dropout', 0))

  def forward(self, state: Tensor) -> Distribution:
    mean, log_std_dev = self.actor(state).chunk(2, dim=1)
    log_std_dev = torch.clamp(log_std_dev, min=self.log_std_dev_min, max=self.log_std_dev_max)
    policy = TransformedDistribution(Independent(Normal(mean, log_std_dev.exp()), 1), TanhTransform(cache_size=1))  # Restrict action range to (-1, 1)
    return policy

  # Calculates the log probability of an action a with the policy π(·|s) given state s
  def log_prob(self, state: Tensor, action: Tensor) -> Tensor:
    action = action.clamp(-1 + 1e-6, 1 - 1e-6)  # Clamp actions to (-1, 1) to prevent NaNs in log likelihood calculation of TanhGaussian policy; predominantly an issue with DRIL uncertainty calculation
    return self.forward(state).log_prob(action)

  def get_greedy_action(self, state: Tensor) -> Tensor:
    return torch.tanh(self.forward(state).base_dist.mean)

  def _get_action_uncertainty(self, state: Tensor, action: Tensor) -> Tensor:
    state, action = torch.repeat_interleave(state, 5, dim=0), torch.repeat_interleave(action, 5, dim=0)  # Repeat state and actions by the ensemble size (5)
    prob = self.log_prob(state, action).exp()  # Perform Monte-Carlo dropout for an implicit ensemble; PyTorch implementation does not share masks across a batch (all independent)
    return prob.view(-1, 5).var(dim=1)  # Resized tensor is batch size x ensemble size
  # Sets the uncertainty threshold at the given quantile (e.g. the 98th) of the uncertainty costs calculated over the expert data
  def set_uncertainty_threshold(self, expert_state: Tensor, expert_action: Tensor, quantile_cutoff: float):
    self.q = torch.quantile(self._get_action_uncertainty(expert_state, expert_action), quantile_cutoff).item()

  def predict_reward(self, state: Tensor, action: Tensor) -> Tensor:
    # Calculate (raw) uncertainty cost
    uncertainty_cost = self._get_action_uncertainty(state, action)
    # Calculate clipped uncertainty cost
    neg_idxs = uncertainty_cost.less_equal(self.q)
    uncertainty_cost[neg_idxs] = -1
    uncertainty_cost[~neg_idxs] = 1
    return -uncertainty_cost
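# Illustrative usage sketch (not part of the original module): sampling a tanh-squashed action with the
# reparameterisation trick and evaluating its log probability. Assumes a model config with the
# `hidden_size`, `depth` and `activation` fields that SoftActor reads above; the values are arbitrary.
def _example_soft_actor() -> Tuple[Tensor, Tensor]:
  from omegaconf import OmegaConf
  actor = SoftActor(state_size=3, action_size=2, model_cfg=OmegaConf.create({'hidden_size': 64, 'depth': 2, 'activation': 'relu'}))
  state = torch.randn(5, 3)
  policy = actor(state)
  action = policy.rsample()  # Differentiable sample in (-1, 1)
  return actor.log_prob(state, action), actor.get_greedy_action(state)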
class Critic(nn.Module):
  def __init__(self, state_size: int, action_size: int, model_cfg: DictConfig):
    super().__init__()
    self.critic = _create_fcnn(state_size + action_size, model_cfg.hidden_size, model_cfg.depth, output_size=1, activation_function=model_cfg.activation)

  def forward(self, state: Tensor, action: Tensor) -> Tensor:
    value = self.critic(_join_state_action(state, action)).squeeze(dim=1)
    return value
class TwinCritic(nn.Module):
  def __init__(self, state_size: int, action_size: int, model_cfg: DictConfig):
    super().__init__()
    self.critic_1 = Critic(state_size, action_size, model_cfg)
    self.critic_2 = Critic(state_size, action_size, model_cfg)

  def forward(self, state: Tensor, action: Tensor) -> Tuple[Tensor, Tensor]:
    value_1, value_2 = self.critic_1(state, action), self.critic_2(state, action)
    return value_1, value_2
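# Illustrative usage sketch (not part of the original module): the usual clipped double-Q pattern, taking
# the minimum of the two critic estimates to reduce overestimation bias when forming value targets.
def _example_twin_critic_min(twin_critic: TwinCritic, state: Tensor, action: Tensor) -> Tensor:
  value_1, value_2 = twin_critic(state, action)
  return torch.min(value_1, value_2)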
# Constructs the input for the GAIL discriminator
def make_gail_input(state: Tensor, action: Tensor, next_state: Tensor, terminal: Tensor, actor: SoftActor, reward_shaping: bool, subtract_log_policy: bool) -> Dict[str, Tensor]:
  input = {'state': state, 'action': action}
  if reward_shaping: input.update({'next_state': next_state, 'terminal': terminal})
  if subtract_log_policy: input.update({'log_policy': actor.log_prob(state, action)})
  return input
class GAILDiscriminator(nn.Module):
  def __init__(self, state_size: int, action_size: int, imitation_cfg: DictConfig, discount: float):
    super().__init__()
    model_cfg = imitation_cfg.discriminator
    self.discount, self.state_only, self.reward_shaping, self.subtract_log_policy, self.reward_function = discount, imitation_cfg.state_only, model_cfg.reward_shaping, model_cfg.subtract_log_policy, model_cfg.reward_function
    if self.reward_shaping:
      self.g = nn.Linear(state_size if self.state_only else state_size + action_size, 1)  # Reward function r
      if imitation_cfg.spectral_norm: self.g = parametrizations.spectral_norm(self.g)
      self.h = _create_fcnn(state_size, model_cfg.hidden_size, model_cfg.depth, 1, activation_function=model_cfg.activation, spectral_norm=imitation_cfg.spectral_norm)  # Shaping function Φ
    else:
      self.g = _create_fcnn(state_size if self.state_only else state_size + action_size, model_cfg.hidden_size, model_cfg.depth, 1, activation_function=model_cfg.activation, spectral_norm=imitation_cfg.spectral_norm)

  def _reward(self, state: Tensor, action: Tensor) -> Tensor:
    if self.state_only:
      return self.g(state).squeeze(dim=1)
    else:
      return self.g(_join_state_action(state, action)).squeeze(dim=1)

  def _value(self, state: Tensor) -> Tensor:
    return self.h(state).squeeze(dim=1)

  def forward(self, state: Tensor, action: Tensor, next_state: Optional[Tensor]=None, terminal: Optional[Tensor]=None, log_policy: Optional[Tensor]=None) -> Tensor:
    f = self._reward(state, action) + (1 - terminal) * (self.discount * self._value(next_state) - self._value(state)) if self.reward_shaping else self._reward(state, action)  # Note that vanilla GAIL does not learn a "reward function", but this naming just makes the code simpler to read
    return f - log_policy if self.subtract_log_policy else f  # Note that the former is equivalent to sigmoid^-1(e^f / (e^f + π))

  def predict_reward(self, state: Tensor, action: Tensor, next_state: Optional[Tensor]=None, terminal: Optional[Tensor]=None, log_policy: Optional[Tensor]=None) -> Tensor:
    D = torch.sigmoid(self.forward(state, action, next_state=next_state, terminal=terminal, log_policy=log_policy))
    h = -torch.log1p(-D + 1e-6) if self.reward_function == 'GAIL' else torch.log(D + 1e-6) - torch.log1p(-D + 1e-6)  # Add epsilon to improve numerical stability given limited floating point precision
    return torch.exp(h) * -h if self.reward_function == 'FAIRL' else h  # FAIRL reward function is based on AIRL reward function
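# Illustrative usage sketch (not part of the original module): building the discriminator input with
# make_gail_input and converting the discriminator output into an imitation reward. The config dictionary
# is an assumption based on the fields this class reads (not a copy of the project's config files), and
# `terminal` is assumed to be a float tensor of episode-termination flags.
def _example_gail_reward(actor: SoftActor, state: Tensor, action: Tensor, next_state: Tensor, terminal: Tensor) -> Tensor:
  from omegaconf import OmegaConf
  imitation_cfg = OmegaConf.create({'state_only': False, 'spectral_norm': True, 'discriminator': {'hidden_size': 64, 'depth': 2, 'activation': 'tanh', 'reward_shaping': True, 'subtract_log_policy': True, 'reward_function': 'AIRL'}})
  discriminator = GAILDiscriminator(state.size(1), action.size(1), imitation_cfg, discount=0.99)
  input = make_gail_input(state, action, next_state, terminal, actor, discriminator.reward_shaping, discriminator.subtract_log_policy)
  return discriminator.predict_reward(**input)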
class GMMILDiscriminator(nn.Module):
  def __init__(self, state_size: int, action_size: int, imitation_cfg: DictConfig):
    super().__init__()
    self.state_only = imitation_cfg.state_only
    self.gamma_1, self.gamma_2 = None, None

  def predict_reward(self, state: Tensor, action: Tensor, expert_state: Tensor, expert_action: Tensor, weight: Tensor, expert_weight: Tensor) -> Tensor:
    state_action = state if self.state_only else _join_state_action(state, action)
    expert_state_action = expert_state if self.state_only else _join_state_action(expert_state, expert_action)
    # Use median heuristics to set data-dependent bandwidths
    if self.gamma_1 is None:
      self.gamma_1 = 1 / (_weighted_median(_squared_distance(state_action, expert_state_action), torch.outer(weight, expert_weight)).item() + 1e-8)
      self.gamma_2 = 1 / (_weighted_median(_squared_distance(expert_state_action, expert_state_action), torch.outer(expert_weight, expert_weight)).item() + 1e-8)  # Add epsilon for numerical stability (if distance is zero)
    # Calculate negative of witness function (based on kernel mean embeddings)
    weight_norm, exp_weight_norm = weight / weight.sum(), expert_weight / expert_weight.sum()
    s_a_e_s_a_sq_dist, s_a_s_a_sq_dist = _squared_distance(state_action, expert_state_action), _squared_distance(state_action, state_action)
    similarity = _weighted_similarity(s_a_e_s_a_sq_dist, weight_norm, exp_weight_norm, gamma=self.gamma_1) + _weighted_similarity(s_a_e_s_a_sq_dist, weight_norm, exp_weight_norm, gamma=self.gamma_2)
    self_similarity = _weighted_similarity(s_a_s_a_sq_dist, weight_norm, weight_norm, gamma=self.gamma_1) + _weighted_similarity(s_a_s_a_sq_dist, weight_norm, weight_norm, gamma=self.gamma_2)
    return similarity - self_similarity
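# Illustrative usage sketch (not part of the original module): the MMD-based reward only needs raw agent
# and expert batches, with bandwidths set lazily by the median heuristic on the first call. Uniform sample
# weights are assumed here for simplicity.
def _example_gmmil_reward(state: Tensor, action: Tensor, expert_state: Tensor, expert_action: Tensor) -> Tensor:
  from omegaconf import OmegaConf
  discriminator = GMMILDiscriminator(state.size(1), action.size(1), OmegaConf.create({'state_only': False}))
  weight, expert_weight = torch.ones(state.size(0)), torch.ones(expert_state.size(0))  # Uniform sample weights
  return discriminator.predict_reward(state, action, expert_state, expert_action, weight, expert_weight)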
# Returns the scale and offset to normalise data based on mean and standard deviation
def _calculate_normalisation_scale_offset(data: Tensor) -> Tuple[Tensor, Tensor]:
  inv_scale, offset = data.std(dim=0, keepdims=True), -data.mean(dim=0, keepdims=True)  # Calculate statistics over dataset
  inv_scale[inv_scale == 0] = 1  # Set (inverse) scale to 1 if feature is constant (no variance)
  return 1 / inv_scale, offset
# Returns a tensor with a "row" (dim 0) deleted
def _delete_row(data: Tensor, index: int) -> Tensor:
  return torch.cat([data[:index], data[index + 1:]], dim=0)
class PWILDiscriminator(nn.Module):
  def __init__(self, state_size: int, action_size: int, imitation_cfg: DictConfig, expert_memory: ReplayMemory, time_horizon: int):
    super().__init__()
    self.state_only = imitation_cfg.state_only
    self.expert_memory, self.time_horizon = expert_memory, time_horizon
    self.data_scale, self.data_offset = _calculate_normalisation_scale_offset(self._get_expert_atoms())  # Calculate normalisation parameters for the data
    self.reward_scale, self.reward_bandwidth = imitation_cfg.reward_scale, imitation_cfg.reward_bandwidth_scale * self.time_horizon / sqrt(state_size if imitation_cfg.state_only else (state_size + action_size))  # Reward function hyperparameters (based on α and β)
    self.reset()

  def _get_expert_atoms(self) -> Tensor:
    return self.expert_memory['states'] if self.state_only else torch.cat([self.expert_memory['states'], self.expert_memory['actions']], dim=1)

  def reset(self):
    self.expert_atoms = self.data_scale * (self._get_expert_atoms() + self.data_offset)  # Get and normalise the expert atoms
    self.expert_weights = torch.full((len(self.expert_memory), ), 1 / len(self.expert_memory))

  def compute_reward(self, state: Tensor, action: Tensor) -> float:
    agent_atom = state if self.state_only else torch.cat([state, action], dim=1)
    agent_atom = self.data_scale * (agent_atom + self.data_offset)  # Normalise the agent atom
    weight, cost = 1 / self.time_horizon - 1e-6, 0  # Note: subtracting eps from initial agent atom weight for numerical stability
    dists = torch.linalg.norm(self.expert_atoms - agent_atom, dim=1)
    while weight > 0:
      closest_expert_idx = dists.argmin().item()  # Find closest expert atom
      expert_weight = self.expert_weights[closest_expert_idx].item()
      # Update costs and weights
      if weight >= expert_weight:
        cost += expert_weight * dists[closest_expert_idx].item()
        weight -= expert_weight
        self.expert_atoms, self.expert_weights, dists = _delete_row(self.expert_atoms, closest_expert_idx), _delete_row(self.expert_weights, closest_expert_idx), _delete_row(dists, closest_expert_idx)  # Remove the expert atom
      else:
        cost += weight * dists[closest_expert_idx].item()
        self.expert_weights[closest_expert_idx] -= weight
        weight = 0
    return self.reward_scale * exp(-self.reward_bandwidth * cost)
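# Illustrative sketch (standalone, not part of the original module, and not using the class above since
# constructing it requires a populated ReplayMemory): the greedy coupling that compute_reward performs,
# written for plain tensors, where an agent atom of weight 1/T "consumes" the nearest remaining expert atoms.
def _example_greedy_coupling(agent_atom: Tensor, expert_atoms: Tensor, time_horizon: int) -> float:
  expert_weights = torch.full((expert_atoms.size(0), ), 1 / expert_atoms.size(0))  # Uniform expert atom weights
  weight, cost = 1 / time_horizon - 1e-6, 0
  dists = torch.linalg.norm(expert_atoms - agent_atom, dim=1)
  while weight > 0:
    idx = dists.argmin().item()  # Closest remaining expert atom
    consumed = min(weight, expert_weights[idx].item())
    cost += consumed * dists[idx].item()
    expert_weights[idx] -= consumed
    weight -= consumed
    if expert_weights[idx] <= 0: dists[idx] = float('inf')  # Exhausted atoms can no longer be matched
  return cost  # Would then be mapped to a reward via reward_scale * exp(-reward_bandwidth * cost)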
class EmbeddingNetwork(nn.Module):
  def __init__(self, input_size: int, model_cfg: DictConfig, input_dropout=0, dropout=0):  # Takes dropout as separate arguments so that it is not applied to the target network
    super().__init__()
    self.embedding = _create_fcnn(input_size, model_cfg.hidden_size, model_cfg.depth, input_size, model_cfg.activation, input_dropout=input_dropout, dropout=dropout)

  def forward(self, input: Tensor) -> Tensor:
    return self.embedding(input)
class REDDiscriminator(nn.Module):
  def __init__(self, state_size: int, action_size: int, imitation_cfg: DictConfig):
    super().__init__()
    self.state_only = imitation_cfg.state_only
    self.predictor = EmbeddingNetwork(state_size if self.state_only else state_size + action_size, imitation_cfg.discriminator, input_dropout=imitation_cfg.discriminator.input_dropout, dropout=imitation_cfg.discriminator.dropout)
    self.target = EmbeddingNetwork(state_size if self.state_only else state_size + action_size, imitation_cfg.discriminator)
    for param in self.target.parameters():
      param.requires_grad = False
    self.sigma_1 = imitation_cfg.reward_bandwidth_scale

  def forward(self, state: Tensor, action: Tensor) -> Tuple[Tensor, Tensor]:
    state_action = state if self.state_only else _join_state_action(state, action)
    prediction, target = self.predictor(state_action), self.target(state_action)
    return prediction, target
  # Originally σ is set such that r(s, a) ≈ 1 on the expert demonstrations; here the kernel median heuristic is used instead (as in GMMIL)
  def set_sigma(self, expert_state: Tensor, expert_action: Tensor):
    if not self.sigma_1:
      prediction, target = self.forward(expert_state, expert_action)
      self.sigma_1 = 1 / _squared_distance(prediction, target).median().item()

  def predict_reward(self, state: Tensor, action: Tensor) -> Tensor:
    prediction, target = self.forward(state, action)
    return torch.exp(-self.sigma_1 * (prediction - target).pow(2).mean(dim=1))
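# Illustrative usage sketch (not part of the original module): one RED training step, regressing the
# predictor onto the frozen random target on expert data before using the prediction error as a reward.
# The config dictionary is an assumption based on the fields this class reads; reward_bandwidth_scale is
# set to 0 so that set_sigma falls back to the median heuristic.
def _example_red_training_step(expert_state: Tensor, expert_action: Tensor) -> REDDiscriminator:
  from omegaconf import OmegaConf
  imitation_cfg = OmegaConf.create({'state_only': False, 'reward_bandwidth_scale': 0, 'discriminator': {'hidden_size': 64, 'depth': 2, 'activation': 'relu', 'input_dropout': 0.1, 'dropout': 0.1}})
  discriminator = REDDiscriminator(expert_state.size(1), expert_action.size(1), imitation_cfg)
  optimiser = torch.optim.Adam(discriminator.predictor.parameters(), lr=3e-4)
  prediction, target = discriminator(expert_state, expert_action)
  loss = (prediction - target).pow(2).mean()  # Regress the predictor onto the frozen target
  optimiser.zero_grad(); loss.backward(); optimiser.step()
  discriminator.set_sigma(expert_state, expert_action)  # Median-heuristic bandwidth (since reward_bandwidth_scale is falsy)
  return discriminator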
def mix_expert_agent_transitions(transitions: Dict[str, Tensor], expert_transitions: Dict[str, Tensor]):
  batch_size = transitions['rewards'].size(0)
  for key in transitions.keys():
    transitions[key][:batch_size // 2] = expert_transitions[key][:batch_size // 2]  # Replace first half of the batch with expert data
class RewardRelabeller():
  def __init__(self, update_freq: int, balanced: bool):
    self.update_freq, self.balanced, self.sample_expert = update_freq, balanced, True

  def resample_and_relabel(self, transitions: Dict[str, Tensor], expert_transitions: Dict[str, Tensor], step: int, num_trajectories: int, num_expert_trajectories: int):
    # Creates a batch of training data made from a mix of expert and policy data; rewrites transitions in-place  TODO: Add sampling ratio option?
    batch_size = transitions['rewards'].size(0)
    if self.balanced:  # Alternate between sampling expert and policy data
      if self.sample_expert:
        for key in transitions.keys():
          transitions[key] = expert_transitions[key]  # Replace all of the batch with expert data
        expert_idxs, policy_idxs = range(batch_size), []
      else:
        expert_idxs, policy_idxs = [], range(batch_size)
      self.sample_expert = not self.sample_expert  # Sample from other data next time
    else:  # Replace first half of the batch with expert data
      mix_expert_agent_transitions(transitions, expert_transitions)
      expert_idxs, policy_idxs = range(batch_size // 2), range(batch_size // 2, batch_size)
    # Label rewards according to the algorithm
    if self.update_freq > 0:  # AdRIL
      transitions['rewards'][expert_idxs] = 1 / num_expert_trajectories  # Set a constant +1 reward for expert data, normalised by |trajectories|
      round_num = ceil(step / self.update_freq)
      transitions['rewards'][policy_idxs] = -1 * (round_num > torch.ceil(transitions['step'][policy_idxs] / self.update_freq)).to(dtype=torch.float32) / max(num_trajectories, 1)  # Set a constant 0 reward for current round of policy data, and -1 for old rounds, normalised by |trajectories|
    else:  # SQIL
      transitions['rewards'][expert_idxs] = 1  # Set a constant +1 reward for expert data
      transitions['rewards'][policy_idxs] = 0  # Set a constant 0 reward for policy data
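# Illustrative usage sketch (not part of the original module): relabelling a mixed batch with SQIL-style
# rewards. Assumes the transition batches are dictionaries of tensors containing at least a 'rewards'
# entry (plus 'step' when the AdRIL branch is used), as consumed above.
def _example_sqil_relabel(transitions: Dict[str, Tensor], expert_transitions: Dict[str, Tensor]):
  relabeller = RewardRelabeller(update_freq=0, balanced=False)  # update_freq = 0 selects the SQIL branch
  relabeller.resample_and_relabel(transitions, expert_transitions, step=1, num_trajectories=1, num_expert_trajectories=1)
  # transitions now holds expert data (reward +1) in its first half and policy data (reward 0) in its second half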