-
Notifications
You must be signed in to change notification settings - Fork 3
/
predictions.py
181 lines (132 loc) · 5 KB
/
predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/env python
# coding: utf-8
import numpy as np
import pandas as pd
import cv2
import pytesseract
from glob import glob
import spacy
import re
import string
import warnings
# Suppress every warning category for the whole process.
warnings.filterwarnings('ignore')
# Load the spaCy pipeline used as the NER model. This runs once, at import
# time, and reads from the relative path ./output/model-best/.
model_ner = spacy.load('./output/model-best/')
def cleanText(txt):
    """Return ``txt`` as a lower-cased string without whitespace or listed symbols.

    The argument is passed through ``str()`` first, so non-string values
    (``None``, numbers, ...) are accepted. Characters such as ``. , - / @ _``
    are not in the removal list and are kept.
    """
    # One translation table that deletes whitespace and the unwanted symbols
    # in a single pass over the text.
    unwanted = string.whitespace + "!#$%&\'()*+:;<=>?[\\]^`{|}~"
    strip_table = str.maketrans('', '', unwanted)
    return str(txt).lower().translate(strip_table)
# Assign a group id to each run of consecutive identical labels
class groupgen():
    """Hand out an integer id that changes whenever the incoming text changes.

    Consecutive calls to ``getgroup`` with the same text return the same id;
    a call with a different text increments the id first.
    """

    def __init__(self):
        # Last text seen and the id currently associated with it.
        self.text = ''
        self.id = 0

    def getgroup(self, text):
        """Return the id for ``text``, bumping it when ``text`` differs from the last call."""
        if self.text != text:
            self.id += 1
            self.text = text
        return self.id
def parser(text, label):
    """Normalise ``text`` according to the entity ``label`` it was tagged with.

    Rules per label (all lower-case the text first):
      * ``PHONE``        - keep digits only.
      * ``EMAIL``        - keep letters, digits, space and ``@ _ . -``.
      * ``WEB``          - keep letters, digits, space and ``: / . % # -``.
      * ``NAME``/``DES`` - keep letters and spaces, then title-case.
      * ``ORG``          - keep letters, digits and spaces, then title-case.

    Any other label returns ``text`` unchanged (not even lower-cased).

    The patterns are raw-string literals: the previous ``'@_.\\-'`` style
    relied on an invalid string escape, which newer Python versions warn about.
    """
    if label == 'PHONE':
        return re.sub(r'\D', '', text.lower())
    if label == 'EMAIL':
        return re.sub(r'[^A-Za-z0-9@_.\- ]', '', text.lower())
    if label == 'WEB':
        return re.sub(r'[^A-Za-z0-9:/.%#\- ]', '', text.lower())
    if label in ('NAME', 'DES'):
        return re.sub(r'[^a-z ]', '', text.lower()).title()
    if label == 'ORG':
        return re.sub(r'[^a-z0-9 ]', '', text.lower()).title()
    return text
