-
Notifications
You must be signed in to change notification settings - Fork 1
/
trainMultipleRNN.m
181 lines (158 loc) · 7.16 KB
/
trainMultipleRNN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
%% ------------------------------------------------------------------------
%% GAIT RECOGNITION BASED ON IMU DATA AND ML ALGORITHM
% Albi Matteo, Cardone Andrea, Oselin Pierfrancesco
%
% Required packages:
% Parallel Computing Toolbox
% Deep Learning Toolbox (formerly Neural Network Toolbox)
% Signal Processing Toolbox
% Statistics and Machine Learning Toolbox
% -------------------------------------------------------------------------
%% ------------------------------------------------------------------------
%% GOAL OF THE FUNCTION
% Goal of this script is to train several networks (one per combination of
% hyperparameters) in order to find the one that provides the best performance
% -------------------------------------------------------------------------
clear;
close all;
clc
% project helpers; presumably provides dataPreprocessing and simulateStream
addpath("include");
%% DATA IMPORTING
% Every CSV is read with its original column headers preserved, so names
% such as 'AccX (g)' stay intact for the column selection done later.
walkDir  = 'data/record_walk_7-12-21_caviglia/';
labDir   = 'data/record_lab_15-12-21/';
readOpts = {"VariableNamingRule", "preserve"};
try
    % walking recordings, persona A-E, first speed (4 km/h)
    file01 = readtable([walkDir 'personaA4kmh.csv'], readOpts{:});
    file02 = readtable([walkDir 'personaB4kmh.csv'], readOpts{:});
    file03 = readtable([walkDir 'personaC4kmh.csv'], readOpts{:});
    file04 = readtable([walkDir 'personaD4kmh.csv'], readOpts{:});
    file05 = readtable([walkDir 'personaE4kmh.csv'], readOpts{:});
    % walking recordings, persona A-E, second speed (6 km/h, 5.8 km/h for C)
    file06 = readtable([walkDir 'personaA6kmh.csv'], readOpts{:});
    file07 = readtable([walkDir 'personaB6kmh.csv'], readOpts{:});
    file08 = readtable([walkDir 'personaC5_8kmh.csv'], readOpts{:});
    file09 = readtable([walkDir 'personaD6kmh.csv'], readOpts{:});
    file10 = readtable([walkDir 'personaE6kmh.csv'], readOpts{:});
    % lab recordings (already cut, per the original note)
    file11 = readtable([labDir 'IMU1_1.csv'], readOpts{:});
    file12 = readtable([labDir 'IMU1_2.csv'], readOpts{:});
    file13 = readtable([labDir 'IMU2_1.csv'], readOpts{:});
    file14 = readtable([labDir 'IMU3_1.csv'], readOpts{:});
    file15 = readtable('data/record_lab_15-12-21_afternoon/IMU4_1.csv', readOpts{:});
    disp("Data successfully imported");
catch ME
    if strcmp(ME.identifier, 'MATLAB:textio:textio:FileNotFound')
        disp("ERROR: some data cannot be found");
        % tell the user which file is missing
        disp(ME.message);
        return;
    end
    % Any other failure (malformed CSV, permissions, ...) must not be
    % swallowed: the script would otherwise continue with undefined fileXX
    % variables and fail later with a misleading error.
    rethrow(ME);
end
%% Train / test split
% Persona E (file05 at 4 km/h, file10 at 6 km/h) is held out for testing;
% personas A-D at both speeds plus all lab recordings are used for training.
% NOTE(review): `train` shadows the toolbox function of the same name.
test  = {file05, file10};
train = {file01, file02, file03, file04, ...        % persona A-D, 4 km/h
         file06, file07, file08, file09, ...        % persona A-D, second speed
         file11, file12, file13, file14, file15};   % lab recordings

% Column headers to keep from every input table.
% NOTE(review): horizontal concatenation of char vectors yields ONE char row
% vector ('AccX (g)AccY (g)...ID'), not a list of names. dataPreprocessing
% (not visible here) presumably relies on that form - confirm before turning
% this into a cell/string array.
useful_data = horzcat( ...
    'AccX (g)', 'AccY (g)', 'AccZ (g)', ...
    'GyroX (deg/s)', 'GyroY (deg/s)', 'GyroZ (deg/s)', ...
    'EulerX (deg)', 'EulerY (deg)', 'EulerZ (deg)', ...
    'LinAccX (g)', 'LinAccY (g)', 'LinAccZ (g)', ...
    'ID');

