-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_inferences_benchmark.py
37 lines (30 loc) · 1.13 KB
/
create_inferences_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import random
import json
from utils.data_loader import DataLoader
def parse_subclass_axiom(line):
    """Parse an atomic OWL functional-syntax ``SubClassOf(<sub> <super>)`` line.

    Only the outer ``SubClassOf(`` ... ``)`` wrapper is removed, so parentheses
    that occur inside an IRI are preserved. Operands may be separated by any
    amount of whitespace.

    Args:
        line: One line of an OWL functional-syntax file.

    Returns:
        A ``(sub, super)`` tuple of the two operand tokens, or ``None`` when the
        line is not a ``SubClassOf`` axiom with exactly two whitespace-separated
        operands (e.g. nested class expressions or annotated axioms).
    """
    prefix = 'SubClassOf('
    text = line.strip()
    if not (text.startswith(prefix) and text.endswith(')')):
        return None
    operands = text[len(prefix):-1].split()
    if len(operands) != 2:
        return None
    return operands[0], operands[1]


def collect_inferences(lines, classes, known_pairs):
    """Collect inferred subsumptions that are not already stated in the data.

    Args:
        lines: Iterable of lines from an OWL functional-syntax file.
        classes: Mapping from class token (as it appears in the file) to id.
        known_pairs: Set of ``(sub_id, super_id)`` pairs to exclude.

    Returns:
        A list of ``(sub, super)`` token pairs, in file order, without
        duplicates. Non-atomic axioms and axioms mentioning classes missing
        from ``classes`` are reported and skipped.
    """
    inferences = []
    seen = set()  # guards against one axiom landing in both splits
    for line in lines:
        text = line.strip()
        if not text.startswith('SubClassOf'):
            continue
        axiom = parse_subclass_axiom(text)
        if axiom is None:
            print(f'WARNING: skipping non-atomic SubClassOf axiom: {text}')
            continue
        sub, sup = axiom
        if sub not in classes or sup not in classes:
            print(f'ERROR: encountered unknown class in: {text}')
            continue
        if (classes[sub], classes[sup]) in known_pairs or axiom in seen:
            continue
        seen.add(axiom)
        inferences.append(axiom)
    return inferences


def train_val_split(items, val_fraction=0.1):
    """Shuffle ``items`` in place with the global RNG and split it.

    The global ``random`` module is used deliberately, so that a caller who
    seeds it gets the same permutation as a plain ``random.shuffle`` call.

    Returns:
        ``(remaining, val)``, where ``val`` is the first
        ``int(val_fraction * len(items))`` shuffled items.
    """
    random.shuffle(items)
    num_val = int(val_fraction * len(items))
    return items[num_val:], items[:num_val]


def main(dataset='ANATOMY', seed=100, val_fraction=0.1):
    """Build the inference benchmark files for ``dataset``.

    Reads ``data/<dataset>/inferences/inferences.owl`` and keeps the atomic
    SubClassOf inferences that are not already ``nf1`` pairs of the loaded
    data. It shuffles them and writes the first ``val_fraction`` to
    ``val.json`` and the rest to ``inferences.json`` in the same folder.
    """
    # Seed before loading data so that the global RNG state at shuffle time
    # is the same as in the original flat script (load_data may use it).
    random.seed(seed)
    data_loader = DataLoader.from_task('inferences')
    data, classes, _relations = data_loader.load_data(dataset)
    folder = f'data/{dataset}/inferences'
    # Presumably each row of data['nf1'] is a pair of tensor scalars
    # (hence .item()) -- TODO confirm against DataLoader.
    nf1_set = {(pair[0].item(), pair[1].item()) for pair in data['nf1']}

    with open(f'{folder}/inferences.owl', 'r', encoding='utf-8') as f:
        inference_data = collect_inferences(f, classes, nf1_set)

    inference_data, val_data = train_val_split(inference_data, val_fraction)

    with open(f'{folder}/inferences.json', 'w', encoding='utf-8') as f:
        json.dump(inference_data, f)
    with open(f'{folder}/val.json', 'w', encoding='utf-8') as f:
        json.dump(val_data, f)


if __name__ == '__main__':
    main()