-
Notifications
You must be signed in to change notification settings - Fork 1
/
NeuralNetwork.java
154 lines (128 loc) · 3.92 KB
/
NeuralNetwork.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import java.util.Random;
import java.lang.Math;
import Jama.Matrix;
public class NeuralNetwork{
//nodes at each layer.
private final int inputNodes,
hiddenNodes,
outputNodes;
//how much the network can change during training.
private double learningRate;
private Matrix weightInputHidden,
weightHiddenOutput;
private Random random = new Random();
/**
* creates NeuralNetwork.
*
*@param int inputNodes
*@param int hiddenNodes
*@param int outputNodes
*@param double learningRate
*/
public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes, double learningRate){
this.inputNodes = inputNodes;
this.hiddenNodes = hiddenNodes;
this.outputNodes = outputNodes;
this.learningRate = learningRate;
weightInputHidden = randomWeightMatrix(hiddenNodes, inputNodes);
weightHiddenOutput = randomWeightMatrix(outputNodes, hiddenNodes);
}
/**
* trains NeuralNetwork.
*
*@param Matrix inputs : one line of data.
*@param Matrix targets : target matrix (what output should look like).
*/
public void train(Matrix inputs, Matrix targets){
Matrix hiddenOutput = generateHiddenOutput(inputs);
Matrix outputOutput = generateOutputOutput(hiddenOutput);
//calulate output error. (targets - actuall)
Matrix outputError = targets.minus(outputOutput);
// apply weightHiddenOutput to output error.
Matrix hiddenError = weightHiddenOutput.transpose().times(outputError);
backwardPropagation(weightHiddenOutput, outputError, outputOutput, hiddenOutput);
backwardPropagation(weightInputHidden, hiddenError, hiddenOutput, inputs);
}
/**
* feeds errors through layers backward to update layer weights.
*
*@param Matrix weightMatrix : Matrix to be updated.
*@param Matrix error
*@param Matrix from : starting layer.
*@param Matrix to : ending layer.
*/
private void backwardPropagation(Matrix weightMatrix, Matrix error, Matrix from, Matrix to){
weightMatrix.plusEquals(
(
error.arrayTimes(from)
.arrayTimes(
new Matrix(error.getRowDimension(), error.getColumnDimension(), 1)
.minus(from))
)
.times(to.transpose())
.times(this.learningRate)
);
}
/**
* feeds inputs through hidden and output layers to get guess.
*
*@param Matrix inputs
*@return Matrix
*/
public Matrix generateGuess(Matrix inputs){
Matrix hiddenOutput = generateHiddenOutput(inputs);
return generateOutputOutput(hiddenOutput);
}
/**
* calculate outputs of output layer.
*
*@param input : input to be normalized
*@return Matrix.
*/
private Matrix generateHiddenOutput(Matrix inputs){
//feed inputs into hidden layer.
Matrix hiddenInput = weightInputHidden.times(inputs);
return activationFunction(hiddenInput);
}
/**
* calculate outputs of output layer.
*
*@param input : input to be normalized
*@return Matrix
*/
private Matrix generateOutputOutput(Matrix hiddenOutput){
//feed hidden layer outputs into output layer.
Matrix outputInput = weightHiddenOutput.times(hiddenOutput);
return activationFunction(outputInput);
}
/**
* activationFunction (sigmoid)
*
*@param input : input to be normalized
*@return Matrix : input normalized
*/
private Matrix activationFunction(Matrix input){
for(int i = 0; i < input.getRowDimension(); i++){
for(int j = 0; j < input.getColumnDimension(); j++){
input.set(i, j, 1/(1 + Math.exp(- input.get(i, j))));
}
}
return input;
}
/**
* randomWeightMatrix
*
*@param row
*@param col
*@return row by col matrix
*/
private Matrix randomWeightMatrix(int row, int col){
Matrix retMatrix = new Matrix(row, col);
for(int i = 0; i < retMatrix.getRowDimension(); i++){
for(int j = 0; j < retMatrix.getColumnDimension(); j++){
retMatrix.set(i, j, random.nextGaussian() * Math.pow(hiddenNodes, -0.5));
}
}
return retMatrix;
}
}