From 5b701d7182d3e4d70d85a1327088ecba963885cb Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 11 Sep 2024 11:44:25 +0000
Subject: [PATCH] do all reduce on cpu

---
 open_diloco/train_pure_fsdp.py | 34 ++++++++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index c1cff51..02ada3e 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -37,7 +37,8 @@
 
 # Function to initialize the distributed process group
 def ddp_setup():
-    init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))
+    init_process_group(timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))
+
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
 
@@ -109,7 +110,7 @@ def get_model(config: Config) -> LlamaForCausalLM:
 
 
 def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
-    return [param.data.detach().clone().to("cuda") for param in model.parameters()]
+    return [param.data.detach().clone().to("cpu") for param in model.parameters()]
 
 
 def train(config: Config):
@@ -135,9 +136,12 @@ def train(config: Config):
     local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
 
     nnodes = world_size // local_world_size
+
+    # right now device mesh does not support two backends, so we just create two identical meshes except for the backend
     device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
+    device_mesh_cpu = init_device_mesh("gloo", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
 
-    global_pg = device_mesh.get_group("global")
+    global_pg = device_mesh_cpu.get_group("global")
     local_pg = device_mesh.get_group("local")
     log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}")
 
@@ -145,7 +149,7 @@ def train(config: Config):
         model,
         sharding_strategy=sharding_strategy,
         mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
-        use_orig_params=config.torch_compile,
+        use_orig_params=True,
         process_group=local_pg,
     )
     if config.torch_compile:
@@ -159,6 +163,9 @@ def train(config: Config):
     )  # todo: in case of sharded grap op we need to offload the cpu model only once per nodes
     outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)
+    # for param in outer_optimizer.param_groups[0]["params"]:
+    #     log(param.device)
+
 
     scheduler = get_cosine_schedule_with_warmup(
         inner_optimizer,
         num_warmup_steps=config.warmup_steps,
@@ -179,6 +186,11 @@ def train(config: Config):
     while True:
         if rank == 0:
             log(f"outer_step step: {outer_step}")
+        # if "momentum_buffer" in outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]:
+        #     momentum_buffer = outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]['momentum_buffer']
+        #     log(f"momentum buffer device: {momentum_buffer.device}, shape: {momentum_buffer.shape}")
+        # else:
+        #     log("no momentum buffer")
         for inner_step in range(config.diloco.local_steps):
             for grad_acc_step in range(gradient_accumulation_steps):
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
@@ -218,16 +230,22 @@ def train(config: Config):
 
         ## do the all reduce on cpu or on gpu
         ## do the outer optimizer step on cpu or on gpu
-        for param_offloaded, param in zip(
-            cpu_model, model.parameters()
-        ):  # There is only one big fat tensor in the param because of fsdp 1 bucket stuff
+        for param_offloaded, param in zip(cpu_model, model.parameters()):  # todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
-            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)
+
+            if param_offloaded.grad.device == torch.device("cpu"):
+                # gloo does not support AVG
+                param_offloaded.grad = param_offloaded.grad / global_pg.size()
+                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
+            else:
+                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)
 
         outer_optimizer.step()
         outer_optimizer.zero_grad()
 
+        # todo: for the SHARD_GRAD_OP strategy we need to do one cpu -> gpu 0 copy and then a
+        # gpu 0 -> gpu 1,2.. copy, as it would be faster
         for param_offloaded, param in zip(cpu_model, model.parameters()):
             param.data = param_offloaded.data.to("cuda")