Skip to content

Commit

Permalink
fix(plots): remove additional unused comments from posterior predicti…
Browse files Browse the repository at this point in the history
…ve extrapolation

Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Sep 16, 2024
1 parent 44cd705 commit 8174ff9
Showing 1 changed file with 0 additions and 36 deletions.
36 changes: 0 additions & 36 deletions src/pyrovelocity/plots/_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def posterior_curve(
output_fig_objects = []
for figi, gene in enumerate(gene_set):
(index,) = np.where(adata.var_names == gene)
# print(adata.shape, index, posterior_samples["st_mean"].shape)

fig, ax = plt.subplots(3, 4)
fig.set_size_inches(15, 10)
Expand Down Expand Up @@ -90,7 +89,6 @@ def posterior_curve(
grid_cell_time.mean(0).flatten() < t0_sample
] = 0
grid_cell_colors = colors[grid_mask_t0_sample]
# print(grid_time_samples_st.shape)

im = ax[sample].scatter(
posterior_samples["st_mean"][:, index[0]],
Expand Down Expand Up @@ -163,13 +161,6 @@ def posterior_curve(
:, index[0]
].flatten()

##u0 = posterior_samples['u_offset'][sample][:, index[0]].flatten()
##s0 = posterior_samples['s_offset'][sample][:, index[0]].flatten()
##u_inf = posterior_samples['u_inf'][sample][:, index[0]].flatten()
##s_inf = posterior_samples['s_inf'][sample][:, index[0]].flatten()
##switching = posterior_samples['switching'][sample][:, index[0]].flatten()
##dt_switching = posterior_samples['dt_switching'][sample][:, index[0]].flatten()

ax[sample + 4].scatter(
t0_sample,
u0 * uscale,
Expand Down Expand Up @@ -219,20 +210,6 @@ def posterior_curve(
linewidth=0.5,
c="black",
)
# ax[sample].plot(grid_time_samples_st[sample][:, index[0]],
# grid_time_samples_ut[sample][:, index[0]],
# linestyle="--", linewidth=3, color='g')
# if sample == 0:
# print(gene, u0 * uscale, s0)
# print(gene, u_inf * uscale, s_inf)
# print(
# t0_sample,
# dt_switching_sample,
# cell_time_sample_min,
# cell_time_sample_max,
# (cell_time_sample <= t0_sample).sum(),
# )
# print(cell_time_sample.shape)

switching = t0_sample + dt_switching_sample
state0 = (cell_gene_state_grid == 0) & (
Expand Down Expand Up @@ -311,7 +288,6 @@ def extrapolate_prediction_sample_predictive(

posterior_samples_list = []
for tensor_dict in scdl:
# print("--------------------")
u_obs = tensor_dict["U"]
s_obs = tensor_dict["X"]
u_log_library = tensor_dict["u_lib_size"]
Expand Down Expand Up @@ -396,7 +372,6 @@ def extrapolate_prediction_sample_predictive(
axis=-3,
)
)
# ).to("cuda:0")

posterior_samples_new_tmp = Predictive(
pyro.poutine.uncondition(
Expand All @@ -408,7 +383,6 @@ def extrapolate_prediction_sample_predictive(
posterior_samples_new_tmp[key] = posterior_samples[key]
posterior_samples_list.append(posterior_samples_new_tmp)

# print(len(posterior_samples_list))
posterior_samples_new = {}
for key in posterior_samples_list[0].keys():
if posterior_samples_list[0][key].shape[-2] == 1:
Expand All @@ -417,12 +391,6 @@ def extrapolate_prediction_sample_predictive(
posterior_samples_new[key] = torch.concat(
[element[key] for element in posterior_samples_list], axis=-2
)
# posterior_samples_new = model.generate_posterior_samples(
# adata=adata, batch_size=512, num_samples=8
# )

# for key in posterior_samples_new.keys():
# print(posterior_samples_new[key].shape)

grid_time_samples_ut = posterior_samples_new["ut"]
grid_time_samples_st = posterior_samples_new["st"]
Expand All @@ -443,10 +411,6 @@ def extrapolate_prediction_sample_predictive(
grid_time_samples_uscale = np.ones(grid_time_samples_uinf.shape)

grid_time_samples_state = posterior_samples_new["cell_gene_state"]
# print(grid_time_samples_state.shape)
# print(grid_time_samples_uscale.shape)
# print(grid_time_samples_ut.shape)
# print(grid_time_samples_st.shape)
if isinstance(grid_time_samples_state, np.ndarray):
return (
grid_time_samples_ut,
Expand Down

0 comments on commit 8174ff9

Please sign in to comment.