
🐛[BUG]: CorrDiff loss is scaled by hyper-parameter #605

Closed
chychen opened this issue Jul 17, 2024 · 3 comments
Labels: ? - Needs Triage, bug

Comments

chychen (Contributor) commented Jul 17, 2024

Version

latest

On which installation method(s) does this occur?

Source

Describe the issue

The CorrDiff loss is scaled by hyper-parameters, so we cannot run a hyper-parameter search: each run cannot be compared to the others.

Example (see the sketch after this list):

- if `batch_gpu_total` = 1, `loss_accum` = L; when `batch_gpu_total` = 2, `loss_accum` = L/2
- if `batch_size_gpu` = 1, `loss_accum` = L; when `batch_size_gpu` = 2, `loss_accum` = 2*L
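
A minimal sketch of this scaling behaviour on a single rank (an illustration with an assumed helper name, not the actual CorrDiff training loop): with identical per-sample losses L = 1, the accumulated value changes when only the batch split changes.

```python
import torch

# Illustrative sketch only (assumed helper, not repository code): reproduce the
# current normalization  loss.sum() * loss_scaling / batch_gpu_total  followed by
# loss_accum += loss / num_accumulation_rounds  for one GPU.
def accumulate_current(per_sample_losses, batch_size_gpu, loss_scaling=1.0):
    batch_gpu_total = len(per_sample_losses)
    num_accumulation_rounds = batch_gpu_total // batch_size_gpu
    loss_accum = 0.0
    for r in range(num_accumulation_rounds):
        chunk = per_sample_losses[r * batch_size_gpu:(r + 1) * batch_size_gpu]
        loss = chunk.sum() * (loss_scaling / batch_gpu_total)
        loss_accum += loss / num_accumulation_rounds
    return float(loss_accum)

# Per-sample loss L = 1 everywhere, so a comparable metric should stay at 1.
print(accumulate_current(torch.ones(1), batch_size_gpu=1))  # batch_gpu_total=1 -> 1.0
print(accumulate_current(torch.ones(2), batch_size_gpu=1))  # batch_gpu_total=2 -> 0.5 (L/2)
print(accumulate_current(torch.ones(2), batch_size_gpu=2))  # same data, batch_size_gpu=2 -> 1.0 (2x the previous run)
```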

Why not just normalize by `batch_size_global` instead, as shown below?

Current Implementation

```python
for round_idx in range(num_accumulation_rounds):
    with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
        ...
        loss = loss.sum().mul(loss_scaling / batch_gpu_total)
        loss_accum += loss / num_accumulation_rounds
        loss.backward()

loss_sum = torch.tensor([loss_accum], device=device)
if dist.world_size > 1:
    torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
average_loss = loss_sum / dist.world_size
if dist.rank == 0:
    wb.log({"training loss": average_loss}, step=cur_nimg)
```

Proposed Modification

```python
for round_idx in range(num_accumulation_rounds):
    with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
        ...
        loss = loss.sum().mul(loss_scaling / batch_size_global)  ### Modified
        loss_accum += loss  ### Modified
        loss.backward()

loss_sum = torch.tensor([loss_accum], device=device)
if dist.world_size > 1:
    torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
average_loss = loss_sum / dist.world_size
if dist.rank == 0:
    wb.log({"training loss": average_loss}, step=cur_nimg)
```

Minimum reproducible example

see README

Relevant log output

example:
- if `batch_gpu_total` = 1, `loss_accum` = L, when `batch_gpu_total` = 2, `loss_accum` = L/2
- if `batch_size_gpu` = 1, `loss_accum` = L, when `batch_size_gpu` = 2, `loss_accum` = 2*L

Environment details

No response

chychen added the ? - Needs Triage and bug labels Jul 17, 2024
mnabian (Collaborator) commented Oct 10, 2024

Hi @chychen, thanks for reporting the issue. I agree with the proposed modification. Could you please open a PR?

mnabian (Collaborator) commented Oct 17, 2024

@chychen did you have a chance to make a PR for this modification?

chychen (Contributor, Author) commented Oct 23, 2024

This is a 3-month-old issue; it seems the latest version has already solved it.

chychen closed this as completed Oct 23, 2024
chychen reopened this Oct 23, 2024
chychen closed this as completed Oct 23, 2024