-
Notifications
You must be signed in to change notification settings - Fork 3
/
gen_train_captions.py
57 lines (45 loc) · 1.63 KB
/
gen_train_captions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python
# -*- coding=utf-8 -*-
from collections import Counter

import joblib
# Token that add_sos_eos prepends (followed by a space) to every caption;
# start-of-sequence marker.
SOS_TOKEN = 'zsosz'
# Token that add_sos_eos appends (preceded by a space) to every caption;
# end-of-sequence marker.
EOS_TOKEN = 'zeosz'
# Minimum total number of occurrences across the training captions that a
# word needs in order to be kept by filter_by_count (and hence written to
# the vocabulary file).
TH_WORD_COUNT = 10
def to_vocabulary(words):
    """Return the set of distinct words in *words*.

    Args:
        words: Any iterable of hashable items (here, word strings).

    Returns:
        A new ``set`` holding each distinct element of *words* once.
    """
    # set() consumes the iterable directly; the previous version built a
    # throwaway list of None values just to call vocab.add() as a side effect.
    return set(words)
def add_sos_eos(captions):
    """Surround every caption with the start and end tokens, in place.

    A caption ``c`` becomes ``SOS_TOKEN + ' ' + c + ' ' + EOS_TOKEN``.

    Args:
        captions: Mapping of key -> iterable of caption strings. Each value
            is replaced by a new list of wrapped captions; the mapping itself
            is modified in place.

    Returns:
        The same ``captions`` object that was passed in.
    """
    head = SOS_TOKEN + ' '
    tail = ' ' + EOS_TOKEN
    for key, caption_list in captions.items():
        wrapped = []
        for caption in caption_list:
            wrapped.append(head + caption + tail)
        # Rebinding an existing key does not change the dict's size, so it is
        # safe while iterating items().
        captions[key] = wrapped
    return captions
def filter_by_count(captions, threshold=None):
    """Return the words that occur at least *threshold* times in *captions*.

    Words are the whitespace-separated tokens of every caption string.

    Args:
        captions: Mapping of key -> iterable of caption strings.
        threshold: Minimum total number of occurrences (summed over all
            captions) for a word to be kept. When None, the module-level
            ``TH_WORD_COUNT`` is used.

    Returns:
        A list of qualifying words, in the order each word was first seen
        while iterating *captions*.
    """
    if threshold is None:
        # Resolved at call time so the default tracks the module constant.
        threshold = TH_WORD_COUNT
    word_counts = Counter()
    for caption_list in captions.values():
        for caption in caption_list:
            # split() with no argument collapses runs of whitespace. The old
            # split(' ') produced '' tokens for doubled spaces, and those could
            # end up in the vocabulary as an empty entry.
            word_counts.update(caption.split())
    # Counter keeps first-insertion order, matching the original output order.
    return [word for word, count in word_counts.items() if count >= threshold]
if __name__ == "__main__":
    # NOTE(review): joblib.load unpickles arbitrary objects; only load files
    # produced by this pipeline, never ones from an untrusted source.
    with open('flickr8k/cleaned_captions.pkl', 'rb') as f:
        cleaned_captions = joblib.load(f)

    # Add start and end token (add_sos_eos mutates and returns the same dict).
    cleaned_captions = add_sos_eos(cleaned_captions)

    # One image file name per line; blank lines (e.g. a trailing newline) are
    # skipped rather than recorded as empty names.
    with open('flickr8k/Flickr_8k.trainImages.txt', 'r', encoding='utf-8') as f:
        train_images = [name for name in (line.strip() for line in f) if name]

    # Keep the captions of the training images that have cleaned captions.
    train_captions = {
        image: cleaned_captions[image]
        for image in train_images
        if image in cleaned_captions
    }

    # Save train captions
    with open('flickr8k/train_captions.pkl', 'wb') as f:
        joblib.dump(train_captions, f, compress=3)

    # Save vocabulary: words seen at least TH_WORD_COUNT times in the training
    # captions (the SOS/EOS tokens are counted too, since they were added
    # above). Written sorted so vocabulary.txt is identical across runs:
    # iterating a set of str follows hash order, which varies per interpreter
    # run under string hash randomization.
    train_words = filter_by_count(train_captions)
    vocabulary = to_vocabulary(train_words)
    with open('flickr8k/vocabulary.txt', 'w', encoding='utf-8') as f:
        f.writelines(word + '\n' for word in sorted(vocabulary))
    print('Vocabulary size:', len(vocabulary))