Save and use cache['num_cached']
borzunov committed Dec 21, 2021
1 parent df89951 commit 9febeec
Showing 2 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions dalle_pytorch/attention.py
@@ -72,7 +72,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         if exists(rotary_pos_emb):
             if using_cache:
-                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
+                rotary_pos_emb = rotary_pos_emb[..., cache['num_cached']:, :]
             q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
 
         q = q * self.scale
@@ -92,7 +92,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+        if self.causal and not using_cache: # causality is naturally enforced for the cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
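For context, a minimal sketch (not part of the commit) of why the rotary embedding is now sliced by the number of already-cached tokens: during cached decoding only the newly fed tokens get fresh queries and keys, so their rotary angles have to start at the first uncached position rather than at position 0. The rotary_angles helper below is an illustration only, not dalle_pytorch code.

import torch

def rotary_angles(seq_len, dim, base = 10000):
    # standard rotary frequencies: one angle per (position, feature pair)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    return torch.einsum('i,j->ij', positions, inv_freq)  # shape (seq_len, dim // 2)

full = rotary_angles(seq_len = 8, dim = 64)

# Suppose 5 tokens are already cached and only the newest tokens are fed.
# Slicing by the cached count gives the new tokens the angles of their true
# positions (5, 6, ...), which is what cache['num_cached'] supplies above.
num_cached = 5
incremental = full[num_cached:, :]
assert torch.equal(incremental[0], full[num_cached])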
11 changes: 6 additions & 5 deletions dalle_pytorch/dalle_pytorch.py
@@ -586,7 +586,7 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)
 
-        if cache is not None and 'decoding' in cache:
+        if exists(cache) and cache.get('num_cached'):
             tokens = tokens[:, -1:]
         out = self.transformer(tokens, cache=cache)
 
@@ -598,13 +598,14 @@ def forward(
         # mask logits to make sure text predicts text (except last token), and image predicts image
 
         logits_mask = self.logits_mask[:, :seq_len]
-        if cache is not None:
-            if 'decoding' in cache:
-                logits_mask = logits_mask[:, -1:]
-            cache['decoding'] = True
+        if exists(cache) and cache.get('num_cached'):
+            logits_mask = logits_mask[:, -1:]
         max_neg_value = -torch.finfo(logits.dtype).max
         logits.masked_fill_(logits_mask, max_neg_value)
 
+        if exists(cache):
+            cache['num_cached'] = cache.get('num_cached', 0) + logits.shape[1]
+
         if not return_loss:
             return logits
 
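For context, a minimal sketch (not part of the commit) of how the 'num_cached' counter behaves across calls, replacing the old 'decoding' flag: on the first pass the key is missing, so the full prompt runs through the transformer and the logits mask stays full-length; on every later pass the counter is non-zero, so only the newest token is fed and only the last mask column is kept. The forward_sketch helper, and its use of the token count where the real forward uses logits.shape[1], are simplifications for illustration.

import torch

def forward_sketch(tokens, cache):
    if cache is not None and cache.get('num_cached'):
        tokens = tokens[:, -1:]          # everything earlier is already cached
    processed = tokens.shape[1]          # matches logits.shape[1] in the real forward
    # ... transformer pass and logit masking would happen here ...
    if cache is not None:
        cache['num_cached'] = cache.get('num_cached', 0) + processed
    return processed

cache = {}
prompt = torch.zeros(1, 5, dtype = torch.long)                             # 5 prompt tokens
assert forward_sketch(prompt, cache) == 5                                  # first call: full prompt processed
assert forward_sketch(torch.zeros(1, 6, dtype = torch.long), cache) == 1   # later calls: one new token
assert cache['num_cached'] == 6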
