diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6c39545
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,136 @@
+# Custom
+.vscode
+wandb
+outputs
+tmp*
+slurm-logs
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+.github
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Lightning /research
+test_tube_exp/
+tests/tests_tt_dir/
+tests/save_dir
+default/
+data/
+test_tube_logs/
+test_tube_data/
+datasets/
+model_weights/
+tests/save_dir
+tests/tests_tt_dir/
+processed/
+raw/
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# IDEs
+.idea
+.vscode
+
+# seed project
+lightning_logs/
+MNIST
+.DS_Store
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..476f182
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+
+## Do You Even Need Attention?
+### A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet
+
+[![Paper](http://img.shields.io/badge/Paper-B31B1B.svg)]()
+
+### TL;DR
+We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.
+
+### Abstract
+Recent research in architecture design for computer vision has shown that transformers applied to sequences of unrolled image patches make for strong image classifiers. Much of this research focuses on modifying the transformer's attention layer, either to make it more efficient or better suited to the spatial structure of images. In this short report, we ask the question: is the attention layer even necessary? Specifically, we replace the attention layer with a second feed-forward layer over the patch features, resulting in an architecture that is simply a series of feed-forward networks applied over the patch and feature dimensions. In experiments on ImageNet, we show that a ViT-base-sized model obtains 74.9% top-1 accuracy, rivaling a ResNet-50 (albeit with more parameters). Apart from its simplicity, this architecture does not appear to have any immediate practical advantages over a vision transformer with attention---it performs slightly worse and still requires $O(n^2)$ memory---but we hope the computer vision community will find it interesting nonetheless.
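+
+In code, the change amounts to swapping the self-attention sub-layer for a second MLP that mixes information across patches. The sketch below (the class name is ours, not from the repo) strips the block in `vision_transformer_linear.py` down to its essentials; the full implementation further adds dropout and stochastic depth:
+
+```python
+import torch.nn as nn
+
+
+class FeedForwardOnlyBlock(nn.Module):
+    """Simplified sketch of LinearBlock from vision_transformer_linear.py."""
+
+    def __init__(self, dim, num_tokens, mlp_ratio=4):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim)
+        self.mlp1 = nn.Sequential(
+            nn.Linear(dim, dim * mlp_ratio), nn.GELU(), nn.Linear(dim * mlp_ratio, dim))
+        self.norm2 = nn.LayerNorm(num_tokens)
+        self.mlp2 = nn.Sequential(
+            nn.Linear(num_tokens, num_tokens * mlp_ratio), nn.GELU(), nn.Linear(num_tokens * mlp_ratio, num_tokens))
+
+    def forward(self, x):                 # x: (batch, num_tokens, dim)
+        x = x + self.mlp1(self.norm1(x))  # feed-forward over the feature dimension
+        x = x.transpose(-2, -1)           # (batch, dim, num_tokens)
+        x = x + self.mlp2(self.norm2(x))  # feed-forward over the patch (token) dimension
+        return x.transpose(-2, -1)
+```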
+
+### Note
+This is concurrent research with [MLP-Mixer](https://arxiv.org/abs/2105.01601) from Google Research. The ideas are exactly the same, with the one difference being that they use (a lot) more compute.
+
+### How to train
+
+The model definition in `vision_transformer_linear.py` is designed to be run with the repo from DeiT, which is itself based on the wonderful `timm` package.
+
+Steps:
+ * Clone the DeiT repo (an example training command is sketched below):
+   ```bash
+   # assumed to be the official DeiT repository
+   git clone https://github.com/facebookresearch/deit.git
+   ```
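+
+The remaining steps are not spelled out here. As a rough sketch (assuming the standard DeiT training script): copy `vision_transformer_linear.py` into the cloned repo and import it in `main.py`, which registers `linear_tiny`, `linear_base`, and `linear_large` with `timm` via `@register_model`; training can then be launched with DeiT's usual command, for example (paths and hyperparameters are placeholders):
+
+```bash
+# Example only: flags follow DeiT's main.py; adjust GPU count, batch size, and paths.
+python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
+    --model linear_base \
+    --batch-size 128 \
+    --data-path /path/to/imagenet \
+    --output_dir /path/to/checkpoints
+```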
+
+#### Pretrained models and logs
+
+Here is a Weights and Biases report with the expected training trajectory: [W&B](https://wandb.ai/lukemelas2/deit-experiments/reports/Do-You-Even-Need-Attention---Vmlldzo2NjUxMzI?accessToken=8kebvweue0gd1s6qiav2orco97v85glogsi8i83576j42bb1g39e59px56lkk4zu)
+
+| name | acc@1 | #params | url |
+| --- | --- | --- | --- |
+| FF-tiny | 61.4 | 7.7M | [model](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) |
+| FF-base | 74.9 | 62M | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) |
+| FF-large | 71.4 | 206M | [model](https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth) |
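+
+To load a downloaded checkpoint into the model defined in `vision_transformer_linear.py`, something along these lines should work (a sketch, not from the repo: the filename is hypothetical, and it assumes the weights are stored under a `model` key, as in DeiT checkpoints):
+
+```python
+import torch
+from vision_transformer_linear import linear_base
+
+model = linear_base()
+checkpoint = torch.load('ff_base.pth', map_location='cpu')  # hypothetical local filename
+state_dict = checkpoint.get('model', checkpoint)            # unwrap DeiT-style checkpoints
+model.load_state_dict(state_dict)
+model.eval()
+```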
+
+#### Citation
+```
+@article{melaskyriazi2021doyoueven,
+  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
+  author={Luke Melas-Kyriazi},
+  journal={arXiv},
+  year={2021}
+}
+```
diff --git a/vision_transformer_linear.py b/vision_transformer_linear.py
new file mode 100644
index 0000000..7c98edb
--- /dev/null
+++ b/vision_transformer_linear.py
@@ -0,0 +1,212 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from timm.models.layers import DropPath, trunc_normal_
+from timm.models.registry import register_model
+from timm.models.vision_transformer import _cfg, Mlp
+
+
+def requires_grad(module, requires_grad):
+    for p in module.parameters():
+        p.requires_grad = requires_grad
+
+
+class LinearBlock(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_tokens=197):
+        super().__init__()
+
+        # First stage
+        self.mlp1 = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.norm1 = norm_layer(dim)
+
+        # Second stage
+        self.mlp2 = Mlp(in_features=num_tokens, hidden_features=int(num_tokens * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.norm2 = norm_layer(num_tokens)
+
+        # Dropout (or a variant)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        x = x + self.drop_path(self.mlp1(self.norm1(x)))
+        x = x.transpose(-2, -1)
+        x = x + self.drop_path(self.mlp2(self.norm2(x)))
+        x = x.transpose(-2, -1)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Wraps a convolution """
+
+    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class LearnedPositionalEncoding(nn.Module):
+    """ Learned positional encoding with dynamic interpolation at runtime """
+
+    def __init__(self, height, width, embed_dim):
+        super().__init__()
+        self.height = height
+        self.width = width
+        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, height, width))
+        self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_pos_embed, std=.02)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        if H == self.height and W == self.width:
+            pos_embed = self.pos_embed
+        else:
+            pos_embed = F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)
+        return self.cls_pos_embed, pos_embed
+
+
+class LinearVisionTransformer(nn.Module):
+    """
+    Basically the same as the standard Vision Transformer, but with support for resizable
+    or sinusoidal positional embeddings.
+    """
+
+    def __init__(self, *, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
+                 positional_encoding='learned', learned_positional_encoding_size=(14, 14), block_cls=LinearBlock):
+        super().__init__()
+
+        # Config
+        self.num_classes = num_classes
+        self.patch_size = patch_size
+        self.num_features = self.embed_dim = embed_dim
+
+        # Patch embedding
+        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+        # Class token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # Positional encoding
+        if positional_encoding == 'learned':
+            height, width = self.learned_positional_encoding_size = learned_positional_encoding_size
+            self.pos_encoding = LearnedPositionalEncoding(height, width, embed_dim)
+        else:
+            raise NotImplementedError('Unsupported positional encoding')
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # Stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+        self.blocks = nn.ModuleList([
+            block_cls(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                      num_tokens=1 + (224 // patch_size)**2)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Classifier head
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        # Init
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+
+        # Patch embedding
+        B, C, H, W = x.shape  # B x C x H x W
+        x = self.patch_embed(x)  # B x E x H//p x W//p
+
+        # Positional encoding
+        # NOTE: cls_pos_embed for compatibility with pretrained models
+        cls_pos_embed, pos_embed = self.pos_encoding(x)
+
+        # Flatten image, append class token, add positional encoding
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = x.flatten(2).transpose(1, 2)  # flatten
+        x = torch.cat((cls_tokens, x), dim=1)  # class token
+        pos_embed = pos_embed.flatten(2).transpose(1, 2)  # flatten
+        pos_embed = torch.cat([cls_pos_embed, pos_embed], dim=1)  # class pos emb
+        x = x + pos_embed
+        x = self.pos_drop(x)
+
+        # Transformer
+        for blk in self.blocks:
+            x = blk(x)
+
+        # Final layernorm
+        x = self.norm(x)
+        return x[:, 0]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+@register_model
+def linear_tiny(pretrained=False, **kwargs):
+    model = LinearVisionTransformer(
+        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+@register_model
+def linear_base(pretrained=False, **kwargs):
+    model = LinearVisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+@register_model
+def linear_large(pretrained=False, **kwargs):
+    model = LinearVisionTransformer(
+        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+if __name__ == '__main__':
+
+    # Test
+    x = torch.randn(2, 3, 224, 224)
+    m = linear_tiny()
+    out = m(x)
+    print('-----')
+    print(f'num params: {sum(p.numel() for p in m.parameters())}')
+    print(out.shape)
+    loss = out.sum()
+    loss.backward()
+    print('Single iteration completed successfully')