Implementation of Gated Slot Attention in PyTorch, from scratch, in a single file, from the paper "Gated Slot Attention for Efficient Linear-Time Sequence Modeling" (https://arxiv.org/abs/2409.07146).
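At its core, gated slot attention keeps a small set of m memory slots: each step applies a learned forget gate to the slot key/value memories, writes the current token in, and reads the output with a softmax over the slots. Below is a minimal, single-head sketch of that recurrence as described in the paper; it is illustrative only, and the function and variable names are assumptions, not this repository's API.

```python
import torch

def gsa_recurrent_step(q, k, v, alpha, K_mem, V_mem):
    """One recurrent gated-slot-attention step for a single head (illustrative sketch).

    q, k : (d_k,)    query / key for the current token
    v    : (d_v,)    value for the current token
    alpha: (m,)      per-slot forget gate in (0, 1)
    K_mem: (m, d_k)  slot key memory
    V_mem: (m, d_v)  slot value memory
    """
    # Gated update: decay the old slot contents, write the new token in.
    K_mem = alpha.unsqueeze(-1) * K_mem + (1 - alpha).unsqueeze(-1) * k.unsqueeze(0)
    V_mem = alpha.unsqueeze(-1) * V_mem + (1 - alpha).unsqueeze(-1) * v.unsqueeze(0)
    # Read: softmax attention over the m slots.
    attn = torch.softmax(K_mem @ q, dim=-1)  # (m,)
    out = V_mem.t() @ attn                   # (d_v,)
    return out, K_mem, V_mem

# Toy usage with hypothetical sizes.
m, d_k, d_v = 64, 64, 64
K_mem, V_mem = torch.zeros(m, d_k), torch.zeros(m, d_v)
q, k, v = torch.randn(d_k), torch.randn(d_k), torch.randn(d_v)
alpha = torch.sigmoid(torch.randn(m))  # gate values in (0, 1)
o, K_mem, V_mem = gsa_recurrent_step(q, k, v, alpha, K_mem, V_mem)
print(o.shape)  # torch.Size([64])
```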
```bash
pip3 install -U gated-slot-attention
```
The example below covers basic usage; for real workloads, plug in your own tokenizer and a matching vocabulary size (see the tokenizer sketch after the example).
```python
import torch
from gated_slot_attention.model import GSATransformer

model = GSATransformer(
    dim=512,           # model (embedding) dimension
    heads=8,           # number of attention heads
    m=64,              # number of memory slots
    tau=0.1,           # temperature parameter
    depth=1,           # number of layers
    vocab_size=10000,  # vocabulary size
    max_seq_len=1024,  # maximum sequence length
)

# Dummy batch of token ids, shape (batch, seq_len).
x = torch.randint(0, 10000, (1, 1024))
out = model(x)
print(out.shape)
```
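To go beyond the dummy batch, drive the model with a real tokenizer and take `vocab_size` from it. The sketch below assumes the Hugging Face `transformers` package and the GPT-2 tokenizer purely for illustration; any tokenizer that yields integer ids smaller than `vocab_size` works the same way.

```python
from transformers import AutoTokenizer
from gated_slot_attention.model import GSATransformer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer, for illustration only

model = GSATransformer(
    dim=512,
    heads=8,
    m=64,
    tau=0.1,
    depth=1,
    vocab_size=tokenizer.vocab_size,  # keep the model's output head in sync with the tokenizer
    max_seq_len=1024,
)

# Encode text into token ids of shape (batch, seq_len), capped at max_seq_len.
ids = tokenizer(
    "Gated slot attention in a single file.",
    return_tensors="pt",
    truncation=True,
    max_length=1024,
).input_ids

out = model(ids)
print(out.shape)
```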
License: MIT
Citation:

```bibtex
@misc{zhang2024gatedslotattentionefficient,
    title={Gated Slot Attention for Efficient Linear-Time Sequence Modeling},
    author={Yu Zhang and Songlin Yang and Ruijie Zhu and Yue Zhang and Leyang Cui and Yiqiao Wang and Bolun Wang and Freda Shi and Bailin Wang and Wei Bi and Peng Zhou and Guohong Fu},
    year={2024},
    eprint={2409.07146},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2409.07146},
}
```