-
Notifications
You must be signed in to change notification settings - Fork 0
/
scTour_infer.py
144 lines (111 loc) · 5.5 KB
/
scTour_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
## Running using scGPT environment
## "/N/slate/xuexiao/scgpt_env"
import sctour as sct
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Workflow settings. Neither value is referenced in the visible part of this
# script; they are kept for any code that relies on them.
n_hvg = 2000
n_bins = 15
## Load data
# Training dataset: read the AnnData object from its .h5ad file
adata = sc.read('/N/slate/xuexiao/GSE165897/Patient_EOC_with_ptime.h5ad')
#adata.obs['celltype'] = 'cell_subtype'
print(adata.shape)
# Cast the expression matrix to float32
# NOTE(review): presumably to match the dtype the trained model expects - confirm
adata.X = adata.X.astype(np.float32)
# Display the first few rows of the metadata
print(adata.obs.head())
# Display the columns in the metadata
print(adata.obs.columns)
# Summary of the metadata.
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would emit an extra "None" line.
adata.obs.info()
# Restore the trained scTour model from its saved checkpoint
tnode = sct.predict.load_model(
    '/N/slate/xuexiao/GSE165897/scTour_model_GSE165897_EOC.pth'
)

## Infer cellular dynamics
# Step 1 - pseudotime, stored per cell in adata.obs
# NOTE(review): assumes the cells held by the loaded model are in the same
# order as the rows of `adata` - confirm
inferred_ptime = tnode.get_time()
adata.obs['ptime'] = inferred_ptime

# Step 2 - latent space
#   zs      : latent z from variational inference
#   pred_zs : latent z from the ODE solver
#   mix_zs  : weighted combination of the two, used for downstream analysis
latent_weights = dict(alpha_z=0.5, alpha_predz=0.5)
mix_zs, zs, pred_zs = tnode.get_latentsp(**latent_weights)
adata.obsm['X_TNODE'] = mix_zs

# Step 3 - vector field, computed from the pseudotime and latent coordinates
# that were just stored on adata
ptime_values = adata.obs['ptime'].values
latent_coords = adata.obsm['X_TNODE']
adata.obsm['X_VF'] = tnode.get_vector_field(ptime_values, latent_coords)

# Saving the annotated object back to disk is currently disabled
#adata.write("/N/slate/xuexiao/GSE165897/Patient_EOC_with_ptime.h5ad")
## Visualization
# Reorder the cells by ascending pseudotime, then build the neighbor graph and
# UMAP embedding on the scTour latent representation
ptime_order = np.argsort(adata.obs['ptime'].values)
adata = adata[ptime_order, :]
sc.pp.neighbors(adata, use_rep='X_TNODE', n_neighbors=15)
sc.tl.umap(adata, min_dist=0.1)

# 2 x 3 grid of panels; five are drawn below, axs[1, 2] is left empty
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(20, 15))

# One entry per UMAP panel:
# (obs column used for coloring, target axis, title, snapshot file, extra kwargs)
umap_panels = (
    ('cell_subtype', axs[0, 0], "UMAP by Cell Type",
     "scTour_model_GSE165897_EOC_umap_celltype.png", {'legend_loc': 'on data'}),
    ('patient_id', axs[0, 1], "UMAP by Sample Batch",
     "scTour_model_GSE165897_EOC_umap_sample_batch.png", {}),
    ('treatment_phase', axs[0, 2], "UMAP by Treatment Phase",
     "scTour_model_GSE165897_EOC_umap_treatment_phase.png", {}),
    ('ptime', axs[1, 0], "UMAP by Pseudotime",
     "scTour_model_GSE165897_EOC_umap_pseudotime.png", {}),
)
for color_key, panel_ax, panel_title, snapshot_png, extra_kwargs in umap_panels:
    sc.pl.umap(adata, color=color_key, ax=panel_ax, show=False, frameon=False, **extra_kwargs)
    panel_ax.set_title(panel_title)
    # Saves the whole figure as drawn so far, not just this panel
    fig.savefig(snapshot_png)

# Vector field panel, colored by cell subtype
vf_ax = axs[1, 1]
sct.vf.plot_vector_field(
    adata,
    zs_key='X_TNODE',
    vf_key='X_VF',
    use_rep_neigh='X_TNODE',
    color='cell_subtype',
    show=False,
    ax=vf_ax,
    legend_loc='none',
    frameon=False,
    size=100,
    alpha=0.2,
)
vf_ax.set_title("Vector Field")
fig.savefig("scTour_model_GSE165897_EOC_vector_field.png")

