Skip to content

Commit

Permalink
[FIX] Decompose evaluate DCASE into simpler functions
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Feb 16, 2024
1 parent fe3b672 commit 64b99f8
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 48 deletions.
154 changes: 154 additions & 0 deletions evaluate/_utils_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import torch
import pytorch_lightning as pl
import numpy as np
import pandas as pd
import tqdm

from datamodules.TestDCASEDataModule import DCASEDataModule
from prototypicalbeats.prototraining import ProtoBEATsModel
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

def to_dataframe(features, labels):
    """Combine saved features and labels into a single dataframe.

    Args:
        features: Path (or file object) of an ``.npz`` archive; every array
            stored in it becomes one entry of the ``feature`` column, in the
            order given by the archive's ``files`` listing.
        labels: Path (or file object) of an array saved with ``np.save``; its
            content becomes the ``category`` column.

    Returns:
        A ``pandas.DataFrame`` with the columns ``feature`` and ``category``.
    """
    # np.load on an .npz archive returns a lazy NpzFile holding an open file
    # handle; read every array inside the context manager so the handle is
    # released instead of leaking until garbage collection.
    with np.load(features) as input_features:
        list_input_features = [input_features[key] for key in input_features.files]

    categories = np.load(labels)
    df = pd.DataFrame({"feature": list_input_features, "category": categories})

    return df


def train_model(
    model_type="pann",
    datamodule_class=DCASEDataModule,
    max_epochs=1,
    enable_model_summary=False,
    num_sanity_val_steps=0,
    seed=42,
    pretrained_model=None,
    state=None,
    beats_path="/data/model/BEATs/BEATs_iter3_plus_AS2M.pt"
):
    """Fit a ``ProtoBEATsModel`` with a Lightning trainer and return it.

    Args:
        model_type: Forwarded to ``ProtoBEATsModel`` when no checkpoint is given.
        datamodule_class: Object handed to ``trainer.fit`` as ``datamodule``.
            NOTE(review): despite the name, callers are expected to pass a
            datamodule instance -- confirm the default is never relied upon.
        max_epochs: Number of training epochs; also used as the early-stopping
            patience.
        enable_model_summary: Forwarded to ``pl.Trainer``.
        num_sanity_val_steps: Forwarded to ``pl.Trainer``.
        seed: Seed applied to the global RNGs before the model is created.
        pretrained_model: Optional checkpoint path. When truthy, the model is
            restored from it instead of being created from scratch.
        state: Currently unused; kept for interface compatibility.
        beats_path: Currently unused; kept for interface compatibility.

    Returns:
        The fitted model, or ``None`` when the checkpoint could not be loaded.
    """
    # ``deterministic=True`` only gives reproducible runs if the RNGs are
    # seeded; previously ``seed`` was accepted but never applied.
    pl.seed_everything(seed)

    # Build the model first: a bad checkpoint aborts before any trainer setup.
    if pretrained_model:
        try:
            # Previously the loaded checkpoint was assigned to a throwaway
            # variable and a freshly initialised model was trained instead.
            model = ProtoBEATsModel.load_from_checkpoint(pretrained_model)
        except KeyError:
            print(
                "Failed to load the pretrained model. Please check the checkpoint file."
            )
            return None
    else:
        model = ProtoBEATsModel(model_type=model_type)

    # create the lightning trainer object
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        enable_model_summary=enable_model_summary,
        num_sanity_val_steps=num_sanity_val_steps,
        deterministic=True,
        gpus=1,
        auto_select_gpus=True,
        callbacks=[
            pl.callbacks.LearningRateMonitor(logging_interval="step"),
            # patience == max_epochs: early stopping can never fire before
            # the epoch budget is exhausted.
            pl.callbacks.EarlyStopping(
                monitor="train_acc", mode="max", patience=max_epochs
            ),
        ],
        default_root_dir="logs/",
        enable_checkpointing=False
    )

    # train the model
    trainer.fit(model, datamodule=datamodule_class)

    return model


def training(model_type, pretrained_model, state, custom_datamodule, max_epoch, beats_path):
    """Train a model through ``train_model`` with the evaluation defaults.

    The model summary and sanity validation steps are disabled and the seed
    is fixed to 42; all remaining arguments are forwarded unchanged.

    Returns:
        Whatever ``train_model`` returns.
    """
    fit_options = dict(
        max_epochs=max_epoch,
        enable_model_summary=False,
        num_sanity_val_steps=0,
        seed=42,
        pretrained_model=pretrained_model,
        state=state,
        beats_path=beats_path,
    )
    return train_model(model_type, custom_datamodule, **fit_options)


