Skip to content

Commit

Permalink
black format + typo fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Oct 20, 2023
1 parent 864bbba commit 417b90e
Show file tree
Hide file tree
Showing 24 changed files with 39 additions and 98 deletions.
28 changes: 1 addition & 27 deletions examples/borzoi_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# Helper functions (prediction, attribution, visualization)


# Make one-hot coded sequence
def make_seq_1hot(genome_open, chrm, start, end, seq_len):
if start < 0:
Expand All @@ -50,10 +51,8 @@ def make_seq_1hot(genome_open, chrm, start, end, seq_len):

# Predict tracks
def predict_tracks(models, sequence_one_hot):

predicted_tracks = []
for fold_ix in range(len(models)):

yh = models[fold_ix](sequence_one_hot[None, ...])[:, None, ...].astype(
"float16"
)
Expand All @@ -67,7 +66,6 @@ def predict_tracks(models, sequence_one_hot):

# Helper function to get (padded) one-hot
def process_sequence(fasta_open, chrom, start, end, seq_len=524288):

seq_len_actual = end - start

# Pad sequence to input window size
Expand All @@ -81,7 +79,6 @@ def process_sequence(fasta_open, chrom, start, end, seq_len=524288):


def dna_letter_at(letter, x, y, yscale=1, ax=None, color=None, alpha=1.0):

fp = FontProperties(family="DejaVu Sans", weight="bold")

globscale = 1.35
Expand Down Expand Up @@ -146,7 +143,6 @@ def _prediction_input_grad(
prox_bin_index,
dist_bin_index,
):

mean_dist_prox_ratio = None
with tf.GradientTape() as tape:
tape.watch(input_sequence)
Expand Down Expand Up @@ -241,7 +237,6 @@ def get_prediction_gradient_w_rc(
subtract_avg=False,
fold_index=[0, 1, 2, 3],
):

