forked from prokia/drugVQA
-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
161 lines (153 loc) · 6.42 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy as np
import re
import torch
from torch.autograd import Variable
# from torch.utils.data import Dataset, DataLoader
def create_variable(tensor):
    """Wrap ``tensor`` in an autograd ``Variable``.

    When CUDA is available the tensor is first moved to the GPU (the move
    happens before wrapping); otherwise it is wrapped as-is.
    """
    placed = tensor.cuda() if torch.cuda.is_available() else tensor
    return Variable(placed)
def replace_halogen(string):
    """Replace the two-letter halogens with single letters.

    'Br' becomes 'R' and then 'Cl' becomes 'L', so that each halogen is a
    single character in the returned string.
    """
    # Substitutions are applied in this fixed order: Br first, then Cl.
    for pattern, single_letter in (('Br', 'R'), ('Cl', 'L')):
        string = re.sub(pattern, single_letter, string)
    return string
# Create necessary variables, lengths, and target
def make_variables(lines, properties, letters):
    """Encode ``lines`` and return padded, length-sorted variables plus targets.

    Each line is converted with ``line2voc_arr`` into a list of vocabulary
    indices and its length; the collected sequences, their lengths (as a
    ``LongTensor``) and ``properties`` are then handed to ``pad_sequences``,
    whose result is returned unchanged.
    """
    index_seqs = []
    lengths = []
    for line in lines:
        indices, n_tokens = line2voc_arr(line, letters)
        index_seqs.append(indices)
        lengths.append(n_tokens)
    return pad_sequences(index_seqs, torch.LongTensor(lengths), properties)
def make_variables_seq(lines, letters):
    """Encode ``lines`` and return padded, length-sorted variables (no targets).

    Each line is converted with ``line2voc_arr`` into a list of vocabulary
    indices and its length; the sequences and their lengths (as a
    ``LongTensor``) are passed to ``pad_sequences_seq``, whose result is
    returned unchanged.
    """
    index_seqs = []
    lengths = []
    for line in lines:
        indices, n_tokens = line2voc_arr(line, letters)
        index_seqs.append(indices)
        lengths.append(n_tokens)
    return pad_sequences_seq(index_seqs, torch.LongTensor(lengths))
def line2voc_arr(line, letters):
    """Tokenize ``line`` and map every token to its index in ``letters``.

    The line first goes through ``replace_halogen``. It is then split so
    that each bracketed group such as ``[nH]`` (1 to 10 characters between
    the brackets) is one token, and every remaining character is a token of
    its own.

    Args:
        line: the string to encode.
        letters: the vocabulary; indices are looked up via ``letterToIndex``.

    Returns:
        A tuple ``(indices, length)`` where ``indices`` is the list of
        vocabulary indices and ``length`` is ``len(indices)``.

    Any error raised by ``letterToIndex`` for an unknown token propagates.
    """
    # Raw string: in a plain literal '\[' and '\]' are invalid escape
    # sequences, which newer Python versions warn about.
    bracket_token = r'(\[[^\[\]]{1,10}\])'
    arr = []
    for chunk in re.split(bracket_token, replace_halogen(line)):
        if chunk.startswith('['):
            # A bracketed group is looked up as a single token.
            arr.append(letterToIndex(chunk, letters))
        else:
            # Everything else is encoded character by character.
            arr.extend(letterToIndex(unit, letters) for unit in chunk)
    return arr, len(arr)
def letterToIndex(letter, smiles_letters):
    """Return the position of ``letter`` within ``smiles_letters``.

    The lookup is delegated to ``smiles_letters.index``, so for a list the
    first matching position is returned and a missing entry raises
    ``ValueError``.
    """
    position = smiles_letters.index(letter)
    return position
# pad sequences and sort the tensor
def pad_sequences(vectorized_seqs, seq_lengths, properties):
    """Zero-pad index sequences, sort them by length, and reorder the targets.

    Args:
        vectorized_seqs: list of index lists, one per sequence.
        seq_lengths: tensor holding the length of each sequence.
        properties: tensor of targets; it is converted to double and, when
            non-empty, reordered with the same permutation as the sequences.

    Returns:
        ``(padded_sequences, sorted_lengths, target)``, each passed through
        ``create_variable``. Sequences are ordered by descending length.
    """
    # One row per sequence, as wide as the longest one; unused cells stay 0.
    padded = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()
    row = 0
    for seq, seq_len in zip(vectorized_seqs, seq_lengths):
        padded[row, :seq_len] = torch.LongTensor(seq)
        row += 1

    # Longest first; perm_idx records where each sorted row came from.
    sorted_lengths, perm_idx = torch.sort(seq_lengths, 0, descending=True)
    padded = padded[perm_idx]

    # Keep the targets aligned with the reordered sequences.
    target = properties.double()
    if len(properties) > 0:
        target = target[perm_idx]

    seq_var = create_variable(padded)
    len_var = create_variable(sorted_lengths)
    target_var = create_variable(target)
    return seq_var, len_var, target_var
