From df8995172f64e532d4ae2c768df8f0f4fc35cb51 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Tue, 21 Dec 2021 01:24:30 +0000
Subject: [PATCH] Rename FixCacheKey -> CachedAs

---
 dalle_pytorch/cache.py       | 10 ----------
 dalle_pytorch/transformer.py | 12 ++++++++++--
 2 files changed, 10 insertions(+), 12 deletions(-)
 delete mode 100644 dalle_pytorch/cache.py

diff --git a/dalle_pytorch/cache.py b/dalle_pytorch/cache.py
deleted file mode 100644
index 524b1153..00000000
--- a/dalle_pytorch/cache.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import torch.nn as nn
-
-class FixCacheKey(nn.Module):
-    def __init__(self, cache_key, fn):
-        super().__init__()
-        self.cache_key = cache_key
-        self.fn = fn
-
-    def forward(self, x, *, cache=None, **kwargs):
-        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index eb57cadc..78f695ba 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -9,7 +9,6 @@
 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import FixCacheKey
 
 from rotary_embedding_torch import RotaryEmbedding, broadcat
 
 from g_mlp_pytorch import gMLPBlock
@@ -36,6 +35,15 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True)
         return x / maxes
 
+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
+
 # https://arxiv.org/abs/2103.17239
 class LayerScale(nn.Module):
     def __init__(self, dim, depth, fn):
@@ -200,7 +208,7 @@ def __init__(
             ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
             shared_ff_layers[ff_id] = ff
 
-            attn = FixCacheKey(f'attn_{ind}', attn)
+            attn = CachedAs(f'attn_{ind}', attn)
 
             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
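
Note (not part of the patch): CachedAs is a thin wrapper that pins a fixed cache_key onto a module whose forward accepts cache and cache_key keyword arguments, so each wrapped layer reads and writes its own slot in a shared cache. A minimal usage sketch, assuming the cache is a plain dict and using a hypothetical DummyAttention stand-in (not from the repository):

    import torch
    import torch.nn as nn

    class DummyAttention(nn.Module):
        # Hypothetical stand-in for an attention layer that appends new
        # tokens to its per-layer cache slot during incremental decoding.
        def forward(self, x, *, cache=None, cache_key=None):
            if cache is not None:
                prev = cache.get(cache_key)
                x = x if prev is None else torch.cat((prev, x), dim=1)
                cache[cache_key] = x  # store the running sequence
            return x

    attn = CachedAs('attn_0', DummyAttention())
    cache = {}
    attn(torch.randn(1, 3, 8), cache=cache)  # cache['attn_0'] holds 3 tokens
    attn(torch.randn(1, 1, 8), cache=cache)  # appended; now 4 tokens

The caller only ever passes cache; the wrapper supplies the per-layer key (f'attn_{ind}' in the patch), which keeps layers from clobbering each other's cached state.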