From 13d97b048cd9149694b98332c53330394c2fba3a Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 24 Sep 2024 01:38:20 +0000
Subject: [PATCH 1/3] add smth special maybe ?

---
 open_diloco/simulate_multi_node.sh |  0
 open_diloco/train_pure_fsdp.py     | 15 +++++++++------
 2 files changed, 9 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 open_diloco/simulate_multi_node.sh

diff --git a/open_diloco/simulate_multi_node.sh b/open_diloco/simulate_multi_node.sh
old mode 100644
new mode 100755
diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index a9a6afc..d35b6b5 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -240,13 +240,16 @@ def train(config: Config):
         for param_offloaded, param in zip(cpu_model, model.parameters()):
             # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
+
+            mask = torch.rand_like(param_offloaded.grad) > 0.95
+
+            data_to_all_reduce = param_offloaded.grad * mask
 
-            if param_offloaded.grad.device == torch.device("cpu"):
-                # gloo does not support AVG
-                param_offloaded.grad = param_offloaded.grad / global_pg.size()
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
-            else:
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)
+            # gloo does not support AVG
+            data_to_all_reduce = data_to_all_reduce / global_pg.size()
+            dist.all_reduce(data_to_all_reduce, op=dist.ReduceOp.SUM, group=global_pg)
+
+            param_offloaded.grad += data_to_all_reduce
 
         outer_optimizer.step()
         outer_optimizer.zero_grad()

From b3e2dd8854f6d3cbda9fd73f5f3b70080f853967 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 24 Sep 2024 02:24:33 +0000
Subject: [PATCH 2/3] fix it

---
 open_diloco/train_pure_fsdp.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index d35b6b5..cbbfa45 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -243,14 +243,15 @@ def train(config: Config):
 
             mask = torch.rand_like(param_offloaded.grad) > 0.95
 
-            data_to_all_reduce = param_offloaded.grad * mask
+            data_to_send = param_offloaded.grad * mask
 
+            data_to_send_pre_reduce = data_to_send.clone()
             # gloo does not support AVG
-            data_to_all_reduce = data_to_all_reduce / global_pg.size()
-            dist.all_reduce(data_to_all_reduce, op=dist.ReduceOp.SUM, group=global_pg)
-
-            param_offloaded.grad += data_to_all_reduce
+            data_to_send = data_to_send / global_pg.size()
+            dist.all_reduce(data_to_send, op=dist.ReduceOp.SUM, group=global_pg)
 
+            param_offloaded.grad += data_to_send - data_to_send_pre_reduce # removing the
+
         outer_optimizer.step()
         outer_optimizer.zero_grad()
 

From 1e029724deabbc93b66aaabffa7edf1a8fc729c5 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 24 Sep 2024 02:27:40 +0000
Subject: [PATCH 3/3] fix it

---
 open_diloco/simulate_multi_node.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_diloco/simulate_multi_node.sh b/open_diloco/simulate_multi_node.sh
index c5def4a..fde4efa 100755
--- a/open_diloco/simulate_multi_node.sh
+++ b/open_diloco/simulate_multi_node.sh
@@ -57,7 +57,7 @@ mkdir -p logs
 
 for i in $(seq 0 $(($N - 1 )))
 do
     > logs/log$i
-    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
+    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
     child_pids+=($!)
 done
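
Note (not part of the patches): the snippet below is a minimal standalone sketch of the masked all-reduce that PATCH 1/3 and 2/3 introduce on the offloaded outer gradient. It assumes torch.distributed has already been initialized with the gloo backend; `global_pg` stands in for the global process group used in train_pure_fsdp.py, and `masked_all_reduce` is a hypothetical helper name added here only for illustration. The 0.95 mask threshold and the divide-then-SUM workaround for gloo's missing ReduceOp.AVG are taken directly from the diffs above.

# Minimal sketch of the masked outer-gradient all-reduce from the patches above.
import torch
import torch.distributed as dist


def masked_all_reduce(grad: torch.Tensor, global_pg) -> None:
    # Keep roughly 5% of the entries (rand_like > 0.95), zero out the rest.
    mask = torch.rand_like(grad) > 0.95
    data_to_send = grad * mask

    # Remember the local masked values so they are not counted twice
    # once the reduced result is added back (the PATCH 2/3 fix).
    data_to_send_pre_reduce = data_to_send.clone()

    # gloo does not support ReduceOp.AVG, so divide by world size first
    # and then SUM across workers.
    data_to_send = data_to_send / global_pg.size()
    dist.all_reduce(data_to_send, op=dist.ReduceOp.SUM, group=global_pg)

    # Swap the local masked contribution for the cross-worker average.
    grad += data_to_send - data_to_send_pre_reduce

Because each rank draws its own random mask, only a sparse slice of the outer gradient is exchanged on any given step, which appears to be the point of the experiment.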