
Commit

do all reduce on cpu
samsja committed Sep 11, 2024
1 parent 94bc0ae commit 5b701d7
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions open_diloco/train_pure_fsdp.py
@@ -37,7 +37,8 @@

# Function to initialize the distributed process group
def ddp_setup():
- init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))
+ init_process_group(timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


@@ -109,7 +110,7 @@ def get_model(config: Config) -> LlamaForCausalLM:


def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
- return [param.data.detach().clone().to("cuda") for param in model.parameters()]
+ return [param.data.detach().clone().to("cpu") for param in model.parameters()]


def train(config: Config):
@@ -135,17 +136,20 @@ def train(config: Config):

local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
nnodes = world_size // local_world_size

+ # right now device mesh does not support two backends, so we just create two identical meshes except for the backend
device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
+ device_mesh_cpu = init_device_mesh("gloo", (nnodes, local_world_size), mesh_dim_names=("global", "local"))

- global_pg = device_mesh.get_group("global")
+ global_pg = device_mesh_cpu.get_group("global")
local_pg = device_mesh.get_group("local")
log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}")

model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
- use_orig_params=config.torch_compile,
+ use_orig_params=True,
process_group=local_pg,
)
if config.torch_compile:
@@ -159,6 +163,9 @@
) # todo: in case of sharded grad op we need to offload the cpu model only once per node

outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)

+ # for param in outer_optimizer.param_groups[0]["params"]:
+ # log(param.device)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
num_warmup_steps=config.warmup_steps,
@@ -179,6 +186,11 @@
while True:
if rank == 0:
log(f"outer_step step: {outer_step}")
# if "momentum_buffer" in outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]:
# momentum_buffer = outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]['momentum_buffer']
# log(f"momentum buffer device: {momentum_buffer.device}, shape: {momentum_buffer.shape}")
# else:
# log("no momentum buffer")
for inner_step in range(config.diloco.local_steps):
for grad_acc_step in range(gradient_accumulation_steps):
is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
@@ -218,16 +230,22 @@
## do the all reduce on cpu or on gpu
## do the outer optimizer step on cpu or on gpu

- for param_offloaded, param in zip(
- cpu_model, model.parameters()
- ): # There is only one big fat tensor in the param because of fsdp 1 bucket stuff
+ for param_offloaded, param in zip(cpu_model, model.parameters()):
# todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
- dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

+ if param_offloaded.grad.device == torch.device("cpu"):
+ # gloo does not support AVG
+ param_offloaded.grad = param_offloaded.grad / global_pg.size()
+ dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
+ else:
+ dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

outer_optimizer.step()
outer_optimizer.zero_grad()

+ # todo for the SHARD_GRAD_OP strategy we need to do one cpu -> gpu 0 copy and then do
+ # gpu 0 -> gpu 1,2.. copy as it would be faster
for param_offloaded, param in zip(cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")
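
For readers skimming the diff, here is a minimal standalone sketch of the pattern this commit moves to. It is not code from the repository; `average_on_cpu` and the tensor shape are made up for illustration. The idea: keep NCCL as the default backend for the GPU collectives, create an additional gloo-backed process group for CPU tensors, and emulate `ReduceOp.AVG` (which gloo does not implement) by dividing by the group size before a SUM all-reduce.

```python
# Sketch only: assumed helper name and shapes; run with torchrun so RANK/WORLD_SIZE/LOCAL_RANK are set.
import os

import torch
import torch.distributed as dist


def average_on_cpu(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """Average a CPU tensor across a gloo group; gloo lacks ReduceOp.AVG, so divide then SUM."""
    tensor = tensor / group.size()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # default group stays NCCL for GPU work
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    gloo_pg = dist.new_group(backend="gloo")  # extra CPU-capable group for the outer all-reduce

    delta = torch.randn(1024)  # deliberately kept on CPU
    delta = average_on_cpu(delta, gloo_pg)

    dist.destroy_process_group()
```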

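And a simplified sketch of the outer step itself, under assumed names (`outer_step`, `cpu_params`, `gpu_params` are illustrative, not the repository's API): the pseudo-gradient is the drift of the GPU weights from the offloaded CPU copy, it is averaged across nodes on the gloo group, the SGD step with Nesterov momentum runs entirely on CPU, and the updated copy is pushed back to the GPU, mirroring the loop in the diff above.

```python
# Simplified sketch with assumed names; mirrors the structure of the loop in the diff above.
import torch
import torch.distributed as dist


def outer_step(
    cpu_params: list[torch.Tensor],         # offloaded weight copy, kept on CPU
    gpu_params: list[torch.nn.Parameter],   # the live (sharded) model parameters on GPU
    outer_optimizer: torch.optim.SGD,       # constructed over cpu_params, so its state stays on CPU
    global_pg: dist.ProcessGroup,           # gloo-backed group spanning the nodes
) -> None:
    for cpu_p, gpu_p in zip(cpu_params, gpu_params):
        # pseudo-gradient: how far the inner steps moved the weights from the last outer point
        cpu_p.grad = cpu_p.data - gpu_p.data.to("cpu")
        # gloo has no ReduceOp.AVG, so divide first and then SUM
        cpu_p.grad /= global_pg.size()
        dist.all_reduce(cpu_p.grad, op=dist.ReduceOp.SUM, group=global_pg)

    outer_optimizer.step()       # momentum buffers and the update stay on CPU
    outer_optimizer.zero_grad()

    # copy the updated outer weights back into the GPU model
    for cpu_p, gpu_p in zip(cpu_params, gpu_params):
        gpu_p.data = cpu_p.data.to("cuda")
```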