# Module-level group-id generator used by getPredictions. It is created once,
# so its counter keeps increasing across calls rather than restarting per image.
grp_gen = groupgen()
def getPredictions(image):
    """OCR an image, tag the words with the NER model, and collect the entities.

    Parameters
    ----------
    image
        Image passed to ``pytesseract.image_to_data``; it must also provide
        ``.copy()`` (presumably a NumPy/OpenCV array - confirm with callers).

    Returns
    -------
    tuple
        ``(img_bb, entities)``: ``img_bb`` is a copy of ``image`` with one
        rectangle and caption drawn per run of identically-labelled words;
        ``entities`` maps NAME/ORG/DES/PHONE/EMAIL/WEB to lists of parsed strings.

    When OCR produces no text or the model predicts no entities, an
    un-annotated copy of the image and empty entity lists are returned
    (previously this raised ``KeyError``).
    """
    # extract data using Pytesseract (tab-separated text, first row = header)
    tessData = pytesseract.image_to_data(image)
    # convert into dataframe
    tessList = [row.split('\t') for row in tessData.split('\n')]
    df = pd.DataFrame(tessList[1:], columns=tessList[0])
    df.dropna(inplace=True)  # drop rows with missing fields
    df['text'] = df['text'].apply(cleanText)

    # convert data into content; copy so the columns added below do not
    # write into a view of df
    df_clean = df.query('text != "" ').copy()
    content = " ".join(df_clean['text'])
    print(content)

    img_bb = image.copy()
    if not content:
        # nothing was recognised, so there is nothing to tag
        return img_bb, _collectEntities([])

    # get prediction from NER model
    doc = model_ner(content)
    # convert doc to json
    docjson = doc.to_json()
    if not docjson['ents']:
        # an empty entity list would yield a frame without 'start'/'label'
        return img_bb, _collectEntities([])
    doc_text = docjson['text']

    # creating tokens
    datafram_tokens = pd.DataFrame(docjson['tokens'])
    datafram_tokens['token'] = [
        doc_text[start:end]
        for start, end in zip(datafram_tokens['start'], datafram_tokens['end'])
    ]
    right_table = pd.DataFrame(docjson['ents'])[['start', 'label']]
    datafram_tokens = pd.merge(datafram_tokens, right_table, how='left', on='start')
    datafram_tokens.fillna('O', inplace=True)  # tokens without an entity get 'O'

    # join label to df_clean dataframe: rebuild each word's character offsets
    # inside `content` (words were joined with one space, hence the +1)
    lengths = df_clean['text'].str.len()
    df_clean['end'] = (lengths + 1).cumsum() - 1
    df_clean['start'] = df_clean['end'] - lengths

    # inner join with start
    dataframe_info = pd.merge(
        df_clean, datafram_tokens[['start', 'token', 'label']],
        how='inner', on='start')

    # Bounding Box
    bb_df = dataframe_info.query("label != 'O' ").copy()
    if not bb_df.empty:
        bb_df['label'] = bb_df['label'].str[2:]  # strip the 2-char tag prefix
        bb_df['group'] = bb_df['label'].apply(grp_gen.getgroup)

        # right and bottom of bounding box
        box_cols = ['left', 'top', 'width', 'height']
        bb_df[box_cols] = bb_df[box_cols].astype(int)
        bb_df['right'] = bb_df['left'] + bb_df['width']
        bb_df['bottom'] = bb_df['top'] + bb_df['height']

        # tagging: groupby group, merging the word boxes of each group
        col_group = ['left', 'top', 'right', 'bottom', 'label', 'token', 'group']
        img_tagging = bb_df[col_group].groupby(by='group').agg({
            'left': 'min',
            'right': 'max',
            'top': 'min',
            'bottom': 'max',
            'label': np.unique,
            'token': lambda x: " ".join(x)
        })

        # column order follows the agg dict above
        for left, right, top, bottom, label, token in img_tagging.values:
            cv2.rectangle(img_bb, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(img_bb, str(label), (left, top),
                        cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 255), 2)

    # Entities
    entities = _collectEntities(dataframe_info[['token', 'label']].values)
    return img_bb, entities


def _collectEntities(token_label_pairs):
    """Fold (token, label) pairs into a dict of parsed entity strings per label."""
    entities = dict(NAME=[], ORG=[], DES=[], PHONE=[], EMAIL=[], WEB=[])
    previous = 'O'
    for token, label in token_label_pairs:
        bio_tag = label[0]
        label_tag = label[2:]
        text = parser(token, label_tag)
        if bio_tag in ('B', 'I'):
            # setdefault: a label outside the preset keys no longer raises KeyError
            found = entities.setdefault(label_tag, [])
            if not found or previous != label_tag or bio_tag == 'B':
                # a new entity starts here
                found.append(text)
            elif label_tag in ('NAME', 'ORG', 'DES'):
                # multi-word text entities are joined with a space
                found[-1] = found[-1] + " " + text
            else:
                # phone / email / web fragments are concatenated directly
                found[-1] = found[-1] + text
        previous = label_tag
    return entities