-
Notifications
You must be signed in to change notification settings - Fork 33
/
back_translate.py
100 lines (85 loc) · 3.25 KB
/
back_translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Train and evaluate."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from tensor2tensor.bin import t2t_decoder
from tensor2tensor.models import transformer
import decoding
import problems
# Registry exposed by the local `problems` module; used below to register the
# custom hparams set with `@registry.register_hparams`.
registry = problems.registry

# --- Registered problem names for the two translation directions. ---
tf.flags.DEFINE_string(
    'from_problem',
    'translate_vien_iwslt32k',
    'Problem name for source to intermediate language translation.')
tf.flags.DEFINE_string(
    'to_problem',
    'translate_envi_iwslt32k',
    'Problem name for intermediate to source language translation.')

# --- Data directories for each direction. ---
tf.flags.DEFINE_string(
    'from_data_dir',
    'gs://vien-translation/data/translate_vien_iwslt32k',
    'Data directory for source to intermediate language translation.')
tf.flags.DEFINE_string(
    'to_data_dir',
    'gs://vien-translation/data/translate_envi_iwslt32k',
    'Data directory for intermediate to source language translation.')

# --- Checkpoints for each direction. A directory value is resolved to its
# latest checkpoint by the script entry point. ---
tf.flags.DEFINE_string(
    'from_ckpt',
    'gs://vien-translation/checkpoints/translate_vien_iwslt32k_tiny/avg/',
    'Pretrained checkpoint (or checkpoint directory) for source to '
    'intermediate language translation.')
tf.flags.DEFINE_string(
    'to_ckpt',
    'gs://vien-translation/checkpoints/translate_envi_iwslt32k_tiny/avg/',
    'Pretrained checkpoint (or checkpoint directory) for intermediate to '
    'source language translation.')

# --- File-mode input/output (ignored when back-translating interactively). ---
tf.flags.DEFINE_string(
    'paraphrase_from_file',
    'test_input.vi',
    'Input text file to paraphrase.')
tf.flags.DEFINE_string(
    'paraphrase_to_file',
    'test_output.vi',
    'Output text file to write the paraphrases to.')
tf.flags.DEFINE_boolean(
    'backtranslate_interactively',
    False,
    'Whether to back-translate interactively instead of from '
    'paraphrase_from_file.')

FLAGS = tf.flags.FLAGS
@registry.register_hparams
def transformer_tall9():
    """Return the `transformer_tall9` hparams set.

    Starts from `transformer.transformer_big()` and overrides the hidden size,
    feed-forward filter size, number of layers and number of attention heads.
    """
    # Applied in this order on top of transformer_big.
    overrides = (
        ('hidden_size', 768),
        ('filter_size', 3072),
        ('num_hidden_layers', 9),
        ('num_heads', 12),
    )
    hp = transformer.transformer_big()
    for attr_name, attr_value in overrides:
        setattr(hp, attr_name, attr_value)
    return hp
def _resolve_checkpoint(path):
    """Return a concrete checkpoint path for `path`.

    If `path` is a directory, the latest checkpoint recorded in it (via
    `tf.train.latest_checkpoint`) is returned; otherwise `path` is returned
    unchanged, on the assumption that it already names a checkpoint.

    Raises:
      ValueError: if `path` is a directory that contains no checkpoint.
    """
    if not tf.gfile.IsDirectory(path):
        return path
    ckpt = tf.train.latest_checkpoint(path)
    # latest_checkpoint returns None rather than raising when nothing is
    # found; fail here instead of handing None to the decoder.
    if ckpt is None:
        raise ValueError('No checkpoint found in directory: {}'.format(path))
    return ckpt


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # The checkpoint flags may name either a checkpoint or a directory of
    # checkpoints; resolve directories to their latest checkpoint.
    from_ckpt = _resolve_checkpoint(FLAGS.from_ckpt)
    to_ckpt = _resolve_checkpoint(FLAGS.to_ckpt)

    if FLAGS.backtranslate_interactively:
        # Pass the resolved checkpoints so both modes treat the flags the same.
        decoding.backtranslate_interactively(
            FLAGS.from_problem, FLAGS.to_problem,
            FLAGS.from_data_dir, FLAGS.to_data_dir,
            from_ckpt, to_ckpt)
    else:
        # File mode needs an intermediate file in the other language before
        # translating back into the source language; it is written next to
        # the input file and left in place afterwards.
        tmp_file = '{}.tmp.txt'.format(FLAGS.paraphrase_from_file)
        # Step 1: source language -> other language.
        decoding.t2t_decoder(FLAGS.from_problem, FLAGS.from_data_dir,
                             FLAGS.paraphrase_from_file, tmp_file,
                             from_ckpt)
        # Step 2: other language (tmp_file) -> source language.
        decoding.t2t_decoder(FLAGS.to_problem, FLAGS.to_data_dir,
                             tmp_file, FLAGS.paraphrase_to_file,
                             to_ckpt)