Skip to content

Commit

Permalink
[ADD] class imbalance + stratified train_test_split
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 1, 2024
1 parent ee2fb7f commit de13bca
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 15 deletions.
6 changes: 4 additions & 2 deletions dcase_fine_tune/CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ data:
# Otherwise the hash of the folders will be different!!

trainer:
max_epochs: 500
max_epochs: 10000
default_root_dir: /data/lightning_logs/BEATs
accelerator: gpu
gpus: 1
batch_size: 32
batch_size: 64
num_workers: 4
patience: 20
min_sample_per_category: 10

model:
lr: 1.0e-05
Expand Down
2 changes: 1 addition & 1 deletion dcase_fine_tune/FTBeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
lr: float = 1e-3,
lr_scheduler_gamma: float = 1e-1,
model_path: str = None,
ft_entire_network: bool = False, # Boolean on whether the classifier layer + BEATs should be fine-tuned
ft_entire_network: bool = True, # Boolean on whether the classifier layer + BEATs should be fine-tuned
**kwargs,
) -> None:
"""TransferLearningModel.
Expand Down
71 changes: 62 additions & 9 deletions dcase_fine_tune/FTDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from sklearn.model_selection import train_test_split
import torch
import pandas as pd
import numpy as np

from torch.utils.data import WeightedRandomSampler


class AudioDatasetDCASE(Dataset):
Expand Down Expand Up @@ -51,6 +54,30 @@ def __getitem__(self, idx):
def get_label_dict(self):
return self.label_dict

class AudioDatasetDCASEV2(Dataset):
def __init__(
self,
data_frame,
):
self.data_frame = data_frame

def __len__(self):
return len(self.data_frame)

def get_labels(self):
labels = []

for i in range(0, len(self.data_frame)):
label = self.data_frame.iloc[i]["category"]
labels.append(label)

return labels

def __getitem__(self, idx):
input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
label = self.data_frame.iloc[idx]["category"]

return input_feature, label

class DCASEDataModule(LightningDataModule):
def __init__(
Expand All @@ -60,6 +87,7 @@ def __init__(
num_workers = 4,
tensor_length = 128,
test_size = 0,
min_sample_per_category = 5,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -68,33 +96,58 @@ def __init__(
self.num_workers=num_workers
self.tensor_length = tensor_length
self.test_size = test_size
self.min_sample_per_category = min_sample_per_category

self.label_encoder = LabelEncoder()
self.label_encoder.fit(self.data_frame["category"])
self.label_dict = dict(
zip(
self.label_encoder.classes_,
self.label_encoder.transform(self.label_encoder.classes_),
)
)

self.setup()
self.divide_train_val()

def setup(self, stage=None):
# load data
self.complete_dataset = AudioDatasetDCASE(
data_frame=self.data_frame,
)
self.data_frame["category"] = self.label_encoder.fit_transform(self.data_frame["category"])
self.complete_dataset = AudioDatasetDCASEV2(data_frame=self.data_frame)

def divide_train_val(self):

value_counts = self.data_frame["category"].value_counts()
self.num_target_classes = len(self.data_frame["category"].unique())

# Separate into training and validation set
train_indices, validation_indices, _, _ = train_test_split(
range(len(self.complete_dataset)),
self.complete_dataset.get_labels(),
test_size=self.test_size,
random_state=42,
random_state=1,
stratify=self.data_frame["category"]
)

data_frame_train = self.data_frame.loc[train_indices]
data_frame_train.reset_index(drop=True, inplace=True)

# deal with class imbalance
value_counts = data_frame_train["category"].value_counts()
weight = 1. / value_counts
samples_weight = np.array([weight[t] for t in data_frame_train["category"]])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

data_frame_validation = self.data_frame.loc[validation_indices]
data_frame_validation.reset_index(drop=True, inplace=True)

# generate subset based on indices
self.train_set = AudioDatasetDCASE(
self.train_set = AudioDatasetDCASEV2(
data_frame=data_frame_train,
)
self.val_set = AudioDatasetDCASE(
self.val_set = AudioDatasetDCASEV2(
data_frame=data_frame_validation,
)

Expand All @@ -103,16 +156,16 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
shuffle=True,
collate_fn=self.collate_fn)
collate_fn=self.collate_fn,
sampler=self.sampler
)
return train_loader

def val_dataloader(self):
val_loader = DataLoader(self.val_set,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=False,
shuffle=True,
collate_fn=self.collate_fn)
return val_loader

Expand Down
14 changes: 11 additions & 3 deletions dcase_fine_tune/FTtrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def train_model(
model,
datamodule_class,
max_epochs,
patience,
num_sanity_val_steps=0,
root_dir="logs/"
):
Expand All @@ -29,7 +30,7 @@ def train_model(
auto_select_gpus=True,
callbacks=[
pl.callbacks.LearningRateMonitor(logging_interval="step"),
pl.callbacks.EarlyStopping(monitor="train_acc", mode="max", patience=max_epochs),
pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=patience),
],
default_root_dir=root_dir,
enable_checkpointing=True
Expand Down Expand Up @@ -80,15 +81,22 @@ def main(cfg: DictConfig):
batch_size=cfg["trainer"]["batch_size"],
num_workers=cfg["trainer"]["num_workers"],
tensor_length=cfg["data"]["tensor_length"],
test_size=0.2)
test_size=0.2,
min_sample_per_category=cfg["trainer"]["min_sample_per_category"])

# create the model object
num_target_classes = len(df["category"].unique())
print(num_target_classes)

model = BEATsTransferLearningModel(model_path=cfg["model"]["model_path"],
num_target_classes=num_target_classes,
lr=cfg["model"]["lr"])

train_model(model, Loader, cfg["trainer"]["max_epochs"], root_dir=cfg["trainer"]["default_root_dir"])
train_model(model,
Loader,
cfg["trainer"]["max_epochs"],
patience=cfg["trainer"]["patience"],
root_dir=cfg["trainer"]["default_root_dir"])

if __name__ == "__main__":
main()

0 comments on commit de13bca

Please sign in to comment.