
Are you sure .train() puts batch norm in eval mode? I'd assume that means train... #23

Open
brando90 opened this issue Sep 10, 2024 · 2 comments

Comments

@brando90

        # Fix batch norm running statistics (i.e., put batch_norm layers in eval mode)
        self.model.train()

Is this truly correct?

.train() usually puts layers in training mode. For batch norm that means it keeps updating the running statistics while normalizing with the mini-batch statistics, if I remember correctly, whereas in .eval() it normalizes with the saved running stats. Right?
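A minimal standalone sketch of that behavior (my own illustration, not code from this repo):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    bn = nn.BatchNorm1d(4)          # running_mean starts at 0, running_var at 1
    x = torch.randn(8, 4) * 3 + 5   # mini-batch with mean ~5, std ~3

    bn.train()
    _ = bn(x)                       # normalizes with the mini-batch mean/var...
    print(bn.running_mean)          # ...and the running stats were updated toward 5

    bn.eval()
    _ = bn(x)                       # now normalizes with the saved running stats
    print(bn.running_mean)          # unchanged: eval mode does not update them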

@brando90 (Author)

    # Since we are fine-tuning the model during T2V/FIM computation, .train() is the right choice in general.
    self.model.train()

@brando90 (Author) commented Sep 10, 2024

Updated comment for the batch norm line:

    # Since we are fine-tuning the model during T2V/FIM computation, .train() is the
    # right choice, as it ensures batch norm uses mini-batch statistics and properly
    # adapts the model to the new task.
    self.model.train()

But LLMs don't really use batch norm, so in practice it doesn't matter here...
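For reference, if one actually wanted what the original comment says (freeze the batch norm running statistics while fine-tuning everything else), the usual pattern is to call .train() on the whole model and then flip just the batch norm layers back to eval mode. A sketch, assuming a standard PyTorch nn.Module (the helper name is my own):

    import torch.nn as nn

    def train_with_frozen_bn(model: nn.Module) -> None:
        """Training mode for everything except batch norm, so the
        running statistics stay fixed at their pretrained values."""
        model.train()
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.eval()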
