-
Notifications
You must be signed in to change notification settings - Fork 4
/
search.go
189 lines (172 loc) · 4.85 KB
/
search.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package watertower
import (
"fmt"
"github.com/future-architect/watertower/nlp"
"golang.org/x/sync/errgroup"
"math"
"sort"
)
// Search searches documents.
//
// searchWord is matched against title and content. It is tokenized by the
// NLP logic selected via the lang parameter; when lang is empty the
// instance's default language is used, and failing that a unigram/bigram
// heuristic based on the query length.
//
// tags filters the search result: a document must appear under every given tag.
//
// Results are scored with TF-IDF and sorted by score in ascending order.
// NOTE(review): ascending order puts the lowest-scoring document first —
// confirm this is intended; search APIs usually return best matches first.
func (wt *WaterTower) Search(searchWord string, tags []string, lang string) ([]*Document, error) {
	if lang == "" {
		lang = wt.defaultLanguage
	}
	if lang == "" {
		// No language configured anywhere: single-character queries fall back
		// to unigram tokenization, longer ones to bigram.
		if len(searchWord) < 2 {
			lang = "unigram"
		} else {
			lang = "bigram"
		}
	}
	tokenizer, err := nlp.FindTokenizer(lang)
	if err != nil {
		return nil, fmt.Errorf("tokenizer for language '%s' is not found: %w", lang, err)
	}
	// Tokenization error is deliberately dropped; an empty token map combined
	// with no tags simply yields an empty result below. NOTE(review): confirm
	// the ignored error cannot signal anything other than "no tokens".
	searchTokens, _ := tokenizer.TokenizeToMap(searchWord, 0)
	if len(searchTokens) == 0 && len(tags) == 0 {
		return nil, nil
	}
	// Fetch tag postings, token postings, and the total document count
	// concurrently. Each goroutine writes only to its own captured variables,
	// which are read again only after errGroup.Wait returns.
	errGroup, ctx := errgroup.WithContext(wt.storage.Context())
	var tagDocIDGroups [][]uint32
	if len(tags) > 0 {
		errGroup.Go(func() error {
			findTags, err := wt.FindTagsWithContext(ctx, tags...)
			if err != nil {
				return err
			}
			tagDocIDGroups = make([][]uint32, len(findTags))
			for i, findTag := range findTags {
				tagDocIDGroups[i] = findTag.DocumentIDs
			}
			return nil
		})
	}
	var foundTokens []*token
	var tokenDocIDGroups [][]uint32
	var docCount int
	// Total document count feeds the IDF part of the TF-IDF score.
	errGroup.Go(func() (err error) {
		docCount, err = wt.storage.DocCount()
		return
	})
	if len(searchTokens) > 0 {
		errGroup.Go(func() (err error) {
			tokens := make([]string, 0, len(searchTokens))
			for token := range searchTokens {
				tokens = append(tokens, token)
			}
			foundTokens, err = wt.FindTokensWithContext(ctx, tokens...)
			// Collect, per token, the IDs of the documents that contain it.
			// When err is non-nil foundTokens is nil and this loop is a no-op.
			for _, token := range foundTokens {
				docIDs := make([]uint32, len(token.Postings))
				for i, posting := range token.Postings {
					docIDs[i] = posting.DocumentID
				}
				tokenDocIDGroups = append(tokenDocIDGroups, docIDs)
			}
			return
		})
	}
	err = errGroup.Wait()
	if err != nil {
		return nil, err
	}
	// A candidate document must appear in every ID group: under every tag
	// AND in the posting list of every search token.
	var docIDs []uint32
	if len(tags) > 0 && len(searchTokens) > 0 {
		docIDGroups := append(tagDocIDGroups, tokenDocIDGroups...)
		docIDs = intersection(docIDGroups...)
	} else if len(searchTokens) > 0 {
		docIDs = intersection(tokenDocIDGroups...)
	} else {
		// len(tags) > 0
		docIDs = intersection(tagDocIDGroups...)
	}
	// Keep only documents where the tokens occur in the same relative
	// positions as in the query (phrase match); discard the positions.
	if len(searchTokens) > 0 {
		docIDs, _ = phraseSearchFilter(docIDs, searchTokens, foundTokens)
	}
	docs, err := wt.FindDocuments(docIDs...)
	if err != nil {
		return nil, err
	}
	// Assumes FindDocuments returns documents in the same order as docIDs —
	// TODO confirm; otherwise scores get attached to the wrong documents.
	for i, doc := range docs {
		doc.Score = calcScore(foundTokens, docCount, docIDs[i])
	}
	sort.Slice(docs, func(i, j int) bool {
		return docs[i].Score < docs[j].Score
	})
	return docs, nil
}
// phraseSearchFilter narrows docIDs down to the documents in which the search
// tokens occur in the same relative positions as in the query. For every
// matched document it also reports the candidate phrase start offsets.
func phraseSearchFilter(docIDs []uint32, searchTokens map[string]*nlp.Token, foundTokens []*token) (matchedDocIDs []uint32, foundPositions [][]uint32) {
	for _, candidateID := range docIDs {
		positionsByWord := convertToTokenPositionMap(foundTokens, candidateID)
		// One group of candidate start offsets per distinct word; a phrase
		// start must be shared by every word, hence the intersection below.
		groups := make([][]uint32, 0, len(positionsByWord))
		for word, positions := range positionsByWord {
			groups = append(groups, findPhraseMatchPositions(searchTokens[word], positions))
		}
		if starts := intersection(groups...); len(starts) > 0 {
			matchedDocIDs = append(matchedDocIDs, candidateID)
			foundPositions = append(foundPositions, starts)
		}
	}
	return
}
// findPhraseMatchPositions returns the candidate phrase start offsets inside
// one document for a single query token. token.Positions holds the offsets of
// the word inside the query; positionMap holds the offsets at which the word
// occurs in the document. A document position qualifies when every further
// query occurrence of the word is mirrored at the matching relative offset.
func findPhraseMatchPositions(token *nlp.Token, positionMap map[uint32]bool) []uint32 {
	base := token.Positions[0]
	var starts []uint32
	for docPos := range positionMap {
		ok := true
		for _, queryPos := range token.Positions[1:] {
			if !positionMap[docPos-base+queryPos] {
				ok = false
				break
			}
		}
		if ok {
			starts = append(starts, docPos-base)
		}
	}
	// Map iteration order is random; sort for a deterministic result.
	sort.Slice(starts, func(a, b int) bool {
		return starts[a] < starts[b]
	})
	return starts
}
// convertToTokenPositionMap builds, for a single document, a word-to-position-set
// lookup from the posting lists of all found tokens. A word without a posting
// for docID maps to an empty (non-nil) set.
func convertToTokenPositionMap(foundTokens []*token, docID uint32) map[string]map[uint32]bool {
	result := make(map[string]map[uint32]bool)
	for _, tok := range foundTokens {
		positions := make(map[uint32]bool)
		for _, posting := range tok.Postings {
			if posting.DocumentID != docID {
				continue
			}
			for _, p := range posting.Positions {
				positions[p] = true
			}
			// At most one posting per document is expected for a token.
			break
		}
		result[tok.Word] = positions
	}
	return result
}
// calcScore sums the TF-IDF contributions of every found token over the
// postings that belong to documentID.
func calcScore(foundTokens []*token, docCount int, documentID uint32) float64 {
	score := 0.0
	for _, tok := range foundTokens {
		for _, posting := range tok.Postings {
			if posting.DocumentID != documentID {
				continue
			}
			score += tfIdfScore(len(posting.Positions), docCount, len(tok.Postings))
		}
	}
	return score
}
func tfIdfScore(tokenCount, allDocCount, docCount int) float64 {
var tf float64
if tokenCount > 0 {
tf = 1.0 + math.Log(float64(tokenCount))
}
idf := math.Log(float64(allDocCount) / float64(docCount))
return tf * idf
}