%% Labeling and preparing data to train and test the network
[XTrain, YTrain] = dataPreprocessing(train, useful_data);
[XTest,  YTest ] = dataPreprocessing(test,  useful_data);
%% Parameters of the RNN network
% ---- Fixed settings, identical for every trained net ----
NumFeatures          = size(XTrain{1}, 1);  % rows of the first training sequence
NumClasses           = 4;                   % number of gait-phase labels
ExecutionEnvironment = 'cpu';
MiniBatchSize        = 1000;

% ---- Hyperparameter grid: every combination gets trained ----
NetType           = {'gru', 'lstm'};   % recurrent layer type
NHiddenLayers     = [50 100 150];      % hidden units of the recurrent layer
MaxEpochs         = [150 175 200];
GradientThreshold = [1 1.5 2];

% Record filled in for each trained net: hyperparameters, the trained net
% and its accuracies.
netData = struct();
netData.netType           = '';
netData.nHiddenLayers     = 0;
netData.maxEpochs         = 0;
netData.gradientThreshold = 0;
netData.net               = [];
netData.phaseAcc          = [];
netData.testAcc           = [];
netData.streamAcc         = 0;

% Grid sizes; results{i,j,k,l} receives the record of the net trained with
% NetType{i}, NHiddenLayers(j), MaxEpochs(k) and GradientThreshold(l).
I = numel(NetType);
J = numel(NHiddenLayers);
K = numel(MaxEpochs);
L = numel(GradientThreshold);
results = cell(I, J, K, L);
%% Training
% Grid search: one network is trained for every combination of
% NetType x NHiddenLayers x MaxEpochs x GradientThreshold. Each net is
% evaluated on the held-out test sequences and its record (hyperparameters,
% net, accuracies) is stored in results{i,j,k,l}.
disp("Start training");
for i = 1:I
    % net type
    netData.netType = NetType{i};
    for j = 1:J
        % number of hidden units of the recurrent layer
        netData.nHiddenLayers = NHiddenLayers(j);
        % build the net: only the recurrent layer depends on netType
        switch netData.netType
            case 'gru'
                recurrentLayer = gruLayer(netData.nHiddenLayers, 'OutputMode', 'sequence');
            case 'lstm'
                recurrentLayer = lstmLayer(netData.nHiddenLayers, 'OutputMode', 'sequence');
            otherwise
                disp('Error in param definition: netType');
                return;
        end
        layers = [sequenceInputLayer(NumFeatures)
                  recurrentLayer
                  fullyConnectedLayer(NumClasses)
                  softmaxLayer
                  classificationLayer];
        % training options
        for k = 1:K
            % number of epochs
            netData.maxEpochs = MaxEpochs(k);
            for l = 1:L
                % gradient threshold
                netData.gradientThreshold = GradientThreshold(l);
                options = trainingOptions( ...
                    'adam', ...
                    'MiniBatchSize', MiniBatchSize, ...
                    'MaxEpochs', netData.maxEpochs, ...
                    'GradientThreshold', netData.gradientThreshold, ...
                    'Verbose', false, ...
                    'Plots', 'none', ...
                    'ExecutionEnvironment', ExecutionEnvironment);
                % linear index of the net within the grid (1-based)
                netNumber = (i-1)*J*K*L + (j-1)*K*L + (k-1)*L + l;
                disp("Training net number: "+num2str(netNumber));
                netData.net = trainNetwork(XTrain, YTrain, layers, options);

                % Evaluation counters (same method as main2.m). They are
                % reset for every net; otherwise phaseAcc would pool the
                % counts of all previously trained nets.
                correct   = zeros(1, NumClasses);  % correct predictions per phase
                totPhases = zeros(1, NumClasses);  % true labels per phase
                acc       = zeros(1, numel(XTest)); % accuracy per test sequence
                for m = 1:numel(XTest)
                    YPred = classify(netData.net, XTest{m});
                    % fraction of correctly classified samples in sequence m
                    acc(m) = sum(YPred == YTest{m}) / numel(YTest{m});
                    for n = 1:NumClasses
                        isPhase = YTest{m} == categorical(n);
                        totPhases(n) = totPhases(n) + sum(isPhase);
                        correct(n) = correct(n) + sum(YPred(isPhase) == categorical(n));
                    end
                end
                % NaN for a phase that never occurs in the test labels
                netData.phaseAcc = correct ./ totPhases;
                netData.testAcc = acc;
                % stream accuracy on the raw file10 table (see simulateStream.m;
                % the meaning of the two trailing 0 arguments is defined there)
                netData.streamAcc = simulateStream(netData.net, file10, 0, 0);
                results{i,j,k,l} = netData;
            end
        end
    end
end
save('results.mat', 'results');