def pad_sequences_seq(vectorized_seqs, seq_lengths):
    """Zero-pad index sequences and sort them by descending length.

    Args:
        vectorized_seqs: list of index lists, one per sequence.
        seq_lengths: tensor holding the length of each sequence.

    Returns:
        ``(padded_sequences, sorted_lengths)``, each passed through
        ``create_variable``.
    """
    # One row per sequence, as wide as the longest one; unused cells stay 0.
    padded = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()
    row = 0
    for seq, seq_len in zip(vectorized_seqs, seq_lengths):
        padded[row, :seq_len] = torch.LongTensor(seq)
        row += 1

    # Longest first; perm_idx records where each sorted row came from.
    sorted_lengths, perm_idx = torch.sort(seq_lengths, 0, descending=True)
    padded = padded[perm_idx]

    seq_var = create_variable(padded)
    len_var = create_variable(sorted_lengths)
    return seq_var, len_var
def construct_vocabulary(smiles_list, fname):
    """Collect every token present in ``smiles_list`` and write them to ``fname``.

    Each SMILES string goes through ``replace_halogen`` and is then split so
    that bracketed groups of the form ``[x]`` are single tokens and all other
    characters are individual tokens.

    Args:
        smiles_list: iterable of SMILES strings.
        fname: path of the vocabulary file to write. The first line is
            ``<pad>``, followed by one token per line.

    Returns:
        The set of tokens found (``<pad>`` is not included).

    Side effects:
        Prints the number of distinct tokens and overwrites ``fname``.
    """
    # Raw string: in a plain literal '\[' and '\]' are invalid escape
    # sequences, which newer Python versions warn about.
    bracket_token = r'(\[[^\[\]]{1,10}\])'
    add_chars = set()
    for smiles in smiles_list:
        # replace_halogen is defined in this module; the previous `ds.`
        # prefix referred to a name that does not exist here.
        for chunk in re.split(bracket_token, replace_halogen(smiles)):
            if chunk.startswith('['):
                add_chars.add(chunk)
            else:
                # Updating a set from a string adds each character.
                add_chars.update(chunk)
    print("Number of characters: {}".format(len(add_chars)))
    with open(fname, 'w') as f:
        f.write('<pad>' + "\n")
        # Sorted so the file (and therefore the token indices derived from
        # it) is the same on every run; set iteration order is arbitrary.
        for char in sorted(add_chars):
            f.write(char + "\n")
    return add_chars
def readLinesStrip(lines):
    """Strip trailing newline characters from every entry of ``lines``.

    The list is modified in place and the same list object is returned.
    """
    for position, text in enumerate(lines):
        lines[position] = text.rstrip('\n')
    return lines
def getProteinSeq(path, contactMapName):
    """Return the sequence stored on the second line of a contact-map file.

    Args:
        path: directory containing the file.
        contactMapName: file name, joined to ``path`` with ``/``.

    Returns:
        The second line of the file with trailing newlines removed.

    Raises:
        IndexError: if the file has fewer than two lines.
    """
    # The context manager closes the file even if reading fails.
    with open(path + "/" + contactMapName) as f:
        lines = f.readlines()
    return lines[1].rstrip('\n')
def getProtein(path, contactMapName, contactMap=True):
    """Read a contact-map file and return its sequence and, optionally, its rows.

    The second line of the file is taken as the sequence; every line from
    the third onwards is a contact-map row.

    Args:
        path: directory containing the file.
        contactMapName: file name, joined to ``path`` with ``/``.
        contactMap: when truthy, also return the contact-map rows.

    Returns:
        ``(seq, rows)`` when ``contactMap`` is truthy, where ``rows`` is a
        list of lines without trailing newlines; otherwise just ``seq``.

    Raises:
        IndexError: if the file has fewer than two lines.
    """
    # The context manager closes the file even if reading fails.
    with open(path + "/" + contactMapName) as f:
        lines = [line.rstrip('\n') for line in f]
    seq = lines[1]
    if contactMap:
        return seq, lines[2:]
    return seq
