-
Notifications
You must be signed in to change notification settings - Fork 0
/
01-A2-train-preprocessing.py
186 lines (160 loc) · 7.42 KB
/
01-A2-train-preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Script to read the GISAID preprocessed fasta files and generate country-wise CSVs.
Pipeline:
Read each fasta file.
Get metadata & genome sequence for each strain from fasta file.
Find relevant folder (generated by covseq script) that contains the TSV corresponding to the strain in the fasta file.
Get start and end of each peptide from the TSV and append sequence metadata in CSV for fasta file.
Filter fasta file CSV by countries and generate separate CSV for each country. Append if country CSV already exists.
"""
import os
import glob
import logging
import traceback
import pandas as pd
from Bio import SeqIO
# Input/output locations, relative to the working directory.
PATH_FASTA_DIR = "data/fasta_preprocessing/processed/"  # preprocessed FASTA files to read
PATH_TSV_FOLDERS = "data/fasta_preprocessing/segmented/"  # covseq output folders holding per-strain TSVs
PATH_OUTPUT_DIR = "data/fasta_preprocessing/CSVs/countrywise/"  # one CSV per country is written here

# Per-peptide columns built by get_amino_sequences (3-mer columns were removed).
SEQ_METADATA_COLS = [
    'Peptide_Name',
    'Sequence',
    'Sequence_Length',
    'DividesBy3',
    'Triples_Count',
    'Full_Sequence_Length',
]
# Per-strain columns taken from the FASTA header.
OTHER_METADATA_COLS = ['Accession_ID', 'Virus_Name', 'Country', 'Collection_Date']
# Column order of the country-wise CSVs: strain identity first, then the
# peptide metadata, with the collection date last.
OUTPUT_COLS = OTHER_METADATA_COLS[:3] + SEQ_METADATA_COLS + OTHER_METADATA_COLS[3:]

# Log to app.log, overwriting it on every run.
logging.basicConfig(
    filename='app.log',
    filemode='w',
    format='%(asctime)s - %(levelname)s - %(message)s',
)
def validate_sequence(sequence, full_sequence=False):
    """Sanity-check a nucleotide sequence.

    Args:
        sequence: nucleotide string to check.
        full_sequence: if True, treat ``sequence`` as a whole genome and
            require that it contains only the characters A, T, G and C.
            If False, treat it as a peptide slice and require that it starts
            with the ATG start codon and ends with TAA, TAG or TGA.

    Returns:
        True if the sequence passes the selected check, else False.
    """
    if full_sequence:
        # Valid when no character falls outside the four bases.
        return set(sequence) <= set("ATGC")
    return sequence.startswith('ATG') and sequence[-3:] in ('TAA', 'TAG', 'TGA')
def get_fasta_strains(filepath):
    """Read a preprocessed GISAID FASTA file and return per-strain records.

    Each record header is expected to look like
    ``hCoV-19/Mayotte/IPP02391/2021_EPI_ISL_1167000_2021-01-21``. It is split
    on '_' into three parts:

    - the virus name: everything before the first '_';
    - the accession ID: everything between the first and last '_';
    - the collection date: everything after the last '_'.

    The country is the second '/'-separated field of the virus name. Records
    whose sequence contains any character other than A, T, G or C are
    skipped.

    Args:
        filepath: path to the FASTA file.

    Returns:
        List of ``[accession_id, virus_name, country, sequence, collection_date]``.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a kept record's virus name contains no '/'.
    """
    strains_for_file = []
    # The context manager closes the handle even if parsing fails part-way;
    # the original bare open() leaked it.
    with open(filepath) as fasta_handle:
        for record in SeqIO.parse(fasta_handle, 'fasta'):
            description, sequence = str(record.description), str(record.seq)
            if not validate_sequence(sequence, full_sequence=True):
                continue
            # NOTE(review): a virus name that itself contains '_' would be split
            # incorrectly here; confirm GISAID names never contain underscores.
            splitted_descr = description.split('_')
            virus_name = splitted_descr[0]
            country = virus_name.split('/')[1]
            accession_id = '_'.join(splitted_descr[1:-1])
            collection_date = splitted_descr[-1]
            strains_for_file.append([accession_id, virus_name, country, sequence, collection_date])
    return strains_for_file
def _split_coordinate_pair(value, split_at):
    """Parse one ORF1ab coordinate cell into a pair of ints.

    The cell is either comma-separated (e.g. "266,13468") or the two numbers
    concatenated with no separator (e.g. "26613468"). In the concatenated
    form, the first number is taken to be the leading ``split_at`` characters.
    """
    # str() so numeric cells also work: pandas parses an all-digit column as
    # int, on which the original `',' in value` raised TypeError.
    text = str(value).strip()
    if ',' in text:
        first, second = text.split(',')[:2]
        return int(first), int(second)
    return int(text[:split_at]), int(text[split_at:])


def handle_orf1ab(start, end, full_sequence):
    """Extract the ORF1ab nucleotide sequence from its two annotated segments.

    ``start`` and ``end`` each hold two 1-based, inclusive coordinates:
    segment a and segment b.

    - If segment a ends exactly where segment b starts (end_a == start_b),
      the single contiguous slice start_a..end_b is returned.
    - Otherwise the two segments are sliced separately and concatenated.

    NOTE(review): for separator-less cells, the split widths (3 digits for
    start_a, 5 for end_a) are hard-coded from the original and assume the
    reference coordinates' digit counts; confirm against the covseq output.

    Args:
        start: Start cell, as a string or int.
        end: End cell, as a string or int.
        full_sequence: full nucleotide sequence of the strain.

    Returns:
        The ORF1ab nucleotide sequence.

    Raises:
        ValueError: if a cell cannot be parsed as two integers.
    """
    start_a, start_b = _split_coordinate_pair(start, 3)
    end_a, end_b = _split_coordinate_pair(end, 5)
    if end_a == start_b:
        return full_sequence[start_a - 1:end_b]
    return full_sequence[start_a - 1:end_a] + full_sequence[start_b - 1:end_b]
def get_amino_sequences(tsv, full_sequence, n_for_nmers=3):
    """Slice each annotated peptide out of a strain's full nucleotide sequence.

    Args:
        tsv: covseq annotation table with ``Product``, ``Start`` and ``End``
            columns. Coordinates are 1-based and inclusive. The table must
            have exactly 12 rows.
        full_sequence: full nucleotide sequence of the strain.
        n_for_nmers: unused; kept only for backward compatibility (the
            n-mer columns were removed).

    Returns:
        DataFrame with columns ``SEQ_METADATA_COLS`` and one row per TSV row.

    Raises:
        Exception: if the TSV does not have 12 rows, or if a peptide slice
            does not start with ATG and end with a stop codon.
    """
    if len(tsv) != 12:
        raise Exception('TSV length not 12 for the Strain')
    rows = []
    for _, val in tsv.iterrows():
        peptide_name = val['Product']
        if peptide_name == 'orf1ab polyprotein':
            # ORF1ab is annotated as two segments and needs special handling.
            sequence = handle_orf1ab(val['Start'], val['End'], full_sequence)
        else:
            sequence = full_sequence[int(val['Start']) - 1:int(val['End'])]
        if not validate_sequence(sequence):
            raise Exception('Problem in start/end codon')
        # Use the actual slice length rather than the TSV's lengths, which can
        # be ambiguous.
        seq_length = len(sequence)
        rows.append([
            peptide_name,
            sequence,
            seq_length,
            seq_length % 3 == 0,
            seq_length // 3,
            len(full_sequence),
        ])
    # Build the frame once: DataFrame.append was removed in pandas 2.0, and
    # appending row by row is quadratic.
    return pd.DataFrame(rows, columns=SEQ_METADATA_COLS)
def find_covseq_tsv(country, accession_id):
    """Locate the covseq ORF TSV for a strain.

    Searches ``PATH_TSV_FOLDERS/*/`` for directories whose name contains
    ``accession_id``. Within the first such directory (in sorted order), it
    returns the first file (in sorted order) whose name contains 'orf'.

    Args:
        country: unused; kept for interface compatibility.
        accession_id: GISAID accession, e.g. 'EPI_ISL_1167000'.

    Returns:
        Path to the TSV, or None if no matching directory or 'orf' file exists.
    """
    # Escape so that any glob metacharacters in the ID are matched literally.
    pattern = f'{PATH_TSV_FOLDERS}/*/*{glob.escape(accession_id)}*'
    # Sort for reproducibility: glob/listdir order is filesystem-dependent.
    # Keep only directories, since os.listdir on a file would raise.
    # NOTE(review): this is a substring match, so 'EPI_ISL_116700' also matches
    # a folder for 'EPI_ISL_1167000'; confirm the covseq folder naming before
    # tightening it.
    candidate_dirs = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    if not candidate_dirs:
        return None
    tsv_dir = candidate_dirs[0]
    orf_files = sorted(name for name in os.listdir(tsv_dir) if 'orf' in name)
    if not orf_files:
        return None
    return os.path.join(tsv_dir, orf_files[0])
