Skip to content

Commit

Permalink
[ADD] Pipeline for fine tuning on DCASE
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Feb 28, 2024
1 parent 4979946 commit 372d845
Show file tree
Hide file tree
Showing 5 changed files with 857 additions and 0 deletions.
53 changes: 53 additions & 0 deletions dcase_fine_tune/CONFIG.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
###########################################
###########################################
##### CONFIG FOR DCASE CHALLENGE 2024 #####
###########################################
###########################################

##################################
# PARAMETERS FOR DATA PROCESSING #
##################################
data:
target_fs: 16000 # used in preprocessing
resample: True # used in preprocessing
denoise: True # used in preprocessing
normalize: True # used in preprocessing
frame_length: 25.0 # used in preprocessing
tensor_length: 128 # used in preprocessing
overlap: 0.5 # used in preprocessing
num_mel_bins: 128 # used in preprocessing
max_segment_length: 1.0 # used in preprocessing
status: validate # used in preprocessing, train or validate or evaluate
set_type: "Validation_Set"


#################################
# PARAMETERS FOR MODEL TRAINING #
#################################
# Be sure the parameters match the ones in data processing
# Otherwise the hash of the folders will be different!!

trainer:
max_epochs: 1
default_root_dir: /data
accelerator: gpu
gpus: 1
batch_size: 4
num_workers: 4

model:
lr: 1.0e-05
model_path: "/data/models/BEATs/BEATs_iter3_plus_AS2M.pt"
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
# time_mask: 40
# freq_mask: 40

###################################
# PARAMETERS FOR MODEL PREDICTION #
###################################
predict:
wav_save: False
overwrite: True
tolerance: 0
129 changes: 129 additions & 0 deletions dcase_fine_tune/FTBeats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import numpy as np

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torchmetrics import Accuracy

import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info

from BEATs.BEATs import BEATs, BEATsConfig

class BEATsTransferLearningModel(pl.LightningModule):
def __init__(
self,
num_target_classes: int = 2,
lr: float = 1e-3,
lr_scheduler_gamma: float = 1e-1,
model_path: str = "/model/BEATs_iter3_plus_AS2M.pt",
ft_entire_network: bool = False, # Boolean on whether the classifier layer + BEATs should be fine-tuned
**kwargs,
) -> None:
"""TransferLearningModel.
Args:
lr: Initial learning rate
"""
super().__init__()
self.lr = lr
self.lr_scheduler_gamma = lr_scheduler_gamma
self.num_target_classes = num_target_classes
self.ft_entire_network = ft_entire_network

# Initialise BEATs model
self.checkpoint = torch.load(model_path)
self.cfg = BEATsConfig(
{
**self.checkpoint["cfg"],
"predictor_class": self.num_target_classes,
"finetuned_model": False,
}
)

self._build_model()

self.train_acc = Accuracy(
task="multiclass", num_classes=self.num_target_classes
)
self.valid_acc = Accuracy(
task="multiclass", num_classes=self.num_target_classes
)
self.save_hyperparameters()

def _build_model(self):
# 1. Load the pre-trained network
self.beats = BEATs(self.cfg)

print("LOADING THE PRE-TRAINED WEIGHTS")
self.beats.load_state_dict(self.checkpoint["model"])

# 2. Classifier
self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)

def extract_features(self, x, padding_mask=None):
if padding_mask != None:
x, _ = self.beats.extract_features(x, padding_mask)
else:
x, _ = self.beats.extract_features(x)
return x

def forward(self, x, padding_mask=None):
"""Forward pass. Return x"""

# Get the representation
if padding_mask != None:
x, _ = self.beats.extract_features(x, padding_mask)
else:
x, _ = self.beats.extract_features(x)

# Get the logits
x = self.fc(x)

# Mean pool the second layer
x = x.mean(dim=1)

return x

def loss(self, lprobs, labels):
self.loss_func = nn.CrossEntropyLoss()
return self.loss_func(lprobs, labels)

def training_step(self, batch, batch_idx):
# 1. Forward pass:
x, y_true = batch
y_probs = self.forward(x)

# 2. Compute loss
train_loss = self.loss(y_probs, y_true)

