Commit

CHORE: ran black
femke-sintef committed Apr 2, 2024
1 parent 58e5809 commit b5d6eda
Showing 27 changed files with 1,530 additions and 922 deletions.
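
For reference, the diff below reflects running the black code formatter (default 88-character line length), which wraps over-long lines, normalizes spacing such as 2 ** 15 to 2**15, and adds trailing commas and a final newline. A minimal sketch of the same transformation through black's library API (black.format_str / black.Mode) follows; the sample line is copied from the diff below, while the repository itself was presumably formatted with the black command-line tool rather than this exact call.

# Minimal sketch: reproduce the line wrapping black applies in this commit.
# Assumes black is installed; format_str and Mode are black's public library API.
import black

src = (
    "self.layer_wise_gradient_decay_ratio: float = 1.0  "
    "# ratio for layer-wise gradient decay\n"
)

# With the default 88-character limit, the over-long annotated assignment is
# expected to be wrapped into a parenthesized form, as in the diff below.
print(black.format_str(src, mode=black.Mode()))
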
11 changes: 7 additions & 4 deletions BEATs/BEATs.py
@@ -154,12 +154,12 @@ def preprocess(
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank

def specaugment(self, fbank, specaugment_params):
# FBG: Add spectral masking
if torch.rand(1) < specaugment_params["application_ratio"]:
masking = ta_transforms.TimeMasking(
time_mask_param=specaugment_params["time_mask"],
time_mask_param=specaugment_params["time_mask"],
)
fbank = masking(fbank)
masking = ta_transforms.FrequencyMasking(
@@ -183,8 +183,11 @@ def extract_features(

fbank = source.unsqueeze(1)

# FBG: add spectral masking
if hasattr(self.cfg, "specaugment_params") and not self.cfg.specaugment_params is None:
# FBG: add spectral masking
if (
hasattr(self.cfg, "specaugment_params")
and not self.cfg.specaugment_params is None
):
fbank = self.specaugment(fbank, self.cfg.specaugment_params)
# end NOTE FBG
features = self.patch_embedding(fbank)
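
The BEATs/BEATs.py hunks above touch a SpecAugment-style masking step. Below is a rough, self-contained sketch of what that masking does, using torchaudio's TimeMasking and FrequencyMasking transforms; the parameter values and the "freq_mask" key are illustrative assumptions, not values taken from this repository's configuration.

# Sketch of SpecAugment-style masking with torchaudio; parameters are made up.
import torch
import torchaudio.transforms as ta_transforms

specaugment_params = {"application_ratio": 0.5, "time_mask": 30, "freq_mask": 20}

fbank = torch.randn(1, 128, 500)  # (channel, freq, time) log-mel spectrogram

# Apply the masks stochastically, mirroring the application_ratio check above.
if torch.rand(1) < specaugment_params["application_ratio"]:
    fbank = ta_transforms.TimeMasking(
        time_mask_param=specaugment_params["time_mask"]
    )(fbank)
    fbank = ta_transforms.FrequencyMasking(
        freq_mask_param=specaugment_params["freq_mask"]
    )(fbank)
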
89 changes: 58 additions & 31 deletions BEATs_on_ESC50/BEATs/BEATs.py
@@ -35,25 +35,41 @@ def __init__(self, cfg=None):
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use

self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
self.layer_wise_gradient_decay_ratio: float = (
1.0 # ratio for layer-wise gradient decay
)
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.deep_norm: bool = False # apply deep_norm first in the transformer

# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
self.activation_dropout: float = (
0.0 # dropout probability after activation in FFN
)
self.encoder_layerdrop: float = (
0.0 # probability of dropping a transformer layer
)
self.dropout_input: float = (
0.0 # dropout to apply to the input (after feat extr)
)

# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
self.conv_pos: int = (
128 # number of filters for convolutional positional embeddings
)
self.conv_pos_groups: int = (
16 # number of groups for convolutional positional embedding
)

# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.relative_position_embedding: bool = (
False # apply relative position embedding
)
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.max_distance: int = (
1280 # maximum distance for relative position embedding
)
self.gru_rel_pos: bool = False # apply gated relative position embedding

# label predictor
@@ -70,8 +86,8 @@ def update(self, cfg: dict):

class BEATs(nn.Module):
def __init__(
self,
cfg: BEATsConfig,
self,
cfg: BEATsConfig,
) -> None:
super().__init__()
logger.info(f"BEATs Config: {cfg.__dict__}")
@@ -86,8 +102,13 @@ def __init__(
)

self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.patch_embedding = nn.Conv2d(
1,
self.embed,
kernel_size=self.input_patch_size,
stride=self.input_patch_size,
bias=cfg.conv_bias,
)

self.dropout_input = nn.Dropout(cfg.dropout_input)

@@ -102,40 +123,44 @@ def __init__(
self.predictor = None

def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask

def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
waveform = waveform.unsqueeze(0) * 2**15
fbank = ta_kaldi.fbank(
waveform,
num_mel_bins=128,
sample_frequency=16000,
frame_length=25,
frame_shift=10,
)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank

def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

@@ -168,12 +193,14 @@ def extract_features(
if padding_mask is not None and padding_mask.any():
logits[padding_mask] = 0
logits = logits.sum(dim=1)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
)
else:
logits = logits.mean(dim=1)

lprobs = torch.sigmoid(logits)

return lprobs, padding_mask
else:
return x, padding_mask
return x, padding_mask
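
For context, a hypothetical usage sketch of the feature extractor whose formatting changes appear above. The checkpoint filename, the "cfg"/"model" dictionary keys, and the import path follow the upstream BEATs release convention and are assumptions, not part of this commit.

# Hypothetical usage of the BEATs encoder; paths and checkpoint layout assumed.
import torch
from BEATs.BEATs import BEATs, BEATsConfig  # assumes the repo root is on PYTHONPATH

checkpoint = torch.load("BEATs_iter3_plus_AS2M.pt")  # assumed checkpoint file
cfg = BEATsConfig(checkpoint["cfg"])
model = BEATs(cfg)
model.load_state_dict(checkpoint["model"])
model.eval()

waveforms = torch.randn(2, 16000)  # two one-second clips at 16 kHz
with torch.no_grad():
    representation, padding_mask = model.extract_features(waveforms)
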
93 changes: 60 additions & 33 deletions BEATs_on_ESC50/BEATs/Tokenizers.py
@@ -44,23 +44,37 @@ def __init__(self, cfg=None):
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
self.activation_dropout: float = (
0.0 # dropout probability after activation in FFN
)
self.encoder_layerdrop: float = (
0.0 # probability of dropping a transformer layer
)
self.dropout_input: float = (
0.0 # dropout to apply to the input (after feat extr)
)

# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
self.conv_pos: int = (
128 # number of filters for convolutional positional embeddings
)
self.conv_pos_groups: int = (
16 # number of groups for convolutional positional embedding
)

# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.relative_position_embedding: bool = (
False # apply relative position embedding
)
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.max_distance: int = (
1280 # maximum distance for relative position embedding
)
self.gru_rel_pos: bool = False # apply gated relative position embedding

# quantizer
self.quant_n: int = 1024 # codebook number in quantizer
self.quant_dim: int = 256 # codebook dimension in quantizer
self.quant_n: int = 1024 # codebook number in quantizer
self.quant_dim: int = 256 # codebook dimension in quantizer

if cfg is not None:
self.update(cfg)
@@ -71,8 +85,8 @@ def update(self, cfg: dict):

class Tokenizers(nn.Module):
def __init__(
self,
cfg: TokenizersConfig,
self,
cfg: TokenizersConfig,
) -> None:
super().__init__()
logger.info(f"Tokenizers Config: {cfg.__dict__}")
@@ -87,8 +101,13 @@ def __init__(
)

self.input_patch_size = cfg.input_patch_size
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
bias=cfg.conv_bias)
self.patch_embedding = nn.Conv2d(
1,
self.embed,
kernel_size=self.input_patch_size,
stride=self.input_patch_size,
bias=cfg.conv_bias,
)

self.dropout_input = nn.Dropout(cfg.dropout_input)

@@ -97,50 +116,58 @@ def __init__(
self.layer_norm = LayerNorm(self.embed)

self.quantize = NormEMAVectorQuantizer(
n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
n_embed=cfg.quant_n,
embedding_dim=cfg.quant_dim,
beta=1.0,
kmeans_init=True,
decay=0.99,
)
self.quant_n = cfg.quant_n
self.quantize_layer = nn.Sequential(
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
nn.Tanh(),
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim), # for quantize
)

def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask

def preprocess(
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
self,
source: torch.Tensor,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
) -> torch.Tensor:
fbanks = []
for waveform in source:
waveform = waveform.unsqueeze(0) * 2 ** 15
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
waveform = waveform.unsqueeze(0) * 2**15
fbank = ta_kaldi.fbank(
waveform,
num_mel_bins=128,
sample_frequency=16000,
frame_length=25,
frame_shift=10,
)
fbanks.append(fbank)
fbank = torch.stack(fbanks, dim=0)
fbank = (fbank - fbank_mean) / (2 * fbank_std)
return fbank

def extract_labels(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
):
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

@@ -169,4 +196,4 @@ def extract_labels(
quantize_input = self.quantize_layer(x)
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

return embed_ind
return embed_ind