diff --git a/rl_games/__init__.py b/rl_games/__init__.py index e69de29b..7c443754 100644 --- a/rl_games/__init__.py +++ b/rl_games/__init__.py @@ -0,0 +1 @@ +from rl_games.networks import * \ No newline at end of file diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index e93ea362..786fda61 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -41,7 +41,7 @@ def __init__(self, base_name, params): self.init_rnn_from_model(self.model) self.last_lr = float(self.last_lr) self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound' - self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay) + self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay, fused=True) if self.has_central_value: cv_config = { diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 289812dd..5ba5492d 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -5,12 +5,12 @@ import torch.nn as nn from rl_games.algos_torch.d2rl import D2RLNet +from rl_games.common.layers.switch_ffn import MoEBlock from rl_games.algos_torch.sac_helper import SquashedNormal from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax - def _create_initializer(func, **kwargs): return lambda v : func(v, **kwargs) @@ -68,6 +68,8 @@ def get_default_rnn_state(self): return None def get_aux_loss(self): + if self.moe_block: + return self.actor_mlp.get_aux_loss() return None def _calc_input_size(self, input_shape,cnn_layers=None): @@ -129,6 +131,9 @@ def _build_mlp(self, else: return self._build_sequential_mlp(input_size, units, activation, dense_func, norm_func_name = None,) + def _build_moe_block(self, input_size, expert_units, model_units, num_experts): + return MoEBlock(input_size, expert_units, model_units, num_experts) + def _build_conv(self, ctype, **kwargs): print('conv_name:', ctype) @@ -232,9 +237,8 @@ def __init__(self, params, **kwargs): cnn_output_size = self._calc_input_size(input_shape, self.actor_cnn) mlp_input_size = cnn_output_size - if len(self.units) == 0: - out_size = cnn_output_size - else: + out_size = mlp_input_size + if len(self.units) > 0: out_size = self.units[-1] if self.has_rnn: @@ -264,18 +268,23 @@ def __init__(self, params, **kwargs): if self.rnn_ln: self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - mlp_args = { - 'input_size' : mlp_input_size, - 'units' : self.units, - 'activation' : self.activation, - 'norm_func_name' : self.normalization, - 'dense_func' : torch.nn.Linear, - 'd2rl' : self.is_d2rl, - 'norm_only_first_layer' : self.norm_only_first_layer - } - self.actor_mlp = self._build_mlp(**mlp_args) - if self.separate: - self.critic_mlp = self._build_mlp(**mlp_args) + + if self.moe_block: + self.actor_mlp = self._build_moe_block(mlp_input_size, self.expert_units, self.model_units, self.num_experts) + assert(not self.separate) + else: + mlp_args = { + 'input_size' : mlp_input_size, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + self.actor_mlp = 
self._build_mlp(**mlp_args) + if self.separate: + self.critic_mlp = self._build_mlp(**mlp_args) self.value = self._build_value_layer(out_size, self.value_size) self.value_act = self.activations_factory.create(self.value_activation) @@ -507,11 +516,22 @@ def get_default_rnn_state(self): def load(self, params): self.separate = params.get('separate', False) - self.units = params['mlp']['units'] - self.activation = params['mlp']['activation'] - self.initializer = params['mlp']['initializer'] - self.is_d2rl = params['mlp'].get('d2rl', False) - self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.moe_block = params.get('moe', False) + + if self.moe_block: + assert(not params.get('mlp', False)) + self.num_experts = self.moe_block['num_experts'] + self.expert_units = self.moe_block['expert_units'] + self.model_units = self.moe_block['model_units'] + self.initializer = self.moe_block['initializer'] + self.units = self.expert_units + + else: + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) self.value_activation = params.get('value_activation', 'None') self.normalization = params.get('normalization', None) self.has_rnn = 'rnn' in params diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 08170847..d0ffd006 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -10,7 +10,6 @@ import math - class HCRewardEnv(gym.RewardWrapper): def __init__(self, env): gym.RewardWrapper.__init__(self, env) diff --git a/rl_games/common/layers/switch_ffn.py b/rl_games/common/layers/switch_ffn.py new file mode 100644 index 00000000..6101fa80 --- /dev/null +++ b/rl_games/common/layers/switch_ffn.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwitchFeedForward(nn.Module): + + def __init__(self, + model_dim: int, + hidden_dim: int, + out_dim: int, + is_scale_prob: bool, + num_experts: int, + activation: nn.Module = nn.ReLU + + ): + super().__init__() + self.hidden_dim = hidden_dim + self.model_dim = model_dim + self.out_dim = out_dim + self.is_scale_prob = is_scale_prob + self.num_experts = num_experts + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(model_dim, out_dim), + activation(), + #nn.Linear(model_dim, hidden_dim), + #activation(), + #nn.Linear(hidden_dim, out_dim), + #activation(), + ) + for _ in range(num_experts) + ]) + # Routing layer and softmax + self.switch = nn.Linear(model_dim, num_experts) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: torch.Tensor): + route_prob = self.softmax(self.switch(x)) + route_prob_max, routes = torch.max(route_prob, dim=-1) + indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.num_experts)] + + final_output = torch.zeros((x.size(0), self.out_dim), device=x.device) + counts = x.new_tensor([len(indexes_list[i]) for i in range(self.num_experts)]) + + # Get outputs of the expert FFNs + expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.num_experts)] + # Assign to final output + for i in range(self.num_experts): + final_output[indexes_list[i], :] = expert_output[i] + + if self.is_scale_prob: + # Multiply by the expert outputs by the probabilities $y = p_i(x) E_i(x)$ + final_output = final_output * 
route_prob_max.view(-1, 1) + else: + # not sure if this is correct + final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1) + + + return final_output, counts, route_prob.sum(0), route_prob_max + + + +class MoEFF(nn.Module): + def __init__(self, + model_dim: int, + hidden_dim: int, + out_dim: int, + num_experts: int, + activation: nn.Module = nn.ReLU, + **kwargs + ): + super().__init__() + + # Parameters from params + self.model_dim = model_dim + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.gating_hidden_size = kwargs.get('gating_hidden_size', 64) + self.use_sparse_gating = kwargs.get('use_sparse_gating', True) + self.use_entropy_loss = kwargs.get('use_entropy_loss', True) + self.use_diversity_loss = kwargs.get('use_diversity_loss', True) + self.top_k = kwargs.get('top_k', 2) + self.lambda_entropy = kwargs.get('lambda_entropy', 0.01) + self.lambda_diversity = kwargs.get('lambda_diversity', 0.00) + + + # Gating Network + self.gating_fc1 = nn.Linear(self.model_dim, self.gating_hidden_size) + self.gating_fc2 = nn.Linear(self.gating_hidden_size, num_experts) + + # Expert Networks + self.expert_networks = nn.ModuleList([ + nn.Sequential( + nn.Linear(self.model_dim, out_dim), + activation(), + ) for _ in range(num_experts) + ]) + + + # Auxiliary loss map + self.aux_loss_map = { + } + if self.use_diversity_loss: + self.aux_loss_map['moe_diversity_loss'] = 0.0 + if self.use_entropy_loss: + self.aux_loss_map['moe_entropy_loss'] = 0.0 + + def get_aux_loss(self): + return self.aux_loss_map + + def forward(self, x): + + # Gating Network Forward Pass + gating_x = F.relu(self.gating_fc1(x)) + gating_logits = self.gating_fc2(gating_x) # Shape: [batch_size, num_experts] + orig_gating_weights = F.softmax(gating_logits, dim=1) + gating_weights = orig_gating_weights + # Apply Sparse Gating if enabled + if self.use_sparse_gating: + topk_values, topk_indices = torch.topk(gating_weights, self.top_k, dim=1) + sparse_mask = torch.zeros_like(gating_weights) + sparse_mask.scatter_(1, topk_indices, topk_values) + # probably better go with masked softmax + gating_weights = sparse_mask / sparse_mask.sum(dim=1, keepdim=True) + + if self.use_entropy_loss: + # Compute Entropy Loss for Gating Weights + entropy = -torch.sum(gating_weights * torch.log(gating_weights + 1e-8), dim=1) + entropy_loss = torch.mean(entropy) + self.aux_loss_map['moe_entropy_loss'] = -self.lambda_entropy * entropy_loss + + # Expert Networks Forward Pass + expert_outputs = [] + for expert in self.expert_networks: + expert_outputs.append(expert(x)) # Each output shape: [batch_size, hidden_size] + expert_outputs = torch.stack(expert_outputs, dim=1) # Shape: [batch_size, num_experts, hidden_size] + + # Compute Diversity Loss + if self.use_diversity_loss: + diversity_loss = 0.0 + num_experts = len(self.expert_networks) + for i in range(num_experts): + for j in range(i + 1, num_experts): + similarity = F.cosine_similarity(expert_outputs[:, i, :], expert_outputs[:, j, :], dim=-1) + diversity_loss += torch.mean(similarity) + num_pairs = num_experts * (num_experts - 1) / 2 + diversity_loss = diversity_loss / num_pairs + self.aux_loss_map['moe_diversity_loss'] = self.lambda_diversity * diversity_loss + + # Aggregate Expert Outputs + gating_weights = gating_weights.unsqueeze(-1) # Shape: [batch_size, num_experts, 1] + aggregated_output = torch.sum(gating_weights * expert_outputs, dim=1) # Shape: [batch_size, hidden_size] + out = aggregated_output + return out + + +class 
MoEBlock(nn.Module): + def __init__(self, + input_size: int, + expert_units: list[int], + model_units: list[int], + num_experts: int, + ): + super().__init__() + self.num_experts = num_experts + in_size = input_size + layers = [] + for u, m in zip(expert_units, model_units): + layers.append(MoEFF(in_size, m, u, num_experts)) + in_size = u + self.layers = nn.ModuleList(layers) + self.load_balancing_loss = None + + def get_aux_loss(self): + return { + "moe_load_balancing_loss": self.load_balancing_loss + } + + def forward(self, x: torch.Tensor): + moe_diversity_loss, moe_entropy_loss = 0, 0 + for layer in self.layers: + x = layer(x) + moe_diversity_loss = moe_diversity_loss + layer.get_aux_loss()['moe_diversity_loss'] + moe_entropy_loss = moe_entropy_loss + layer.get_aux_loss()['moe_entropy_loss'] + + self.load_balancing_loss = moe_diversity_loss / len(self.layers) + moe_entropy_loss / len(self.layers) + return x + +''' + def forward(self, x: torch.Tensor): + counts, route_prob_sums, route_prob_maxs = [], [], [] + for layer in self.layers: + x, count, route_prob_sum, route_prob_max = layer(x) + counts.append(count) + route_prob_sums.append(route_prob_sum) + route_prob_maxs.append(route_prob_max) + + counts = torch.stack(counts) + route_prob_sums = torch.stack(route_prob_sums) + route_prob_maxs = torch.stack(route_prob_maxs) + + total = counts.sum(dim=-1, keepdims=True) + route_frac = counts / total + route_prob = route_prob_sums / total + + self.load_balancing_loss = self.num_experts * (route_frac * route_prob).sum() + return x +''' \ No newline at end of file diff --git a/rl_games/common/player.py b/rl_games/common/player.py index 98be6501..f2181603 100644 --- a/rl_games/common/player.py +++ b/rl_games/common/player.py @@ -12,6 +12,8 @@ from rl_games.common import env_configurations from rl_games.algos_torch import model_builder +import pandas as pd + class BasePlayer(object): @@ -271,6 +273,9 @@ def init_rnn(self): )[2]), dtype=torch.float32).to(self.device) for s in rnn_states] def run(self): + # create pandas dataframe with fields: game_index, observation, action, reward and done + df = pd.DataFrame(columns=['game_index', 'observation', 'action', 'reward', 'done']) + n_games = self.games_num render = self.render_env n_game_life = self.n_game_life @@ -313,6 +318,8 @@ def run(self): print_game_res = False + game_indices = torch.arange(0, batch_size).to(self.device) + cur_games = batch_size for n in range(self.max_steps): if self.evaluation and n % self.update_checkpoint_freq == 0: self.maybe_load_new_checkpoint() @@ -324,7 +331,11 @@ def run(self): else: action = self.get_action(obses, is_deterministic) + prev_obses = obses obses, r, done, info = self.env_step(self.env, action) + + for i in range(batch_size): + df.loc[len(df)] = [game_indices[i].cpu().numpy().item(), prev_obses[i].cpu().numpy(), action[i].cpu().numpy(), r[i].cpu().numpy().item(), done[i].cpu().numpy().item()] cr += r steps += 1 @@ -337,6 +348,9 @@ def run(self): done_count = len(done_indices) games_played += done_count + for bid in done_indices: + game_indices[bid] = cur_games + cur_games += 1 if done_count > 0: if self.is_rnn: for s in self.states: @@ -380,6 +394,9 @@ def run(self): print('av reward:', sum_rewards / games_played * n_game_life, 'av steps:', sum_steps / games_played * n_game_life) + # save game data to parquet file + df.to_parquet('game_data.parquet') + def get_batch_size(self, obses, batch_size): obs_shape = self.obs_shape if type(self.obs_shape) is dict: diff --git a/rl_games/common/vecenv.py 
b/rl_games/common/vecenv.py index c29fd4be..275f9e37 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -7,6 +7,7 @@ from time import sleep import torch + class RayWorker: """Wrapper around a third-party (gym for example) environment class that enables parallel training. @@ -47,7 +48,7 @@ def step(self, action): """ next_state, reward, is_done, info = self.env.step(action) - + if np.isscalar(is_done): episode_done = is_done else: @@ -64,7 +65,7 @@ def seed(self, seed): np.random.seed(seed) random.seed(seed) self.env.seed(seed) - + def render(self): self.env.render() @@ -95,7 +96,7 @@ def get_env_info(self): info = {} observation_space = self.env.observation_space - #if isinstance(observation_space, gym.spaces.dict.Dict): + # if isinstance(observation_space, gym.spaces.dict.Dict): # observation_space = observation_space['observations'] info['action_space'] = self.env.action_space @@ -115,12 +116,16 @@ def get_env_info(self): class RayVecEnv(IVecEnv): """Main env class that manages several `rl_games.common.vecenv.Rayworker` objects for parallel training - + The RayVecEnv class manages a set of individual environments and wraps around the methods from RayWorker. Each worker is executed asynchronously. """ - import ray + # To avoid import errors when Ray is not installed and this class is not used + try: + import ray + except ImportError: + pass def __init__(self, config_name, num_actors, **kwargs): """Initialise the class. Sets up the config for the environment and creates individual workers to manage. @@ -136,7 +141,6 @@ def __init__(self, config_name, num_actors, **kwargs): self.use_torch = False self.seed = kwargs.pop('seed', None) - self.remote_worker = self.ray.remote(RayWorker) self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)] @@ -162,7 +166,7 @@ def __init__(self, config_name, num_actors, **kwargs): self.concat_func = np.stack else: self.concat_func = np.concatenate - + def step(self, actions): """Step all individual environments (using the created workers). Returns a concatenated array of observations, rewards, done states, and infos if the env allows concatenation. 
@@ -201,7 +205,7 @@ def step(self, actions): if self.use_global_obs: newobsdict = {} newobsdict["obs"] = ret_obs - + if self.state_type_dict: newobsdict["states"] = dicts_to_dict_with_arrays(newstates, True) else: @@ -231,7 +235,7 @@ def get_action_masks(self): def reset(self): res_obs = [worker.reset.remote() for worker in self.workers] - newobs, newstates = [],[] + newobs, newstates = [], [] for res in res_obs: cobs = self.ray.get(res) if self.use_global_obs: @@ -248,7 +252,7 @@ def reset(self): if self.use_global_obs: newobsdict = {} newobsdict["obs"] = ret_obs - + if self.state_type_dict: newobsdict["states"] = dicts_to_dict_with_arrays(newstates, True) else: @@ -256,8 +260,10 @@ def reset(self): ret_obs = newobsdict return ret_obs + vecenv_config = {} + def register(config_name, func): """Add an environment type (for example RayVecEnv) to the list of available types `rl_games.common.vecenv.vecenv_config` Args: @@ -267,10 +273,12 @@ def register(config_name, func): """ vecenv_config[config_name] = func + def create_vec_env(config_name, num_actors, **kwargs): vec_env_name = configurations[config_name]['vecenv_type'] return vecenv_config[vec_env_name](config_name, num_actors, **kwargs) + register('RAY', lambda config_name, num_actors, **kwargs: RayVecEnv(config_name, num_actors, **kwargs)) from rl_games.envs.brax import BraxEnv diff --git a/rl_games/common/wrappers.py b/rl_games/common/wrappers.py index dab4a648..5c3b17c7 100644 --- a/rl_games/common/wrappers.py +++ b/rl_games/common/wrappers.py @@ -1,4 +1,3 @@ -import gymnasium import numpy as np from numpy.random import randint @@ -11,12 +10,12 @@ from copy import copy - class InfoWrapper(gym.Wrapper): def __init__(self, env): gym.RewardWrapper.__init__(self, env) - + self.reward = 0 + def reset(self, **kwargs): self.reward = 0 return self.env.reset(**kwargs) @@ -87,7 +86,7 @@ def __init__(self, env): """ gym.Wrapper.__init__(self, env) self.lives = 0 - self.was_real_done = True + self.was_real_done = True def step(self, action): obs, reward, done, info = self.env.step(action) @@ -122,7 +121,7 @@ def __init__(self, env): gym.Wrapper.__init__(self, env) self.max_stacked_steps = 1000 - self.current_steps=0 + self.current_steps = 0 def step(self, action): obs, reward, done, info = self.env.step(action) @@ -140,17 +139,17 @@ def step(self, action): class MaxAndSkipEnv(gym.Wrapper): - def __init__(self, env,skip=4, use_max = True): + def __init__(self, env, skip=4, use_max=True): """Return only every `skip`-th frame""" gym.Wrapper.__init__(self, env) self.use_max = use_max # most recent raw observations (for max pooling across time steps) if self.use_max: - self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) + self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8) else: - self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.float32) - self._skip = skip - + self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.float32) + self._skip = skip + def step(self, action): """Repeat action, sum reward, and max over last observations.""" total_reward = 0.0 @@ -211,8 +210,9 @@ def observation(self, frame): frame = np.expand_dims(frame, -1) return frame + class FrameStack(gym.Wrapper): - def __init__(self, env, k, flat = False): + def __init__(self, env, k, flat=False): """ Stack k last frames. Returns lazy array, which is much more memory efficient. 
@@ -262,7 +262,7 @@ def _get_ob(self): class BatchedFrameStack(gym.Wrapper): - def __init__(self, env, k, transpose = False, flatten = False): + def __init__(self, env, k, transpose=False, flatten=False): gym.Wrapper.__init__(self, env) self.k = k self.frames = deque([], maxlen=k) @@ -303,8 +303,9 @@ def _get_ob(self): frames = np.transpose(self.frames, (1, 0, 2)) return frames + class BatchedFrameStackWithStates(gym.Wrapper): - def __init__(self, env, k, transpose = False, flatten = False): + def __init__(self, env, k, transpose=False, flatten=False): gym.Wrapper.__init__(self, env) self.k = k self.obses = deque([], maxlen=k) @@ -363,6 +364,7 @@ def process_data(self, data): obses = np.transpose(data, (1, 0, 2)) return obses + class ProcgenStack(gym.Wrapper): def __init__(self, env, k = 2, greyscale=True): gym.Wrapper.__init__(self, env) @@ -370,7 +372,7 @@ def __init__(self, env, k = 2, greyscale=True): self.curr_frame = 0 self.frames = deque([], maxlen=k) - self.greyscale=greyscale + self.greyscale = greyscale self.prev_frame = None shp = env.observation_space.shape if greyscale: @@ -421,6 +423,7 @@ def observation(self, observation): # with smaller replay buffers only. return np.array(observation).astype(np.float32) / 255.0 + class LazyFrames(object): def __init__(self, frames): """This object ensures that common frames between the observations are only stored once. @@ -449,6 +452,7 @@ def __len__(self): def __getitem__(self, i): return self._force()[i] + class ReallyDoneWrapper(gym.Wrapper): def __init__(self, env): """ @@ -457,7 +461,7 @@ def __init__(self, env): self.old_env = env gym.Wrapper.__init__(self, env) self.lives = 0 - self.was_real_done = True + self.was_real_done = True def step(self, action): old_lives = self.env.unwrapped.ale.lives() @@ -471,6 +475,7 @@ def step(self, action): done = lives == 0 return obs, reward, done, info + class AllowBacktracking(gym.Wrapper): """ Use deltas in max(X) as the reward, rather than deltas @@ -506,6 +511,7 @@ def unwrap(env): else: return env + class StickyActionEnv(gym.Wrapper): def __init__(self, env, p=0.25): super(StickyActionEnv, self).__init__(env) @@ -591,7 +597,7 @@ def step(self, action): obs, reward, done, info = self.env.step(action) obs = { 'observation': obs, - 'reward':np.clip(reward, -1, 1), + 'reward': np.clip(reward, -1, 1), 'last_action': action } return obs, reward, done, info @@ -625,10 +631,13 @@ def __init__(self, env, name): raise NotImplementedError def observation(self, observation): - return observation * self.mask + return observation * self.mask + class OldGymWrapper(gym.Env): def __init__(self, env): + import gymnasium + self.env = env # Convert Gymnasium spaces to Gym spaces @@ -636,6 +645,8 @@ def __init__(self, env): self.action_space = self.convert_space(env.action_space) def convert_space(self, space): + import gymnasium + """Recursively convert Gymnasium spaces to Gym spaces.""" if isinstance(space, gymnasium.spaces.Box): return gym.spaces.Box( @@ -691,6 +702,7 @@ def render(self, mode='human'): def close(self): return self.env.close() + # Example usage: if __name__ == "__main__": # Create a MyoSuite environment @@ -718,19 +730,21 @@ def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directo env = MontezumaInfoWrapper(env, room_address=3 if 'Montezuma' in env_id else 1) env = StickyActionEnv(env) env = InfoWrapper(env) - if directory != None: - env = gym.wrappers.Monitor(env,directory=directory,force=True) + + if directory is not None: + env = gym.wrappers.Monitor(env, 
directory=directory, force=True) if sticky: env = StickyActionEnv(env) if not timelimit: env = env.env - #assert 'NoFrameskip' in env.spec.id + # assert 'NoFrameskip' in env.spec.id if noop_max > 0: env = NoopResetEnv(env, noop_max=noop_max) env = MaxAndSkipEnv(env, skip=skip) - #env = EpisodeStackedEnv(env) + # env = EpisodeStackedEnv(env) return env + def wrap_deepmind(env, episode_life=False, clip_rewards=True, frame_stack=True, scale =False, wrap_impala=False): """Configure environment for DeepMind-style Atari. """ @@ -749,6 +763,7 @@ def wrap_deepmind(env, episode_life=False, clip_rewards=True, frame_stack=True, env = ImpalaEnvWrapper(env) return env + def wrap_carracing(env, clip_rewards=True, frame_stack=True, scale=False): """Configure environment for DeepMind-style Atari. """ @@ -761,11 +776,12 @@ def wrap_carracing(env, clip_rewards=True, frame_stack=True, scale=False): env = FrameStack(env, 4) return env + def make_car_racing(env_id, skip=4): env = make_atari(env_id, noop_max=0, skip=skip) return wrap_carracing(env, clip_rewards=False) + def make_atari_deepmind(env_id, noop_max=30, skip=4, sticky=False, episode_life=True, wrap_impala=False, **kwargs): env = make_atari(env_id, noop_max=noop_max, skip=skip, sticky=sticky, **kwargs) return wrap_deepmind(env, episode_life=episode_life, clip_rewards=False, wrap_impala=wrap_impala) - diff --git a/rl_games/configs/bark/ppo_merging.yaml b/rl_games/configs/bark/ppo_merging.yaml new file mode 100644 index 00000000..253c4273 --- /dev/null +++ b/rl_games/configs/bark/ppo_merging.yaml @@ -0,0 +1,64 @@ +params: + seed: 5 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + mlp: + units: [256, 128, 64] + activation: elu + initializer: + name: default + + config: + name: Ant-v3_ray + env_name: openai_gym + score_to_win: 20000 + normalize_input: True + normalize_value: True + value_bootstrap: True + reward_shaper: + scale_value: 0.1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + max_epochs: 2000 + num_actors: 8 #64 + horizon_length: 256 #64 + minibatch_size: 2048 + mini_epochs: 4 + critic_coef: 2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0 + + env_config: + name: "merging-v0" + seed: 5 + + player: + render: True \ No newline at end of file diff --git a/rl_games/configs/mujoco/ant_envpool.yaml b/rl_games/configs/mujoco/ant_envpool.yaml index da769e45..54eb015f 100644 --- a/rl_games/configs/mujoco/ant_envpool.yaml +++ b/rl_games/configs/mujoco/ant_envpool.yaml @@ -62,4 +62,7 @@ params: #flat_observation: True player: - render: False \ No newline at end of file + render: False + num_actors: 64 + games_num: 1000 + use_vecenv: True \ No newline at end of file diff --git a/rl_games/configs/mujoco/ant_envpool_moe.yaml b/rl_games/configs/mujoco/ant_envpool_moe.yaml new file mode 100644 index 00000000..cdfbc1f2 --- /dev/null +++ b/rl_games/configs/mujoco/ant_envpool_moe.yaml @@ -0,0 +1,70 @@ +params: + seed: 5 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + 
name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + + moe: + num_experts: 4 + expert_units: [256, 128, 64] + model_units: [256, 128, 64] + #expert_activation: elu + initializer: + name: default + config: + name: Ant-v4_envpool_moe + env_name: envpool + score_to_win: 20000 + normalize_input: True + normalize_value: True + value_bootstrap: True + normalize_advantage: True + reward_shaper: + scale_value: 1 + + gamma: 0.99 + tau: 0.95 + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0 + max_epochs: 2000 + num_actors: 64 + horizon_length: 64 + minibatch_size: 2048 + mini_epochs: 4 + critic_coef: 2 + + env_config: + env_name: Ant-v4 + seed: 5 + #flat_observation: True + + player: + render: False + num_actors: 64 + games_num: 1000 + use_vecenv: True \ No newline at end of file diff --git a/rl_games/configs/mujoco/humanoid_envpool_moe.yaml b/rl_games/configs/mujoco/humanoid_envpool_moe.yaml new file mode 100644 index 00000000..1de67a92 --- /dev/null +++ b/rl_games/configs/mujoco/humanoid_envpool_moe.yaml @@ -0,0 +1,66 @@ +params: + seed: 5 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + moe: + num_experts: 4 + expert_units: [512, 256, 128] + model_units: [512, 256, 128] + #expert_activation: elu + is_scale_prob: False + initializer: + name: default + + config: + name: Humanoid-v4_envpool + env_name: envpool + score_to_win: 20000 + normalize_input: True + normalize_value: True + value_bootstrap: True + reward_shaper: + scale_value: 0.1 + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + grad_norm: 1.0 + entropy_coef: 0.0 + truncate_grads: True + e_clip: 0.2 + clip_value: True + use_smooth_clamp: True + bound_loss_type: regularisation + bounds_loss_coef: 0.0005 + max_epochs: 2000 + num_actors: 64 + horizon_length: 128 + minibatch_size: 2048 + mini_epochs: 5 + critic_coef: 4 + + env_config: + env_name: Humanoid-v4 + + player: + render: True \ No newline at end of file
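
A minimal usage sketch (not part of the patch) of the new `moe` network section: it builds the `MoEBlock` added in `rl_games/common/layers/switch_ffn.py` with the values from `ant_envpool_moe.yaml` and runs one forward pass. It assumes the patched `rl_games` is importable; the observation size (27, as in Ant-v4) and the batch size are illustrative.

import torch
from rl_games.common.layers.switch_ffn import MoEBlock

obs_dim = 27                          # Ant-v4 observation size (illustrative)
block = MoEBlock(
    input_size=obs_dim,
    expert_units=[256, 128, 64],      # per-layer expert output sizes (config: expert_units)
    model_units=[256, 128, 64],       # per-layer hidden sizes (config: model_units)
    num_experts=4,                    # config: num_experts
)

x = torch.randn(2048, obs_dim)        # one minibatch of flattened observations
out = block(x)                        # shape [2048, 64]; feeds the mu/sigma/value heads
print(out.shape)
print(block.get_aux_loss())           # {'moe_load_balancing_loss': tensor(...)}

This mirrors what `A2CBuilder.Network` does via `_build_moe_block` when a `moe:` section is given instead of `mlp:`; the returned aux-loss dict is what `get_aux_loss()` exposes to the trainer.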
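
A similar sketch for the new rollout logging in `BasePlayer.run()`: the player now appends one row per environment step and writes `game_data.parquet` when the run finishes. Reading it back only needs pandas with a parquet engine (pyarrow or fastparquet) installed; column names follow the patch.

import pandas as pd

df = pd.read_parquet('game_data.parquet')
print(df.columns.tolist())       # ['game_index', 'observation', 'action', 'reward', 'done']

# per-episode return and length, grouped by the game index assigned in run()
episode_stats = df.groupby('game_index').agg(episode_return=('reward', 'sum'),
                                             episode_len=('reward', 'size'))
print(episode_stats.describe())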