# Config for manual run
seml:
  executable: chemCPA/experiments_run.py
  name: manual_run_chemCPA
  output_dir: project_folder/logs
  conda_environment: chemical_CPA
  project_root_dir: .
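
# The `seml` block above is read by the seml experiment manager, not by chemCPA itself.
# A typical way to register and launch this config (assuming the standard seml CLI and a
# MongoDB collection name of your choosing; the commands below are illustrative, not part
# of this file) is:
#   seml manual_run_chemCPA add manual_run.yaml
#   seml manual_run_chemCPA start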
fixed:
  profiling.run_profiler: False
  profiling.outdir: "./"

  training.checkpoint_freq: 25 # how often (in epochs) to run evaluation and potentially save a checkpoint
  training.num_epochs: 101 # maximum number of training epochs. Each epoch updates either the autoencoder or the adversary, depending on adversary_steps.
  training.max_minutes: 1200 # maximum computation time in minutes
  training.full_eval_during_train: False
  training.run_eval_disentangle: True # whether to compute the disentanglement loss during the full evaluation
  training.run_eval_r2: True
  training.run_eval_r2_sc: False
  training.run_eval_logfold: False
  training.save_checkpoints: True # checkpoints tend to be ~250 MB for LINCS
  training.save_dir: project_folder/checkpoints
  dataset.dataset_type: trapnell
  # dataset.data_params.perturbation_key: condition # stores the name of the drug
  dataset.data_params.perturbation_key: condition # stores the name of the drug
  dataset.data_params.pert_category: cov_drug_dose_name # stores celltype_drugname_drugdose
  dataset.data_params.dose_key: pert_dose # stores the drug dose as a float
  dataset.data_params.covariate_keys: cell_type # required field for cell types. Fill it with a dummy variable if no cell types are present.
  # dataset.data_params.smiles_key: SMILES
  dataset.data_params.smiles_key: canonical_smiles
  dataset.data_params.use_drugs_idx: True # if False, one-hot encoding is used instead
  # dataset.data_params.split_key: split_ood_multi_task
  dataset.data_params.split_key: random_split
  model.pretrained_model_path: project_folder/checkpoints
  model.pretrained_model_hashes: null
  model.additional_params.patience: 50 # patience for early stopping. Effective epochs: patience * checkpoint_freq.
  model.additional_params.decoder_activation: ReLU
  model.additional_params.doser_type: amortized
  model.additional_params.multi_task: false
  model.embedding.directory: embeddings
  model.additional_params.seed: 1337
  model.hparams.dim: 32
  model.hparams.dropout: 0.262378
  model.hparams.autoencoder_width: 256
  model.hparams.autoencoder_depth: 4
  model.hparams.reg_multi_task: 0
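
# Everything under `fixed:` is held constant across all runs generated from this config.
# The dotted keys (e.g. model.hparams.dim) follow seml's usual convention of being
# unflattened into the nested configuration dictionary that the experiment script receives.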
random:
  samples: 1
  seed: 42
  model.hparams.batch_size:
    type: choice
    options:
      - 128
  model.hparams.autoencoder_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.autoencoder_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
  model.hparams.adversary_width:
    type: choice
    options:
      - 64
      - 128
      - 256
  model.hparams.adversary_depth:
    type: choice
    options:
      - 2
      - 3
      - 4
  model.hparams.adversary_lr:
    type: loguniform
    min: 5e-5
    max: 1e-2
  model.hparams.adversary_wd:
    type: loguniform
    min: 1e-8
    max: 1e-3
  model.hparams.adversary_steps: # every X steps, update the adversary INSTEAD OF the autoencoder
    type: choice
    options:
      - 2
      - 3
  model.hparams.reg_adversary:
    type: loguniform
    min: 1
    max: 40
  model.hparams.reg_adversary_cov:
    type: loguniform
    min: 5
    max: 50
  model.hparams.penalty_adversary:
    type: loguniform
    min: 0.05
    max: 2
  model.hparams.dosers_lr:
    type: loguniform
    min: 1e-4
    max: 1e-2
  model.hparams.dosers_wd:
    type: loguniform
    min: 1e-8
    max: 1e-5
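
# seml draws `samples` (here: 1) configurations from the distributions above, seeded with
# `seed`: `choice` picks uniformly from `options`, and `loguniform` samples log-uniformly
# between `min` and `max`, so e.g. autoencoder_lr ends up somewhere in [1e-4, 1e-2].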
grid:
  model.load_pretrained:
    type: choice
    options:
      # - True
      - False
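
# Values under `grid:` are enumerated exhaustively and combined with the sampled random
# parameters. With only False listed (True is commented out), this grid adds a single
# setting, presumably training from scratch rather than loading the checkpoint referenced
# by model.pretrained_model_path.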
rdkit_model:
  fixed:
    model.embedding.model: rdkit # alternatives: grover_base, jtvae
    model.hparams.dosers_width: 64 # should be larger for jtvae and grover
    model.hparams.dosers_depth: 3
    model.hparams.step_size_lr: 50 # this applies to all optimizers (AE, ADV, DRUG)
    model.hparams.embedding_encoder_width: 128 # should be larger for jtvae and grover
    model.hparams.embedding_encoder_depth: 4

    #____________________________________________________________________________________________________
    #_Shared_gene_set_(setting_1)________________________________________________________________________
    # model.append_ae_layer: False
    # dataset.data_params.dataset_path: project_folder/datasets/sciplex_complete_middle_subset_lincs_genes.h5ad
    # dataset.data_params.degs_key: lincs_DEGs # `uns` column name denoting the DEGs for each perturbation
    #____________________________________________________________________________________________________
    #_Extended_gene_set_(setting_2)______________________________________________________________________
    # model.append_ae_layer: True
    model.append_ae_layer: False
    model.enable_cpa_mode: False
    # dataset.data_params.dataset_path: project_folder/datasets/sciplex_complete_middle_subset.h5ad
    # dataset.data_params.degs_key: all_DEGs
    dataset.data_params.dataset_path: project_folder/datasets/lincs_full_smiles_sciplex_genes.h5ad
    dataset.data_params.degs_key: rank_genes_groups_cov
    #____________________________________________________________________________________________________
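
# Note: `rdkit_model` is a named sub-configuration; as is standard for seml configs, it is
# merged with the top-level `fixed`/`random`/`grid` blocks, so its values apply on top of
# everything defined above. Switching model.embedding.model to grover_base or jtvae would,
# per the comments above, also call for wider doser and embedding-encoder networks.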