diff --git a/evaluate/_utils_compute.py b/evaluate/_utils_compute.py new file mode 100644 index 0000000..96fcfff --- /dev/null +++ b/evaluate/_utils_compute.py @@ -0,0 +1,154 @@ +import torch +import pytorch_lightning as pl +import numpy as np +import pandas as pd +import tqdm + +from datamodules.TestDCASEDataModule import DCASEDataModule +from prototypicalbeats.prototraining import ProtoBEATsModel +from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score + +def to_dataframe(features, labels): + # Load the saved array and map the features and labels into a single dataframe + input_features = np.load(features) + labels = np.load(labels) + list_input_features = [input_features[key] for key in input_features.files] + df = pd.DataFrame({"feature": list_input_features, "category": labels}) + + return df + + +def train_model( + model_type="pann", + datamodule_class=DCASEDataModule, + max_epochs=1, + enable_model_summary=False, + num_sanity_val_steps=0, + seed=42, + pretrained_model=None, + state=None, + beats_path="/data/model/BEATs/BEATs_iter3_plus_AS2M.pt" +): + # create the lightning trainer object + trainer = pl.Trainer( + max_epochs=max_epochs, + enable_model_summary=enable_model_summary, + num_sanity_val_steps=num_sanity_val_steps, + deterministic=True, + gpus=1, + auto_select_gpus=True, + callbacks=[ + pl.callbacks.LearningRateMonitor(logging_interval="step"), + pl.callbacks.EarlyStopping( + monitor="train_acc", mode="max", patience=max_epochs + ), + ], + default_root_dir="logs/", + enable_checkpointing=False + ) + + # create the model object + model = ProtoBEATsModel(model_type=model_type) + + if pretrained_model: + # Load the pretrained model + try: + pretrained_model = ProtoBEATsModel.load_from_checkpoint(pretrained_model) + except KeyError: + print( + "Failed to load the pretrained model. Please check the checkpoint file." + ) + return None + + # train the model + trainer.fit(model, datamodule=datamodule_class) + + return model + + +def training(model_type, pretrained_model, state, custom_datamodule, max_epoch, beats_path): + + model = train_model( + model_type, + custom_datamodule, + max_epochs=max_epoch, + enable_model_summary=False, + num_sanity_val_steps=0, + seed=42, + pretrained_model=pretrained_model, + state=state, + beats_path=beats_path + ) + + return model + + +def get_proto_coordinates(model, model_type, support_data, support_labels, n_way): + + if model_type == "beats": + z_supports, _ = model.get_embeddings(support_data, padding_mask=None) + else: + z_supports = model.get_embeddings(support_data, padding_mask=None) + + # Get the coordinates of the NEG and POS prototypes + prototypes = model.get_prototypes( + z_support=z_supports, support_labels=support_labels, n_way=n_way + ) + + # Return the coordinates of the prototypes and the z_supports + return prototypes, z_supports + + + +def compute_z_scores(distance, mean_support, sd_support): + z_score = (distance - mean_support) / sd_support + return z_score + + +def convert_z_to_p(z_score): + import scipy.stats as stats + + p_value = 1 - stats.norm.cdf(z_score) + return p_value + + +def euclidean_distance(x1, x2): + return torch.sqrt(torch.sum((x1 - x2) ** 2, dim=1)) + + +def calculate_distance(model_type, z_query, z_proto): + # Compute the euclidean distance from queries to prototypes + dists = [] + for q in z_query: + q_dists = euclidean_distance(q.unsqueeze(0), z_proto) + dists.append( + q_dists.unsqueeze(0) + ) # Contrary to prototraining I need to add a dimension to store the + dists = torch.cat(dists, dim=0) + + if model_type == "beats": + # We drop the last dimension without changing the gradients + dists = dists.mean(dim=2).squeeze() + + scores = -dists + + return scores, dists + + +def compute_scores(predicted_labels, gt_labels): + acc = accuracy_score(gt_labels, predicted_labels) + recall = recall_score(gt_labels, predicted_labels) + f1score = f1_score(gt_labels, predicted_labels) + precision = precision_score(gt_labels, predicted_labels) + print(f"Accurracy: {acc}") + print(f"Recall: {recall}") + print(f"precision: {precision}") + print(f"F1 score: {f1score}") + return acc, recall, precision, f1score + +def merge_preds(df, tolerence, tensor_length): + df["group"] = ( + df["Starttime"] > (df["Endtime"] + tolerence * tensor_length).shift().cummax() + ).cumsum() + result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"}) + return result \ No newline at end of file diff --git a/evaluate/_utils_writing.py b/evaluate/_utils_writing.py new file mode 100644 index 0000000..9b9a684 --- /dev/null +++ b/evaluate/_utils_writing.py @@ -0,0 +1,121 @@ +import os +import librosa +import numpy as np +import pandas as pd + +def write_results(predicted_labels, begins, ends): + df_out = pd.DataFrame( + { + "Starttime": begins, + "Endtime": ends, + "PredLabels": predicted_labels, + } + ) + + return df_out + +def write_wav( + files, + cfg, + gt_labels, + pred_labels, + distances_to_pos, + z_scores_pos, + target_fs=16000, + target_path=None, + frame_shift=1, + support_spectrograms=None +): + from scipy.io import wavfile + + # Some path management + filename = ( + os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".wav" + ) + # Return the final product + output = os.path.join(target_path, filename) + + # Find the filepath for the file being analysed + for f in files: + if os.path.basename(f) == filename: + print(os.path.basename(f)) + print(filename) + arr, _ = librosa.load(f, sr=target_fs, mono=True) + break + + print(len(arr)) + print(len(gt_labels)) + print(len(pred_labels)) + + # Expand the dimensions + gt_labels = np.repeat( + np.squeeze(gt_labels, axis=1).T, + int( + cfg["data"]["tensor_length"] + * cfg["data"]["overlap"] + * target_fs + * frame_shift + / 1000 + ), + ) + pred_labels = np.repeat( + pred_labels.T, + int( + cfg["data"]["tensor_length"] + * cfg["data"]["overlap"] + * target_fs + * frame_shift + / 1000 + ), + ) + distances_to_pos = np.repeat( + distances_to_pos.T, + int( + cfg["data"]["tensor_length"] + * cfg["data"]["overlap"] + * target_fs + * frame_shift + / 1000 + ), + ) + z_scores_pos = np.repeat( + z_scores_pos.T, + int( + cfg["data"]["tensor_length"] + * cfg["data"]["overlap"] + * target_fs + * frame_shift + / 1000 + ), + ) + + arr = np.pad( + arr, (0, len(gt_labels) - len(arr)), "constant", constant_values=(0,) + ) + + + # pad with zeros + #gt_labels = np.pad( + # gt_labels, (0, len(gt_labels) - len(arr)), "constant", constant_values=(0,) + #) + #pred_labels = np.pad( + # pred_labels, (0, len(pred_labels) - len(arr) ), "constant", constant_values=(0,) + #) + #distances_to_pos = np.pad( + # distances_to_pos, + # (0, len(distances_to_pos) - len(arr)), + # "constant", + # constant_values=(0,), + #) + #z_scores_pos = np.pad( + # z_scores_pos, + # (0, len(z_scores_pos) - len(arr)), + # "constant", + # constant_values=(0,), + #) + + # Write the results + result_wav = np.vstack( + (arr, gt_labels, pred_labels, distances_to_pos / 10, z_scores_pos) + ) + wavfile.write(output, target_fs, result_wav.T) \ No newline at end of file diff --git a/evaluate/evaluateDCASE.py b/evaluate/evaluateDCASE.py index f2c4ab6..c0c492f 100644 --- a/evaluate/evaluateDCASE.py +++ b/evaluate/evaluateDCASE.py @@ -488,6 +488,7 @@ def write_wav( target_fs=16000, target_path=None, frame_shift=1, + support_spectrograms=None ): from scipy.io import wavfile @@ -731,6 +732,7 @@ def main(cfg: DictConfig): target_fs=cfg["data"]["target_fs"], target_path=target_path, frame_shift=meta_df.loc[filename, "frame_shift"], + support_spectrograms=support_spectrograms ) # Return the final product @@ -753,54 +755,7 @@ def main(cfg: DictConfig): if __name__ == "__main__": - #parser = argparse.ArgumentParser() - - #parser.add_argument( - # "--config", - # help="Path to the config file", - # required=False, - # default="./CONFIG_PREDICT.yaml", - # type=str, - #) - -# parser.add_argument( -# "--wav_save", -# help="Should the results be also saved as a .wav file?", -# default=False, -# required=False, -# action="store_true", -# ) - -# parser.add_argument( -# "--overwrite", -# help="Remove earlier obtained results at start", -# default=False, -# required=False, -# action="store_true", -# ) -# -# parser.add_argument( -# "--n_self_detected_supports", -# help="Remove earlier obtained results at start", -# default=0, -# required=False, -# type=int, -# ) - -# parser.add_argument( -# "--tolerance", -# help="How many non detection in detection still counts for a detection", -# default=0, -# required=False, -# type=int, -# ) - - #cli_args = parser.parse_args() - main(), - #cli_args.overwrite, - #cli_args.tolerance, - #cli_args.n_self_detected_supports, - #cli_args.wav_save) +