diff --git a/dalle_pytorch/attention.py b/dalle_pytorch/attention.py
index 39e807a6..9c5bbea8 100644
--- a/dalle_pytorch/attention.py
+++ b/dalle_pytorch/attention.py
@@ -37,7 +37,8 @@ def apply_pos_emb(pos_emb, qkv):
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
+                 static_mask = None):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -46,6 +47,7 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
         self.stable = stable
         self.causal = causal
+        self.register_buffer('static_mask', static_mask, persistent=False)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -53,19 +55,27 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        offset = cache.get('offset', 0) if exists(cache) else 0
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))
 
         q = q * self.scale
 
-        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        if offset > 0:
+            k_top, v_top = cache[cache_key]
+            k = torch.cat([k_top, k], dim=-2)
+            v = torch.cat([v_top, v], dim=-2)
+        if exists(cache):
+            cache[cache_key] = k, v
+
+        dots = q @ k.swapaxes(-1, -2)
         mask_value = max_neg_value(dots)
 
         if exists(mask):
@@ -73,14 +83,17 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal:
+        if self.causal and offset == 0:  # causality is naturally enforced for the cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
 
+        if exists(self.static_mask):
+            dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)
+
         attn = softmax(dots, dim=-1)
 
-        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
         return out
@@ -109,7 +122,13 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+        n0 = x.shape[1]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim=-2)
+            cache[cache_key] = x
+
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -204,7 +223,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
         out = self.to_out(out)
-        return out[:, :n]
+        return out[:, n - n0:n]
 
 # sparse axial causal attention
 
@@ -229,7 +248,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 76c2e254..cadb07bd 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -344,6 +344,7 @@ def __init__(
         shared_attn_ids = None,
         shared_ff_ids = None,
         share_input_output_emb = False,
+        use_static_masks = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -391,6 +392,7 @@ def __init__(
             rotary_emb = rotary_emb,
             shared_attn_ids = shared_attn_ids,
             shared_ff_ids = shared_ff_ids,
+            use_static_masks = use_static_masks,
         )
 
         self.stable = stable
@@ -484,7 +486,8 @@ def generate_images(
         filter_thres = 0.5,
         temperature = 1.,
         img = None,
-        num_init_img_tokens = None
+        num_init_img_tokens = None,
+        use_cache = False,
     ):
         vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
         total_len = text_seq_len + image_seq_len
@@ -503,12 +506,13 @@ def generate_images(
             indices = indices[:, :num_img_tokens]
             out = torch.cat((out, indices), dim = -1)
 
+        cache = {} if use_cache else None
         for cur_len in range(out.shape[1], total_len):
             is_image = cur_len >= text_seq_len
 
             text, image = out[:, :text_seq_len], out[:, text_seq_len:]
 
-            logits = self(text, image, mask = mask)[:, -1, :]
+            logits = self(text, image, mask = mask, cache = cache)[:, -1, :]
 
             filtered_logits = top_k(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim = -1)
@@ -536,6 +540,7 @@ def forward(
         text,
         image = None,
         mask = None,
+        cache = None,
         return_loss = False
     ):
         assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
@@ -584,7 +589,9 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)
 
-        out = self.transformer(tokens)
+        if exists(cache) and cache.get('offset'):
+            tokens = tokens[:, -1:]
+        out = self.transformer(tokens, cache=cache)
 
         if self.stable:
             out = self.norm_by_max(out)
@@ -594,9 +601,14 @@ def forward(
         # mask logits to make sure text predicts text (except last token), and image predicts image
 
         logits_mask = self.logits_mask[:, :seq_len]
+        if exists(cache) and cache.get('offset'):
+            logits_mask = logits_mask[:, -1:]
         max_neg_value = -torch.finfo(logits.dtype).max
         logits.masked_fill_(logits_mask, max_neg_value)
 
+        if exists(cache):
+            cache['offset'] = cache.get('offset', 0) + logits.shape[1]
+
         if not return_loss:
             return logits
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index c7322a4c..395a3ce4 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -1,3 +1,4 @@
+from collections import deque
 from collections.abc import Iterable
 from functools import partial
 from itertools import islice, cycle
@@ -35,6 +36,15 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True)
         return x / maxes
 
+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
+
 # https://arxiv.org/abs/2103.17239
 
 class LayerScale(nn.Module):
     def __init__(self, dim, depth, fn):
@@ -83,7 +93,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
             nn.Linear(dim * mult, dim)
         )
 
-    def forward(self, x):
+    def forward(self, x, cache=None, cache_key=None):
         return self.net(x)
 
 # token shift classes
 
@@ -94,12 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1
+
+    def forward(self, x, cache=None, cache_key=None, **kwargs):
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]
+            x_left = q[-2][1]
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)
+
+            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
+            return self.fn(x[:, None], cache=cache, **kwargs)
 
-    def forward(self, x, **kwargs):
         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1
 
         # get text and image tokens
@@ -124,8 +152,22 @@ def forward(self, x, **kwargs):
         # merge text and image sequence back together
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache=cache, **kwargs)
 
 # main transformer class
 
@@ -152,11 +194,15 @@ def __init__(
         rotary_emb = True,
         shared_attn_ids = None,
         shared_ff_ids = None,
+        use_static_masks = False,
     ):
         super().__init__()
         layers = nn.ModuleList([])
         sparse_layer = cast_tuple(sparse_attn, depth)
 
+        self.seq_len = seq_len
+        self.image_fmap_size = image_fmap_size
+
         attn_types = default(attn_types, ('full',))
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
@@ -173,9 +219,15 @@ def __init__(
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
+                if use_static_masks:
+                    attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
+                else:
+                    attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
                 attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
@@ -199,8 +251,11 @@ def __init__(
                     ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                     shared_ff_layers[ff_id] = ff
 
+            attn = CachedAs(f'attn_{ind}', attn)
+
             if shift_tokens:
-                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
+                attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
+                ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))
 
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
@@ -209,7 +264,9 @@ def __init__(
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
-        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
+        route_all = ((True, True),) * depth
+        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
+                          'cache': route_all}
 
         self.layers = execute_type(layers, args_route = attn_route_map)
 
@@ -245,3 +302,27 @@ def __init__(
 
     def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
+
+    def _get_static_mask(self, attn_type):
+        # In case of attn_type = "axial_{row,col}",
+        # the sparse implementation is most efficient for training,
+        # but the full attention with a static mask is most efficient for inference
+        # since caching is implemented in this case.
+
+        img_seq_len = self.image_fmap_size ** 2
+        text_len = self.seq_len + 1 - img_seq_len
+
+        static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
+        static_mask[:, :text_len] = True
+        if attn_type == 'axial_row':
+            for row in range(self.image_fmap_size):
+                begin = text_len + row * self.image_fmap_size
+                end = text_len + (row + 1) * self.image_fmap_size
+                static_mask[begin:end, begin:end] = True
+        elif attn_type == 'axial_col':
+            for col in range(self.image_fmap_size):
+                begin = text_len + col
+                static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
+        else:
+            raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
+        return static_mask
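
A minimal usage sketch of the two new flags (the hyperparameters and the toy `text` tensor below are illustrative assumptions, not values taken from this patch): `use_static_masks = True` swaps the sparse axial attention layers for full attention with an equivalent static mask, and `use_cache = True` enables the incremental decoding path so each sampling step only feeds the newest token through the transformer.

import torch
from dalle_pytorch import DiscreteVAE, DALLE

# toy, untrained models for illustration; a trained VAE / DALLE pair
# would be configured and loaded the same way
vae = DiscreteVAE(
    image_size = 64,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 32,
    depth = 4,
    heads = 8,
    dim_head = 64,
    attn_types = ('full', 'axial_row', 'axial_col'),
    shift_tokens = True,
    use_static_masks = True    # full attention + static masks instead of the sparse axial classes
)

text = torch.randint(0, 10000, (1, 32))

# use_cache = True keeps per-layer key/value and token-shift state in a dict,
# so the per-step cost of sampling stays roughly constant instead of growing
# with the generated sequence
images = dalle.generate_images(text, use_cache = True)

With `use_cache = False` (the default), `generate_images` behaves exactly as before, so existing callers are unaffected.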