
Commit

Merge pull request #2092 from huggingface/mesa_ema
ModelEMAV3 + MESA experiments
rwightman authored Feb 11, 2024
2 parents 88889de + 47c9bc4 commit 1b50b15
Showing 13 changed files with 1,065 additions and 119 deletions.
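The headline ModelEmaV3 and MESA changes live in files not rendered below (presumably timm/utils/model_ema.py and the training script). As background only (a generic sketch, not the PR's ModelEmaV3), a model EMA keeps a shadow copy of the weights and lerps it toward the live model after each optimizer step:

    import copy
    import torch

    class SimpleEma:
        """Generic EMA sketch; not the ModelEmaV3 added by this PR."""
        def __init__(self, model, decay=0.9999):
            self.module = copy.deepcopy(model).eval()
            self.decay = decay

        @torch.no_grad()
        def update(self, model):
            for ema_p, p in zip(self.module.parameters(), model.parameters()):
                # ema = decay * ema + (1 - decay) * current
                ema_p.lerp_(p.to(dtype=ema_p.dtype), 1.0 - self.decay)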
8 changes: 2 additions & 6 deletions timm/data/transforms.py
@@ -32,16 +32,12 @@ def __call__(self, pil_img):
 
 
 class ToTensor:
-
+    """ ToTensor with no rescaling of values"""
     def __init__(self, dtype=torch.float32):
        self.dtype = dtype
 
     def __call__(self, pil_img):
-        np_img = np.array(pil_img, dtype=np.uint8)
-        if np_img.ndim < 3:
-            np_img = np.expand_dims(np_img, axis=-1)
-        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
-        return torch.from_numpy(np_img).to(dtype=self.dtype)
+        return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
 
 
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
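For reference, torchvision's pil_to_tensor already returns an unscaled uint8 CHW tensor, so it replaces the removed numpy round-trip one-for-one. A minimal equivalence check (standalone sketch, not part of the diff):

    import numpy as np
    import torch
    import torchvision.transforms.functional as F
    from PIL import Image

    pil_img = Image.new('RGB', (4, 4), color=(1, 2, 3))

    # Removed path: numpy array with explicit HWC -> CHW rollaxis.
    np_img = np.array(pil_img, dtype=np.uint8)
    if np_img.ndim < 3:
        np_img = np.expand_dims(np_img, axis=-1)
    old = torch.from_numpy(np.rollaxis(np_img, 2)).to(dtype=torch.float32)

    # New path: single torchvision call, values left unscaled.
    new = F.pil_to_tensor(pil_img).to(dtype=torch.float32)

    assert torch.equal(old, new)  # both (3, 4, 4)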
8 changes: 4 additions & 4 deletions timm/layers/classifier.py
@@ -180,10 +180,10 @@ def __init__(
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+    def reset(self, num_classes, pool_type=None):
+        if pool_type is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         self.use_conv = self.global_pool.is_identity()
         linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
         if self.hidden_size:
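Callers of reset() now pass pool_type rather than global_pool, matching the argument name SelectAdaptivePool2d itself uses. A usage sketch (assuming this hunk is NormMlpClassifierHead, suggested by the hidden_size reference in the context):

    from timm.layers import NormMlpClassifierHead

    head = NormMlpClassifierHead(in_features=512, num_classes=1000, hidden_size=768)
    head.reset(10, pool_type='max')  # previously: head.reset(10, global_pool='max')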
2 changes: 1 addition & 1 deletion timm/layers/create_act.py
@@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     return _ACT_LAYER_DEFAULT[name]
 
 
-def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
+def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
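The corrected annotation reflects that create_act_layer accepts an activation class (or a string name), never a module instance. A quick sketch:

    import torch.nn as nn
    from timm.layers import create_act_layer

    act1 = create_act_layer('relu', inplace=True)  # by name -> nn.ReLU(inplace=True)
    act2 = create_act_layer(nn.SiLU)               # by class, as Type[nn.Module] now indicates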
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -39,6 +39,7 @@
 from .mvitv2 import *
 from .nasnet import *
 from .nest import *
+from .nextvit import *
 from .nfnet import *
 from .pit import *
 from .pnasnet import *
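With this import in place, the NextViT models registered in timm/models/nextvit.py become discoverable through the factory. A sketch (variant names assumed, e.g. nextvit_small):

    import timm

    print(timm.list_models('nextvit*'))  # e.g. ['nextvit_small', ...]
    model = timm.create_model('nextvit_small', pretrained=False)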
13 changes: 12 additions & 1 deletion timm/models/davit.py
@@ -547,6 +547,17 @@ def _init_weights(self, m):
         if isinstance(m, nn.Linear) and m.bias is not None:
             nn.init.constant_(m.bias, 0)
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem', # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm_pre', (99999,)),
+            ]
+        )
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
@@ -558,7 +569,7 @@ def get_classifier(self):
         return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
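The new group_matcher lets timm's optimizer utilities bucket DaViT parameters by depth, e.g. for layer-wise learning-rate decay. A sketch (assuming create_optimizer_v2's layer_decay option, which consults model.group_matcher()):

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('davit_tiny', pretrained=False)
    optimizer = create_optimizer_v2(
        model,
        opt='adamw',
        lr=1e-3,
        weight_decay=0.05,
        layer_decay=0.75,  # per-group LR scales derived from model.group_matcher()
    )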
(Diffs for the remaining 8 changed files are not shown here.)