def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):
    """Embed the support set and derive the class prototypes from it.

    Args:
        model: Object exposing ``get_embeddings`` and ``get_prototypes``.
        model_type: When equal to ``"beats"``, ``get_embeddings`` is expected
            to return a pair whose first element is the embedding.
        support_data: Support samples passed to ``model.get_embeddings``.
        support_labels: Labels of the support samples.
        n_way: Forwarded to ``model.get_prototypes``.

    Returns:
        A ``(prototypes, z_supports)`` tuple.
    """
    embedded = model.get_embeddings(support_data, padding_mask=None)
    if model_type == "beats":
        # The "beats" variant returns two values; only the first one is kept.
        z_supports, _unused = embedded
    else:
        z_supports = embedded

    # Coordinates of the NEG and POS prototypes
    prototypes = model.get_prototypes(
        z_support=z_supports, support_labels=support_labels, n_way=n_way
    )
    return prototypes, z_supports



def compute_z_scores(distance, mean_support, sd_support):
    """Standardise ``distance`` using the support mean and standard deviation.

    Returns:
        ``(distance - mean_support) / sd_support``.
    """
    centered = distance - mean_support
    return centered / sd_support


def convert_z_to_p(z_score):
    """Convert a z-score into its upper-tail p-value under a standard normal.

    Args:
        z_score: Scalar or array-like of z-scores.

    Returns:
        ``P(Z > z_score)`` for ``Z ~ N(0, 1)``, with the same shape as the input.
    """
    import scipy.stats as stats

    # Use the survival function rather than ``1 - cdf``: for large z the cdf
    # rounds to exactly 1.0 and the subtraction collapses the p-value to 0.
    p_value = stats.norm.sf(z_score)
    return p_value


def euclidean_distance(x1, x2):
    """Return the Euclidean distance between ``x1`` and ``x2`` along dim 1.

    The inputs are subtracted with the usual broadcasting rules, and the
    squared differences are summed over dimension 1 before the square root.
    """
    squared_diff = (x1 - x2).pow(2)
    return squared_diff.sum(dim=1).sqrt()


def calculate_distance(model_type, z_query, z_proto):
    """Compute the Euclidean distance from every query to every prototype.

    Args:
        model_type: When equal to ``"beats"``, the distances are additionally
            averaged over dimension 2 and squeezed.
        z_query: Query embeddings; iterated along their first dimension.
        z_proto: Prototype embeddings.

    Returns:
        A ``(scores, dists)`` tuple where ``scores`` is ``-dists``.
    """
    # One entry per query; each gets a leading axis so the per-query results
    # can be concatenated along dim 0.
    per_query = [
        euclidean_distance(query.unsqueeze(0), z_proto).unsqueeze(0)
        for query in z_query
    ]
    dists = torch.cat(per_query, dim=0)

    if model_type == "beats":
        # Collapse dimension 2 by averaging.
        # NOTE(review): the bare squeeze() also removes any other size-1 axis
        # (e.g. a single query or a single prototype) -- confirm callers
        # never hit that case.
        dists = dists.mean(dim=2).squeeze()

    return -dists, dists


def compute_scores(predicted_labels, gt_labels):
    """Compute and print accuracy, recall, precision and F1 for predictions.

    Args:
        predicted_labels: Labels predicted by the model.
        gt_labels: Ground-truth labels.

    Returns:
        A ``(accuracy, recall, precision, f1)`` tuple.
    """
    # scikit-learn metrics take (y_true, y_pred) in that order.
    truth_and_pred = (gt_labels, predicted_labels)
    acc = accuracy_score(*truth_and_pred)
    recall = recall_score(*truth_and_pred)
    f1score = f1_score(*truth_and_pred)
    precision = precision_score(*truth_and_pred)

    print(
        f"Accurracy: {acc}",
        f"Recall: {recall}",
        f"precision: {precision}",
        f"F1 score: {f1score}",
        sep="\n",
    )
    return acc, recall, precision, f1score

def merge_preds(df, tolerence, tensor_length):
    """Merge detections that overlap or lie within a tolerance of each other.

    A row opens a new group when its ``Starttime`` is strictly greater than
    the largest ``Endtime + tolerence * tensor_length`` of all preceding rows.
    Each group is then collapsed to its earliest start and latest end.

    Note:
        A ``group`` column is written into ``df`` in place.
        Presumably ``df`` is sorted by ``Starttime`` -- verify against callers.

    Returns:
        A dataframe indexed by ``group`` with ``Starttime`` and ``Endtime``.
    """
    allowed_gap = tolerence * tensor_length
    # Furthest tolerance-extended end among all earlier rows (NaN for row 0,
    # which therefore never opens a new group).
    furthest_previous_end = (df["Endtime"] + allowed_gap).shift().cummax()
    opens_new_group = df["Starttime"] > furthest_previous_end
    df["group"] = opens_new_group.cumsum()

    return df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