def get_df_for_strain(strain_metadata):
    """Build the per-peptide rows for a single strain.

    Args:
        strain_metadata: ``[accession_id, virus_name, country, sequence,
            collection_date]``, as produced by get_fasta_strains.

    Returns:
        DataFrame with ``OTHER_METADATA_COLS`` followed by
        ``SEQ_METADATA_COLS``, one row per peptide. The strain-level values
        are repeated on every row.

    Raises:
        Exception: if no covseq TSV is found for the strain, or if peptide
            extraction fails. The message carries the accession ID and
            country, and the underlying error is chained as ``__cause__``.
    """
    accession_id, virus_name, country, sequence, collection_date = strain_metadata[:5]
    tsv_path = find_covseq_tsv(country, accession_id)
    print(tsv_path)
    if not tsv_path:
        raise Exception(f"TSV does not exist for strain. ID:{accession_id}; Country:{country}")
    tsv = pd.read_csv(tsv_path, delimiter="\t")
    try:
        seq_df = get_amino_sequences(tsv, sequence)
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise Exception(f"{e}. ID:{accession_id}; Country:{country}") from e
    # Give seq_df a 0..n-1 index so the column-wise concat aligns row by row.
    seq_df = seq_df.reset_index(drop=True)
    other_df = pd.DataFrame(
        [[accession_id, virus_name, country, collection_date]] * len(seq_df),
        columns=OTHER_METADATA_COLS,
    )
    return pd.concat([other_df, seq_df], axis=1)
