Skip to content

Commit

Permalink
fixed final flake8 errors in esm2 utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jessicaw9910 committed Nov 12, 2024
1 parent a0d0eb4 commit bd7df08
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion missense_kinase_toolkit/ml/src/esm2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
Expand Down Expand Up @@ -310,6 +311,7 @@ def parse_stats_dataframes(
def plot_label_histogram(
val_df: pd.DataFrame,
bool_orig: bool = True,
labels: list[float] | None = None,
path: str = "/data1/tanseyw/projects/whitej/esm_km_atp/",
):
"""Plot histograms of labels for validation set.
Expand All @@ -320,6 +322,8 @@ def plot_label_histogram(
Validation dataframe from trainer state log.
bool_orig : bool
If True, plot labels in original scale.
labels : list[float] | None
List of labels for original scale.
Returns
-------
Expand All @@ -329,7 +333,7 @@ def plot_label_histogram(
list_replace = [f"Fold: {i}\n(n = {sum(val_df["fold"] == i)})" for i in list_fold]
val_df["fold_label"] = val_df["fold"].map(dict(zip(list_fold, list_replace)))

if bool_orig:
if bool_orig and labels is not None:
val_df["orig_label"] = invert_zscore(val_df["label"], labels)
val_df["orig_label"] = val_df["orig_label"].apply(lambda x: 10**x)

Expand Down

0 comments on commit bd7df08

Please sign in to comment.