121 changes: 121 additions & 0 deletions evaluate/_utils_writing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import librosa
import numpy as np
import pandas as pd

def write_results(predicted_labels, begins, ends):
    """Assemble the prediction table.

    Returns:
        A dataframe with the columns ``Starttime`` (from ``begins``),
        ``Endtime`` (from ``ends``) and ``PredLabels`` (from
        ``predicted_labels``), in that order.
    """
    columns = {
        "Starttime": begins,
        "Endtime": ends,
        "PredLabels": predicted_labels,
    }
    return pd.DataFrame(columns)

def _pad_to_length(track, length):
    """Zero-pad the 1-D array ``track`` at its end up to ``length`` samples."""
    return np.pad(track, (0, length - len(track)), "constant", constant_values=(0,))


def write_wav(
    files,
    cfg,
    gt_labels,
    pred_labels,
    distances_to_pos,
    z_scores_pos,
    target_fs=16000,
    target_path=None,
    frame_shift=1,
    support_spectrograms=None
):
    """Write a multi-channel WAV holding the audio and per-sample result tracks.

    The output has five channels, in this order: the audio, the ground-truth
    labels, the predicted labels, the distances to the positive prototype
    divided by 10, and the z-scores.

    Args:
        files: Iterable of audio file paths; the one whose basename matches
            the name derived from ``support_spectrograms`` is loaded.
        cfg: Configuration mapping; ``cfg["data"]["tensor_length"]`` and
            ``cfg["data"]["overlap"]`` are read.
        gt_labels: Array of ground-truth labels with a size-1 axis 1.
        pred_labels: Array of predicted labels.
        distances_to_pos: Array of distances to the positive prototype.
        z_scores_pos: Array of z-scores.
        target_fs: Sampling rate used to load the audio and write the result.
        target_path: Directory the WAV file is written to.
        frame_shift: Scaling factor of the number of samples per label.
            NOTE(review): the ``/ 1000`` suggests milliseconds -- confirm.
        support_spectrograms: Path whose basename looks like
            ``...data_<name>.<ext>``; ``<name>.wav`` is the audio file sought
            and the name of the output file.

    Raises:
        FileNotFoundError: If no entry of ``files`` has the expected basename.
    """
    from scipy.io import wavfile

    # Some path management
    filename = (
        os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".wav"
    )
    output = os.path.join(target_path, filename)

    # Find the filepath for the file being analysed. Fail loudly when it is
    # missing instead of hitting an unbound ``arr`` further down.
    audio_path = next((f for f in files if os.path.basename(f) == filename), None)
    if audio_path is None:
        raise FileNotFoundError(
            f"No audio file named {filename!r} found among the provided files."
        )
    print(os.path.basename(audio_path))
    print(filename)
    arr, _ = librosa.load(audio_path, sr=target_fs, mono=True)

    print(len(arr))
    print(len(gt_labels))
    print(len(pred_labels))

    # Number of audio samples each label/score value is stretched over.
    samples_per_value = int(
        cfg["data"]["tensor_length"]
        * cfg["data"]["overlap"]
        * target_fs
        * frame_shift
        / 1000
    )

    # Expand every track so that it is sampled like the audio.
    channels = [
        arr,
        np.repeat(np.squeeze(gt_labels, axis=1).T, samples_per_value),
        np.repeat(pred_labels.T, samples_per_value),
        np.repeat(distances_to_pos.T, samples_per_value) / 10,
        np.repeat(z_scores_pos.T, samples_per_value),
    ]

    # Zero-pad every channel to the longest one. Padding only the audio (as
    # before) crashed with a negative pad width whenever the audio was longer
    # than the expanded tracks.
    n_samples = max(len(channel) for channel in channels)
    result_wav = np.vstack(
        [_pad_to_length(channel, n_samples) for channel in channels]
    )

    # Write the results, one column per channel
    wavfile.write(output, target_fs, result_wav.T)
51 changes: 3 additions & 48 deletions evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def write_wav(
target_fs=16000,
target_path=None,
frame_shift=1,
support_spectrograms=None
):
from scipy.io import wavfile

Expand Down Expand Up @@ -731,6 +732,7 @@ def main(cfg: DictConfig):
target_fs=cfg["data"]["target_fs"],
target_path=target_path,
frame_shift=meta_df.loc[filename, "frame_shift"],
support_spectrograms=support_spectrograms
)

# Return the final product
Expand All @@ -753,54 +755,7 @@ def main(cfg: DictConfig):


if __name__ == "__main__":
    # The entry point takes no command-line arguments here; the previous
    # commented-out argparse scaffolding and the stray trailing comma after
    # the call (which built a throwaway tuple) have been removed.
    main()



0 comments on commit 64b99f8

Please sign in to comment.