Skip to content

Commit

Permalink
[ADD] n_subsample to the support for distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 7, 2024
1 parent 9aa6475 commit 668a2da
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 32 deletions.
16 changes: 8 additions & 8 deletions evaluate/_utils_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ def merge_preds(df, tolerence, tensor_length,frame_shift):
result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
return result

def reshape_support(support_samples, tensor_length=128):
def reshape_support(support_samples, tensor_length=128, n_subsample=1):
new_input = []
for x in support_samples:
#for _ in range(n_subsample):
if x.shape[1] > tensor_length:
rand_start = torch.randint(0, x.shape[1] - tensor_length, (1,))
new_x = torch.tensor(x[:, rand_start : rand_start + tensor_length])
new_input.append(new_x.unsqueeze(0))
else:
new_input.append(torch.tensor(x))
for _ in range(n_subsample):
if x.shape[1] > tensor_length:
rand_start = torch.randint(0, x.shape[1] - tensor_length, (1,))
new_x = torch.tensor(x[:, rand_start : rand_start + tensor_length])
new_input.append(new_x.unsqueeze(0))
else:
new_input.append(torch.tensor(x))
all_supports = torch.cat([x for x in new_input])
return(all_supports)

Expand Down
53 changes: 52 additions & 1 deletion evaluate/_utils_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,55 @@ def write_wav(
result_wav = np.vstack(
(arr, gt_labels, merged_pred, pred_labels , distances_to_pos / 10, z_scores_pos)
)
wavfile.write(output, target_fs, result_wav.T)
wavfile.write(output, target_fs, result_wav.T)

def plot_2_d_representation():
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Assuming `prototypes`, `z_pos_supports`, `z_neg_supports`, `q_embeddings`, and `labels` are already defined
# Convert tensors to numpy arrays if they are in tensor format
# e.g., z_pos_supports = z_pos_supports.detach().numpy()

# Create a labels array for all points
# Label for prototypes, positive supports, negative supports, and query embeddings respectively
prototypes_labels = np.array([2] * prototypes.shape[0]) # Assuming 2 is not used in `gt_labels`
pos_supports_labels = np.array([3] * z_pos_supports.shape[0]) # Assuming 3 is not used in `gt_labels`
neg_supports_labels = np.array([4] * z_neg_supports.shape[0]) # Assuming 4 is not used in `gt_labels`
q_embeddings = q_embeddings.detach().numpy()
gt_labels = labels.detach().numpy()

# Concatenate everything into one dataset
feat = np.concatenate([prototypes, z_pos_supports, z_neg_supports, q_embeddings])
all_labels = np.concatenate([prototypes_labels, pos_supports_labels, neg_supports_labels, gt_labels])

# Run t-SNE
tsne = TSNE(n_components=2, perplexity=30)
features_2d = tsne.fit_transform(feat)

# Plot
plt.figure(figsize=(10, 8))
# Define marker for each type of point
markers = {2: "P", 3: "o", 4: "X"} # P for prototypes, o for supports, X for negative supports

for label in np.unique(all_labels):
# Plot each class with its own color and marker
idx = np.where(all_labels == label)
if label in markers: # Prototypes or supports
plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=label, alpha=1.0, marker=markers[label], s=100) # Larger size
else: # Query embeddings
plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=label, alpha=0.5, s=50) # Smaller size, more transparent

plt.legend()
plt.title('t-SNE visualization of embeddings, prototypes, and supports')
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.grid(True)

fig_name = os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".png"
output = os.path.join(target_path, fig_name)

# Save the figure
plt.savefig(output, bbox_inches="tight")
plt.show()
31 changes: 8 additions & 23 deletions evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def compute(
# GET EMBEDDINGS FOR THE NEG SAMPLES #
######################################
support_samples_neg = df_support[df_support["category"] == "NEG"]["feature"].to_numpy()
support_samples_neg = reshape_support(support_samples_neg, tensor_length=cfg["data"]["tensor_length"])
support_samples_neg = reshape_support(support_samples_neg,
tensor_length=cfg["data"]["tensor_length"],
n_subsample=cfg["predict"]["n_subsample"])
z_neg_supports, _ = model.get_embeddings(support_samples_neg, padding_mask=None)

### Get the query dataset ###
Expand Down Expand Up @@ -200,28 +202,6 @@ def compute(
################################################
if cfg["plot"]["tsne"]:

from sklearn.manifold import TSNE
import seaborn as sns

prototypes=prototypes.detach().numpy()
z_pos_supports = z_pos_supports.detach().numpy()
z_neg_supports = z_neg_supports.detach().numpy()
q_embeddings = q_embeddings.detach().numpy()
gt_labels = labels
other_labels = np.concatenate(([0,1], np.repeat(1, z_pos_supports.shape(0)), np.repeat(0, z_neg_supports.shape(0))), axis=None)

feat = np.concatenate([q_embeddings, prototypes, z_pos_supports, z_neg_supports])
tsne = TSNE(n_components=2, perplexity=5)
features_2d = tsne.fit_transform(feat)

# Do the figure!
fig = sns.scatterplot(x=features_2d[:, 0], y=features_2d[:, 1], hue=labels)
sns.move_legend(fig, "upper left", bbox_to_anchor=(1, 1))

fig_name = os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".png"
output = os.path.join(target_path, fig_name)
fig.get_figure().savefig(output, bbox_inches="tight")

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
Expand All @@ -235,6 +215,8 @@ def compute(
prototypes_labels = np.array([2] * prototypes.shape[0]) # Assuming 2 is not used in `gt_labels`
pos_supports_labels = np.array([3] * z_pos_supports.shape[0]) # Assuming 3 is not used in `gt_labels`
neg_supports_labels = np.array([4] * z_neg_supports.shape[0]) # Assuming 4 is not used in `gt_labels`
q_embeddings = q_embeddings.detach().numpy()
gt_labels = labels.detach().numpy()

# Concatenate everything into one dataset
feat = np.concatenate([prototypes, z_pos_supports, z_neg_supports, q_embeddings])
Expand Down Expand Up @@ -263,6 +245,9 @@ def compute(
plt.ylabel('Dimension 2')
plt.grid(True)

fig_name = os.path.basename(support_spectrograms).split("data_")[1].split(".")[0] + ".png"
output = os.path.join(target_path, fig_name)

# Save the figure
plt.savefig(output, bbox_inches="tight")
plt.show()
Expand Down

0 comments on commit 668a2da

Please sign in to comment.