diff --git a/main_dino.py b/main_dino.py index cade9873d..bd99c1bf6 100644 --- a/main_dino.py +++ b/main_dino.py @@ -175,7 +175,10 @@ def train_dino(args): elif args.arch in torchvision_models.__dict__.keys(): student = torchvision_models.__dict__[args.arch]() teacher = torchvision_models.__dict__[args.arch]() - embed_dim = student.fc.weight.shape[1] + if args.arch.find("mobile") == -1: + embed_dim = student.fc.weight.shape[1] + else: + embed_dim = student.classifier[1].out_features else: print(f"Unknow architecture: {args.arch}")