[FIX] Decompose evaluate DCASE into simpler functions
1 parent fe3b672 · commit 64b99f8
Showing 3 changed files with 278 additions and 48 deletions.
@@ -0,0 +1,154 @@
import torch
import pytorch_lightning as pl
import numpy as np
import pandas as pd
import tqdm

from datamodules.TestDCASEDataModule import DCASEDataModule
from prototypicalbeats.prototraining import ProtoBEATsModel
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

def to_dataframe(features, labels):
    # Load the saved arrays and map the features and labels into a single dataframe
    input_features = np.load(features)
    labels = np.load(labels)
    list_input_features = [input_features[key] for key in input_features.files]
    df = pd.DataFrame({"feature": list_input_features, "category": labels})

    return df

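A minimal usage sketch, assuming the features were saved as an .npz archive and the labels as an .npy file (the paths below are hypothetical):

# Hypothetical paths to saved support-set arrays
df = to_dataframe("output/data_file1.npz", "output/labels_file1.npy")
print(df.head())  # one row per feature array, with its category label
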
def train_model(
    model_type="pann",
    datamodule_class=DCASEDataModule,
    max_epochs=1,
    enable_model_summary=False,
    num_sanity_val_steps=0,
    seed=42,
    pretrained_model=None,
    state=None,
    beats_path="/data/model/BEATs/BEATs_iter3_plus_AS2M.pt"
):
    # Seed everything for reproducibility
    pl.seed_everything(seed)

    # Create the lightning trainer object
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        enable_model_summary=enable_model_summary,
        num_sanity_val_steps=num_sanity_val_steps,
        deterministic=True,
        gpus=1,
        auto_select_gpus=True,
        callbacks=[
            pl.callbacks.LearningRateMonitor(logging_interval="step"),
            pl.callbacks.EarlyStopping(
                monitor="train_acc", mode="max", patience=max_epochs
            ),
        ],
        default_root_dir="logs/",
        enable_checkpointing=False
    )

    # Create the model object
    model = ProtoBEATsModel(model_type=model_type)

    if pretrained_model:
        # Load the checkpoint weights; they replace the freshly created model
        try:
            model = ProtoBEATsModel.load_from_checkpoint(pretrained_model)
        except KeyError:
            print(
                "Failed to load the pretrained model. Please check the checkpoint file."
            )
            return None

    # Train the model
    trainer.fit(model, datamodule=datamodule_class)

    return model

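A hypothetical call, assuming DCASEDataModule can be instantiated with its default arguments:

# Fine-tune a fresh model for a few epochs on one episode
datamodule = DCASEDataModule()
model = train_model(model_type="beats", datamodule_class=datamodule, max_epochs=5)
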
def training(model_type, pretrained_model, state, custom_datamodule, max_epoch, beats_path):
    model = train_model(
        model_type,
        custom_datamodule,
        max_epochs=max_epoch,
        enable_model_summary=False,
        num_sanity_val_steps=0,
        seed=42,
        pretrained_model=pretrained_model,
        state=state,
        beats_path=beats_path
    )

    return model

def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):
    if model_type == "beats":
        z_supports, _ = model.get_embeddings(support_data, padding_mask=None)
    else:
        z_supports = model.get_embeddings(support_data, padding_mask=None)

    # Get the coordinates of the NEG and POS prototypes
    prototypes = model.get_prototypes(
        z_support=z_supports, support_labels=support_labels, n_way=n_way
    )

    # Return the coordinates of the prototypes and the z_supports
    return prototypes, z_supports

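A usage sketch for a binary episode; support_data and support_labels are assumed to come from the datamodule, and the embedding dimensions depend on the backbone:

# Hypothetical NEG/POS episode on a labelled support set
prototypes, z_supports = get_proto_coordinates(
    model, model_type="beats", support_data=support_data,
    support_labels=support_labels, n_way=2
)
# prototypes: one coordinate vector per class; z_supports: one embedding per support clip
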
def compute_z_scores(distance, mean_support, sd_support):
    # Standardise a query-to-prototype distance against the support distribution
    z_score = (distance - mean_support) / sd_support
    return z_score


def convert_z_to_p(z_score):
    import scipy.stats as stats

    # One-sided p-value under a standard normal
    p_value = 1 - stats.norm.cdf(z_score)
    return p_value

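A worked example with made-up support statistics:

z = compute_z_scores(distance=7.0, mean_support=5.0, sd_support=1.0)  # z = 2.0
p = convert_z_to_p(z)
print(round(p, 3))  # ~0.023: the query is unusually far from the prototype
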
def euclidean_distance(x1, x2):
    return torch.sqrt(torch.sum((x1 - x2) ** 2, dim=1))

def calculate_distance(model_type, z_query, z_proto):
    # Compute the Euclidean distance from queries to prototypes
    dists = []
    for q in z_query:
        q_dists = euclidean_distance(q.unsqueeze(0), z_proto)
        dists.append(
            q_dists.unsqueeze(0)
        )  # Contrary to prototraining, a dimension is added to store the distance to each prototype
    dists = torch.cat(dists, dim=0)

    if model_type == "beats":
        # We drop the last dimension without changing the gradients
        dists = dists.mean(dim=2).squeeze()

    scores = -dists

    return scores, dists

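A shape check with dummy tensors (the embedding dimension is hypothetical):

z_query = torch.randn(4, 8)  # 4 query embeddings of dimension 8
z_proto = torch.randn(2, 8)  # NEG and POS prototypes
scores, dists = calculate_distance("pann", z_query, z_proto)
print(dists.shape)  # torch.Size([4, 2]): one distance per (query, prototype) pair
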
def compute_scores(predicted_labels, gt_labels):
    acc = accuracy_score(gt_labels, predicted_labels)
    recall = recall_score(gt_labels, predicted_labels)
    f1score = f1_score(gt_labels, predicted_labels)
    precision = precision_score(gt_labels, predicted_labels)
    print(f"Accuracy: {acc}")
    print(f"Recall: {recall}")
    print(f"Precision: {precision}")
    print(f"F1 score: {f1score}")
    return acc, recall, precision, f1score

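For instance, with made-up binary labels:

gt = [0, 1, 1, 0, 1]
pred = [0, 1, 0, 0, 1]
compute_scores(pred, gt)  # Accuracy 0.8, Recall ~0.67, Precision 1.0, F1 0.8
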
def merge_preds(df, tolerence, tensor_length):
    # Group detections whose start falls within `tolerence * tensor_length`
    # of the running maximum end time, then merge each group into one event
    df["group"] = (
        df["Starttime"] > (df["Endtime"] + tolerence * tensor_length).shift().cummax()
    ).cumsum()
    result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
    return result
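
A minimal sketch of the merging behaviour, with hypothetical detections (the first two overlap within the tolerance and collapse into one event):

df = pd.DataFrame({"Starttime": [0.0, 0.9, 5.0], "Endtime": [1.0, 2.0, 6.0]})
print(merge_preds(df, tolerence=0.1, tensor_length=1.0))
#        Starttime  Endtime
# group
# 0            0.0      2.0
# 1            5.0      6.0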
@@ -0,0 +1,121 @@
import os
import librosa
import numpy as np
import pandas as pd

def write_results(predicted_labels, begins, ends):
    df_out = pd.DataFrame(
        {
            "Starttime": begins,
            "Endtime": ends,
            "PredLabels": predicted_labels,
        }
    )

    return df_out

def write_wav(
    files,
    cfg,
    gt_labels,
    pred_labels,
    distances_to_pos,
    z_scores_pos,
    target_fs=16000,
    target_path=None,
    frame_shift=1,
    support_spectrograms=None
):
    from scipy.io import wavfile

    # Some path management
    filename = (
        os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".wav"
    )
    # Path of the final product
    output = os.path.join(target_path, filename)

    # Find the filepath for the file being analysed
    for f in files:
        if os.path.basename(f) == filename:
            print(os.path.basename(f))
            print(filename)
            arr, _ = librosa.load(f, sr=target_fs, mono=True)
            break
    else:
        # Guard against `arr` being undefined further down
        raise FileNotFoundError(f"{filename} not found in the provided file list")

    print(len(arr))
    print(len(gt_labels))
    print(len(pred_labels))

    # Expand each per-frame value to the number of audio samples it covers
    samples_per_frame = int(
        cfg["data"]["tensor_length"]
        * cfg["data"]["overlap"]
        * target_fs
        * frame_shift
        / 1000
    )
    gt_labels = np.repeat(np.squeeze(gt_labels, axis=1).T, samples_per_frame)
    pred_labels = np.repeat(pred_labels.T, samples_per_frame)
    distances_to_pos = np.repeat(distances_to_pos.T, samples_per_frame)
    z_scores_pos = np.repeat(z_scores_pos.T, samples_per_frame)
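    # Worked example with a hypothetical config: tensor_length=128, overlap=0.5,
    # target_fs=16000, frame_shift=1 gives int(128 * 0.5 * 16000 * 1 / 1000) = 1024
    # samples per frame, so a 100-frame label track expands to 102,400 samples.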

    # Pad the audio with zeros so it matches the length of the expanded labels
    arr = np.pad(
        arr, (0, len(gt_labels) - len(arr)), "constant", constant_values=(0,)
    )

    # pad with zeros
    # gt_labels = np.pad(
    #     gt_labels, (0, len(gt_labels) - len(arr)), "constant", constant_values=(0,)
    # )
    # pred_labels = np.pad(
    #     pred_labels, (0, len(pred_labels) - len(arr)), "constant", constant_values=(0,)
    # )
    # distances_to_pos = np.pad(
    #     distances_to_pos,
    #     (0, len(distances_to_pos) - len(arr)),
    #     "constant",
    #     constant_values=(0,),
    # )
    # z_scores_pos = np.pad(
    #     z_scores_pos,
    #     (0, len(z_scores_pos) - len(arr)),
    #     "constant",
    #     constant_values=(0,),
    # )

    # Write the results: one wav file whose channels are the audio, ground truth,
    # predictions, scaled distances and z-scores
    result_wav = np.vstack(
        (arr, gt_labels, pred_labels, distances_to_pos / 10, z_scores_pos)
    )
    wavfile.write(output, target_fs, result_wav.T)
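
The channels can be pulled apart again afterwards; a quick sketch (the output path is hypothetical):

from scipy.io import wavfile

fs, tracks = wavfile.read("results/file1.wav")  # hypothetical output file
audio, gt, preds, dists, z_scores = tracks.T    # one row per channel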