model_defn.py
import tensorflow as tf
import opennmt as onmt
#
# Defines the [model] type to train; referenced from the [train*.sh] script.
# Tip on overfitting: for single-pair/smaller models dropout=0.3 works best,
# while for multilingual models a dropout in [0.1, 0.3] gives the best results.
#
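# A hedged usage sketch (an assumption, not the project's documented command:
# flag names vary across OpenNMT-tf releases; with the v1-style API used below,
# a custom model file is typically passed to the CLI like this):
#
#   onmt-main train_and_eval --model model_defn.py --config config.yml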
class Transformer(onmt.models.Transformer):
  """Defines a Transformer model as described in https://arxiv.org/abs/1706.03762."""

  def __init__(self, dtype=tf.float32, share_embeddings=onmt.models.EmbeddingsSharingLevel.NONE):
    super(Transformer, self).__init__(
        source_inputter=onmt.inputters.WordEmbedder(
            vocabulary_file_key="source_words_vocabulary",
            embedding_size=512,
            dtype=dtype),
        target_inputter=onmt.inputters.WordEmbedder(
            vocabulary_file_key="target_words_vocabulary",
            embedding_size=512,
            dtype=dtype),
        num_layers=6,
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        dropout=0.1,
        attention_dropout=0.1,
        relu_dropout=0.1,
        share_embeddings=share_embeddings)
class TransformerFP16(Transformer):
  """Defines a Transformer model that uses half-precision floating points."""

  def __init__(self):
    super(TransformerFP16, self).__init__(dtype=tf.float16)


class TransformerSharedEmbd(Transformer):
  """Defines a Transformer model that uses shared encoder-decoder embeddings."""

  def __init__(self):
    super(TransformerSharedEmbd, self).__init__(
        share_embeddings=onmt.models.EmbeddingsSharingLevel.SOURCE_TARGET_INPUT)
class TransformerMedium(onmt.models.Transformer):
  """Defines a 4-layer (self-attention) Transformer model comparable with related work."""

  def __init__(self):
    super(TransformerMedium, self).__init__(
        source_inputter=onmt.inputters.WordEmbedder(
            vocabulary_file_key="source_words_vocabulary",
            embedding_size=512),
        target_inputter=onmt.inputters.WordEmbedder(
            vocabulary_file_key="target_words_vocabulary",
            embedding_size=512),
        num_layers=4,
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        dropout=0.3,
        attention_dropout=0.1,
        relu_dropout=0.1)
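# Note: dropout=0.3 in TransformerMedium follows the overfitting tip above for
# single-pair/smaller models.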
'''
class TransformerSA(onmt.models.Transformer):
  """Defines a Transformer model as described in https://arxiv.org/abs/1706.03762."""

  def __init__(self, dtype=tf.float32, share_embeddings=onmt.models.EmbeddingsSharingLevel.NONE, dropout=0.1):
    super(TransformerSA, self).__init__(
        source_inputter=onmt.inputters.WordEmbedder(embedding_size=512),
        target_inputter=onmt.inputters.WordEmbedder(embedding_size=512),
        num_layers=6,
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        dropout=dropout,  # honor the constructor argument instead of hardcoding 0.1
        attention_dropout=0.1,
        ffn_dropout=0.1,
        share_embeddings=share_embeddings)


class TransformerShareEmbs(TransformerSA):
  """Defines a Transformer model that uses shared encoder-decoder embeddings."""

  def __init__(self):
    super(TransformerShareEmbs, self).__init__(
        share_embeddings=onmt.models.EmbeddingsSharingLevel.ALL)


class TransformerShareEmbsDropout(TransformerSA):
  """Defines a Transformer model with shared encoder-decoder embeddings and increased dropout."""

  def __init__(self):
    super(TransformerShareEmbsDropout, self).__init__(
        dropout=0.3,
        share_embeddings=onmt.models.EmbeddingsSharingLevel.ALL)
'''
# select the model class to train; update accordingly
#model = Transformer
model = TransformerMedium
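

# Minimal sanity check (a sketch; prints the selected model class without
# building it, so it should be safe to run standalone):
if __name__ == "__main__":
  print("selected model:", model.__name__)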