# Final layout pass, then write and display the combined figure
plt.tight_layout()
plt.savefig("scTour_model_GSE165897_EOC_combined_plots.png")
plt.show()
## Use model to predict other dataset
# Load new dataset
adata_test1 = sc.read('/N/slate/xuexiao/combine_all_3/OC_cell_in_house.h5ad')
print(adata_test1.shape)
# Cast the TEST matrix to float32 and store it back on the test object.
# (Fix: the result used to be assigned to `adata.X`, i.e. the training object,
# which left adata_test1.X uncast and either failed on a shape mismatch or
# overwrote the training expression matrix.)
adata_test1.X = adata_test1.X.astype(np.float32)
# Display the first few rows of the metadata
print(adata_test1.obs.head())
# Display the columns in the metadata
print(adata_test1.obs.columns)
# Summary of the metadata.
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would emit an extra "None" line.
adata_test1.obs.info()
#test dataset 1
#the first parameter is the trained model
# Pseudotime predicted for the test cells
adata_test1.obs['ptime'] = sct.predict.predict_time(tnode, adata_test1)
# Latent space predicted for the test cells; only the weighted combination
# (mix_zs) is stored, zs and pred_zs are not used further in this script
mix_zs, zs, pred_zs = sct.predict.predict_latentsp(tnode, adata_test1, alpha_z=0.4, alpha_predz=0.6, mode='coarse')
adata_test1.obsm['X_TNODE'] = mix_zs
# Vector field predicted from the test pseudotime and latent coordinates
adata_test1.obsm['X_VF'] = sct.predict.predict_vector_field(tnode, adata_test1.obs['ptime'].values, adata_test1.obsm['X_TNODE'])
# View result
# NOTE(review): no neighbor graph / UMAP is computed for adata_test1 in this
# script, so the UMAP panels below assume the loaded file already carries a
# UMAP embedding - confirm.
# 2 x 3 grid of panels; all six are drawn below
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10))
# Each fig.savefig() call below writes the whole figure as drawn so far,
# not just the panel that was added last.
# Plot UMAP colored by cell type
sc.pl.umap(adata_test1, color='cell_subtype', legend_loc='on data', show=False, ax=axs[0, 0], frameon=False)
axs[0, 0].set_title("UMAP by Cell Type")
fig.savefig("scTour_model_GSE165897_EOC_umap_celltype_test1.png")
# Plot UMAP colored by sample batch
sc.pl.umap(adata_test1, color='orig.ident', show=False, ax=axs[0, 1], frameon=False)
axs[0, 1].set_title("UMAP by Sample Batch")
fig.savefig("scTour_model_GSE165897_EOC_umap_sample_batch_test1_in_house.png")
# Plot UMAP colored by CytoTRACE2 potency.
# (Fix: this panel was titled "UMAP by Treatment Phase", copied from the
# training-data figure, although it is colored by 'CytoTRACE2_Potency'.)
sc.pl.umap(adata_test1, color='CytoTRACE2_Potency', ax=axs[0, 2], show=False, frameon=False)
axs[0, 2].set_title("UMAP by CytoTRACE2 Potency")
fig.savefig("scTour_model_GSE165897_EOC_umap_CytoTRACE2_Potency_test1_in_house.png")
# Plot UMAP colored by pseudotime
sc.pl.umap(adata_test1, color='ptime', show=False, ax=axs[1, 0], frameon=False)
axs[1, 0].set_title("UMAP by Pseudotime")
fig.savefig("scTour_model_GSE165897_EOC_umap_pseudotime_test1_in_house.png")
# Plot vector field with cell type color
sct.vf.plot_vector_field(adata_test1, zs_key='X_TNODE', vf_key='X_VF', use_rep_neigh='X_TNODE', show=False, ax=axs[1, 1], color='cell_subtype', t_key='ptime', frameon=False, size=80, alpha=0.05)
axs[1, 1].set_title("Vector Field by Cell Type")
fig.savefig("scTour_model_GSE165897_EOC_vector_field_celltype_test1_in_house.png")
# Plot vector field with treatment phase color
# NOTE(review): assumes adata_test1.obs has a 'treatment_phase' column - confirm
sct.vf.plot_vector_field(adata_test1, zs_key='X_TNODE', vf_key='X_VF', use_rep_neigh='X_TNODE', show=False, ax=axs[1, 2], color='treatment_phase', t_key='ptime', frameon=False, size=80, alpha=0.05)
axs[1, 2].set_title("Vector Field by Treatment Phase")
fig.savefig("scTour_model_GSE165897_EOC_vector_field_treatment_phase_test1_in_house.png")
# Save the entire figure
plt.tight_layout()
plt.savefig("scTour_model_GSE165897_EOC_combined_plots_test1_in_house.png")
plt.show()