-
Notifications
You must be signed in to change notification settings - Fork 0
/
ruthlemm.py
75 lines (64 loc) · 2.54 KB
/
ruthlemm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from itertools import cycle
import random
import argparse
from simpletransformers.seq2seq import Seq2SeqModel
import pandas as pd
# Seed Python's global RNG for reproducibility. This must be a *call*:
# assigning to `random.seed` would replace the function with an int and
# leave the generator unseeded.
random.seed(42)
def load_conllu_dataset(datafile, join=False):
    """Read a tab-separated, CoNLL-U style file into a DataFrame.

    Lines that start with ``#`` and lines that are blank after stripping are
    skipped. Every other line is split on tabs and read by column index:
    1 is the word form, 2 the lemma, 3 the part-of-speech tag and (only when
    *join* is true) 5 an extra tag column -- presumably the morphological
    features; confirm against the data.

    Normalisation applied to each row:

    * A form that is exactly one of ``(``, ``)``, ``[``, ``]`` is kept as is;
      in any other form every bracket character is removed.
    * If the tag is ``PROPN`` the form and the lemma are capitalised,
      otherwise the form is lower-cased and the lemma is left untouched.

    Args:
        datafile: Path of the file to read (opened as UTF-8).
        join: When true, ``input_text`` is ``"<form> <column 3> <column 5>"``
            instead of just the form.

    Returns:
        A ``pandas.DataFrame`` with the columns ``input_text``,
        ``target_text`` (the lemma) and ``pos``, one row per kept line.

    Raises:
        IndexError: If a kept line has fewer tab-separated columns than the
            indices accessed above.
    """
    bracket_chars = "()[]"
    # One translate() pass removes all four bracket characters.
    strip_brackets = str.maketrans("", "", bracket_chars)

    rows = []
    with open(datafile, encoding='utf-8') as inp:
        for line in inp:
            # Comments and blank lines carry no token.
            if line.startswith("#") or not line.strip():
                continue

            columns = line.split('\t')
            form = columns[1]
            pos = columns[3]
            is_proper_noun = pos == "PROPN"

            # A token that *is* a single bracket must survive; brackets
            # embedded in any other token are dropped.
            if not (len(form) == 1 and form in bracket_chars):
                form = form.translate(strip_brackets)
            form = form.capitalize() if is_proper_noun else form.lower()

            lemma = columns[2]
            if is_proper_noun:
                lemma = lemma.capitalize()

            if join:
                inpt = form + " " + pos + " " + columns[5]
            else:
                inpt = form
            rows.append([inpt, lemma, pos])

    return pd.DataFrame(rows, columns=["input_text", "target_text", "pos"])
def predict(in_file, out_file, join=False):
    """Lemmatise *in_file* with a seq2seq model and write the result.

    The model inputs are built by ``load_conllu_dataset(in_file, join=join)``.
    The input file is then re-read and, for every line that is neither a
    ``#`` comment nor blank, the third tab-separated column (index 2) is
    replaced by the next model prediction. All other lines are copied
    verbatim.

    Args:
        in_file: Path of the file to lemmatise (read as UTF-8).
        out_file: Path of the file to write (UTF-8, overwritten). The input
            is read completely before the output is opened.
        join: Selects the model: ``"Futyn-Maker/RuthLemm-morphology"`` when
            true, ``"Futyn-Maker/RuthLemm"`` otherwise. It is also forwarded
            to ``load_conllu_dataset``.

    Raises:
        RuntimeError: If the model does not return exactly one prediction per
            input line. (Cycling through the predictions would silently hide
            such a mismatch and write wrong lemmas.)
    """
    model_name = "Futyn-Maker/RuthLemm-morphology" if join else "Futyn-Maker/RuthLemm"
    model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name=model_name,
        use_cuda=False
    )

    pred_data = load_conllu_dataset(in_file, join=join)["input_text"].tolist()
    predictions = model.predict(pred_data)
    # NOTE(review): assumes model.predict returns a sized sequence with one
    # item per input -- confirm against the simpletransformers version used.
    if len(predictions) != len(pred_data):
        raise RuntimeError(
            f"Expected {len(pred_data)} predictions, got {len(predictions)}"
        )
    remaining = iter(predictions)

    with open(in_file, encoding="utf8") as inp:
        lines = inp.readlines()

    predicted = []
    for line in lines:
        # Same filter as load_conllu_dataset, so token lines and predictions
        # stay aligned one-to-one.
        if line[0] != "#" and line.strip():
            columns = line.split("\t")
            columns[2] = next(remaining)
            predicted.append("\t".join(columns))
        else:
            predicted.append(line)

    with open(out_file, "w", encoding="utf8") as out:
        out.write("".join(predicted))
if __name__ == '__main__':
    # Command-line entry point: two positional paths plus an optional flag
    # that is passed to predict() as its third argument.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input_file",
        type=str,
        help="Path to the input file",
    )
    cli.add_argument(
        "output_file",
        type=str,
        help="Path to the output file",
    )
    cli.add_argument(
        "--morphology",
        "-m",
        action="store_true",
        help="Use morphology",
    )
    options = cli.parse_args()

    predict(
        options.input_file,
        options.output_file,
        options.morphology,
    )
    print("All done!")