Bug in teacher weights calculation #21

Open
ShpihanVlad opened this issue Sep 3, 2023 · 4 comments

@ShpihanVlad

Hi, your code implementation contains an error in updating the teacher weights.
Basically, the current WeightEMA implementation in utils/torch_utils.py breaks the statistics that the model stores for batch norm, if I recall correctly. Because of this, after roughly the 5th epoch the teacher model starts making invalid predictions, which further hurts training and makes it a little worse than not using a teacher model at all.
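For context (a quick illustration with a plain BatchNorm layer, not the actual code from this repo): the running statistics batch norm depends on are buffers, so they appear in state_dict() but not in parameters(), and an EMA update that only touches the parameters, or that mishandles the non-float entries, leaves the teacher's normalization statistics wrong.

import torch.nn as nn

bn = nn.BatchNorm2d(8)
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']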

To fix this, you can refer to the original yolov5 EMA class, ModelEMA, in the same file just above. I was able to rewrite the code that way and then reproduce results close to the ones in the paper, even at an image size of 640.

My old issue #18 was caused by this bug, and the still-open #6 ran into it as well.

I can make a pull request later next week, if you wish.

@hnuzhy
Owner

hnuzhy commented Sep 3, 2023

OK. It is very nice of you to fix this bug. I suggest you also add the environment details of your machine, because I did not see this bug on my server. It may be closely related to the PyTorch version or other libraries.

@ShpihanVlad
Author

ShpihanVlad commented Sep 3, 2023

@hnuzhy OK, I will. I worked with this project in May, so some of my logs are lost and it may take some time. I remember that you can verify that some of the entries in the model's state_dict are integers with something like

for k, v in model.state_dict().items():
    if not v.dtype.is_floating_point:
        print(k, v.dtype)

This is the reason why that check is present in ModelEMA. I'm not sure whether these values became floats or zeros, but that was the root of the issue; I'll try to write up a report later.
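For example, running that check on any model that contains BatchNorm layers should print the num_batches_tracked buffers, which are int64 in PyTorch and therefore must be skipped by the EMA update (a small sketch, not tied to this repo):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for k, v in model.state_dict().items():
    if not v.dtype.is_floating_point:
        print(k, v.dtype)
# prints: 1.num_batches_tracked torch.int64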

Also, as far as I remember, there were some other issues; I'll try to find them and hopefully fix them again.

@nice3310

@ShpihanVlad Hello, I'm currently facing the same issue as you, but after attempting to modify the teacher model weight calculation, there hasn't been a significant improvement.

Could you please share the source code with me?

Additionally, I'd like to ask what final mAP50 you reached when training with an image size of 640.

I greatly appreciate any help you can provide.

@ShpihanVlad
Author

@nice3310 Hi, about the source code, I'll need to check for credentials in my repo first; as far as I remember there may be some info from the GPU cloud service I used. Also, the final mAP@0.5 was not saved, but it was really close to the results reported in the paper, less than 1.0 difference, and a little lower as I remember, though that may just be random initialization.

Here is a sample of the code I used:
class WeightEMA(object):
    """
    Exponential moving average weight optimizer for the mean teacher model.
    Based mainly on the ModelEMA class by the Ultralytics team.
    Relies on torch and on is_parallel / copy_attr from utils/torch_utils.py (the same file).
    """

    def __init__(self, teacher_model, alpha=0.99):
        self.model = teacher_model  # FP32 teacher EMA
        self.alpha = alpha
        self.decay = lambda epoch: self.alpha
        for p in self.model.parameters():  # teacher is never trained directly
            p.requires_grad_(False)

    def update(self, stud_model, epoch):
        # Update EMA parameters from the student model
        with torch.no_grad():
            alpha = self.decay(epoch)
            msd = stud_model.module.state_dict() if is_parallel(stud_model) else stud_model.state_dict()
            for k, v in self.model.state_dict().items():
                if v.dtype.is_floating_point:  # weights, biases, BN running stats
                    v *= alpha
                    v += (1. - alpha) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.model, model, include, exclude)

This was mainly copied from the original yolov5 implementation, and as far as I remember I didn't even need update_attr. I needed to change alpha based on the epoch, so I structured my code like this so that child classes could simply override the decay function self.decay. As far as I remember I also changed some parts of the training code, but I can't provide that for now. Just make sure the teacher weight update is called after each step of the student optimization, as in the rough sketch below.
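As an illustration of the decay function and the call order (a hypothetical child class and loop skeleton, not my actual training code):

class RampWeightEMA(WeightEMA):
    # hypothetical subclass: smaller alpha for the first few epochs so the
    # teacher follows the student faster, then the usual constant alpha
    def __init__(self, teacher_model, alpha=0.99, warmup_epochs=5):
        super().__init__(teacher_model, alpha)
        self.decay = lambda epoch: min(self.alpha, (epoch + 1) / (warmup_epochs + 1))

# in the training loop, per batch:
#     loss.backward()
#     optimizer.step()
#     ema.update(student_model, epoch)  # teacher update right after the student step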
