Skip to content

Commit

Permalink
[ADD] filter by p-values after self learning
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 25, 2024
1 parent 0354add commit 58e5809
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 16 deletions.
4 changes: 2 additions & 2 deletions CONFIG_PREDICT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ predict:
overwrite: True
n_self_detected_supports: 0
tolerance: 0
filter_by_p_values: False # Whether we filter outliers by their pvalues
n_subsample: 1 # Whether each segment should be subsampled
self_detect_support: True # Whether to use the self-training loop
self_detect_support: False # Whether to use the self-training loop
filter_by_p_value: False # Whether we filter outliers by their pvalues
threshold_p_value: 0.1

plot:
Expand Down
19 changes: 7 additions & 12 deletions evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,15 @@ def compute(
# GET THE PVALUES
p_values_pos = 1 - ecdf(distances_to_pos)

if cfg["predict"]["filter_by_p_values"]:
predicted_labels = filter_outliers_by_p_values(predicted_labels,
p_values_pos,
target_class=1,
upper_threshold=0.05)

if cfg["predict"]["self_detect_support"]:
print("[INFO] SELF DETECTING SUPPORT SAMPLES")
#########################################################
# SELF DETECT POS AND NEG SAMPLES AND APPEND TO DATASET #
#########################################################

# Detect POS samples
detected_pos_indices = np.where(p_values_pos == cfg["predict"]["threshold_p_value"])[0]
print(f"[INFO] SELF DETECTED {detected_pos_indices} POS SAMPLES")
detected_pos_indices = np.where(p_values_pos == 1)[0] # We need to be sure that it is POS samples
print(f"[INFO] SELF DETECTED {len(detected_pos_indices)} POS SAMPLES")

# BECAUSE CUDA ERROR WHEN RESAMPLING TOO MANY SAMPLES
if len(detected_pos_indices) > 40:
Expand Down Expand Up @@ -286,10 +280,11 @@ def compute(
p_values_pos = 1 - ecdf(distances_to_pos)

# Filter by pvalues
predicted_labels = filter_outliers_by_p_values(predicted_labels,
p_values_pos,
target_class=1,
upper_threshold=cfg["predict"]["threshold_p_value"])
if cfg["predict"]["filter_by_p_value"]:
predicted_labels = filter_outliers_by_p_values(predicted_labels,
p_values_pos,
target_class=1,
upper_threshold=cfg["predict"]["threshold_p_value"])

################################################
# PLOT PROTOTYPES AND EMBEDDINGS IN A 2D SPACE #
Expand Down
2 changes: 1 addition & 1 deletion evaluate/evaluation_metrics/evaluation_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada

if __name__ == "__main__":

all_files = glob.glob("/data/DCASEfewshot/validate/d8f698b184e75c3ef4e830f9da4f148071fb4c56/results/beats/models/BEATS_SELF_LEARNING_PTHR=02/**/eval_out.csv",
all_files = glob.glob("/data/DCASEfewshot/validate/d8f698b184e75c3ef4e830f9da4f148071fb4c56/results/baseline/version_0/**/eval_out.csv",
recursive=True)

l_fscores = []
Expand Down
1 change: 0 additions & 1 deletion shell_scripts/log.txt

This file was deleted.

0 comments on commit 58e5809

Please sign in to comment.