bpe.go
package bpe

import (
	"bufio"
	"io"
	"sort"
	"strings"

	"github.com/pkg/errors"
)

// defaultTokensCap is the initial capacity of the token slice built by Encode.
var defaultTokensCap = 8
// BPE is a byte-pair-encoding model: a vocabulary of learned tokens plus the
// length of the longest token, which bounds the greedy search in encodeWord.
type BPE struct {
	maxTokenLength int
	vocab          map[string]struct{} // Set for fast vocabulary lookups.
}

// weightedToken pairs a token with its frequency so tokens can be ranked.
type weightedToken struct {
	Token  *string
	Weight int
}
// Encode splits the input into sentences, then greedily tokenizes each
// sentence against the vocabulary and returns the resulting token stream.
func (b *BPE) Encode(r io.Reader) ([]string, error) {
	scanner := bufio.NewScanner(r)
	scanner.Split(scanSentences)

	tokens := make([]string, 0, defaultTokensCap)

	for scanner.Scan() {
		sentence := scanner.Text()
		b.encodeSentence(&tokens, sentence)
	}

	if err := scanner.Err(); err != nil && err != io.EOF {
		return nil, errors.Wrap(err, "file scan")
	}

	return tokens, nil
}
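// Illustrative usage (a sketch, not part of the package API; assumes a model
// built elsewhere, e.g. with newModelFromTokensFrequencyTable, and a
// frequency table tft; the exact token strings depend on the package's
// special-token constants, which are defined elsewhere):
//
//	model := newModelFromTokensFrequencyTable(tft, 100)
//	tokens, err := model.Encode(strings.NewReader("hello world"))
//	if err != nil {
//		// Handle the scan error.
//	}
//	// tokens begins with BeginOfSentence and ends with EndOfSentence.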
// encodeSentence frames a sentence with begin/end-of-sentence tokens and
// encodes each whitespace-separated word. The target is a pointer to a slice
// of tokens so the caller's buffer can be reused, which avoids unnecessary
// memory allocations.
func (b *BPE) encodeSentence(target *[]string, sentence string) {
	*target = append(*target, BeginOfSentence)

	words := strings.Fields(sentence)
	for _, word := range words {
		b.encodeWord(target, word)
	}

	*target = append(*target, EndOfSentence)
}
// encodeWord greedily matches the longest vocabulary token at each position,
// starting from the full remaining word (capped at maxTokenLength) and
// shrinking the candidate until a match is found. Unmatched bytes emit
// UnknownToken.
func (b *BPE) encodeWord(target *[]string, word string) {
	word = BeginOfWord + word + EndOfWord // TODO use special tokens from BPE.

	tokenStart := 0

tokenLoop:
	for tokenStart < len(word) {
		// A candidate never needs to exceed the longest token in the vocabulary.
		tokenEnd := len(word)
		if tokenEnd-tokenStart > b.maxTokenLength {
			tokenEnd = tokenStart + b.maxTokenLength
		}

		// Shrink the candidate from the right until it is found in the vocabulary.
		for ; tokenEnd != tokenStart; tokenEnd-- {
			token := word[tokenStart:tokenEnd]

			_, ok := b.vocab[token]
			if ok {
				*target = append(*target, token)
				tokenStart += len(token)

				continue tokenLoop
			}
		}

		// Nothing matched at this position: emit UnknownToken and skip one byte.
		*target = append(*target, UnknownToken)
		tokenStart++
	}
}
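// Worked example (hypothetical vocabulary and marker strings, for
// illustration only): with BeginOfWord = "<w>", EndOfWord = "</w>",
// vocab = {"<w>low", "er</w>"}, and maxTokenLength = 6, encoding "lower"
// frames it as "<w>lower</w>" and scans left to right: the first candidate
// "<w>low" (capped at 6 bytes) is a vocabulary hit, then the remainder
// "er</w>" matches, yielding ["<w>low", "er</w>"].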
// Decode reassembles text from a token stream by stripping the special
// tokens and inserting a space before each begin-of-word marker.
// The error in the response is kept for potential future use so the
// signature stays backward compatible.
func (b *BPE) Decode(tokens []string) (string, error) {
	builder := strings.Builder{}

	for _, token := range tokens {
		// Skip special tokens.
		// TODO Use special tokens from BPE.
		token = strings.TrimSuffix(token, BeginOfSentence)
		token = strings.TrimSuffix(token, EndOfSentence)
		token = strings.TrimSuffix(token, EndOfWord)
		if strings.HasPrefix(token, BeginOfWord) {
			builder.WriteByte(' ')
			token = token[len(BeginOfWord):]
		}

		_, err := builder.WriteString(token)
		if err != nil {
			return "", err
		}
	}

	return strings.TrimSpace(builder.String()), nil
}
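// Illustrative decoding (a sketch; reuses the hypothetical "<s>"/"</s>" and
// "<w>"/"</w>" marker strings from the worked example above):
//
//	text, _ := model.Decode([]string{"<s>", "<w>low", "er</w>", "<w>fast</w>", "</s>"})
//	// "<s>" and "</s>" trim to nothing, each "<w>" prefix becomes a leading
//	// space, and "</w>" is stripped, so text == "lower fast".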
// newModelFromTokensFrequencyTable builds a BPE model from a token frequency
// table: tokens are sorted by weight, the top tokensLimit of them form the
// vocabulary, and the longest kept token sets maxTokenLength.
func newModelFromTokensFrequencyTable(tft tokensFrequencyTable, tokensLimit int) *BPE {
	tokensListWithWeights := make([]weightedToken, 0, len(tft))

	for t, w := range tft {
		token := t // Copy the loop variable so each entry points at its own string.
		tokensListWithWeights = append(tokensListWithWeights, weightedToken{
			Token:  &token,
			Weight: w,
		})
	}

	sort.Slice(tokensListWithWeights, func(i, j int) bool {
		return tokensListWithWeights[i].Weight > tokensListWithWeights[j].Weight
	})

	if len(tokensListWithWeights) > tokensLimit {
		tokensListWithWeights = tokensListWithWeights[:tokensLimit]
	}

	var maxTokenLength int
	vocab := make(map[string]struct{}, len(tokensListWithWeights))

	for _, t := range tokensListWithWeights {
		token := *t.Token

		// TODO consider removing it and using the value from the config.
		// Need to check the necessity of this change with benchmarks.
		tokenLength := len(token)
		if tokenLength > maxTokenLength {
			maxTokenLength = tokenLength
		}

		vocab[token] = struct{}{}
	}

	return &BPE{
		maxTokenLength: maxTokenLength,
		vocab:          vocab,
	}
}
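// Minimal end-to-end sketch (assumes tokensFrequencyTable is the package's
// map-from-token-to-count type and reuses the hypothetical marker strings
// from the examples above):
//
//	tft := tokensFrequencyTable{"<w>low": 10, "er</w>": 7, "<w>low</w>": 5}
//	model := newModelFromTokensFrequencyTable(tft, 2)
//	// Only the two most frequent tokens ("<w>low", "er</w>") are kept,
//	// so maxTokenLength == 6.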