-
Notifications
You must be signed in to change notification settings - Fork 0
/
AN4CTCTest.lua
28 lines (24 loc) · 1.06 KB
/
AN4CTCTest.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
--[[Calculates the word error rate (WER) on the AN4 audio database test set.
-- Uses the model created by AN4CTCTrain.
-- Usage (as a script): AN4CTCTest.lua [model_path]
--   model_path defaults to ./models/CTCNetwork.t7 when omitted.]]
local Network = require 'Network'

-- Optional model path from the command line. The global `arg` is only set when
-- this file is run as a standalone script, so guard against it being nil.
local modelPath = (arg and arg[1]) or "./models/CTCNetwork.t7"

-- Load the network from the saved model; the model is not saved back (saveModel = false).
local networkParams = {
    loadModel = true,
    saveModel = false,
    fileName = modelPath, -- Rename the evaluated model to CTCNetwork.t7 or pass the file path as an argument.
    modelName = 'DeepSpeechModel',
    backend = 'cudnn',
    nGPU = 1, -- Number of GPUs, set -1 to use CPU
    trainingSetLMDBPath = './prepare_an4/train/', -- online loading path
    validationSetLMDBPath = './prepare_an4/test/',
    logsTrainPath = './logs/TrainingLoss/',
    logsValidationPath = './logs/TestScores/',
    modelTrainingPath = './models/',
    dictionaryPath = './dictionary',
    batchSize = 1,
    validationBatchSize = 1,
    validationIterations = 130 -- batch size 1, goes through 130 samples.
}
Network:init(networkParams)
local wer = Network:testNetwork()
-- testNetwork() presumably returns the WER as a fraction (it is scaled by 100
-- for display). '%.2f' prints it with two decimal places; '%2.f' would mean
-- width 2 / precision 0 and round it to a whole percent.
print(string.format('Number of iterations: %d average WER: %.2f%%', networkParams.validationIterations, 100 * wer))