Commit: [POETRY]
Kye committed Feb 7, 2024
1 parent 1ee60b1 commit 1baa2de
Showing 2 changed files with 47 additions and 70 deletions.
115 changes: 46 additions & 69 deletions audio_flamingo/model.py
@@ -5,7 +5,7 @@
from einops import rearrange
from torch import einsum, nn
from torch.autograd import Function
-from zeta.nn import audio_to_text, Attention
+from zeta.nn import audio_to_text, Attention, SwiGLU
from zeta.structs import Transformer, Decoder, AutoregressiveWrapper

# helper functions
@@ -92,28 +92,6 @@ def backward(ctx, grads):
# they use layernorm without bias, something that pytorch does not offer


-class LayerNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
-
-    def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
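The removed module keeps a learnable gamma but freezes beta at zero (a bias-free layer norm); its replacement, nn.LayerNorm, learns both weight and bias by default. A minimal sketch of the difference, with illustrative sizes:

import torch
from torch import nn

norm = nn.LayerNorm(512)              # learnable weight *and* bias
x = torch.randn(2, 10, 512)
y = norm(x)                           # normalizes over the last dim; shape unchanged
# On PyTorch >= 2.1, a bias-free variant closer to the removed class exists:
# norm = nn.LayerNorm(512, bias=False)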


-# residual
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x


# to latents


@@ -127,45 +105,6 @@ def forward(self, x):
        return F.normalize(latents, dim=-1)


-# rotary positional embedding
-# https://arxiv.org/abs/2104.09864
-
-
-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        inv_freq = 1.0 / (
-            10000 ** (torch.arange(0, dim, 2).float() / dim)
-        )
-        self.register_buffer("inv_freq", inv_freq)
-
-    def forward(self, max_seq_len, *, device):
-        seq = torch.arange(
-            max_seq_len, device=device, dtype=self.inv_freq.dtype
-        )
-        freqs = einsum("i , j -> i j", seq, self.inv_freq)
-        return torch.cat((freqs, freqs), dim=-1)
-
-
-def rotate_half(x):
-    x = rearrange(x, "... (j d) -> ... j d", j=2)
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(pos, t):
-    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
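For reference, a usage sketch of the rotary helpers above (shapes illustrative): the angle table is built once per sequence length, then applied identically to queries and keys before attention.

import torch

rotary = RotaryEmbedding(dim=64)      # class defined above
q = torch.randn(2, 8, 128, 64)        # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 128, 64)
pos = rotary(128, device=q.device)    # (seq, dim_head) table of angles
q = apply_rotary_pos_emb(pos, q)      # shapes unchanged
k = apply_rotary_pos_emb(pos, k)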


-# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
-# https://arxiv.org/abs/2002.05202
-
-
-class SwiGLU(nn.Module):
-    def forward(self, x):
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
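SwiGLU is now imported from zeta.nn instead (see the import change at the top of this diff). The gating itself is a chunk-and-gate, sketched standalone here with illustrative sizes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 1024)          # last dim must be even
h, gate = x.chunk(2, dim=-1)
out = F.silu(gate) * h                # (2, 10, 512)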


# parallel attention and feedforward with residual
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
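A minimal sketch of the multi-query shape logic this comment describes: queries keep multiple heads while a single key/value head is shared across all of them (names and sizes below are illustrative, not the class that follows):

import torch
from einops import rearrange
from torch import einsum

b, h, n, m, d = 2, 8, 128, 64, 64
q = rearrange(torch.randn(b, n, h * d), "b n (h d) -> b h n d", h=h)
k = torch.randn(b, m, d)              # one head of keys, shared across q heads
v = torch.randn(b, m, d)              # one head of values
sim = einsum("b h i d, b j d -> b h i j", q, k) * d**-0.5
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b j d -> b h i d", attn, v)   # (b, h, n, d)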
@@ -189,9 +128,9 @@ def __init__(
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

-        self.norm = LayerNorm(dim)
+        self.norm = nn.LayerNorm(dim)
        self.context_norm = (
-            LayerNorm(context_dim) if norm_context else nn.Identity()
+            nn.LayerNorm(context_dim) if norm_context else nn.Identity()
        )

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
@@ -326,7 +265,7 @@ def __init__(
        self.dense = nn.Linear(dim, dim)

        # LayerNorm
-        self.norm = LayerNorm(dim)
+        self.norm = nn.LayerNorm(dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)
@@ -417,7 +356,7 @@ def __init__(
        )

        # LayerNorm
-        self.norm = LayerNorm(dim)
+        self.norm = nn.LayerNorm(dim)

        # Representation transformation layers
        self.rpl_layers = nn.ModuleList([])
@@ -468,6 +407,36 @@ def forward(self, x: Tensor):


class AudioFlamingo(nn.Module):
"""
AudioFlamingo model for audio-text synthesis.
Args:
dim (int): Dimension of the model.
num_tokens (int): Number of tokens in the input sequence.
max_seq_len (int): Maximum sequence length.
heads (int): Number of attention heads.
depth (int): Depth of the model.
dim_head (int): Dimension of each attention head.
dropout (float): Dropout rate.
context_dim (int): Dimension of the context.
*args: Variable length arguments.
**kwargs: Keyword arguments.
Attributes:
dim (int): Dimension of the model.
num_tokens (int): Number of tokens in the input sequence.
heads (int): Number of attention heads.
depth (int): Depth of the model.
dim_head (int): Dimension of each attention head.
dropout (float): Dropout rate.
context_dim (int): Dimension of the context.
transformer (Transformer): Transformer model.
decoder (AutoregressiveWrapper): Autoregressive wrapper for the transformer.
af_blocks (nn.ModuleList): List of AudioFlamingoEncoderBlock layers.
norm (nn.LayerNorm): Layer normalization.
"""

    def __init__(
        self,
        dim: int,
@@ -523,12 +492,20 @@ def __init__(
        )

        # LayerNorm
-        self.norm = LayerNorm(dim)
+        self.norm = nn.LayerNorm(dim)

    def forward(self, text: Tensor, audio: Tensor):
        # Text shape - (b, s, d)
        # Audio shape - (b, s)
+        """
+        Forward pass of the AudioFlamingo model.
+
+        Args:
+            text (Tensor): Input text tensor of shape (batch_size, seq_len, dim).
+            audio (Tensor): Input audio tensor of shape (batch_size, seq_len).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_len, dim).
+        """
        # Apply audio blocks to audio
        for block in self.af_blocks:
            audio = block(audio)
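For orientation, a hypothetical end-to-end call sketched from the new docstrings; every hyperparameter value below is illustrative, and the input shapes simply mirror the forward() docstring rather than a verified invocation:

import torch

model = AudioFlamingo(
    dim=512,           # illustrative sizes throughout
    num_tokens=32000,
    max_seq_len=2048,
    heads=8,
    depth=6,
    dim_head=64,
    dropout=0.1,
    context_dim=512,
)
text = torch.randn(1, 2048, 512)   # (batch, seq_len, dim) per the docstring
audio = torch.randn(1, 2048)       # (batch, seq_len)
out = model(text, audio)           # (batch, seq_len, dim)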
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-flamingo"
version = "0.0.2"
version = "0.0.3"
description = "Paper - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]
