Skip to content

Commit

Permalink
[ADD] Creation of a figure of the embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 8, 2024
1 parent 2a13350 commit e82a107
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
10 changes: 5 additions & 5 deletions CONFIG_PREDICT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ data:
# Otherwise the hash of the folders will be different!!

trainer:
max_epochs: 1
max_epochs: 5
default_root_dir: /data
accelerator: gpu
gpus: 1
Expand All @@ -58,10 +58,10 @@ predict:
overwrite: True
n_self_detected_supports: 0
tolerance: 0
filter_by_p_values: True
n_subsample: 1
self_detect_support: True
filter_by_p_values: True # Whether we filter outliers by their pvalues
n_subsample: 1 # Whether each segment should be subsampled
self_detect_support: False # Whether to use the self-training loop

plot:
tsne: False
tsne: True
perplexity: 5
55 changes: 37 additions & 18 deletions evaluate/_utils_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ def write_wav(
for ind, row in result_merged.iterrows():
merged_pred[int(row["Starttime"]*target_fs):int(row["Endtime"]*target_fs)] = 1




# pad with zeros
if len(arr) > len(gt_labels):
gt_labels = np.pad(
Expand Down Expand Up @@ -141,41 +138,63 @@ def plot_2_d_representation(prototypes,
z_neg_supports,
q_embeddings,
labels,
output):
output,
perplexity=5):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Create a labels array for all points
# Label for prototypes, positive supports, negative supports, and query embeddings respectively
prototypes_labels = np.array([2] * prototypes.shape[0]) # Assuming 2 is not used in `gt_labels`
pos_supports_labels = np.array([3] * z_pos_supports.shape[0]) # Assuming 3 is not used in `gt_labels`
neg_supports_labels = np.array([4] * z_neg_supports.shape[0]) # Assuming 4 is not used in `gt_labels`
q_embeddings = q_embeddings.detach().numpy()
gt_labels = labels.detach().numpy()
prototypes_labels = np.array([2] * prototypes.shape[0])
pos_supports_labels = np.array([3] * z_pos_supports.shape[0])
neg_supports_labels = np.array([4] * z_neg_supports.shape[0])
q_embeddings = q_embeddings.to("cpu").detach().numpy()
gt_labels = np.squeeze(labels) # already a numpy object

# Concatenate everything into one dataset
feat = np.concatenate([prototypes.to("cpu").detach().numpy(),
z_pos_supports.to("cpu").detach().numpy(),
z_neg_supports.to("cpu").detach().numpy(),
q_embeddings.to("cpu").detach().numpy()])
all_labels = np.concatenate([prototypes_labels, pos_supports_labels, neg_supports_labels, gt_labels])
q_embeddings])
feat = feat[:, -1, :]

all_labels = np.concatenate([prototypes_labels,
pos_supports_labels,
neg_supports_labels,
gt_labels])

# Run t-SNE
tsne = TSNE(n_components=2, perplexity=30)
tsne = TSNE(n_components=2, perplexity=perplexity)
features_2d = tsne.fit_transform(feat)

# Plot
# Define the mapping from numerical labels to descriptive labels
label_descriptions = {
0: "NEG queries",
1: "POS queries",
2: "Prototypes",
3: "POS supports",
4: "NEG supports"
}

# Figure
plt.figure(figsize=(10, 8))

# Define marker for each type of point
markers = {2: "P", 3: "o", 4: "X"} # P for prototypes, o for supports, X for negative supports

for label in np.unique(all_labels):
# Plot each class with its own color and marker
idx = np.where(all_labels == label)
if label in markers: # Prototypes or supports
plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=label, alpha=1.0, marker=markers[label], s=100) # Larger size
else: # Query embeddings
plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=label, alpha=0.5, s=50) # Smaller size, more transparent
# Set a larger size for prototypes
size = 150 if label == 2 else 100 if label in markers else 50
alpha = 1.0 if label == 2 else 0.8 if label in markers else 0.25

plt.scatter(features_2d[idx, 0],
features_2d[idx, 1],
label=label_descriptions[label],
alpha=alpha,
marker=markers.get(label, 'o'),
s=size)

plt.legend()
plt.title('t-SNE visualization of embeddings, prototypes, and supports')
Expand Down
4 changes: 3 additions & 1 deletion evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,16 @@ def compute(
# PLOT PROTOTYPES AND EMBEDDINGS IN A 2D SPACE #
################################################
if cfg["plot"]["tsne"]:
print("[INFO] CREATING A FIGURE")
fig_name = os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".png"
output = os.path.join(target_path, fig_name)
plot_2_d_representation(prototypes,
z_pos_supports,
z_neg_supports,
q_embeddings,
labels,
output)
output,
cfg["plot"]["perplexity"])

# Compute the scores for the analysed file -- just as information
acc, recall, precision, f1score = compute_scores(
Expand Down

0 comments on commit e82a107

Please sign in to comment.