-
Notifications
You must be signed in to change notification settings - Fork 16
/
story_visualization.py
86 lines (68 loc) · 3.72 KB
/
story_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from glob import glob
from PIL import Image
import cv2
import textwrap
import os
class StoryPlot:
    """Look up stories in a story-in-sequence JSON file and plot them.

    ``annotations`` is read from the ``'annotations'`` key of the JSON file.
    Each entry is a list whose first element is an annotation dict with (at
    least) the keys ``story_id``, ``worker_arranged_photo_order``,
    ``photo_flickr_id`` and ``text``.
    """

    def __init__(self, stories_data_set_path='./dataset/vist_dataset/validate_data/val.story-in-sequence.json',
                 images_root_folder_path='./dataset/vist_dataset/validate_data/images/val'):
        """Load the story annotations.

        Args:
            stories_data_set_path: path to the story-in-sequence JSON file.
            images_root_folder_path: directory searched recursively for ``*.jpg`` images.
        """
        self.story_dataset_path = stories_data_set_path
        self.images_root_folder_path = images_root_folder_path
        # Use a context manager so the file handle is closed deterministically;
        # JSON text is UTF-8, so do not depend on the platform's locale encoding.
        with open(stories_data_set_path, encoding='utf-8') as dataset_file:
            self.annotations = json.load(dataset_file)['annotations']

    def _find_story(self, story_id):
        """Return the annotations of ``story_id`` sorted by ``worker_arranged_photo_order``."""
        # Each annotation entry wraps the annotation dict as its first element.
        story = [entry[0] for entry in self.annotations if entry[0]['story_id'] == story_id]
        return sorted(story, key=lambda annotation: annotation['worker_arranged_photo_order'])

    def _find_image_filenames(self, story):
        """Return one ``.jpg`` path per annotation in ``story`` ('' if no image matches).

        Images are found by walking ``images_root_folder_path`` recursively and
        matching the file name without extension against ``photo_flickr_id``.
        """
        paths_by_photo_id = {}
        for dir_path, _, _ in os.walk(self.images_root_folder_path):
            for path in glob(os.path.join(dir_path, "*.jpg")):
                # Compare the exact file stem: a substring test on the full path
                # would let photo id "123" pick up "41234.jpg" or a directory name.
                stem = os.path.splitext(os.path.basename(path))[0]
                paths_by_photo_id[stem] = path
        return [paths_by_photo_id.get(str(annotation['photo_flickr_id']), '') for annotation in story]

    def visualize_story(self, story_id, decoded_sentences):
        """Show the story's images side by side with original and decoded text below each.

        Args:
            story_id: value compared against each annotation's ``story_id``.
            decoded_sentences: indexable sequence of strings, one per story image,
                in the story's photo order.

        Raises:
            FileNotFoundError: if an image of the story cannot be found or read.
            IndexError: if ``decoded_sentences`` has fewer entries than the story.
        """
        story = self._find_story(story_id)
        if story:
            print('found the story')
        image_filenames = self._find_image_filenames(story)

        fig = plt.figure()
        fig.subplots_adjust(bottom=0.6)
        wrapper = textwrap.TextWrapper(width=30)
        n_images = len(image_filenames)
        for i, (annotation, filename) in enumerate(zip(story, image_filenames)):
            im = cv2.imread(filename)
            if im is None:
                # cv2.imread reports failure by returning None; fail here with a clear
                # message rather than letting cv2.resize raise an opaque error.
                raise FileNotFoundError(
                    "could not read image for photo {!r}: {!r}".format(annotation['photo_flickr_id'], filename))
            im = cv2.resize(im, (500, 500))
            ax = fig.add_subplot(1, n_images, i + 1)
            # The image is resized to 500x500 px and imshow's default origin is the
            # upper-left corner, so y=550 and y=700 fall below the picture.
            ax.text(0, 550, "\n".join(wrapper.wrap(annotation['text'])), ha='left', va="top")
            ax.text(0, 700, "\n".join(wrapper.wrap(decoded_sentences[i])), ha='left', va="top")
            ax.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
            ax.axis("off")
        plt.show()

    def get_story_data(self, story_id):
        """Return the image paths and original sentences of ``story_id`` in photo order.

        Returns:
            dict with ``'image_filenames'`` (list of paths, '' where no image was
            found) and ``"original_sentences"`` (list of annotation texts).
        """
        story = self._find_story(story_id)
        return {'image_filenames': self._find_image_filenames(story),
                "original_sentences": [annotation['text'] for annotation in story]}
# story_plot = StoryPlot(stories_data_set_path='./dataset/vist_sis/train.story-in-sequence.json',
# images_root_folder_path='./dataset/sample_images')
# story_plot.visualize_story("11053", [
# "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard",
# "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard",
# "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard",
# "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard",
# "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard"])