-
Notifications
You must be signed in to change notification settings - Fork 1
/
fever_score_test.py
68 lines (59 loc) · 2.49 KB
/
fever_score_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import json
import sys
from fever_score import fever_score
from prettytable import PrettyTable
# Evaluate FEVER-style predictions: merge a predicted-labels file and a
# predicted-evidence file onto the claim ids of a gold file, then report the
# FEVER score, label accuracy and evidence precision/recall/F1.


def count_lines(path):
    """Return the number of lines in the text file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return sum(1 for _ in handle)


def read_jsonl(path):
    """Yield one parsed JSON value per non-blank line of the file at *path*.

    Blank lines (e.g. a trailing empty line) are skipped instead of making
    ``json.loads`` fail.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                yield json.loads(line)


def index_gold(records):
    """Index gold records by their ``"id"`` field.

    Returns ``(ids, actual)``:

    * ``actual`` holds one record per distinct id, in first-seen order.
    * ``ids`` maps each id to its position in ``actual``.

    If an id occurs more than once, its position comes from the first
    occurrence but the stored record is the last one.
    """
    ids = {}
    actual = []
    for record in records:
        claim_id = record["id"]
        if claim_id in ids:
            actual[ids[claim_id]] = record
        else:
            ids[claim_id] = len(actual)
            actual.append(record)
    return ids, actual


def extract_evidence(record):
    """Return the first two fields of every evidence item in *record*.

    Items are taken from ``"predicted_evidence"`` followed by ``"evidence"``.
    If both keys are present, their items are concatenated. Returns ``None``
    when the record has neither key, so callers can leave any earlier value
    untouched. The first two fields are presumably (page, sentence index);
    confirm against the producer of the file.
    """
    # NOTE(review): concatenating both keys mirrors the original script;
    # confirm a record is never expected to carry both.
    if "predicted_evidence" not in record and "evidence" not in record:
        return None
    evidences = [item[:2] for item in record.get("predicted_evidence", [])]
    evidences.extend(item[:2] for item in record.get("evidence", []))
    return evidences


def build_predictions(ids, label_records, evidence_records):
    """Build one prediction dict per gold claim, aligned with ``ids``.

    ``ids`` maps each claim id to a distinct index in ``range(len(ids))``, as
    returned by ``index_gold``. Each output entry is
    ``{"predicted_evidence": [...], "predicted_label": ...}``.

    Defaults:

    * Claims with no label record get ``'NOT ENOUGH INFO'``.
    * Claims with no usable evidence record get ``[]``.
    * When several records share an id, the last one wins.

    Raises ``KeyError`` for a label record, or an evidence record that
    carries evidence, whose id is not in ``ids``.
    """
    predicted_labels = ['NOT ENOUGH INFO'] * len(ids)
    # One distinct list per claim (``[[]] * n`` would alias a single list).
    predicted_evidence = [[] for _ in range(len(ids))]
    for record in label_records:
        predicted_labels[ids[record["id"]]] = record["predicted_label"]
    for record in evidence_records:
        evidences = extract_evidence(record)
        if evidences is not None:
            predicted_evidence[ids[record["id"]]] = evidences
    return [
        {"predicted_evidence": evidence, "predicted_label": label}
        for evidence, label in zip(predicted_evidence, predicted_labels)
    ]


def main(argv=None):
    """Score FEVER predictions against a gold file and print a results table.

    Prints, in order:

    1. The line count of the predicted-labels file.
    2. The line count of the predicted-evidence file.
    3. A PrettyTable with the FEVER score, label accuracy and evidence
       precision/recall/F1 returned by ``fever_score``, rounded to 4 places.

    Args:
        argv: Command-line arguments without the program name; ``None``
            means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--predicted_labels", type=str, default='temp_dev_bert_verifier.json',
                        help="JSONL file with an 'id' and 'predicted_label' per line")
    # NOTE(review): '.json' looks like a placeholder; pass your results file.
    parser.add_argument("--predicted_evidence", type=str, default='.json',
                        help="JSONL file with an 'id' and 'predicted_evidence' "
                             "and/or 'evidence' per line (your results)")
    parser.add_argument("--actual", default="data/dev_eval.json",
                        help="gold JSONL file with an 'id' per line")
    args = parser.parse_args(argv)

    # Quick sanity check: raw line counts of the two prediction files.
    print(count_lines(args.predicted_labels))
    print(count_lines(args.predicted_evidence))

    ids, actual = index_gold(read_jsonl(args.actual))
    predictions = build_predictions(
        ids,
        read_jsonl(args.predicted_labels),
        read_jsonl(args.predicted_evidence),
    )

    score, acc, precision, recall, f1 = fever_score(predictions, actual)
    tab = PrettyTable()
    tab.field_names = ["FEVER Score", "Label Accuracy", "Evidence Precision", "Evidence Recall", "Evidence F1"]
    tab.add_row(tuple(round(value, 4) for value in (score, acc, precision, recall, f1)))
    print(tab)


# Guarded so importing this module (e.g. during pytest collection of this
# *_test.py file) does not parse sys.argv or touch the filesystem.
if __name__ == "__main__":
    main()