feature(wrh): add adaptive batch size for transition #256

Open · wants to merge 4 commits into base: dev-unizero-multitask-v2
13 changes: 10 additions & 3 deletions lzero/entry/train_unizero_multitask.py
@@ -13,7 +13,7 @@
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.entry.utils import log_buffer_memory_usage, clamp, softmax_with_temperature
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator
from lzero.mcts import UniZeroGameBuffer as GameBuffer
@@ -192,10 +192,16 @@ def train_unizero_multitask(
for i in range(update_per_collect):
train_data_multi_task = []
envstep_multi_task = 0
if cfg.policy.adaptive_batch_size_for_transition:
buffer_num_of_transitions_list = [buffer.get_num_of_transitions() for buffer in game_buffers]
buffer_ratio_list = [x / sum(buffer_num_of_transitions_list) for x in buffer_num_of_transitions_list]
softmax_buffer_ratio_list = softmax_with_temperature(buffer_ratio_list, temperature=cfg.policy.temperature_for_softmax_list)
adaptive_batch_size_list = [int(x * cfg.policy.adaptive_total_batch_size) for x in softmax_buffer_ratio_list]
for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)):
envstep_multi_task += collector.envstep
if replay_buffer.get_num_of_transitions() > batch_size:
batch_size = cfg.policy.batch_size[task_id]
batch_size = adaptive_batch_size_list[task_id] \
if cfg.policy.adaptive_batch_size_for_transition else cfg.policy.batch_size[task_id]
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
policy.recompute_pos_emb_diff_and_clear_cache()
@@ -210,7 +216,8 @@
break

if train_data_multi_task:
log_vars = learner.train(train_data_multi_task, envstep_multi_task)
batch_size_list = adaptive_batch_size_list if cfg.policy.adaptive_batch_size_for_transition else cfg.policy.batch_size
log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs={"batch_size_list": batch_size_list})

if cfg.policy.use_priority:
for task_id, replay_buffer in enumerate(game_buffers):
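A minimal, self-contained sketch of the allocation performed at the top of the update loop above, with the softmax helper inlined so it runs without this branch installed; the buffer sizes, temperature, and total budget are illustrative values, not ones taken from this PR.

import torch
import torch.nn.functional as F

def softmax_with_temperature(input_list, temperature=1.0):
    # Same logic as the helper this PR adds to lzero/entry/utils.py.
    scaled = torch.tensor(input_list, dtype=torch.float32) / temperature
    return [round(x, 2) for x in F.softmax(scaled, dim=0).tolist()]

# Illustrative number of transitions stored in each task's game buffer.
buffer_num_of_transitions_list = [10000, 20000, 30000, 40000]
total = sum(buffer_num_of_transitions_list)
buffer_ratio_list = [x / total for x in buffer_num_of_transitions_list]  # [0.1, 0.2, 0.3, 0.4]

# Lower temperature makes the allocation track the buffer ratios more sharply.
softmax_buffer_ratio_list = softmax_with_temperature(buffer_ratio_list, temperature=0.5)

adaptive_total_batch_size = 512
adaptive_batch_size_list = [int(w * adaptive_total_batch_size) for w in softmax_buffer_ratio_list]
print(softmax_buffer_ratio_list)  # roughly [0.18, 0.22, 0.27, 0.33]
print(adaptive_batch_size_list)   # roughly [92, 112, 138, 168]; larger buffers get larger batches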
15 changes: 14 additions & 1 deletion lzero/entry/utils.py
@@ -1,12 +1,25 @@
import os
from typing import Optional, Callable
from typing import Optional, Callable, Union, List

import psutil
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter
from typing import Optional, Callable
import torch
import torch.nn.functional as F

def clamp(x: Union[int, float], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> Union[int, float]:
    # Clip x to the closed interval [min, max]; either bound may be omitted.
    if min is not None and x < min:
        return min
    elif max is not None and x > max:
        return max
    else:
        return x

def softmax_with_temperature(input_list: List[Union[int, float]], temperature: float = 1) -> List[Union[int, float]]:
    # Temperature-scaled softmax over a plain Python list; the returned weights are rounded to two decimals.
    list_2_tensor = torch.tensor(input_list, dtype=torch.float32) / temperature
    softmax_tensor = F.softmax(list_2_tensor, dim=0)
    return [round(x, 2) for x in softmax_tensor.tolist()]

def initialize_zeros_batch(observation_shape, batch_size, device):
"""
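clamp is imported by the training entry above but not exercised in the hunks shown, and the Atari config below carries commented-out min/max_clamp_ratio_for_adaptive_bs options, so the sketch here is only a guess at the intended use: bounding each task's softmax weight before it is turned into a batch size.

from typing import List, Optional, Union

Number = Union[int, float]

def clamp(x: Number, min: Optional[Number] = None, max: Optional[Number] = None) -> Number:
    # Mirrors the helper added to lzero/entry/utils.py in this PR.
    if min is not None and x < min:
        return min
    elif max is not None and x > max:
        return max
    return x

# Hypothetical use of the commented-out config knobs: keep every task's share
# of the total batch inside [0.06, 0.84] so no buffer is starved or dominant.
softmax_buffer_ratio_list = [0.02, 0.10, 0.40, 0.48]
clamped = [clamp(r, min=0.06, max=0.84) for r in softmax_buffer_ratio_list]
print(clamped)  # [0.06, 0.1, 0.4, 0.48]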
33 changes: 22 additions & 11 deletions lzero/policy/unizero_multitask.py
@@ -18,8 +18,8 @@
from .utils import configure_optimizers_nanogpt
from line_profiler import line_profiler

sys.path.append('/Users/puyuan/code/LibMTL/')
from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
# sys.path.append('/Users/puyuan/code/LibMTL/')
# from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect
# from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect # NOTE: FAMO has bugs now

@@ -339,6 +339,15 @@ class UniZeroMTPolicy(UniZeroPolicy):
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,

# ****** Adaptive batch size for transitions ******
# (bool) Whether to adapt each task's batch size according to the number of transitions in its replay buffer.
adaptive_batch_size_for_transition=False,
# (int) The total batch size to be distributed across all tasks.
adaptive_total_batch_size=1500,
# (float) The softmax temperature used when converting buffer ratios into per-task batch-size weights.
temperature_for_softmax_list=1.0,


# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
use_priority=False,
@@ -467,17 +476,17 @@ def _init_learn(self) -> None:
# Pass wrapped_model as share_model to GradCorrect
# ========= Initialize MoCo / CAGrad parameters =========
self.task_num = self._cfg.task_num
self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
self.grad_correct.init_param()
self.grad_correct.rep_grad = False
# self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
# self.grad_correct.init_param()
# self.grad_correct.rep_grad = False

# ========= only for FAMO =========
# self.grad_correct.set_min_losses(torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device))
# self.curr_min_loss = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)
# self.grad_correct.prev_loss = self.curr_min_loss

#@profile
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
def _forward_learn(self, data: Tuple[torch.Tensor], **kwargs) -> Dict[str, Union[float, int]]:
"""
Overview:
The forward function for learning policy in learn mode, which is the core of the learning process.
@@ -490,6 +499,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
- info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
current learning loss and learning statistics.
"""
batch_size_list = kwargs.get("batch_size_list", self._cfg.batch_size)
print(f"batch size list is {batch_size_list}")
self._learn_model.train()
self._target_model.train()

@@ -540,10 +551,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list,
self._cfg.device)

target_reward = target_reward.view(self._cfg.batch_size[task_id], -1)
target_value = target_value.view(self._cfg.batch_size[task_id], -1)
target_reward = target_reward.view(batch_size_list[task_id], -1)
target_value = target_value.view(batch_size_list[task_id], -1)

# assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0)
# assert obs_batch.size(0) == batch_size_list == target_reward.size(0)

# Transform rewards and values to their scaled forms
transformed_target_reward = scalar_transform(target_reward)
@@ -557,10 +568,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
batch_for_gpt = {}
if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1:
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape)
batch_size_list[task_id], -1, self._cfg.model.observation_shape)
elif len(self._cfg.model.observation_shape) == 3:
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(
self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape)
batch_size_list[task_id], -1, *self._cfg.model.observation_shape)

batch_for_gpt['actions'] = action_batch.squeeze(-1)
batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]
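A toy sketch of the new keyword path from the entry into the policy: learner.train(..., policy_kwargs={...}) ends up in _forward_learn's **kwargs, which falls back to the static cfg.batch_size when no adaptive list is supplied. The classes here are trimmed stand-ins, not the real BaseLearner or UniZeroMTPolicy, and assume DI-engine forwards policy_kwargs as keyword arguments in the way this PR relies on.

from typing import Dict, List, Optional

class ToyPolicy:
    def __init__(self, cfg_batch_size: List[int]):
        self._cfg_batch_size = cfg_batch_size  # stands in for self._cfg.batch_size

    def _forward_learn(self, data, **kwargs) -> Dict[str, int]:
        # Same fallback as the change in unizero_multitask.py.
        batch_size_list = kwargs.get("batch_size_list", self._cfg_batch_size)
        return {f"task_{i}_batch_size": b for i, b in enumerate(batch_size_list)}

class ToyLearner:
    def __init__(self, policy: ToyPolicy):
        self._policy = policy

    def train(self, data, envstep: int = -1, policy_kwargs: Optional[dict] = None):
        # Forward policy_kwargs into the policy's learn-mode forward.
        return self._policy._forward_learn(data, **(policy_kwargs or {}))

policy = ToyPolicy(cfg_batch_size=[128, 128, 128, 128])
learner = ToyLearner(policy)
print(learner.train([], envstep=0))  # static per-task sizes
print(learner.train([], envstep=0, policy_kwargs={"batch_size_list": [92, 112, 138, 168]}))  # adaptive sizes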
18 changes: 16 additions & 2 deletions zoo/atari/config/atari_unizero_multitask_config.py
@@ -58,6 +58,13 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
# collector_env_num=collector_env_num,
# evaluator_env_num=evaluator_env_num,
task_num=len(env_id_list),

# ==== for soft modularization ====
use_soft_modulization_head=False,
num_modules_per_layer=4,
num_layers_for_sm=3,
gating_embed_mlp_num=2,

use_normal_head=True,
# use_normal_head=False,
use_softmoe_head=False,
@@ -91,6 +98,13 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
replay_buffer_size=int(1e6),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,

adaptive_batch_size_for_transition=False,
# adaptive_total_batch_size=1500,
adaptive_total_batch_size=512, # for debug
# min_clamp_ratio_for_adaptive_bs=0.06,
# max_clamp_ratio_for_adaptive_bs=0.84,
temperature_for_softmax_list=0.5
),
))

@@ -103,7 +117,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
# exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/'
# exp_name_prefix = f'data_unizero_mt_0722_debug/{len(env_id_list)}games_1-encoder-{norm_type}_trans-ffw-moeV2-expert4_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/'
# exp_name_prefix = f'data_unizero_mt_0722_profile/lineprofile_{len(env_id_list)}games_1-encoder-{norm_type}_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/'
exp_name_prefix = f'data_unizero_mt_0722/{len(env_id_list)}games_1-encoder-{norm_type}_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/'
exp_name_prefix = f'data_unizero_mt_0801/{len(env_id_list)}games_1-encoder-{norm_type}_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/'

for task_id, env_id in enumerate(env_id_list):
config = create_config(
@@ -172,7 +186,7 @@ def create_env_manager():
max_env_step = int(1e6)
reanalyze_ratio = 0.
# batch_size = [32, 32, 32, 32]
max_batch_size = 2000
max_batch_size = 512
batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))]
num_unroll_steps = 10
infer_context_length = 4
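For completeness, flipping the feature on in this config would look roughly like the snippet below; the values are the debug numbers already present in the diff, and the surrounding dict layout is abbreviated.

policy = dict(
    # Split the per-iteration batch across tasks according to buffer transition counts
    # instead of the fixed, even batch_size list computed from max_batch_size.
    adaptive_batch_size_for_transition=True,
    adaptive_total_batch_size=512,       # total budget shared by all tasks (debug value)
    temperature_for_softmax_list=0.5,    # lower -> allocation tracks buffer sizes more sharply
)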