# 3. Compute accuracy:
self.log("train_acc", self.train_acc(y_probs, y_true), prog_bar=True)

return train_loss

def validation_step(self, batch, batch_idx):
# 1. Forward pass:
x, y_true = batch
y_probs = self.forward(x)

# 2. Compute loss
self.log("val_loss", self.loss(y_probs, y_true), prog_bar=True)

# 3. Compute accuracy:
self.log("val_acc", self.valid_acc(y_probs, y_true), prog_bar=True)

def configure_optimizers(self):
if self.ft_entire_network:
optimizer = optim.AdamW(
[{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
)
else:
optimizer = optim.AdamW(
self.fc.parameters(),
lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
)

return optimizer
158 changes: 158 additions & 0 deletions dcase_fine_tune/FTDataModule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
from sklearn.preprocessing import LabelEncoder
import torch
import pandas as pd


class AudioDatasetDCASE(Dataset):
def __init__(
self,
data_frame,
label_dict=None,
):
self.data_frame = data_frame
self.label_encoder = LabelEncoder()
if label_dict is not None:
self.label_encoder.fit(list(label_dict.keys()))
self.label_dict = label_dict
else:
self.label_encoder.fit(self.data_frame["category"])
self.label_dict = dict(
zip(
self.label_encoder.classes_,
self.label_encoder.transform(self.label_encoder.classes_),
)
)

def __len__(self):
return len(self.data_frame)

def get_labels(self):
labels = []

for i in range(0, len(self.data_frame)):
label = self.data_frame.iloc[i]["category"]
label = self.label_encoder.transform([label])[0]
labels.append(label)

return labels

def __getitem__(self, idx):
input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
label = self.data_frame.iloc[idx]["category"]

# Encode label as integer
label = self.label_encoder.transform([label])[0]

return input_feature, label

def get_label_dict(self):
return self.label_dict

class DCASEDataModule(LightningDataModule):
def __init__(
self,
data_frame= pd.DataFrame,
batch_size = 4,
num_workers = 4,
tensor_length = 128,
**kwargs
):
super().__init__(**kwargs)
self.data_frame = data_frame
self.batch_size=batch_size
self.num_workers=num_workers
self.tensor_length = tensor_length
self.setup()

def setup(self, stage=None):
# load data
self.complete_dataset = AudioDatasetDCASE(
data_frame=self.data_frame,
)

def train_dataloader(self):
train_loader = DataLoader(self.complete_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
collate_fn=self.collate_fn)
return train_loader

def get_label_dict(self):
label_dic = self.complete_dataset.get_label_dict()
return label_dic

def collate_fn(
self, input_data
):
true_class_ids = list({x[1] for x in input_data})
new_input = []
for x in input_data:
if x[0].shape[1] > self.tensor_length:
rand_start = torch.randint(
0, x[0].shape[1] - self.tensor_length, (1,)
)
new_input.append(
(x[0][:, rand_start : rand_start + self.tensor_length], x[1])
)
else:
new_input.append(x)

all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data]))

return (all_images, all_labels)


class predictLoader():
def __init__(
self,
data_frame= pd.DataFrame,
batch_size = 1,
num_workers = 4,
tensor_length = 128
):
self.data_frame = data_frame
self.batch_size=batch_size
self.num_workers=num_workers
self.tensor_length = tensor_length
self.setup()

def setup(self, stage=None):
# load data
self.complete_dataset = AudioDatasetDCASE(
data_frame=self.data_frame,
)

def pred_dataloader(self):
pred_loader = DataLoader(self.complete_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
collate_fn=self.collate_fn)
return pred_loader


def collate_fn(
self, input_data
):
true_class_ids = list({x[1] for x in input_data})
new_input = []
for x in input_data:
if x[0].shape[1] > self.tensor_length:
rand_start = torch.randint(
0, x[0].shape[1] - self.tensor_length, (1,)
)
new_input.append(
(x[0][:, rand_start : rand_start + self.tensor_length], x[1])
)
else:
new_input.append(x)

all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data]))

return (all_images, all_labels)

Loading

0 comments on commit 372d845

Please sign in to comment.