# Get gradients for fwd
pred_grads = get_prediction_gradient(
models,
Expand Down Expand Up @@ -334,11 +329,9 @@ def get_prediction_gradient(
subtract_avg=False,
fold_index=[0, 1, 2, 3],
):

pred_grads = np.zeros((len(sequence_one_hots), len(fold_index), 524288, 4))

for fold_i, fold_ix in enumerate(fold_index):

prediction_model = models[fold_ix].model.layers[1]

input_sequence = tf.keras.layers.Input(shape=(524288, 4), name="sequence")
Expand Down Expand Up @@ -413,7 +406,6 @@ def get_prediction_gradient_noisy_w_rc(
n_samples=5,
sample_prob=0.75,
):

# Get gradients for fwd
pred_grads = get_prediction_gradient_noisy(
models,
Expand Down Expand Up @@ -512,11 +504,9 @@ def get_prediction_gradient_noisy(
n_samples=5,
sample_prob=0.75,
):

pred_grads = np.zeros((len(sequence_one_hots), len(fold_index), 524288, 4))

for fold_i, fold_ix in enumerate(fold_index):

print("fold_ix = " + str(fold_ix))

prediction_model = models[fold_ix].model.layers[1]
Expand Down Expand Up @@ -549,13 +539,11 @@ def get_prediction_gradient_noisy(

with tf.device("/cpu:0"):
for example_ix in range(len(sequence_one_hots)):

print("example_ix = " + str(example_ix))

inp = sequence_one_hots[example_ix][None, ...]

for sample_ix in range(n_samples):

print("sample_ix = " + str(sample_ix))

inp_corrupted = np.copy(inp)
Expand Down Expand Up @@ -609,7 +597,6 @@ def _prediction_ism_score(
prox_bin_index,
dist_bin_index,
):

if not use_mean:
if dist_bin_index is None:
mean_dist = np.sum(pred[:, dist_bin_start:dist_bin_end], axis=1)
Expand Down Expand Up @@ -661,13 +648,11 @@ def get_ism(
use_ratio=True,
use_logodds=False,
):

pred_ism = np.zeros((len(sequence_one_hots), len(models), 524288, 4))

bases = [0, 1, 2, 3]

for example_ix in range(len(sequence_one_hots)):

print("example_ix = " + str(example_ix))

sequence_one_hot_wt = sequence_one_hots[example_ix]
Expand Down Expand Up @@ -785,14 +770,12 @@ def get_ism_shuffle(
use_ratio=True,
use_logodds=False,
):

pred_shuffle = np.zeros((len(sequence_one_hots), len(models), 524288, n_samples))
pred_ism = np.zeros((len(sequence_one_hots), len(models), 524288, 4))

bases = [0, 1, 2, 3]

for example_ix in range(len(sequence_one_hots)):

print("example_ix = " + str(example_ix))

sequence_one_hot_wt = sequence_one_hots[example_ix]
Expand Down Expand Up @@ -832,7 +815,6 @@ def get_ism_shuffle(
)

for j in range(ism_start, ism_end):

j_start = j - window_size // 2
j_end = j + window_size // 2 + 1

Expand Down Expand Up @@ -945,7 +927,6 @@ def plot_seq_scores(
save_figs=False,
fig_name="default",
):

importance_scores = importance_scores.T

fig = plt.figure(figsize=figsize)
Expand Down Expand Up @@ -1011,7 +992,6 @@ def plot_seq_scores(
def visualize_input_gradient_pair(
att_grad_wt, att_grad_mut, plot_start=0, plot_end=100, save_figs=False, fig_name=""
):

scores_wt = att_grad_wt[plot_start:plot_end, :]
scores_mut = att_grad_mut[plot_start:plot_end, :]

Expand Down Expand Up @@ -1072,7 +1052,6 @@ def plot_coverage_track_pair_bins(
gene_slice=None,
anno_df=None,
):

plot_start = center_pos - plot_window // 2
plot_end = center_pos + plot_window // 2

Expand Down Expand Up @@ -1104,7 +1083,6 @@ def plot_coverage_track_pair_bins(
for track_name, track_index, track_scale, track_transform, clip_soft in zip(
track_names, track_indices, track_scales, track_transforms, clip_softs
):

# Plot track densities (bins)
y_wt_curr = np.array(np.copy(y_wt), dtype=np.float32)
y_mut_curr = np.array(np.copy(y_mut), dtype=np.float32)
Expand Down Expand Up @@ -1197,7 +1175,6 @@ def plot_coverage_track_pair_bins(
xtick_vals = []

for pas_ix, anno_pos in enumerate(anno_poses):

pas_bin = int((anno_pos - start) // 32) - 16

xtick_vals.append(pas_bin)
Expand Down Expand Up @@ -1279,7 +1256,6 @@ def plot_coverage_track_pair_bins(
def get_coverage_reader(
cov_files, target_length, crop_length, blacklist_bed, blacklist_pct=0.5
):

# open genome coverage files
cov_opens = [CovFace(cov_file) for cov_file in cov_files]

Expand All @@ -1299,14 +1275,12 @@ def _read_coverage(
crop_length=crop_length,
black_chr_trees=black_chr_trees,
):

n_targets = len(cov_opens)

targets = []

# for each targets
for target_i in range(n_targets):

# extract sequence as BED style
if start < 0:
seq_cov_nt = np.concatenate(
Expand Down
1 change: 1 addition & 0 deletions src/scripts/borzoi_bench_crispr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
borzoi_bench_crispr.py
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/scripts/borzoi_bench_crispr_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Benchmark Borzoi model replicates on CRISPR enhancer scoring task.
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/scripts/borzoi_bench_flowfish_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Benchmark Borzoi model replicates on CRISPR enhancer scoring task.
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/scripts/borzoi_bench_gasperini_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Benchmark Borzoi model replicates on CRISPR enhancer scoring task.
"""


################################################################################
# main
################################################################################
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/borzoi_bench_ipaqtl_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Benchmark Borzoi model replicates on GTEx ipaQTL classification task.
"""


################################################################################
# main
################################################################################
Expand Down Expand Up @@ -529,7 +530,6 @@ def split_sed(it_out_dir, posneg, vcf_dir, sed_stats):
tissue_dir = "%s/%s_%s" % (it_out_dir, tissue_label, posneg)
os.makedirs(tissue_dir, exist_ok=True)
with h5py.File("%s/sed.h5" % tissue_dir, "w") as tissue_h5:

# write SNP indexes
tissue_h5.create_dataset("si", data=np.array(sed_si, dtype="uint32"))

Expand Down
2 changes: 1 addition & 1 deletion src/scripts/borzoi_bench_paqtl_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Benchmark Borzoi model replicates on GTEx paQTL classification task.
"""


################################################################################
# main
################################################################################
Expand Down Expand Up @@ -529,7 +530,6 @@ def split_sed(it_out_dir, posneg, vcf_dir, sed_stats):
tissue_dir = "%s/%s_%s" % (it_out_dir, tissue_label, posneg)
os.makedirs(tissue_dir, exist_ok=True)
with h5py.File("%s/sed.h5" % tissue_dir, "w") as tissue_h5:

# write SNP indexes
tissue_h5.create_dataset("si", data=np.array(sed_si, dtype="uint32"))

Expand Down
2 changes: 1 addition & 1 deletion src/scripts/borzoi_bench_sqtl_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Benchmark Borzoi model replicates on GTEx sQTL classification task.
"""


################################################################################
# main
################################################################################
Expand Down Expand Up @@ -557,7 +558,6 @@ def split_sed(it_out_dir, posneg, vcf_dir, sed_stats):
tissue_dir = "%s/%s_%s" % (it_out_dir, tissue_label, posneg)
os.makedirs(tissue_dir, exist_ok=True)
with h5py.File("%s/sed.h5" % tissue_dir, "w") as tissue_h5:

# write SNP indexes
tissue_h5.create_dataset("si", data=np.array(sed_si, dtype="int32"))

Expand Down
1 change: 1 addition & 0 deletions src/scripts/borzoi_bench_trip_folds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Benchmark Borzoi model replicates on TRIP prediction task.
"""


################################################################################
# main
################################################################################
Expand Down
15 changes: 8 additions & 7 deletions src/scripts/borzoi_satg_gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pysam

from baskerville.dataset import targets_prep_strand
from baskerville import dna_io
from baskerville import dna
from baskerville import gene as bgene
from baskerville import seqnn

Expand All @@ -36,6 +36,7 @@
Perform a gradient saliency analysis for genes specified in a GTF file.
"""


################################################################################
# main
################################################################################
Expand All @@ -44,8 +45,8 @@ def main():
parser = OptionParser(usage)
parser.add_option(
"-f",
dest="genome_fasta",
default="%s/assembly/ucsc/hg38.fa" % os.environ["HG38"],
dest="genome_fasta", ##default="%s/assembly/ucsc/hg38.fa" % os.environ["HG38"],
default=None,
help="Genome FASTA for sequences [Default: %default]",
)
parser.add_option(
Expand Down Expand Up @@ -97,7 +98,7 @@ def main():
help="File specifying target indexes and labels in table format",
)
(options, args) = parser.parse_args()

print(options, args)
if len(args) == 3:
# single worker
params_file = args[0]
Expand Down Expand Up @@ -264,13 +265,13 @@ def main():
else:
grads_ens = []
for shift in options.shifts:
seq_1hot_aug = dna_io.hot1_augment(seq_1hot, shift=shift)
seq_1hot_aug = dna.hot1_augment(seq_1hot, shift=shift)
grads_aug = seqnn_model.gradients(seq_1hot_aug, pos_slice=gene_slice)
grads_aug = unaugment_grads(grads_aug, fwdrc=True, shift=shift)
grads_ens.append(grads_aug)

if options.rc:
seq_1hot_aug = dna_io.hot1_rc(seq_1hot_aug)
seq_1hot_aug = dna.hot1_rc(seq_1hot_aug)
grads_aug = seqnn_model.gradients(
seq_1hot_aug, pos_slice=gene_slice_rc
)
Expand Down Expand Up @@ -340,7 +341,7 @@ def make_seq_1hot(genome_open, chrm, start, end, seq_len):
if len(seq_dna) < seq_len:
seq_dna += "N" * (seq_len - len(seq_dna))

seq_1hot = dna_io.dna_1hot(seq_dna)
seq_1hot = dna.dna_1hot(seq_dna)
return seq_1hot


Expand Down
Loading

0 comments on commit 417b90e

Please sign in to comment.