diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index c561ea1b22..9beb7b0162 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -434,6 +434,7 @@ def __init__( ), LayerNorm2d(neck_chans), ) + self.num_features = neck_chans else: self.neck = nn.Identity() neck_chans = embed_dim