cleanup and simplenet
DenSumy committed Nov 30, 2024
1 parent 7e0d74d commit 15bd59a
Showing 5 changed files with 70 additions and 13 deletions.
25 changes: 18 additions & 7 deletions rl_games/common/a2c_common.py
@@ -158,6 +158,7 @@ def __init__(self, base_name, params):
self.save_freq = config.get('save_frequency', 0)
self.save_best_after = config.get('save_best_after', 100)
self.print_stats = config.get('print_stats', True)
self.epochs_between_resets = config.get('epochs_between_resets', 0)
self.rnn_states = None
self.name = base_name

@@ -382,6 +383,12 @@ def set_eval(self):
self.model.eval()
if self.normalize_rms_advantage:
self.advantage_mean_std.eval()
if self.epochs_between_resets > 0:
if self.epoch_num % self.epochs_between_resets == 0:
self.reset_envs()
self.init_current_rewards()
print(f"Forcing env reset after {self.epoch_num} epochs")


def set_train(self):
self.model.train()
@@ -466,10 +473,7 @@ def init_tensors(self):

val_shape = (self.horizon_length, batch_size, self.value_size)
current_rewards_shape = (batch_size, self.value_size)
self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)
self.init_current_rewards(batch_size, current_rewards_shape)

if self.is_rnn:
self.rnn_states = self.model.get_default_rnn_state()
@@ -480,6 +484,12 @@
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_current_rewards(self, batch_size=None, current_rewards_shape=None):
    # Re-create the per-env reward/length buffers; the defaults cover calls from set_eval() that pass no sizes.
    if batch_size is None:
        batch_size = self.num_agents * self.num_actors
    if current_rewards_shape is None:
        current_rewards_shape = (batch_size, self.value_size)
    self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
    self.current_shaped_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
    self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
    self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)

def init_rnn_from_model(self, model):
self.is_rnn = self.model.is_rnn()

@@ -571,12 +581,12 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext
mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t
return mb_advs

def clear_stats(self):
batch_size = self.num_agents * self.num_actors
def clear_stats(self, clean_rewards=True):
self.game_rewards.clear()
self.game_shaped_rewards.clear()
self.game_lengths.clear()
self.mean_rewards = self.last_mean_rewards = -100500
if clean_rewards:
self.mean_rewards = self.last_mean_rewards = -100500
self.algo_observer.after_clear_stats()

def update_epoch(self):
@@ -772,6 +782,7 @@ def play_steps(self):
self.current_rewards += rewards
self.current_shaped_rewards += shaped_rewards
self.current_lengths += 1

all_done_indices = self.dones.nonzero(as_tuple=False)
env_done_indices = all_done_indices[::self.num_agents]

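The new epochs_between_resets option is read from the training config in __init__ and acted on in set_eval. A minimal sketch of how it might be enabled, written as the Python dict a YAML config loads into; only epochs_between_resets comes from this commit, the other keys are the existing ones shown above and their values are illustrative:

config = {
    'save_frequency': 100,
    'save_best_after': 100,
    'print_stats': True,
    'epochs_between_resets': 50,  # force a full env reset every 50 epochs; 0 (the default) disables it
}
# In set_eval() this then triggers the equivalent of:
#   if self.epochs_between_resets > 0 and self.epoch_num % self.epochs_between_resets == 0:
#       self.reset_envs()
#       self.init_current_rewards()
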
7 changes: 3 additions & 4 deletions rl_games/common/datasets.py
@@ -22,13 +22,12 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_
self.length = self.batch_size // self.minibatch_size
self.is_discrete = is_discrete
self.is_continuous = not is_discrete
total_games = self.batch_size // self.seq_length
self.num_games_batch = self.minibatch_size // self.seq_length
self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)

self.special_names = ['rnn_states']
self.permute = permute
self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)
if self.permute:
self.permutation_indices = torch.arange(self.batch_size, dtype=torch.long, device=self.device)

def update_values_dict(self, values_dict):
"""Update the internal values dictionary."""
1 change: 0 additions & 1 deletion rl_games/common/player.py
@@ -342,7 +342,6 @@ def run(self):
done_indices = all_done_indices[::self.num_agents]
done_count = len(done_indices)
games_played += done_count
print(games_played)
if done_count > 0:
if self.is_rnn:
for s in self.states:
3 changes: 2 additions & 1 deletion rl_games/envs/__init__.py
@@ -1,7 +1,8 @@


from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder
from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder, SimpleNetBuilder
from rl_games.algos_torch import model_builder

model_builder.register_network('testnet', TestNetBuilder)
model_builder.register_network('simplenet', SimpleNetBuilder)
model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder)
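
Once registered, 'simplenet' can be selected by name wherever a network is specified. A hedged sketch of the relevant config fragment, again as the Python dict a YAML file would load into; only the 'simplenet' name is established by this commit, the surrounding keys are illustrative:

params = {
    'network': {
        'name': 'simplenet',  # resolved through model_builder.register_network('simplenet', SimpleNetBuilder)
    },
    # ... model, config and env sections as usual ...
}
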
47 changes: 47 additions & 0 deletions rl_games/envs/test_network.py
@@ -114,5 +114,52 @@ def load(self, params):
def build(self, name, **kwargs):
return TestNetWithAuxLoss(self.params, **kwargs)

def __call__(self, name, **kwargs):
return self.build(name, **kwargs)



class SimpleNet(NetworkBuilder.BaseNetwork):
def __init__(self, params, **kwargs):
nn.Module.__init__(self)
actions_num = kwargs.pop('actions_num')
input_shape = kwargs.pop('input_shape')
num_inputs = input_shape[0]
self.actions_num = actions_num
self.central_value = params.get('central_value', False)
self.value_size = kwargs.pop('value_size', 1)
self.linear = torch.nn.Sequential(
nn.Linear(num_inputs, 512),
nn.SiLU(),
nn.Linear(512, 256),
nn.SiLU(),
nn.Linear(256, 128),
nn.SiLU(),
nn.Linear(128, actions_num + 1),
)
self.sigma = nn.Parameter(torch.zeros(actions_num, dtype=torch.float32), requires_grad=True)

def is_rnn(self):
return False

@torch.compile
def forward(self, obs):
obs = obs['obs']
x = self.linear(obs)
mu, value = torch.split(x, [self.actions_num, 1], dim=-1)
return mu, self.sigma.unsqueeze(0).expand(mu.size()[0], self.actions_num), value, None




class SimpleNetBuilder(NetworkBuilder):
def __init__(self, **kwargs):
NetworkBuilder.__init__(self)

def load(self, params):
self.params = params

def build(self, name, **kwargs):
return SimpleNet(self.params, **kwargs)

def __call__(self, name, **kwargs):
return self.build(name, **kwargs)

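For reference, a minimal standalone sketch of what the new builder produces and what its forward pass returns; the actions_num, input_shape and batch values are made up for illustration, everything else follows the code above:

import torch
from rl_games.envs.test_network import SimpleNetBuilder

builder = SimpleNetBuilder()
builder.load({'central_value': False})  # params dict; SimpleNet only reads 'central_value'
net = builder('simplenet', actions_num=6, input_shape=(24,), value_size=1)

obs = {'obs': torch.randn(32, 24)}      # batch of 32 observations with 24 features each
mu, sigma, value, states = net(obs)     # forward is wrapped in @torch.compile, so the first call compiles
print(mu.shape, sigma.shape, value.shape, states)
# torch.Size([32, 6]) torch.Size([32, 6]) torch.Size([32, 1]) None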