Skip to content

Commit

Permalink
improve perfomance by using blas like primitives (eg. faxpy aka fma)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepybishop committed Aug 30, 2020
1 parent 4a34847 commit d4c1c30
Showing 1 changed file with 57 additions and 49 deletions.
106 changes: 57 additions & 49 deletions src/rnn.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,38 @@ static OPUS_INLINE float relu(float x)
return x < 0 ? 0 : x;
}

void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
{
if (u == 0.0) return;
for (int idx = 0; idx < k; idx++)
a[idx] += b[idx] * u;
}

void compute_dense(const DenseLayer *layer, float *output, const float *input)
{
int i, j;
int N, M;
int stride;
M = layer->nb_inputs;
N = layer->nb_neurons;
stride = N;
for (i=0;i<N;i++)
{
/* Compute update gate. */
float sum = layer->bias[i];
for (j=0;j<M;j++)
sum += layer->input_weights[j*stride + i]*input[j];
output[i] = WEIGHTS_SCALE*sum;
}
const rnn_weight *ip = layer->input_weights;
/* Compute update gate. */
for(i = 0; i < N; i++)
output[i] = layer->bias[i];
for (j=0;j<M;j++,ip+=N)
faxpy(output, ip, N, input[j]);
switch (layer->activation) {
case ACTIVATION_SIGMOID:
for (i=0;i<N;i++)
output[i] = sigmoid_approx(output[i]);
output[i] = sigmoid_approx(WEIGHTS_SCALE * output[i]);
break;
case ACTIVATION_TANH:
for (i=0;i<N;i++)
output[i] = tansig_approx(output[i]);
output[i] = tansig_approx(WEIGHTS_SCALE * output[i]);
break;
default:
case ACTIVATION_RELU:
for (i=0;i<N;i++)
output[i] = relu(output[i]);
output[i] = relu(WEIGHTS_SCALE * output[i]);
break;
}
}
Expand All @@ -120,44 +123,49 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
M = gru->nb_inputs;
N = gru->nb_neurons;
stride = 3*N;
for (i=0;i<N;i++)
{
/* Compute update gate. */
float sum = gru->bias[i];
for (j=0;j<M;j++)
sum += gru->input_weights[j*stride + i]*input[j];
for (j=0;j<N;j++)
sum += gru->recurrent_weights[j*stride + i]*state[j];
z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
}
for (i=0;i<N;i++)
{
/* Compute reset gate. */
float sum = gru->bias[N + i];
for (j=0;j<M;j++)
sum += gru->input_weights[N + j*stride + i]*input[j];
for (j=0;j<N;j++)
sum += gru->recurrent_weights[N + j*stride + i]*state[j];
r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
const rnn_weight *ip = gru->input_weights;
const rnn_weight *rp = gru->recurrent_weights;
/* Compute update gate. */
for(i = 0; i < N; i++)
z[i] = gru->bias[i];
for (j=0;j<M;j++,ip+=stride)
faxpy(z, ip, N, input[j]);
for (j=0;j<N;j++,rp+=stride)
faxpy(z, rp, N, state[j]);
for(i = 0; i < N; i++)
z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
/* Compute reset gate. */
for(i = 0; i < N; i++)
r[i] = gru->bias[N+i];
ip = gru->input_weights + N;
rp = gru->recurrent_weights + N;
for (j=0;j<M;j++,ip+=stride)
faxpy(r, ip, N, input[j]);
for (j=0;j<N;j++,rp+=stride)
faxpy(r, rp, N, state[j]);
for(i = 0; i < N; i++)
r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);

/* Compute output. */
for(i = 0; i < N; i++)
h[i] = gru->bias[2*N+i];
ip = gru->input_weights + 2*N;
rp = gru->recurrent_weights + 2*N;
for (j=0;j<M;j++,ip+=stride)
faxpy(h, ip, N, input[j]);
for (j=0;j<N;j++,rp+=stride)
faxpy(h, rp, N, r[j]*state[j]);
for (i=0;i<N;i++) {
switch (gru->activation) {
case ACTIVATION_SIGMOID: h[i] = sigmoid_approx(WEIGHTS_SCALE*h[i]);break;
case ACTIVATION_TANH: h[i] = tansig_approx(WEIGHTS_SCALE*h[i]); break;
default:
case ACTIVATION_RELU: h[i] = relu(WEIGHTS_SCALE*h[i]); break;
}
h[i] = z[i]*state[i] + (1-z[i])*h[i];
}
for (i=0;i<N;i++)
{
/* Compute output. */
float sum = gru->bias[2*N + i];
for (j=0;j<M;j++)
sum += gru->input_weights[2*N + j*stride + i]*input[j];
for (j=0;j<N;j++)
sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
switch (gru->activation) {
case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum);break;
case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
default:
case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
}
h[i] = z[i]*state[i] + (1-z[i])*sum;
}
for (i=0;i<N;i++)
state[i] = h[i];
state[i] = h[i];
}

#define INPUT_SIZE 42
Expand Down

0 comments on commit d4c1c30

Please sign in to comment.