-
Notifications
You must be signed in to change notification settings - Fork 0
/
Evaluator.lua
61 lines (51 loc) · 1.85 KB
/
Evaluator.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
-- Evaluation helpers: a normalised edit-distance error rate and a greedy
-- CTC decoder that turns network outputs into token sequences.
local Evaluator = {}

-- Computes a sequence error rate: the Levenshtein edit distance between
-- `target` and `prediction`, divided by the length of `target` and clamped
-- to the range [0, 1].
--
-- Parameters:
--   target      reference sequence; lengths are taken with `#`, so this is
--               expected to be an array-like Lua table (no nil holes).
--   prediction  hypothesis sequence, same kind of table as `target`.
--               Elements of the two sequences are compared with `==`.
--
-- Returns:
--   a number in [0, 1]. 0 means the sequences are identical; 1 means the
--   edit distance is at least as large as the reference length.
--   Empty reference: returns 0 when the prediction is empty too, otherwise 1
--   (previously the empty/empty case produced NaN from 0/0).
function Evaluator.sequenceErrorRate(target, prediction)
    local targetLen, predictionLen = #target, #prediction

    -- The rate is normalised by the reference length, so an empty reference
    -- has to be handled before dividing.
    if targetLen == 0 then
        if predictionLen == 0 then
            return 0
        end
        return 1
    end

    -- Only two rows of the edit-distance matrix are ever needed at once.
    -- previous[j] holds the distance between the first (i - 1) target tokens
    -- and the first j prediction tokens; row 0 is "insert j tokens".
    local previous = {}
    for j = 0, predictionLen do
        previous[j] = j
    end

    for i = 1, targetLen do
        -- Column 0 is "delete the first i target tokens".
        local current = { [0] = i }
        local targetToken = target[i]
        for j = 1, predictionLen do
            if targetToken == prediction[j] then
                -- Tokens match: no extra cost over the diagonal.
                current[j] = previous[j - 1]
            else
                -- Cheapest of substitution, insertion and deletion.
                current[j] = math.min(previous[j - 1], current[j - 1], previous[j]) + 1
            end
        end
        previous = current
    end

    local rate = previous[predictionLen] / targetLen
    -- A prediction much longer than the reference can need more edits than
    -- the reference has tokens; cap the rate at 1.
    if rate > 1 then
        return 1
    end
    return rate
end

-- Turns a predictions tensor into a list of the most likely tokens using
-- greedy (best-path) CTC decoding: take the arg-max class of every row,
-- collapse consecutive repeats, then drop blanks.
--
-- Parameters:
--   predictions  tensor whose rows are per-step likelihood vectors; the
--                arg-max is taken along dimension 2.
--   mapper       table exposing `alphabet2token`; the entry for '$' is used
--                as the blank token id.
--
-- Returns:
--   an array-like table of token ids (0-based ids, i.e. class index - 1).
--
-- NOTE: no whitespace stripping happens here; callers that need leading or
-- trailing spaces removed before computing WER must do that themselves.
function Evaluator.predict2tokens(predictions, mapper)
    local tokens = {}
    local blankToken = mapper.alphabet2token['$']
    local preToken = blankToken

    -- The prediction is a sequence of likelihood vectors
    local _, maxIndices = torch.max(predictions, 2)
    maxIndices = maxIndices:squeeze()

    for i = 1, maxIndices:size(1) do
        local token = maxIndices[i] - 1 -- CTC indexes start from 1, while token starts from 0
        -- Emit the token if it is not blank and differs from the previous frame.
        if token ~= blankToken and token ~= preToken then
            tokens[#tokens + 1] = token
        end
        -- Track the previous frame unconditionally: a blank between two equal
        -- labels must separate them ("a, blank, a" decodes to "a, a").
        preToken = token
    end
    return tokens
end
-- Return the module table as the result of this chunk.
return Evaluator