Skip to content

Commit

Permalink
Fixing efficient_vit torchscript, fx, default_cfg issues
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Aug 19, 2023
1 parent 58ea1c0 commit 7d7589e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 58 deletions.
23 changes: 12 additions & 11 deletions timm/models/efficientvit_mit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
dilation=1,
groups=1,
bias=False,
dropout=0,
dropout=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
Expand Down Expand Up @@ -248,7 +248,7 @@ def forward(self, x):
# lightweight global attention
q = self.kernel_func(q)
k = self.kernel_func(k)
v = F.pad(v, (0, 1), mode="constant", value=1)
v = F.pad(v, (0, 1), mode="constant", value=1.)

kv = k.transpose(-1, -2) @ v
out = q @ kv
Expand Down Expand Up @@ -443,7 +443,7 @@ def __init__(
in_channels,
widths,
n_classes=1000,
dropout=0,
dropout=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
global_pool='avg',
Expand Down Expand Up @@ -547,7 +547,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.classifier[-1]

def reset_classifier(self, num_classes, global_pool=None, dropout=0):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
Expand All @@ -561,7 +561,7 @@ def reset_classifier(self, num_classes, global_pool=None, dropout=0):
)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
else:
self.head = nn.Identity()

Expand Down Expand Up @@ -592,6 +592,7 @@ def _cfg(url='', **kwargs):
'classifier': 'head.classifier.4',
'crop_pct': 0.95,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
**kwargs,
}

Expand All @@ -605,33 +606,33 @@ def _cfg(url='', **kwargs):
),
'efficientvit_b1.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b1.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_b2.r224_in1k': _cfg(
hf_hub_id='timm/',
),
'efficientvit_b2.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b2.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_b3.r224_in1k': _cfg(
hf_hub_id='timm/',
),
'efficientvit_b3.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b3.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
})

Expand Down
103 changes: 56 additions & 47 deletions timm/models/efficientvit_msra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
from typing import Dict

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
Expand Down Expand Up @@ -113,6 +114,8 @@ def forward(self, x):


class CascadedGroupAttention(torch.nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]

r""" Cascaded Group Attention.
Args:
Expand All @@ -136,19 +139,19 @@ def __init__(
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.d = int(attn_ratio * key_dim)
self.val_dim = int(attn_ratio * key_dim)
self.attn_ratio = attn_ratio

qkvs = []
dws = []
for i in range(num_heads):
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.d))
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
self.qkvs = torch.nn.ModuleList(qkvs)
self.dws = torch.nn.ModuleList(dws)
self.proj = torch.nn.Sequential(
torch.nn.ReLU(),
ConvNorm(self.d * num_heads, dim, bn_weight_init=0)
ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
)

points = list(itertools.product(range(resolution), range(resolution)))
Expand All @@ -161,37 +164,44 @@ def __init__(
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(
torch.zeros(num_heads, len(attention_offsets)))
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
self.attention_bias_cache = {}

@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache

def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if torch.jit.is_tracing() or self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]

def forward(self, x):
B, C, H, W = x.shape
feats_in = x.chunk(len(self.qkvs), dim=1)
feats_out = []
feat = feats_in[0]
for i, qkv in enumerate(self.qkvs):
attn_bias = self.attention_biases[:, self.attention_bias_idxs][i] if self.training else self.ab[i]
if i > 0:
feat = feat + feats_in[i]
attn_bias = self.get_attention_biases(x.device)
for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
if head_idx > 0:
feat = feat + feats_in[head_idx]
feat = qkv(feat)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
q = self.dws[i](q)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
q = dws(q)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
q = q * self.scale
attn = q.transpose(-2, -1) @ k
attn = attn + attn_bias
attn = attn + attn_bias[head_idx]
attn = attn.softmax(dim=-1)
feat = v @ attn.transpose(-2, -1)
feat = feat.view(B, self.d, H, W)
feat = feat.view(B, self.val_dim, H, W)
feats_out.append(feat)
x = self.proj(torch.cat(feats_out, 1))
return x
Expand Down Expand Up @@ -237,8 +247,8 @@ def forward(self, x):
H = W = self.resolution
B, C, H_, W_ = x.shape
# Only check this for classifcation models
assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))

_assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
_assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
if H <= self.window_resolution and W <= self.window_resolution:
x = self.attn(x)
else:
Expand Down Expand Up @@ -519,38 +529,37 @@ def _cfg(url='', **kwargs):
'first_conv': 'patch_embed.conv1.conv',
'classifier': 'head.linear',
'fixed_input_size': True,
'pool_size': (4, 4),
**kwargs,
}


default_cfgs = generate_default_cfgs(
{
'efficientvit_m0.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
),
'efficientvit_m1.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
),
'efficientvit_m2.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
),
'efficientvit_m3.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
),
'efficientvit_m4.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
),
'efficientvit_m5.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
),
}
)
default_cfgs = generate_default_cfgs({
'efficientvit_m0.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
),
'efficientvit_m1.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
),
'efficientvit_m2.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
),
'efficientvit_m3.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
),
'efficientvit_m4.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
),
'efficientvit_m5.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
),
})


def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
Expand Down

0 comments on commit 7d7589e

Please sign in to comment.