def getTrainDataSet(trainFoldPath):
    """Load a training fold file as a list of whitespace-split records.

    The file content is stripped and split on newlines; each line is then
    stripped and split on whitespace. Per the original inline note, a record
    has the form ``[smiles, sequence, interaction]``.
    """
    with open(trainFoldPath, 'r') as f:
        content = f.read()
    records = []
    for row in content.strip().split('\n'):
        records.append(row.strip().split())
    return records
def getTestProteinList(testFoldPath):
    """Return the whitespace-separated names on the first line of a test-fold file.

    Per the original inline note the result looks like
    ``['kpcb_2i0eA_full', 'fabp4_2nnqA_full', ...]``.

    Raises:
        IndexError: if the file is empty.
    """
    # The context manager closes the file even if reading fails.
    with open(testFoldPath) as f:
        lines = f.readlines()
    return lines[0].rstrip('\n').split()
def getSeqContactDict(contactPath, contactDictPath):
    """Build a dict mapping each protein sequence to its contact-map tensor.

    Every non-blank line of the index file at ``contactDictPath`` must have
    the form ``<anything>:<contactMapName>``. The named file is loaded from
    ``contactPath`` with ``getProtein``, and each of its contact-map rows is
    parsed as space-separated floats.

    Args:
        contactPath: directory containing the contact-map files.
        contactDictPath: path of the index file.

    Returns:
        A dict ``{sequence: FloatTensor}`` where the tensor is the parsed
        contact map with an extra leading axis of size 1.

    Raises:
        ValueError: if a non-blank index line does not contain exactly one
            ``:``, or a contact-map cell is not a valid float.
    """
    # The context manager closes the index file even if reading fails.
    with open(contactDictPath) as f:
        entries = f.readlines()
    seqContactDict = {}
    for entry in entries:
        entry = entry.strip()
        if not entry:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        _, contactMapName = entry.split(':')
        seq, contactMap = getProtein(contactPath, contactMapName)
        contactmap_np = [list(map(float, row.strip(' ').split(' '))) for row in contactMap]
        # Prepend an axis of size 1 in front of the 2-D map.
        feature2D = np.expand_dims(contactmap_np, axis=0)
        seqContactDict[seq] = torch.FloatTensor(feature2D)
    return seqContactDict
def getLetters(path):
    """Return the whitespace-separated tokens contained in the file at ``path``."""
    with open(path, 'r') as f:
        return f.read().split()
def getDataDict(testProteinList, activePath, decoyPath, contactPath):
    """Build the per-protein test samples from actives and decoys files.

    For every entry of ``testProteinList`` (e.g. ``'xiap_2jk7A_full'``) the
    part before the first ``_`` is used as the protein name to locate
    ``<activePath>/<protein>_actives_final.ism`` and
    ``<decoyPath>/<protein>_decoys_final.ism``. The first space-separated
    field of each line in those files is taken as the compound. The protein
    sequence is read with ``getProtein`` from ``contactPath``.

    Args:
        testProteinList: list of contact-map file names.
        activePath: directory containing the actives files.
        decoyPath: directory containing the decoys files.
        contactPath: directory containing the contact-map files.

    Returns:
        A dict mapping each entry of ``testProteinList`` to a list of
        ``[compound, sequence, label]`` records, actives (label 1) first,
        then decoys (label 0).

    Side effects:
        Prints the protein name and the number of records for each entry.
    """
    dataDict = {}
    for name in testProteinList:
        protein = name.split('_')[0]
        print(protein)
        proteinActPath = activePath + "/" + protein + "_actives_final.ism"
        proteinDecPath = decoyPath + "/" + protein + "_decoys_final.ism"
        # Context managers close both files; they were previously leaked.
        with open(proteinActPath, 'r') as f:
            actives = [line.split(' ')[0] for line in f]
        with open(proteinDecPath, 'r') as f:
            decoys = [line.split(' ')[0] for line in f]
        seq = getProtein(contactPath, name, contactMap=False)
        xData = [[compound, seq, 1] for compound in actives]
        xData.extend([compound, seq, 0] for compound in decoys)
        print(len(xData))
        dataDict[name] = xData
    return dataDict