-
Notifications
You must be signed in to change notification settings - Fork 4
/
vqaPathEval.py
57 lines (47 loc) · 2.08 KB
/
vqaPathEval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import json
import random
import os
import argparse
import yaml
from vqaTools.vqa import *
from vqaTools.vqaEval import *
def derived_path(path, suffix):
    """Return *path* with *suffix* inserted just before its file extension.

    Only the final extension is considered, so a directory whose name
    contains ``.json`` is left intact. A path without an extension still
    gets a distinct name, so an output file can never silently overwrite
    the input result file (``str.replace('.json', ...)`` could do both).

    >>> derived_path('out/res_3.json', '_acc')
    'out/res_3_acc.json'
    """
    root, ext = os.path.splitext(path)
    return root + suffix + ext


def summary_path(res_template):
    """Return the path of ``all_acc.json`` in the directory of *res_template*.

    ``os.path.dirname`` is used instead of slicing at ``rfind('/')``. The
    slice dropped the template's last character when it contained no ``/``
    (``rfind`` returns -1). Here a bare file name maps to the current
    directory.

    >>> summary_path('r_<epoch>.json')
    'all_acc.json'
    """
    return os.path.join(os.path.dirname(res_template), 'all_acc.json')


def main(argv=None):
    """Score a series of per-epoch VQA result files and save accuracies.

    For each index ``i`` in ``range(--numEpochs)``, the ``<epoch>``
    placeholder in ``--resFile`` is replaced with ``i``. That result file
    is then loaded via ``vqa.loadRes`` and scored with ``VQAEval``.

    Per epoch, two files are written next to the result file:
    ``<name>_acc.json`` (``vqaEval.accuracy``) and ``<name>_compare.json``
    (``vqaEval.ansComp``). Finally, the list of per-epoch accuracy dicts is
    written to ``all_acc.json`` in the directory of ``--resFile``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--quesFile', default='/mnt/sda/lpf/data/vqa/data_PathVQA/pathvqa_test.json')
    parser.add_argument('--resFile', default='./output/pathvqa/result/med_pretrain_29_vqa_result_<epoch>.json')
    # Previously hard-coded as range(40); the default keeps that behaviour.
    parser.add_argument('--numEpochs', type=int, default=40,
                        help='number of epoch result files to evaluate (indices 0..N-1)')
    args = parser.parse_args(argv)

    ques_file = args.quesFile
    # The same file is passed for both constructor arguments; per the
    # original note it supplies questions, answers and imgToQA.
    vqa = VQA(ques_file, ques_file)

    all_result_list = []
    for i in range(args.numEpochs):
        res_file = args.resFile.replace('<epoch>', str(i))
        print(res_file)
        # Create the vqaRes object, then the evaluator over vqa + vqaRes.
        vqa_res = vqa.loadRes(res_file, ques_file)
        # n is the precision of accuracy (number of places after decimal).
        vqa_eval = VQAEval(vqa, vqa_res, n=2)
        vqa_eval.evaluate()

        print("\n")
        print("Overall Accuracy is: %.02f\n" % (vqa_eval.accuracy['overall']))
        # NOTE(review): result files are indexed from 0 but the summary
        # records a 1-based 'Epoch'; kept as-is — confirm this is intended.
        acc_dict = {'Epoch': i + 1, 'Overall': vqa_eval.accuracy['overall']}
        print("Per Answer Type Accuracy is the following:")
        for ans_type, acc in vqa_eval.accuracy['perAnswerType'].items():
            acc_dict[ans_type] = acc
            print("%s : %.02f" % (ans_type, acc))
        print("\n")

        # Save per-epoch outputs next to the result file; `with` ensures
        # each handle is flushed and closed.
        with open(derived_path(res_file, '_acc'), 'w') as f:
            json.dump(vqa_eval.accuracy, f)
        with open(derived_path(res_file, '_compare'), 'w') as f:
            json.dump(vqa_eval.ansComp, f)
        all_result_list.append(acc_dict)

    all_acc_file = summary_path(args.resFile)
    with open(all_acc_file, 'w') as f:
        json.dump(all_result_list, f)
    print('All accuracy file saved to: ', all_acc_file)


if __name__ == '__main__':
    main()