Commit: Rename FixCacheKey -> CachedAs

borzunov committed Dec 21, 2021
1 parent 8e8dea8 commit df89951

Showing 2 changed files with 10 additions and 12 deletions.
10 changes: 0 additions & 10 deletions dalle_pytorch/cache.py

This file was deleted.

12 changes: 10 additions & 2 deletions dalle_pytorch/transformer.py
@@ -9,7 +9,6 @@

from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import FixCacheKey

from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
@@ -36,6 +35,15 @@ def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True)
        return x / maxes

+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
+
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
@@ -200,7 +208,7 @@ def __init__(
                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                shared_ff_layers[ff_id] = ff

-            attn = FixCacheKey(f'attn_{ind}', attn)
+            attn = CachedAs(f'attn_{ind}', attn)

            if shift_tokens:
                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
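
For context, the wrapper renamed here simply tags a sub-module with a fixed cache key and forwards that key alongside a shared cache object on every call. Below is a minimal usage sketch: only the CachedAs class is taken from the diff above; the ToyLayer module, the dict-based cache layout, and the tensor shapes are illustrative assumptions, not part of the commit.

import torch
from torch import nn

class CachedAs(nn.Module):
    # Wrapper from the diff above: injects a fixed cache_key into each call.
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)

class ToyLayer(nn.Module):
    # Hypothetical layer: appends new activations to whatever is stored
    # under its cache_key, mimicking incremental-decoding state.
    def forward(self, x, *, cache=None, cache_key=None):
        if cache is not None and cache_key is not None:
            prev = cache.get(cache_key)
            if prev is not None:
                x = torch.cat([prev, x], dim=1)
            cache[cache_key] = x
        return x

layer = CachedAs('attn_0', ToyLayer())
cache = {}
out1 = layer(torch.randn(1, 2, 8), cache=cache)  # seeds cache['attn_0'] with 2 steps
out2 = layer(torch.randn(1, 1, 8), cache=cache)  # concatenates onto the cached steps
print(out1.shape, out2.shape)  # torch.Size([1, 2, 8]) torch.Size([1, 3, 8])

Wrapping each layer this way lets a single cache object be threaded through the whole Transformer stack while every layer reads and writes only its own slot.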
