diff --git a/tests/test_models.py b/tests/test_models.py
index d8ac8d6438..e75b17f961 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -41,7 +41,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index f308a580b8..56ff246c49 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -17,6 +17,8 @@
 from .efficientformer import *
 from .efficientformer_v2 import *
 from .efficientnet import *
+from .efficientvit_mit import *
+from .efficientvit_msra import *
 from .eva import *
 from .focalnet import *
 from .gcvit import *
diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
new file mode 100644
index 0000000000..6d123cd44d
--- /dev/null
+++ b/timm/models/efficientvit_mit.py
@@ -0,0 +1,680 @@
+""" EfficientViT (by MIT Song Han's Lab)
+
+Paper: `EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation Visual Recognition`
+    - https://arxiv.org/abs/2205.14756
+
+Adapted from official impl at https://github.com/mit-han-lab/efficientvit
+"""
+
+__all__ = ['EfficientVit']
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SelectAdaptivePool2d, create_conv2d
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+
+def val2list(x: Union[list, tuple, Any], repeat_time: int = 1) -> list:
+    if isinstance(x, (list, tuple)):
+        return list(x)
+    return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
+    # repeat elements if necessary
+    x = val2list(x)
+    if len(x) > 0:
+        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+    return tuple(x)
+
+
+def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
+    if isinstance(kernel_size, tuple):
+        return tuple([get_same_padding(ks) for ks in kernel_size])
+    else:
+        assert kernel_size % 2 > 0, "kernel size should be an odd number"
+        return kernel_size // 2
+
+
+class ConvNormAct(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            stride=1,
+            dilation=1,
+            groups=1,
+            bias=False,
+            dropout=0.,
+            norm_layer=nn.BatchNorm2d,
+            act_layer=nn.ReLU,
+    ):
+        super(ConvNormAct, self).__init__()
+        self.dropout = nn.Dropout(dropout, inplace=False)
+        self.conv = create_conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+        self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
+        self.act = act_layer(inplace=True) if act_layer else nn.Identity()
+
+    def forward(self, x):
+        x = self.dropout(x)
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class DSConv(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            stride=1,
+            use_bias=False,
+            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
+            act_layer=(nn.ReLU6, None),
+    ):
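+        """Depthwise-separable conv: a k x k depthwise conv followed by a 1x1 pointwise projection."""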
+        super(DSConv, self).__init__()
+        use_bias = val2tuple(use_bias, 2)
+        norm_layer = val2tuple(norm_layer, 2)
+        act_layer = val2tuple(act_layer, 2)
+
+        self.depth_conv = ConvNormAct(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride,
+            groups=in_channels,
+            norm_layer=norm_layer[0],
+            act_layer=act_layer[0],
+            bias=use_bias[0],
+        )
+        self.point_conv = ConvNormAct(
+            in_channels,
+            out_channels,
+            1,
+            norm_layer=norm_layer[1],
+            act_layer=act_layer[1],
+            bias=use_bias[1],
+        )
+
+    def forward(self, x):
+        x = self.depth_conv(x)
+        x = self.point_conv(x)
+        return x
+
+
+class MBConv(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size=3,
+            stride=1,
+            mid_channels=None,
+            expand_ratio=6,
+            use_bias=False,
+            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
+            act_layer=(nn.ReLU6, nn.ReLU6, None),
+    ):
+        super(MBConv, self).__init__()
+        use_bias = val2tuple(use_bias, 3)
+        norm_layer = val2tuple(norm_layer, 3)
+        act_layer = val2tuple(act_layer, 3)
+        mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+        self.inverted_conv = ConvNormAct(
+            in_channels,
+            mid_channels,
+            1,
+            stride=1,
+            norm_layer=norm_layer[0],
+            act_layer=act_layer[0],
+            bias=use_bias[0],
+        )
+        self.depth_conv = ConvNormAct(
+            mid_channels,
+            mid_channels,
+            kernel_size,
+            stride=stride,
+            groups=mid_channels,
+            norm_layer=norm_layer[1],
+            act_layer=act_layer[1],
+            bias=use_bias[1],
+        )
+        self.point_conv = ConvNormAct(
+            mid_channels,
+            out_channels,
+            1,
+            norm_layer=norm_layer[2],
+            act_layer=act_layer[2],
+            bias=use_bias[2],
+        )
+
+    def forward(self, x):
+        x = self.inverted_conv(x)
+        x = self.depth_conv(x)
+        x = self.point_conv(x)
+        return x
+
+
+class LiteMSA(nn.Module):
+    """Lightweight multi-scale attention"""
+
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            heads: Optional[int] = None,
+            heads_ratio: float = 1.0,
+            dim=8,
+            use_bias=False,
+            norm_layer=(None, nn.BatchNorm2d),
+            act_layer=(None, None),
+            kernel_func=nn.ReLU,
+            scales=(5,),
+            eps=1e-5,
+    ):
+        super(LiteMSA, self).__init__()
+        self.eps = eps
+        heads = heads or int(in_channels // dim * heads_ratio)
+        total_dim = heads * dim
+        use_bias = val2tuple(use_bias, 2)
+        norm_layer = val2tuple(norm_layer, 2)
+        act_layer = val2tuple(act_layer, 2)
+
+        self.dim = dim
+        self.qkv = ConvNormAct(
+            in_channels,
+            3 * total_dim,
+            1,
+            bias=use_bias[0],
+            norm_layer=norm_layer[0],
+            act_layer=act_layer[0],
+        )
+        self.aggreg = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(
+                    3 * total_dim,
+                    3 * total_dim,
+                    scale,
+                    padding=get_same_padding(scale),
+                    groups=3 * total_dim,
+                    bias=use_bias[0],
+                ),
+                nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
+            )
+            for scale in scales
+        ])
+        self.kernel_func = kernel_func(inplace=False)
+
+        self.proj = ConvNormAct(
+            total_dim * (1 + len(scales)),
+            out_channels,
+            1,
+            bias=use_bias[1],
+            norm_layer=norm_layer[1],
+            act_layer=act_layer[1],
+        )
+
+    def forward(self, x):
+        B, _, H, W = x.shape
+
+        # generate multi-scale q, k, v
+        qkv = self.qkv(x)
+        multi_scale_qkv = [qkv]
+        for op in self.aggreg:
+            multi_scale_qkv.append(op(qkv))
+        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
+        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
+        q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
+
+        # lightweight global attention
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+        v = F.pad(v, (0, 1), mode="constant", value=1.)
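+
+        # the ones-channel appended to v lets a single matmul accumulate both the
+        # attention numerator and its normalizer (recovered from the last channel below)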
+        kv = k.transpose(-1, -2) @ v
+        out = q @ kv
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+
+        # final projection
+        out = out.transpose(-1, -2).reshape(B, -1, H, W)
+        out = self.proj(out)
+        return out
+
+
+class EfficientVitBlock(nn.Module):
+    def __init__(
+            self,
+            in_channels,
+            heads_ratio=1.0,
+            head_dim=32,
+            expand_ratio=4,
+            norm_layer=nn.BatchNorm2d,
+            act_layer=nn.Hardswish,
+    ):
+        super(EfficientVitBlock, self).__init__()
+        self.context_module = ResidualBlock(
+            LiteMSA(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                heads_ratio=heads_ratio,
+                dim=head_dim,
+                norm_layer=(None, norm_layer),
+            ),
+            nn.Identity(),
+        )
+        self.local_module = ResidualBlock(
+            MBConv(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                expand_ratio=expand_ratio,
+                use_bias=(True, True, False),
+                norm_layer=(None, None, norm_layer),
+                act_layer=(act_layer, act_layer, None),
+            ),
+            nn.Identity(),
+        )
+
+    def forward(self, x):
+        x = self.context_module(x)
+        x = self.local_module(x)
+        return x
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+            self,
+            main: Optional[nn.Module],
+            shortcut: Optional[nn.Module] = None,
+            pre_norm: Optional[nn.Module] = None,
+    ):
+        super(ResidualBlock, self).__init__()
+        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
+        self.main = main
+        self.shortcut = shortcut
+
+    def forward(self, x):
+        res = self.main(self.pre_norm(x))
+        if self.shortcut is not None:
+            res = res + self.shortcut(x)
+        return res
+
+
+def build_local_block(
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        expand_ratio: float,
+        norm_layer,
+        act_layer,
+        fewer_norm: bool = False,
+):
+    if expand_ratio == 1:
+        block = DSConv(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            stride=stride,
+            use_bias=(True, False) if fewer_norm else False,
+            norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
+            act_layer=(act_layer, None),
+        )
+    else:
+        block = MBConv(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            stride=stride,
+            expand_ratio=expand_ratio,
+            use_bias=(True, True, False) if fewer_norm else False,
+            norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
+            act_layer=(act_layer, act_layer, None),
+        )
+    return block
+
+
+class Stem(nn.Sequential):
+    def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer):
+        super().__init__()
+        self.stride = 2
+
+        self.add_module(
+            'in_conv',
+            ConvNormAct(
+                in_chs, out_chs,
+                kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer,
+            )
+        )
+        stem_block = 0
+        for _ in range(depth):
+            self.add_module(f'res{stem_block}', ResidualBlock(
+                build_local_block(
+                    in_channels=out_chs,
+                    out_channels=out_chs,
+                    stride=1,
+                    expand_ratio=1,
+                    norm_layer=norm_layer,
+                    act_layer=act_layer,
+                ),
+                nn.Identity(),
+            ))
+            stem_block += 1
+
+
+class EfficientVitStage(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth,
+            norm_layer,
+            act_layer,
+            expand_ratio,
+            head_dim,
+            vit_stage=False,
+    ):
+        super(EfficientVitStage, self).__init__()
+        blocks = [ResidualBlock(
+            build_local_block(
+                in_channels=in_chs,
+                out_channels=out_chs,
+                stride=2,
+                expand_ratio=expand_ratio,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                fewer_norm=vit_stage,
+            ),
+            None,
+        )]
+        in_chs = out_chs
+
+        if vit_stage:
+            # for stage 3, 4
+            for _ in range(depth):
+                blocks.append(
+                    EfficientVitBlock(
+                        in_channels=in_chs,
+                        head_dim=head_dim,
+                        expand_ratio=expand_ratio,
+                        norm_layer=norm_layer,
+                        act_layer=act_layer,
+                    )
+                )
+        else:
+            # for stage 1, 2
+            for i in range(1, depth):
+                blocks.append(ResidualBlock(
+                    build_local_block(
+                        in_channels=in_chs,
+                        out_channels=out_chs,
+                        stride=1,
+                        expand_ratio=expand_ratio,
+                        norm_layer=norm_layer,
+                        act_layer=act_layer,
+                    ),
+                    nn.Identity(),
+                ))
+
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.blocks(x)
+
+
+class ClassifierHead(nn.Module):
+    def __init__(
+            self,
+            in_channels,
+            widths,
+            n_classes=1000,
+            dropout=0.,
+            norm_layer=nn.BatchNorm2d,
+            act_layer=nn.Hardswish,
+            global_pool='avg',
+    ):
+        super(ClassifierHead, self).__init__()
+        self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
+        self.classifier = nn.Sequential(
+            nn.Linear(widths[0], widths[1], bias=False),
+            nn.LayerNorm(widths[1]),
+            act_layer(inplace=True),
+            nn.Dropout(dropout, inplace=False),
+            nn.Linear(widths[1], n_classes, bias=True),
+        )
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self.in_conv(x)
+        x = self.global_pool(x)
+        if pre_logits:
+            return x
+        x = self.classifier(x)
+        return x
+
+
+class EfficientVit(nn.Module):
+    def __init__(
+            self,
+            in_chans=3,
+            widths=(),
+            depths=(),
+            head_dim=32,
+            expand_ratio=4,
+            norm_layer=nn.BatchNorm2d,
+            act_layer=nn.Hardswish,
+            global_pool='avg',
+            head_widths=(),
+            drop_rate=0.0,
+            num_classes=1000,
+    ):
+        super(EfficientVit, self).__init__()
+        self.grad_checkpointing = False
+        self.global_pool = global_pool
+        self.num_classes = num_classes
+
+        # input stem
+        self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer)
+        stride = self.stem.stride
+
+        # stages
+        self.feature_info = []
+        stages = []
+        stage_idx = 0
+        in_channels = widths[0]
+        for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
+            stages.append(EfficientVitStage(
+                in_channels,
+                w,
+                depth=d,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                expand_ratio=expand_ratio,
+                head_dim=head_dim,
+                vit_stage=i >= 2,
+            ))
+            stride *= 2
+            in_channels = w
+            self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
+            stage_idx += 1
+
+        self.stages = nn.Sequential(*stages)
+        self.num_features = in_channels
+        self.head_widths = head_widths
+        self.head_dropout = drop_rate
+        if num_classes > 0:
+            self.head = ClassifierHead(
+                self.num_features,
+                self.head_widths,
+                n_classes=num_classes,
+                dropout=self.head_dropout,
+                global_pool=self.global_pool,
+            )
+        else:
+            if self.global_pool == 'avg':
+                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',  # stem and embed
+            blocks=[(r'^stages\.(\d+)', None)]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.classifier[-1]
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        if num_classes > 0:
+            self.head = ClassifierHead(
+                self.num_features,
+                self.head_widths,
+                n_classes=num_classes,
+                dropout=self.head_dropout,
+                global_pool=self.global_pool,
+            )
+        else:
+            if self.global_pool == 'avg':
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.in_conv.conv',
+        'classifier': 'head.classifier.4',
+        'crop_pct': 0.95,
+        'input_size': (3, 224, 224),
+        'pool_size': (7, 7),
+        **kwargs,
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientvit_b0.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientvit_b1.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientvit_b1.r256_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
+    ),
+    'efficientvit_b1.r288_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
+    ),
+    'efficientvit_b2.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientvit_b2.r256_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
+    ),
+    'efficientvit_b2.r288_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
+    ),
+    'efficientvit_b3.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'efficientvit_b3.r256_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
+    ),
+    'efficientvit_b3.r288_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
+    ),
+})
+
+
+def _create_efficientvit(variant, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
+    model = build_model_with_cfg(
+        EfficientVit,
+        variant,
+        pretrained,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs
+    )
+    return model
+
+
+@register_model
+def efficientvit_b0(pretrained=False, **kwargs):
+    model_args = dict(
+        widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280))
+    return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_b1(pretrained=False, **kwargs):
+    model_args = dict(
+        widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600))
+    return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_b2(pretrained=False, **kwargs):
+    model_args = dict(
+        widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560))
+    return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_b3(pretrained=False, **kwargs):
+    model_args = dict(
+        widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560))
+    return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py
new file mode 100644
index 0000000000..8940df0f16
--- /dev/null
+++ b/timm/models/efficientvit_msra.py
@@ -0,0 +1,657 @@
+""" EfficientViT (by MSRA)
+
+Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
+    - https://arxiv.org/abs/2305.07027
+
+Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
+"""
+
+__all__ = ['EfficientVitMsra']
+import itertools
+from collections import OrderedDict
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+
+class ConvNorm(torch.nn.Sequential):
+    def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+        super().__init__()
+        self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False)
+        self.bn = nn.BatchNorm2d(out_chs)
+        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
+        torch.nn.init.constant_(self.bn.bias, 0)
+
+    @torch.no_grad()
+    def fuse(self):
+        c, bn = self.conv, self.bn
+        w = bn.weight / (bn.running_var + bn.eps)**0.5
+        w = c.weight * w[:, None, None, None]
+        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps)**0.5
+        m = torch.nn.Conv2d(
+            w.size(1) * c.groups, w.size(0), w.shape[2:],
+            stride=c.stride, padding=c.padding, dilation=c.dilation, groups=c.groups)
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+
+class NormLinear(torch.nn.Sequential):
+    def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
+        super().__init__()
+        self.bn = nn.BatchNorm1d(in_features)
+        self.drop = nn.Dropout(drop)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+        trunc_normal_(self.linear.weight, std=std)
+        if self.linear.bias is not None:
+            nn.init.constant_(self.linear.bias, 0)
+
+    @torch.no_grad()
+    def fuse(self):
+        bn, linear = self.bn, self.linear
+        w = bn.weight / (bn.running_var + bn.eps)**0.5
+        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps)**0.5
+        w = linear.weight * w[None, :]
+        if linear.bias is None:
+            b = b @ self.linear.weight.T
+        else:
+            b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
+        m = torch.nn.Linear(w.size(1), w.size(0))
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+
+class PatchMerging(torch.nn.Module):
+    def __init__(self, dim, out_dim):
+        super().__init__()
+        hid_dim = int(dim * 4)
+        self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0)
+        self.act = torch.nn.ReLU()
+        self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
+        self.se = SqueezeExcite(hid_dim, .25)
+        self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0)
+
+    def forward(self, x):
+        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
+        return x
+
+
+class ResidualDrop(torch.nn.Module):
+    def __init__(self, m, drop=0.):
+        super().__init__()
+        self.m = m
+        self.drop = drop
+
+    def forward(self, x):
+        if self.training and self.drop > 0:
+            return x + self.m(x) * torch.rand(
+                x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
+        else:
+            return x + self.m(x)
+
+
+class ConvMlp(torch.nn.Module):
+    def __init__(self, ed, h):
+        super().__init__()
+        self.pw1 = ConvNorm(ed, h)
+        self.act = torch.nn.ReLU()
+        self.pw2 = ConvNorm(h, ed, bn_weight_init=0)
+
+    def forward(self, x):
+        x = self.pw2(self.act(self.pw1(x)))
+        return x
+
+
+class CascadedGroupAttention(torch.nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
+    r""" Cascaded Group Attention.
+
+    Args:
+        dim (int): Number of input channels.
+        key_dim (int): The dimension for query and key.
+        num_heads (int): Number of attention heads.
+        attn_ratio (int): Multiplier on the query dim, giving the value dimension.
+        resolution (int): Input resolution, corresponding to the window size.
+        kernels (List[int]): The kernel size of the dw conv on query.
+    """
+    def __init__(
+            self,
+            dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=4,
+            resolution=14,
+            kernels=(5, 5, 5, 5),
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.scale = key_dim ** -0.5
+        self.key_dim = key_dim
+        self.val_dim = int(attn_ratio * key_dim)
+        self.attn_ratio = attn_ratio
+
+        qkvs = []
+        dws = []
+        for i in range(num_heads):
+            qkvs.append(ConvNorm(dim // num_heads, self.key_dim * 2 + self.val_dim))
+            dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
+        self.qkvs = torch.nn.ModuleList(qkvs)
+        self.dws = torch.nn.ModuleList(dws)
+        self.proj = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
+        )
+
+        points = list(itertools.product(range(resolution), range(resolution)))
+        N = len(points)
+        attention_offsets = {}
+        idxs = []
+        for p1 in points:
+            for p2 in points:
+                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+                if offset not in attention_offsets:
+                    attention_offsets[offset] = len(attention_offsets)
+                idxs.append(attention_offsets[offset])
+        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
+        self.attention_bias_cache = {}
+
+    @torch.no_grad()
+    def train(self, mode=True):
+        super().train(mode)
+        if mode and self.attention_bias_cache:
+            self.attention_bias_cache = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if torch.jit.is_tracing() or self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.attention_bias_cache:
+                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.attention_bias_cache[device_key]
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        feats_in = x.chunk(len(self.qkvs), dim=1)
+        feats_out = []
+        feat = feats_in[0]
+        attn_bias = self.get_attention_biases(x.device)
+        for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
+            if head_idx > 0:
+                # cascade: each head refines the sum of its input split and the previous head's output
+                feat = feat + feats_in[head_idx]
+            feat = qkv(feat)
+            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
+            q = dws(q)
+            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
+            q = q * self.scale
+            attn = q.transpose(-2, -1) @ k
+            attn = attn + attn_bias[head_idx]
+            attn = attn.softmax(dim=-1)
+            feat = v @ attn.transpose(-2, -1)
+            feat = feat.view(B, self.val_dim, H, W)
+            feats_out.append(feat)
+        x = self.proj(torch.cat(feats_out, 1))
+        return x
+
+
+class LocalWindowAttention(torch.nn.Module):
+    r""" Local Window Attention.
+
+    Args:
+        dim (int): Number of input channels.
+        key_dim (int): The dimension for query and key.
+        num_heads (int): Number of attention heads.
+        attn_ratio (int): Multiplier on the query dim, giving the value dimension.
+        resolution (int): Input resolution.
+        window_resolution (int): Local window resolution.
+        kernels (List[int]): The kernel size of the dw conv on query.
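+
+    Inputs larger than the window are padded to a multiple of ``window_resolution``
+    and partitioned into non-overlapping windows before attention is applied.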
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=(5, 5, 5, 5), + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.resolution = resolution + assert window_resolution > 0, 'window_size must be greater than 0' + self.window_resolution = window_resolution + window_resolution = min(window_resolution, resolution) + self.attn = CascadedGroupAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels, + ) + + def forward(self, x): + H = W = self.resolution + B, C, H_, W_ = x.shape + # Only check this for classifcation models + _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + if H <= self.window_resolution and W <= self.window_resolution: + x = self.attn(x) + else: + x = x.permute(0, 2, 3, 1) + pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_resolution + nW = pW // self.window_resolution + # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3) + x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2) + x = self.attn(x) + # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C) + x = x.transpose(2, 3).reshape(B, pH, pW, C) + x = x[:, :H, :W].contiguous() + x = x.permute(0, 3, 1, 2) + return x + + +class EfficientVitBlock(torch.nn.Module): + """ A basic EfficientVit building block. + + Args: + dim (int): Number of input channels. + key_dim (int): Dimension for query and key in the token mixer. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. 
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + ): + super().__init__() + + self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + self.mixer = ResidualDrop( + LocalWindowAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=resolution, + window_resolution=window_resolution, + kernels=kernels, + ) + ) + + self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + def forward(self, x): + return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) + + +class EfficientVitStage(torch.nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + downsample=('', 1), + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + depth=1, + ): + super().__init__() + if downsample[0] == 'subsample': + self.resolution = (resolution - 1) // downsample[1] + 1 + down_blocks = [] + down_blocks.append(( + 'res1', + torch.nn.Sequential( + ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)), + ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))), + ) + )) + down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim))) + down_blocks.append(( + 'res2', + torch.nn.Sequential( + ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)), + ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))), + ) + )) + self.downsample = nn.Sequential(OrderedDict(down_blocks)) + else: + assert in_dim == out_dim + self.downsample = nn.Identity() + self.resolution = resolution + + blocks = [] + for d in range(depth): + blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels)) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class PatchEmbedding(torch.nn.Sequential): + def __init__(self, in_chans, dim): + super().__init__() + self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1)) + self.add_module('relu1', torch.nn.ReLU()) + self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('relu2', torch.nn.ReLU()) + self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('relu3', torch.nn.ReLU()) + self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1)) + self.patch_size = 16 + + +class EfficientVitMsra(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dim=(64, 128, 192), + key_dim=(16, 16, 16), + depth=(1, 2, 3), + num_heads=(4, 4, 4), + window_size=(7, 7, 7), + kernels=(5, 5, 5, 5), + down_ops=(('', 1), ('subsample', 2), ('subsample', 2)), + global_pool='avg', + drop_rate=0., + ): + super(EfficientVitMsra, self).__init__() + self.grad_checkpointing = False + self.num_classes = num_classes + self.drop_rate = drop_rate + + # Patch embedding + self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) + stride = self.patch_embed.patch_size + resolution = img_size // self.patch_embed.patch_size + attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] + + # Build EfficientVit blocks + self.feature_info = [] + stages = [] + pre_ed = embed_dim[0] + for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): + stage = EfficientVitStage( + in_dim=pre_ed, + 
+                out_dim=ed,
+                key_dim=kd,
+                downsample=do,
+                num_heads=nh,
+                attn_ratio=ar,
+                resolution=resolution,
+                window_resolution=wd,
+                kernels=kernels,
+                depth=dpth,
+            )
+            pre_ed = ed
+            if do[0] == 'subsample' and i != 0:
+                stride *= do[1]
+            resolution = stage.resolution
+            stages.append(stage)
+            self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
+        self.stages = nn.Sequential(*stages)
+
+        if global_pool == 'avg':
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+        else:
+            assert num_classes == 0
+            self.global_pool = nn.Identity()
+        self.num_features = embed_dim[-1]
+        self.head = NormLinear(
+            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^patch_embed',
+            blocks=[(r'^stages\.(\d+)', None)]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            if global_pool == 'avg':
+                self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+            else:
+                assert num_classes == 0
+                self.global_pool = nn.Identity()
+        self.head = NormLinear(
+            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+# def checkpoint_filter_fn(state_dict, model):
+#     if 'model' in state_dict.keys():
+#         state_dict = state_dict['model']
+#     tmp_dict = {}
+#     out_dict = {}
+#     target_keys = model.state_dict().keys()
+#     target_keys = [k for k in target_keys if k.startswith('stages.')]
+#
+#     for k, v in state_dict.items():
+#         if 'attention_bias_idxs' in k:
+#             continue
+#         k = k.split('.')
+#         if k[-2] == 'c':
+#             k[-2] = 'conv'
+#         if k[-2] == 'l':
+#             k[-2] = 'linear'
+#         k = '.'.join(k)
+#         tmp_dict[k] = v
+#
+#     for k, v in tmp_dict.items():
+#         if k.startswith('patch_embed'):
+#             k = k.split('.')
+#             k[1] = 'conv' + str(int(k[1]) // 2 + 1)
+#             k = '.'.join(k)
+#         elif k.startswith('blocks'):
+#             kw = '.'.join(k.split('.')[2:])
+#             find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
+#             idx = find_kw.index(k)
+#             k = [a for a in target_keys if kw in a][idx]
+#         out_dict[k] = v
+#
+#     return out_dict
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.conv1.conv',
+        'classifier': 'head.linear',
+        'fixed_input_size': True,
+        'pool_size': (4, 4),
+        **kwargs,
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'efficientvit_m0.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
+    ),
+    'efficientvit_m1.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
+    ),
+    'efficientvit_m2.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
+    ),
+    'efficientvit_m3.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
+    ),
+    'efficientvit_m4.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
+    ),
+    'efficientvit_m5.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
+    ),
+})
+
+
+def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', (0, 1, 2))
+    model = build_model_with_cfg(
+        EfficientVitMsra,
+        variant,
+        pretrained,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs
+    )
+    return model
+
+
+@register_model
+def efficientvit_m0(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[64, 128, 192],
+        depth=[1, 2, 3],
+        num_heads=[4, 4, 4],
+        window_size=[7, 7, 7],
+        kernels=[5, 5, 5, 5]
+    )
+    return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_m1(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[128, 144, 192],
+        depth=[1, 2, 3],
+        num_heads=[2, 3, 3],
+        window_size=[7, 7, 7],
+        kernels=[7, 5, 3, 3]
+    )
+    return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_m2(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[128, 192, 224],
+        depth=[1, 2, 3],
+        num_heads=[4, 3, 2],
+        window_size=[7, 7, 7],
+        kernels=[7, 5, 3, 3]
+    )
+    return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_m3(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[128, 240, 320],
+        depth=[1, 2, 3],
+        num_heads=[4, 3, 4],
+        window_size=[7, 7, 7],
+        kernels=[5, 5, 5, 5]
+    )
+    return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_m4(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[128, 256, 384],
+        depth=[1, 2, 3],
+        num_heads=[4, 4, 4],
+        window_size=[7, 7, 7],
+        kernels=[7, 5, 3, 3]
+    )
+    return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def efficientvit_m5(pretrained=False, **kwargs):
+    model_args = dict(
+        img_size=224,
+        embed_dim=[192, 288, 384],
+        depth=[1, 3, 4],
+        num_heads=[3, 3, 4],
+        window_size=[7, 7, 7],
+        kernels=[7, 5, 3, 3]
+    )
+    return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))
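
For reference, a minimal smoke test of the new registrations (model names per the @register_model functions above; pretrained=True additionally assumes the 'timm/' hub weights referenced in the cfgs are published):

    import timm
    import torch

    x = torch.randn(1, 3, 224, 224)

    # one variant from each family; both default to 1000-class ImageNet heads
    for name in ('efficientvit_b0', 'efficientvit_m0'):
        model = timm.create_model(name, pretrained=False).eval()
        with torch.no_grad():
            out = model(x)
        print(name, tuple(out.shape))  # (1, 1000)

    # feature extraction exercises the feature_info / flatten_sequential plumbing above
    fx = timm.create_model('efficientvit_b0', features_only=True).eval()
    with torch.no_grad():
        print([tuple(f.shape) for f in fx(x)])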