-
Notifications
You must be signed in to change notification settings - Fork 6
/
event_constraints.py
72 lines (53 loc) · 2.26 KB
/
event_constraints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from vocab import Vocab
from io_utils import read_yaml, read_lines, read_json_lines
from str_utils import capitalize_first_char, normalize_tok, normalize_sent, collapse_role_type
class EventConstraint(object):
    '''
    Enforce event constraints of the form (entity type, trigger type) -> argument roles.

    The constraint file is read line by line. Each non-blank line is lower-cased
    and split on whitespace into
    ``<arg_type> <tri_type>,<ent_type> [<tri_type>,<ent_type> ...]``.
    Every ``<tri_type>,<ent_type>`` pair yields one
    (ent_type, tri_type, arg_type) triple. The entity spellings 'time' and
    'value' are rewritten to 'tim' and 'val'.

    Attributes:
        constraint_list: list of (ent_type, tri_type, arg_type) string
            triples, in the order they appear in the file.
        ent_tri_to_arg_hash: dict mapping (ent_id, tri_id) -> set of allowed
            arg ids, where each id is ``ent_dict[...]`` / ``tri_dict[...]`` /
            ``arg_dict[...]`` of the corresponding type name.
    '''

    # Default location of the constraint file. It is a relative path, so it is
    # presumably resolved against the current working directory by read_lines.
    DEFAULT_CONSTRAINT_FILE = './data_files/argrole_dict.txt'

    # Entity-type spellings in the constraint file that must be rewritten to
    # match the keys expected in ent_dict.
    _ENT_TYPE_ALIASES = {'time': 'tim', 'value': 'val'}

    def __init__(self, ent_dict, tri_dict, arg_dict,
                 constraint_file=DEFAULT_CONSTRAINT_FILE):
        '''
        Load the constraint file and build the (ent_id, tri_id) -> arg ids index.

        Args:
            ent_dict, tri_dict, arg_dict: mapping-like objects indexed by the
                lower-cased entity / trigger / argument type names from the
                constraint file. Presumably label vocabularies that map names to
                ids; verify against the caller.
            constraint_file: path to the constraint file. Defaults to the
                historical hard-coded location.

        Raises:
            KeyError: if a type name in the file is missing from the
                corresponding dict.
        '''
        self.constraint_list = []  # [(ent_type, tri_type, arg_type)]
        for line in read_lines(constraint_file):
            self.constraint_list.extend(self._parse_line(line))
        print('Event constraint size:', len(self.constraint_list))

        # Built once so check_constraint is an O(1) lookup instead of a linear
        # scan of constraint_list. NOTE: it is not refreshed if constraint_list
        # is mutated after construction.
        self._constraint_set = set(self.constraint_list)

        # { (ent_id, tri_id) : {arg_id1, ...} }
        self.ent_tri_to_arg_hash = {}
        for ent_type, tri_type, arg_type in self.constraint_list:
            ent_id = ent_dict[ent_type]
            tri_id = tri_dict[tri_type]
            arg_id = arg_dict[arg_type]
            self.ent_tri_to_arg_hash.setdefault((ent_id, tri_id), set()).add(arg_id)

    def _parse_line(self, line):
        '''Return the list of (ent_type, tri_type, arg_type) triples on one file line.'''
        fields = str(line).lower().split()
        if not fields:
            # Blank or whitespace-only line: nothing to parse. Without this
            # check, fields[0] would raise IndexError.
            return []
        arg_type = fields[0]
        triples = []
        for pair in fields[1:]:
            pair_arr = pair.split(',')
            tri_type = pair_arr[0]
            ent_type = self._replace_ent(pair_arr[1])
            triples.append((ent_type, tri_type, arg_type))
        return triples

    def _replace_ent(self, ent_type):
        '''Map 'time' -> 'tim' and 'value' -> 'val'; return other types unchanged.'''
        return self._ENT_TYPE_ALIASES.get(ent_type, ent_type)

    def check_constraint(self, ent_type, tri_type, arg_type):
        '''Return True iff (ent_type, tri_type, arg_type) is a loaded constraint triple.

        The arguments are type-name strings, as stored in constraint_list, not ids.
        '''
        return (ent_type, tri_type, arg_type) in self._constraint_set

    def get_constraint_arg_types(self, ent_type_id, tri_type_id):
        '''Return the set of allowed arg ids for (ent_type_id, tri_type_id), or None if none.'''
        return self.ent_tri_to_arg_hash.get((ent_type_id, tri_type_id), None)