-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ADD] Pipeline for fine tuning on DCASE
- Loading branch information
1 parent
4979946
commit 372d845
Showing
5 changed files
with
857 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
########################################### | ||
########################################### | ||
##### CONFIG FOR DCASE CHALLENGE 2024 ##### | ||
########################################### | ||
########################################### | ||
|
||
################################## | ||
# PARAMETERS FOR DATA PROCESSING # | ||
################################## | ||
data: | ||
target_fs: 16000 # used in preprocessing | ||
resample: True # used in preprocessing | ||
denoise: True # used in preprocessing | ||
normalize: True # used in preprocessing | ||
frame_length: 25.0 # used in preprocessing | ||
tensor_length: 128 # used in preprocessing | ||
overlap: 0.5 # used in preprocessing | ||
num_mel_bins: 128 # used in preprocessing | ||
max_segment_length: 1.0 # used in preprocessing | ||
status: validate # used in preprocessing, train or validate or evaluate | ||
set_type: "Validation_Set" | ||
|
||
|
||
################################# | ||
# PARAMETERS FOR MODEL TRAINING # | ||
################################# | ||
# Be sure the parameters match the ones in data processing | ||
# Otherwise the hash of the folders will be different!! | ||
|
||
trainer: | ||
max_epochs: 1 | ||
default_root_dir: /data | ||
accelerator: gpu | ||
gpus: 1 | ||
batch_size: 4 | ||
num_workers: 4 | ||
|
||
model: | ||
lr: 1.0e-05 | ||
model_path: "/data/models/BEATs/BEATs_iter3_plus_AS2M.pt" | ||
specaugment_params: null | ||
# specaugment_params: | ||
# application_ratio: 1.0 | ||
# time_mask: 40 | ||
# freq_mask: 40 | ||
|
||
################################### | ||
# PARAMETERS FOR MODEL PREDICTION # | ||
################################### | ||
predict: | ||
wav_save: False | ||
overwrite: True | ||
tolerance: 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import numpy as np | ||
|
||
import torch | ||
from torch import nn, optim | ||
from torch.nn import functional as F | ||
from torch.optim.lr_scheduler import MultiStepLR | ||
from torch.optim.optimizer import Optimizer | ||
from torchmetrics import Accuracy | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning.utilities.rank_zero import rank_zero_info | ||
|
||
from BEATs.BEATs import BEATs, BEATsConfig | ||
|
||
class BEATsTransferLearningModel(pl.LightningModule): | ||
def __init__( | ||
self, | ||
num_target_classes: int = 2, | ||
lr: float = 1e-3, | ||
lr_scheduler_gamma: float = 1e-1, | ||
model_path: str = "/model/BEATs_iter3_plus_AS2M.pt", | ||
ft_entire_network: bool = False, # Boolean on whether the classifier layer + BEATs should be fine-tuned | ||
**kwargs, | ||
) -> None: | ||
"""TransferLearningModel. | ||
Args: | ||
lr: Initial learning rate | ||
""" | ||
super().__init__() | ||
self.lr = lr | ||
self.lr_scheduler_gamma = lr_scheduler_gamma | ||
self.num_target_classes = num_target_classes | ||
self.ft_entire_network = ft_entire_network | ||
|
||
# Initialise BEATs model | ||
self.checkpoint = torch.load(model_path) | ||
self.cfg = BEATsConfig( | ||
{ | ||
**self.checkpoint["cfg"], | ||
"predictor_class": self.num_target_classes, | ||
"finetuned_model": False, | ||
} | ||
) | ||
|
||
self._build_model() | ||
|
||
self.train_acc = Accuracy( | ||
task="multiclass", num_classes=self.num_target_classes | ||
) | ||
self.valid_acc = Accuracy( | ||
task="multiclass", num_classes=self.num_target_classes | ||
) | ||
self.save_hyperparameters() | ||
|
||
def _build_model(self): | ||
# 1. Load the pre-trained network | ||
self.beats = BEATs(self.cfg) | ||
|
||
print("LOADING THE PRE-TRAINED WEIGHTS") | ||
self.beats.load_state_dict(self.checkpoint["model"]) | ||
|
||
# 2. Classifier | ||
self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class) | ||
|
||
def extract_features(self, x, padding_mask=None): | ||
if padding_mask != None: | ||
x, _ = self.beats.extract_features(x, padding_mask) | ||
else: | ||
x, _ = self.beats.extract_features(x) | ||
return x | ||
|
||
def forward(self, x, padding_mask=None): | ||
"""Forward pass. Return x""" | ||
|
||
# Get the representation | ||
if padding_mask != None: | ||
x, _ = self.beats.extract_features(x, padding_mask) | ||
else: | ||
x, _ = self.beats.extract_features(x) | ||
|
||
# Get the logits | ||
x = self.fc(x) | ||
|
||
# Mean pool the second layer | ||
x = x.mean(dim=1) | ||
|
||
return x | ||
|
||
def loss(self, lprobs, labels): | ||
self.loss_func = nn.CrossEntropyLoss() | ||
return self.loss_func(lprobs, labels) | ||
|
||
def training_step(self, batch, batch_idx): | ||
# 1. Forward pass: | ||
x, y_true = batch | ||
y_probs = self.forward(x) | ||
|
||
# 2. Compute loss | ||
train_loss = self.loss(y_probs, y_true) | ||
|
||
# 3. Compute accuracy: | ||
self.log("train_acc", self.train_acc(y_probs, y_true), prog_bar=True) | ||
|
||
return train_loss | ||
|
||
def validation_step(self, batch, batch_idx): | ||
# 1. Forward pass: | ||
x, y_true = batch | ||
y_probs = self.forward(x) | ||
|
||
# 2. Compute loss | ||
self.log("val_loss", self.loss(y_probs, y_true), prog_bar=True) | ||
|
||
# 3. Compute accuracy: | ||
self.log("val_acc", self.valid_acc(y_probs, y_true), prog_bar=True) | ||
|
||
def configure_optimizers(self): | ||
if self.ft_entire_network: | ||
optimizer = optim.AdamW( | ||
[{"params": self.beats.parameters()}, {"params": self.fc.parameters()}], | ||
lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01 | ||
) | ||
else: | ||
optimizer = optim.AdamW( | ||
self.fc.parameters(), | ||
lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01 | ||
) | ||
|
||
return optimizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from torch.utils.data import Dataset, DataLoader | ||
from pytorch_lightning import LightningDataModule | ||
from sklearn.preprocessing import LabelEncoder | ||
import torch | ||
import pandas as pd | ||
|
||
|
||
class AudioDatasetDCASE(Dataset): | ||
def __init__( | ||
self, | ||
data_frame, | ||
label_dict=None, | ||
): | ||
self.data_frame = data_frame | ||
self.label_encoder = LabelEncoder() | ||
if label_dict is not None: | ||
self.label_encoder.fit(list(label_dict.keys())) | ||
self.label_dict = label_dict | ||
else: | ||
self.label_encoder.fit(self.data_frame["category"]) | ||
self.label_dict = dict( | ||
zip( | ||
self.label_encoder.classes_, | ||
self.label_encoder.transform(self.label_encoder.classes_), | ||
) | ||
) | ||
|
||
def __len__(self): | ||
return len(self.data_frame) | ||
|
||
def get_labels(self): | ||
labels = [] | ||
|
||
for i in range(0, len(self.data_frame)): | ||
label = self.data_frame.iloc[i]["category"] | ||
label = self.label_encoder.transform([label])[0] | ||
labels.append(label) | ||
|
||
return labels | ||
|
||
def __getitem__(self, idx): | ||
input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"]) | ||
label = self.data_frame.iloc[idx]["category"] | ||
|
||
# Encode label as integer | ||
label = self.label_encoder.transform([label])[0] | ||
|
||
return input_feature, label | ||
|
||
def get_label_dict(self): | ||
return self.label_dict | ||
|
||
class DCASEDataModule(LightningDataModule): | ||
def __init__( | ||
self, | ||
data_frame= pd.DataFrame, | ||
batch_size = 4, | ||
num_workers = 4, | ||
tensor_length = 128, | ||
**kwargs | ||
): | ||
super().__init__(**kwargs) | ||
self.data_frame = data_frame | ||
self.batch_size=batch_size | ||
self.num_workers=num_workers | ||
self.tensor_length = tensor_length | ||
self.setup() | ||
|
||
def setup(self, stage=None): | ||
# load data | ||
self.complete_dataset = AudioDatasetDCASE( | ||
data_frame=self.data_frame, | ||
) | ||
|
||
def train_dataloader(self): | ||
train_loader = DataLoader(self.complete_dataset, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
pin_memory=False, | ||
collate_fn=self.collate_fn) | ||
return train_loader | ||
|
||
def get_label_dict(self): | ||
label_dic = self.complete_dataset.get_label_dict() | ||
return label_dic | ||
|
||
def collate_fn( | ||
self, input_data | ||
): | ||
true_class_ids = list({x[1] for x in input_data}) | ||
new_input = [] | ||
for x in input_data: | ||
if x[0].shape[1] > self.tensor_length: | ||
rand_start = torch.randint( | ||
0, x[0].shape[1] - self.tensor_length, (1,) | ||
) | ||
new_input.append( | ||
(x[0][:, rand_start : rand_start + self.tensor_length], x[1]) | ||
) | ||
else: | ||
new_input.append(x) | ||
|
||
all_images = torch.cat([x[0].unsqueeze(0) for x in new_input]) | ||
all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data])) | ||
|
||
return (all_images, all_labels) | ||
|
||
|
||
class predictLoader(): | ||
def __init__( | ||
self, | ||
data_frame= pd.DataFrame, | ||
batch_size = 1, | ||
num_workers = 4, | ||
tensor_length = 128 | ||
): | ||
self.data_frame = data_frame | ||
self.batch_size=batch_size | ||
self.num_workers=num_workers | ||
self.tensor_length = tensor_length | ||
self.setup() | ||
|
||
def setup(self, stage=None): | ||
# load data | ||
self.complete_dataset = AudioDatasetDCASE( | ||
data_frame=self.data_frame, | ||
) | ||
|
||
def pred_dataloader(self): | ||
pred_loader = DataLoader(self.complete_dataset, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
pin_memory=False, | ||
collate_fn=self.collate_fn) | ||
return pred_loader | ||
|
||
|
||
def collate_fn( | ||
self, input_data | ||
): | ||
true_class_ids = list({x[1] for x in input_data}) | ||
new_input = [] | ||
for x in input_data: | ||
if x[0].shape[1] > self.tensor_length: | ||
rand_start = torch.randint( | ||
0, x[0].shape[1] - self.tensor_length, (1,) | ||
) | ||
new_input.append( | ||
(x[0][:, rand_start : rand_start + self.tensor_length], x[1]) | ||
) | ||
else: | ||
new_input.append(x) | ||
|
||
all_images = torch.cat([x[0].unsqueeze(0) for x in new_input]) | ||
all_labels = (torch.tensor([true_class_ids.index(x[1]) for x in input_data])) | ||
|
||
return (all_images, all_labels) | ||
|
Oops, something went wrong.