def export_countrywise_csvs(df_for_fasta, output_dir=None):
    """Write or extend one CSV per country from a FASTA file's peptide rows.

    For each country in ``df_for_fasta['Country']``, the rows are appended to
    ``<output_dir>/<country>.csv`` (created if missing). The result is then
    de-duplicated on (Accession_ID, Peptide_Name), keeping the first
    occurrence, so rows already on disk take precedence. Rows whose Country
    is missing (NaN) are skipped.

    Args:
        df_for_fasta: DataFrame with at least 'Country', 'Accession_ID' and
            'Peptide_Name' columns.
        output_dir: destination directory. Defaults to PATH_OUTPUT_DIR and is
            created if absent.
    """
    if output_dir is None:
        output_dir = PATH_OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)
    for country, country_df in df_for_fasta.groupby('Country'):
        output_path = os.path.join(output_dir, f'{country}.csv')
        if os.path.exists(output_path):
            # DataFrame.append was removed in pandas 2.0; concat is the
            # supported equivalent.
            out_df = pd.concat([pd.read_csv(output_path), country_df], ignore_index=True)
        else:
            out_df = country_df
        out_df = out_df.drop_duplicates(subset=['Accession_ID', 'Peptide_Name'])
        out_df.to_csv(output_path, index=False)
def clean_output_dir(output_dir=None):
    """Delete every file directly inside the country-wise output directory.

    This is best-effort. Entries that cannot be removed (for example,
    sub-directories, which os.remove rejects) are logged as warnings and
    skipped instead of being silently ignored.

    Args:
        output_dir: directory to clean. Defaults to PATH_OUTPUT_DIR.
    """
    if output_dir is None:
        output_dir = PATH_OUTPUT_DIR
    for path in glob.glob(os.path.join(output_dir, '*')):
        try:
            os.remove(path)
        except OSError as e:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # still propagate.
            logging.warning("Could not remove %s: %s", path, e)
    print("Cleaned output directory")
def _collect_strain_frames(strains_for_file):
    """Build one DataFrame per strain, logging and skipping strains that fail."""
    frames = []
    for strain_metadata in strains_for_file:
        try:
            frames.append(get_df_for_strain(strain_metadata))
        except Exception as e:
            print(f"ERROR processing strain. Reason: {e}")
            logging.error(f"ERROR processing strain. Reason: {e}")
    return frames


def main(cleanup=False):
    """Convert every FASTA file in PATH_FASTA_DIR into country-wise CSVs.

    Files are processed in sorted name order. A file whose strains cannot be
    read is logged and skipped. Strains that fail individually are logged and
    skipped while the rest of the file is still exported.

    Args:
        cleanup: if True, empty the output directory before processing.
    """
    if cleanup:
        clean_output_dir()
    # Sorted so runs are reproducible; os.listdir order is arbitrary.
    for file in sorted(os.listdir(PATH_FASTA_DIR)):
        print(f"PROCESSING: {file}")
        filepath = os.path.join(PATH_FASTA_DIR, file)
        try:
            # Each entry: [accession_id, virus_name, country, sequence, collection_date]
            strains_for_file = get_fasta_strains(filepath)
        except Exception as e:
            print(f"ERROR getting strains from fasta file. Reason: {e}")
            logging.error(f"ERROR getting strains from fasta file. Reason: {e}")
            continue
        strain_frames = _collect_strain_frames(strains_for_file)
        if not strain_frames:
            continue
        # Concatenate once (DataFrame.append was removed in pandas 2.0), then
        # restore the OUTPUT_COLS column order the CSVs have always used.
        df_for_fasta = pd.concat(strain_frames, ignore_index=True).reindex(columns=OUTPUT_COLS)
        if not df_for_fasta.empty:
            export_countrywise_csvs(df_for_fasta)


if __name__ == "__main__":
    main(cleanup=False)