diff --git a/Makefile.am b/Makefile.am
index a004dc62..f2185608 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -22,6 +22,7 @@ librnnoise_la_SOURCES = \
 	src/denoise.c \
 	src/rnn.c \
 	src/rnn_data.c \
+	src/rnn_reader.c \
 	src/pitch.c \
 	src/kiss_fft.c \
 	src/celt_lpc.c
@@ -35,7 +36,7 @@ noinst_PROGRAMS = examples/rnnoise_demo
 endif
 
 examples_rnnoise_demo_SOURCES = examples/rnnoise_demo.c
-examples_rnnoise_demo_LDADD = librnnoise.la
+examples_rnnoise_demo_LDADD = librnnoise.la $(LIBM)
 
 pkgconfigdir = $(libdir)/pkgconfig
 pkgconfig_DATA = rnnoise.pc
diff --git a/README b/README
index 27b18fa9..03697801 100644
--- a/README
+++ b/README
@@ -1,4 +1,4 @@
-RNNoise is a noise suppression library based on a recurrent neural network
+RNNoise is a noise suppression library based on a recurrent neural network.
 
 To compile, just type:
 % ./autogen.sh
@@ -12,6 +12,6 @@ While it is meant to be used as a library, a simple command-line tool is
 provided as an example. It operates on RAW 16-bit (machine endian) mono
 PCM files sampled at 48 kHz. It can be used as:
 
-./examples/rnnoise_demo input.pcm output.pcm
+./examples/rnnoise_demo <channels> <max attenuation dB> [model file] < input.raw > output.raw
 
 The output is also a 16-bit raw PCM file.
diff --git a/TRAINING b/TRAINING
new file mode 100644
index 00000000..86c5a4eb
--- /dev/null
+++ b/TRAINING
@@ -0,0 +1,11 @@
+(1) cd src ; ./compile.sh
+
+(2) ./denoise_training signal.raw noise.raw count > training.f32
+
+    (note the matrix size and replace 500000 87 below)
+
+(3) cd training ; ./bin2hdf5.py ../src/training.f32 500000 87 training.h5
+
+(4) ./rnn_train.py
+
+(5) ./dump_rnn.py weights.hdf5 ../src/rnn_data.c ../src/rnn_data.h
diff --git a/examples/rnnoise_demo.c b/examples/rnnoise_demo.c
index e1e239a2..95c2be45 100644
--- a/examples/rnnoise_demo.c
+++ b/examples/rnnoise_demo.c
@@ -1,4 +1,5 @@
-/* Copyright (c) 2017 Mozilla */
+/* Copyright (c) 2018 Gregor Richards
+ * Copyright (c) 2017 Mozilla */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
@@ -24,36 +25,82 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
+#include <math.h>
 #include <stdio.h>
-#include "rnnoise.h"
+#include <stdlib.h>
+#include <string.h>
+#include "rnnoise.h"
 
 #define FRAME_SIZE 480
 
 int main(int argc, char **argv) {
-  int i;
+  int i, ci;
   int first = 1;
+  int channels;
   float x[FRAME_SIZE];
-  FILE *f1, *fout;
-  DenoiseState *st;
-  st = rnnoise_create();
-  if (argc!=3) {
-    fprintf(stderr, "usage: %s <noisy speech> <output denoised>\n", argv[0]);
+  short *tmp;
+  RNNModel *model = NULL;
+  DenoiseState **sts;
+  float max_attenuation;
+  if (argc < 3) {
+    fprintf(stderr, "usage: %s <channels> <max attenuation dB> [model file]\n", argv[0]);
+    return 1;
+  }
+
+  channels = atoi(argv[1]);
+  if (channels < 1) channels = 1;
+  max_attenuation = pow(10, -atof(argv[2])/10);
+
+  if (argc >= 4) {
+    FILE *model_file = fopen(argv[3], "r");
+    if (!model_file) {
+      perror(argv[3]);
+      return 1;
+    }
+    model = rnnoise_model_from_file(model_file);
+    fprintf(stderr, "\n\n\n%p\n\n\n", model);
+    if (!model) {
+      perror(argv[3]);
+      return 1;
+    }
+    fclose(model_file);
+  }
+
+  sts = malloc(channels * sizeof(DenoiseState *));
+  if (!sts) {
+    perror("malloc");
     return 1;
   }
-  f1 = fopen(argv[1], "r");
-  fout = fopen(argv[2], "w");
+  tmp = malloc(channels * FRAME_SIZE * sizeof(short));
+  if (!tmp) {
+    perror("malloc");
+    return 1;
+  }
+  for (i = 0; i < channels; i++) {
+    sts[i] = rnnoise_create(model);
+    rnnoise_set_param(sts[i], RNNOISE_PARAM_MAX_ATTENUATION, max_attenuation);
+  }
+
   while (1) {
-    short tmp[FRAME_SIZE];
-    fread(tmp, sizeof(short), FRAME_SIZE, f1);
-    if (feof(f1)) break;
-    for (i=0;i<FRAME_SIZE;i++) x[i] = tmp[i];
-    rnnoise_process_frame(st, x, x);
-    for (i=0;i<FRAME_SIZE;i++) tmp[i] = x[i];
-    if (!first) fwrite(tmp, sizeof(short), FRAME_SIZE, fout);
+    fread(tmp, sizeof(short), channels * FRAME_SIZE, stdin);
+    if (feof(stdin)) break;
+    for (ci = 0; ci < channels; ci++) {
+      for (i = 0; i < FRAME_SIZE; i++) x[i] = tmp[i*channels+ci];
+      rnnoise_process_frame(sts[ci], x, x);
+      for (i = 0; i < FRAME_SIZE; i++) tmp[i*channels+ci] = x[i];
+    }
+    if (!first) fwrite(tmp, sizeof(short), channels * FRAME_SIZE, stdout);
     first = 0;
   }
-  rnnoise_destroy(st);
-  fclose(f1);
-  fclose(fout);
+  for (i = 0; i < channels; i++)
+    rnnoise_destroy(sts[i]);
+  free(tmp);
+  free(sts);
+  if (model)
+    rnnoise_model_free(model);
   return 0;
 }
diff --git a/include/rnnoise.h b/include/rnnoise.h
--- a/include/rnnoise.h
+++ b/include/rnnoise.h
@@ -27,6 +27,13 @@
 */
 
+#ifndef RNNOISE_H
+#define RNNOISE_H 1
+
+#include <stdio.h>
+
+
 #ifndef RNNOISE_EXPORT
 # if defined(WIN32)
 #  if defined(RNNOISE_BUILD) && defined(DLL_EXPORT)
@@ -38,15 +45,26 @@
 # endif
 #endif
 
-
 typedef struct DenoiseState DenoiseState;
+typedef struct RNNModel RNNModel;
 
 RNNOISE_EXPORT int rnnoise_get_size();
 
-RNNOISE_EXPORT int rnnoise_init(DenoiseState *st);
+RNNOISE_EXPORT int rnnoise_init(DenoiseState *st, RNNModel *model);
 
-RNNOISE_EXPORT DenoiseState *rnnoise_create();
+RNNOISE_EXPORT DenoiseState *rnnoise_create(RNNModel *model);
 
 RNNOISE_EXPORT void rnnoise_destroy(DenoiseState *st);
 
 RNNOISE_EXPORT float rnnoise_process_frame(DenoiseState *st, float *out, const float *in);
+
+RNNOISE_EXPORT RNNModel *rnnoise_model_from_file(FILE *f);
+
+RNNOISE_EXPORT void rnnoise_model_free(RNNModel *model);
+
+/* Parameters to a denoise state */
+#define RNNOISE_PARAM_MAX_ATTENUATION 1
+
+RNNOISE_EXPORT void rnnoise_set_param(DenoiseState *st, int param, float value);
+
+#endif
diff --git a/src/denoise.c b/src/denoise.c
index 128cd999..b1f03716 100644
--- a/src/denoise.c
+++ b/src/denoise.c
@@ -1,4 +1,5 @@
-/* Copyright (c) 2017 Mozilla */
+/* Copyright (c) 2018 Gregor Richards
+ * Copyright (c) 2017 Mozilla */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
@@ -52,24 +53,26 @@
 
 #define SQUARE(x) ((x)*(x))
 
-#define SMOOTH_BANDS 1
-
-#if SMOOTH_BANDS
 #define NB_BANDS 22
-#else
-#define NB_BANDS 21
-#endif
 
 #define CEPS_MEM 8
 #define NB_DELTA_CEPS 6
 
 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
 
+/* We don't allow max attenuation to be more than 60dB */
+#define MIN_MAX_ATTENUATION 0.000001f
+
 #ifndef TRAINING
 #define TRAINING 0
 #endif
 
+
+/* The built-in model, used if no file is given as input */
+extern const struct RNNModel rnnoise_model_orig;
+
+
 static const opus_int16 eband5ms[] = {
 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
   0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
@@ -95,9 +98,10 @@ struct DenoiseState {
   float mem_hp_x[2];
   float lastg[NB_BANDS];
   RNNState rnn;
+
+  float max_attenuation;
 };
 
-#if SMOOTH_BANDS
 void compute_band_energy(float *bandE, const kiss_fft_cpx *X) {
   int i;
   float sum[NB_BANDS] = {0};
@@ -162,32 +166,6 @@ void interp_band_gain(float *g, const float *bandE) {
     }
   }
 }
-#else
-void compute_band_energy(float *bandE, const kiss_fft_cpx *X) {
-  int i;
-  for (i=0;i<NB_BANDS;i++)
-  {
-    int j;
-    bandE[i] = 0;
-    for (j=eband5ms[i]*4;j<eband5ms[i+1]*4;j++) {
-      bandE[i] += SQUARE(X[j].r);
-      bandE[i] += SQUARE(X[j].i);
-    }
-  }
-}
-
-void interp_band_gain(float *g, const float *bandE) {
-  int i;
-  memset(g, 0, FREQ_SIZE);
-  for (i=0;i<NB_BANDS;i++)
-  {
-    int j;
-    for (j=eband5ms[i]*4;j<eband5ms[i+1]*4;j++)
-      g[j] = bandE[i];
-  }
-}
-#endif
 
@@ -262,18 +240,28 @@ int rnnoise_get_size() {
   return sizeof(DenoiseState);
 }
 
-int rnnoise_init(DenoiseState *st) {
+int rnnoise_init(DenoiseState *st, RNNModel *model) {
   memset(st, 0, sizeof(*st));
+  if (model)
+    st->rnn.model = model;
+  else
+    st->rnn.model = &rnnoise_model_orig;
+  st->rnn.vad_gru_state = calloc(sizeof(float), st->rnn.model->vad_gru_size);
+  st->rnn.noise_gru_state = calloc(sizeof(float), st->rnn.model->noise_gru_size);
+  st->rnn.denoise_gru_state = calloc(sizeof(float), st->rnn.model->denoise_gru_size);
   return 0;
 }
 
-DenoiseState *rnnoise_create() {
+DenoiseState *rnnoise_create(RNNModel *model) {
   DenoiseState *st;
   st = malloc(rnnoise_get_size());
-  rnnoise_init(st);
+  rnnoise_init(st, model);
   return st;
 }
 
 void rnnoise_destroy(DenoiseState *st) {
+  free(st->rnn.vad_gru_state);
+  free(st->rnn.noise_gru_state);
+  free(st->rnn.denoise_gru_state);
   free(st);
 }
 
@@ -493,6 +481,25 @@ float rnnoise_process_frame(DenoiseState *st, float *out, const float *in) {
     g[i] = MAX16(g[i], alpha*st->lastg[i]);
     st->lastg[i] = g[i];
   }
+
+  /* Apply maximum attenuation (minimum value) */
+  if (st->max_attenuation) {
+    float min = 1, mult;
+    for (i=0;i<NB_BANDS;i++) {
+      if (g[i] < min) min = g[i];
+    }
+    if (min < st->max_attenuation) {
+      if (min < MIN_MAX_ATTENUATION)
+        min = MIN_MAX_ATTENUATION;
+      mult = (1.0f-st->max_attenuation) / (1.0f-min);
+      for (i=0;i<NB_BANDS;i++) {
+        g[i] = st->max_attenuation + (g[i]-min)*mult;
+        st->lastg[i] = g[i];
+      }
+    }
+  }
+
   interp_band_gain(gf, g);
 #if 1
   for (i=0;i<FREQ_SIZE;i++) {
@@ -512,7 +519,19 @@ float rnnoise_process_frame(DenoiseState *st, float *out, const float *in) {
   frame_synthesis(st, out, X);
   return vad_prob;
 }
 
+void rnnoise_set_param(DenoiseState *st, int param, float value)
+{
+  switch (param) {
+  case RNNOISE_PARAM_MAX_ATTENUATION:
+    if ((value > MIN_MAX_ATTENUATION && value <= 1) || value == 0)
+      st->max_attenuation = value;
+    else
+      st->max_attenuation = MIN_MAX_ATTENUATION;
+    break;
+  }
+}
+
 #if TRAINING
 
 static float uni_rand() {
@@ -538,20 +557,21 @@ int main(int argc, char **argv) {
   int vad_cnt=0;
   int gain_change_count=0;
   float speech_gain = 1, noise_gain = 1;
-  FILE *f1, *f2, *fout;
+  FILE *f1, *f2;
+  int maxCount;
   DenoiseState *st;
   DenoiseState *noise_state;
   DenoiseState *noisy;
-  st = rnnoise_create();
-  noise_state = rnnoise_create();
-  noisy = rnnoise_create();
+  st = rnnoise_create(NULL);
+  noise_state = rnnoise_create(NULL);
+  noisy = rnnoise_create(NULL);
   if (argc!=4) {
-    fprintf(stderr, "usage: %s <speech> <noise> <output>\n", argv[0]);
+    fprintf(stderr, "usage: %s <speech> <noise> <count>\n", argv[0]);
     return 1;
   }
   f1 = fopen(argv[1], "r");
   f2 = fopen(argv[2], "r");
-  fout = fopen(argv[3], "w");
+  maxCount = atoi(argv[3]);
   for(i=0;i<150;i++) {
     short tmp[FRAME_SIZE];
     fread(tmp, sizeof(short), FRAME_SIZE, f2);
@@ -563,12 +583,11 @@ int main(int argc, char **argv) {
     float Ln[NB_BANDS];
     float features[NB_FEATURES];
     float g[NB_BANDS];
-    float gf[FREQ_SIZE]={1};
     short tmp[FRAME_SIZE];
     float vad=0;
-    float vad_prob;
     float E=0;
-    if (count==50000000) break;
+    if (count==maxCount) break;
+    if ((count%1000)==0) fprintf(stderr, "%d\r", count);
     if (++gain_change_count > 2821) {
       speech_gain = pow(10., (-40+(rand()%60))/20.);
       noise_gain = pow(10., (-30+(rand()%50))/20.);
@@ -643,37 +662,16 @@ int main(int argc, char **argv) {
       if (vad==0 && noise_gain==0) g[i] = -1;
     }
     count++;
-#if 0
-    for (i=0;i<NB_FEATURES;i++) printf("%f ", features[i]);
-    printf("\n");
-#endif
-    compute_rnn(&st->rnn, g, &vad_prob, features);
-    interp_band_gain(gf, g);
-#if 1
-    for (i=0;i<FREQ_SIZE;i++) {
-      X[i].r *= gf[i];
-      X[i].i *= gf[i];
-    }
-#endif
-    frame_synthesis(st, xn, X);
-    fwrite(xn, sizeof(short), FRAME_SIZE, fout);
     fwrite(features, sizeof(float), NB_FEATURES, stdout);
     fwrite(g, sizeof(float), NB_BANDS, stdout);
     fwrite(Ln, sizeof(float), NB_BANDS, stdout);
     fwrite(&vad, sizeof(float), 1, stdout);
   }
   fprintf(stderr, "matrix size: %d x %d\n", count, NB_FEATURES + 2*NB_BANDS + 1);
   fclose(f1);
   fclose(f2);
-  fclose(fout);
   return 0;
 }
 
 #endif
diff --git a/src/rnn.c b/src/rnn.c
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -141,17 +141,17 @@ void compute_rnn(RNNState *rnn, float *gains, float *vad, const float *input) {
   float dense_out[MAX_NEURONS];
   float noise_input[MAX_NEURONS*3];
   float denoise_input[MAX_NEURONS*3];
-  compute_dense(&input_dense, dense_out, input);
-  compute_gru(&vad_gru, rnn->vad_gru_state, dense_out);
-  compute_dense(&vad_output, vad, rnn->vad_gru_state);
-  for (i=0;i<INPUT_DENSE_SIZE;i++) noise_input[i] = dense_out[i];
-  for (i=0;i<VAD_GRU_SIZE;i++) noise_input[i+INPUT_DENSE_SIZE] = rnn->vad_gru_state[i];
-  for (i=0;i<INPUT_SIZE;i++) noise_input[i+INPUT_DENSE_SIZE+VAD_GRU_SIZE] = input[i];
-  compute_gru(&noise_gru, rnn->noise_gru_state, noise_input);
+  compute_dense(rnn->model->input_dense, dense_out, input);
+  compute_gru(rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
+  compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
+  for (i=0;i<rnn->model->input_dense_size;i++) noise_input[i] = dense_out[i];
+  for (i=0;i<rnn->model->vad_gru_size;i++) noise_input[i+rnn->model->input_dense_size] = rnn->vad_gru_state[i];
+  for (i=0;i<INPUT_SIZE;i++) noise_input[i+rnn->model->input_dense_size+rnn->model->vad_gru_size] = input[i];
+  compute_gru(rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
 
-  for (i=0;i<VAD_GRU_SIZE;i++) denoise_input[i] = rnn->vad_gru_state[i];
-  for (i=0;i<NOISE_GRU_SIZE;i++) denoise_input[i+VAD_GRU_SIZE] = rnn->noise_gru_state[i];
-  for (i=0;i<INPUT_SIZE;i++) denoise_input[i+VAD_GRU_SIZE+NOISE_GRU_SIZE] = input[i];
-  compute_gru(&denoise_gru, rnn->denoise_gru_state, denoise_input);
-  compute_dense(&denoise_output, gains, rnn->denoise_gru_state);
+  for (i=0;i<rnn->model->vad_gru_size;i++) denoise_input[i] = rnn->vad_gru_state[i];
+  for (i=0;i<rnn->model->noise_gru_size;i++) denoise_input[i+rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
+  for (i=0;i<INPUT_SIZE;i++) denoise_input[i+rnn->model->vad_gru_size+rnn->model->noise_gru_size] = input[i];
+  compute_gru(rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
+  compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
 }
diff --git a/src/rnn.h b/src/rnn.h
index 9e08b44a..10329f55 100644
--- a/src/rnn.h
+++ b/src/rnn.h
@@ -27,6 +27,8 @@
 #ifndef RNN_H_
 #define RNN_H_
 
+#include "rnnoise.h"
+
 #include "opus_types.h"
 
 #define WEIGHTS_SCALE (1.f/256)
diff --git a/src/rnn_data.c b/src/rnn_data.c
index 8f6c99bb..22c53165 100644
--- a/src/rnn_data.c
+++ b/src/rnn_data.c
@@ -5,6 +5,7 @@
 #endif
 
 #include "rnn.h"
+#include "rnn_data.h"
 
 static const rnn_weight input_dense_weights[1008] = {
    -10, 0, -3, 1, -8, -6, 3, -13,
@@ -141,7 +142,7 @@ static const rnn_weight input_dense_bias[24] = {
    -126, 28, 127, 125, -30, 127, -89, -20
 };
 
-const DenseLayer input_dense = {
+static const DenseLayer input_dense = {
    input_dense_bias,
    input_dense_weights,
    42, 24, ACTIVATION_TANH
@@ -597,7 +598,7 @@ static const rnn_weight vad_gru_bias[72] = {
    -29, 127, 34, -66, 49, 53, 27, 62
 };
 
-const GRULayer vad_gru = {
+static const GRULayer vad_gru = {
    vad_gru_bias,
    vad_gru_weights,
    vad_gru_recurrent_weights,
@@ -3115,7 +3116,7 @@ static const rnn_weight noise_gru_bias[144] = {
    -23, -64, 31, 86, -50, 2, -38, 7
 };
 
-const GRULayer noise_gru = {
+static const GRULayer noise_gru = {
    noise_gru_bias,
    noise_gru_weights,
    noise_gru_recurrent_weights,
@@ -10727,7 +10728,7 @@ static const rnn_weight denoise_gru_bias[288] = {
    -21, 25, 18, -58, 25, 126, -84, 127
 };
 
-const GRULayer denoise_gru = {
+static const GRULayer denoise_gru = {
    denoise_gru_bias,
    denoise_gru_weights,
    denoise_gru_recurrent_weights,
@@ -11007,7 +11008,7 @@ static const rnn_weight denoise_output_bias[22] = {
    -126, -105, -53, -49, -18, -9
 };
 
-const DenseLayer denoise_output = {
+static const DenseLayer denoise_output = {
    denoise_output_bias,
    denoise_output_weights,
    96, 22, ACTIVATION_SIGMOID
@@ -11023,9 +11024,28 @@ static const rnn_weight vad_output_bias[1] = {
    -50
 };
 
-const DenseLayer vad_output = {
+static const DenseLayer vad_output = {
    vad_output_bias,
    vad_output_weights,
    24, 1, ACTIVATION_SIGMOID
 };
+
+const struct RNNModel rnnoise_model_orig = {
+   24,
+   &input_dense,
+
+   24,
+   &vad_gru,
+
+   48,
+   &noise_gru,
+
+   96,
+   &denoise_gru,
+
+   22,
+   &denoise_output,
+
+   1,
+   &vad_output
+};
diff --git a/src/rnn_data.h b/src/rnn_data.h
index 56109804..f2186fe0 100644
--- a/src/rnn_data.h
+++ b/src/rnn_data.h
@@ -1,32 +1,33 @@
-/*This file is automatically generated from a Keras model*/
-
 #ifndef RNN_DATA_H
 #define RNN_DATA_H
 
 #include "rnn.h"
 
-#define INPUT_DENSE_SIZE 24
-extern const DenseLayer input_dense;
+struct RNNModel {
+  int input_dense_size;
+  const DenseLayer *input_dense;
 
-#define VAD_GRU_SIZE 24
-extern const GRULayer vad_gru;
+  int vad_gru_size;
+  const GRULayer *vad_gru;
 
-#define NOISE_GRU_SIZE 48
-extern const GRULayer noise_gru;
+  int noise_gru_size;
+  const GRULayer *noise_gru;
 
-#define DENOISE_GRU_SIZE 96
-extern const GRULayer denoise_gru;
+  int denoise_gru_size;
+  const GRULayer *denoise_gru;
 
-#define DENOISE_OUTPUT_SIZE 22
-extern const DenseLayer denoise_output;
+  int denoise_output_size;
+  const DenseLayer *denoise_output;
 
-#define VAD_OUTPUT_SIZE 1
-extern const DenseLayer vad_output;
+  int vad_output_size;
+  const DenseLayer *vad_output;
+};
 
 struct RNNState {
-  float vad_gru_state[VAD_GRU_SIZE];
-  float noise_gru_state[NOISE_GRU_SIZE];
-  float denoise_gru_state[DENOISE_GRU_SIZE];
+  const RNNModel *model;
+  float *vad_gru_state;
+  float *noise_gru_state;
+  float *denoise_gru_state;
 };
diff --git a/src/rnn_reader.c b/src/rnn_reader.c
new file mode 100644
index 00000000..2a031db1
--- /dev/null
+++ b/src/rnn_reader.c
@@ -0,0 +1,168 @@
+/* Copyright (c) 2018 Gregor Richards */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
+   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "rnn.h"
+#include "rnn_data.h"
+#include "rnnoise.h"
+
+/* Although these values are the same as in rnn.h, we make them separate to
+ * avoid accidentally burning internal values into a file format */
+#define F_ACTIVATION_TANH    0
+#define F_ACTIVATION_SIGMOID 1
+#define F_ACTIVATION_RELU    2
+
+RNNModel *rnnoise_model_from_file(FILE *f)
+{
+  int i, in;
+
+  if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
+    return NULL;
+
+  RNNModel *ret = calloc(1, sizeof(RNNModel));
+  if (!ret)
+    return NULL;
+
+#define ALLOC_LAYER(type, name) \
+  type *name; \
+  name = calloc(1, sizeof(type)); \
+  if (!name) { \
+    rnnoise_model_free(ret); \
+    return NULL; \
+  } \
+  ret->name = name
+
+  ALLOC_LAYER(DenseLayer, input_dense);
+  ALLOC_LAYER(GRULayer, vad_gru);
+  ALLOC_LAYER(GRULayer, noise_gru);
+  ALLOC_LAYER(GRULayer, denoise_gru);
+  ALLOC_LAYER(DenseLayer, denoise_output);
+  ALLOC_LAYER(DenseLayer, vad_output);
+
+#define INPUT_VAL(name) do { \
+    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
+      rnnoise_model_free(ret); \
+      return NULL; \
+    } \
+    name = in; \
+  } while (0)
+
+#define INPUT_ACTIVATION(name) do { \
+    int activation; \
+    INPUT_VAL(activation); \
+    switch (activation) { \
+      case F_ACTIVATION_SIGMOID: \
+        name = ACTIVATION_SIGMOID; \
+        break; \
+      case F_ACTIVATION_RELU: \
+        name = ACTIVATION_RELU; \
+        break; \
+      default: \
+        name = ACTIVATION_TANH; \
+    } \
+  } while (0)
+
+#define INPUT_ARRAY(name, len) do { \
+    rnn_weight *values = malloc((len) * sizeof(rnn_weight)); \
+    if (!values) { \
+      rnnoise_model_free(ret); \
+      return NULL; \
+    } \
+    name = values; \
+    for (i = 0; i < (len); i++) { \
+      if (fscanf(f, "%d", &in) != 1) { \
+        rnnoise_model_free(ret); \
+        return NULL; \
+      } \
+      values[i] = in; \
+    } \
+  } while (0)
+
+#define INPUT_DENSE(name) do { \
+    INPUT_VAL(name->nb_inputs); \
+    INPUT_VAL(name->nb_neurons); \
+    ret->name ## _size = name->nb_neurons; \
+    INPUT_ACTIVATION(name->activation); \
+    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
+    INPUT_ARRAY(name->bias, name->nb_neurons); \
+  } while (0)
+
+#define INPUT_GRU(name) do { \
+    INPUT_VAL(name->nb_inputs); \
+    INPUT_VAL(name->nb_neurons); \
+    ret->name ## _size = name->nb_neurons; \
+    INPUT_ACTIVATION(name->activation); \
+    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons * 3); \
+    INPUT_ARRAY(name->recurrent_weights, name->nb_neurons * name->nb_neurons * 3); \
+    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
+  } while (0)
+
+  INPUT_DENSE(input_dense);
+  INPUT_GRU(vad_gru);
+  INPUT_GRU(noise_gru);
+  INPUT_GRU(denoise_gru);
+  INPUT_DENSE(denoise_output);
+  INPUT_DENSE(vad_output);
+
+  return ret;
+}
+
+void rnnoise_model_free(RNNModel *model)
+{
+#define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
+#define FREE_DENSE(name) do { \
+    if (model->name) { \
+      free((void *) model->name->input_weights); \
+      free((void *) model->name->bias); \
+      free((void *) model->name); \
+    } \
+  } while (0)
+#define FREE_GRU(name) do { \
+    if (model->name) { \
+      free((void *) model->name->input_weights); \
+      free((void *) model->name->recurrent_weights); \
+      free((void *) model->name->bias); \
+      free((void *) model->name); \
+    } \
+  } while (0)
+
+  if (!model)
+    return;
+  FREE_DENSE(input_dense);
+  FREE_GRU(vad_gru);
+  FREE_GRU(noise_gru);
+  FREE_GRU(denoise_gru);
+  FREE_DENSE(denoise_output);
+  FREE_DENSE(vad_output);
+  free(model);
+}
diff --git a/training/dump_rnn.py b/training/dump_rnn.py
index 9f267a7e..2f04359d 100755
--- a/training/dump_rnn.py
+++ b/training/dump_rnn.py
@@ -12,46 +12,64 @@
 import re
 import numpy as np
 
-def printVector(f, vector, name):
+def printVector(f, ft, vector, name):
     v = np.reshape(vector, (-1));
     #print('static const float ', name, '[', len(v), '] = \n', file=f)
     f.write('static const rnn_weight {}[{}] = {{\n   '.format(name, len(v)))
     for i in range(0, len(v)):
         f.write('{}'.format(min(127, int(round(256*v[i])))))
+        ft.write('{}'.format(min(127, int(round(256*v[i])))))
         if (i!=len(v)-1):
             f.write(',')
         else:
             break;
+        ft.write(" ")
         if (i%8==7):
             f.write("\n   ")
         else:
             f.write(" ")
     #print(v, file=f)
     f.write('\n};\n\n')
+    ft.write("\n")
     return;
 
-def printLayer(f, hf, layer):
+def printLayer(f, ft, layer):
     weights = layer.get_weights()
-    printVector(f, weights[0], layer.name + '_weights')
+    activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
+    if len(weights) > 2:
+        ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]/3))
+    else:
+        ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]))
+    if activation == 'SIGMOID':
+        ft.write('1\n')
+    elif activation == 'RELU':
+        ft.write('2\n')
+    else:
+        ft.write('0\n')
+    printVector(f, ft, weights[0], layer.name + '_weights')
     if len(weights) > 2:
-        printVector(f, weights[1], layer.name + '_recurrent_weights')
-    printVector(f, weights[-1], layer.name + '_bias')
+        printVector(f, ft, weights[1], layer.name + '_recurrent_weights')
+    printVector(f, ft, weights[-1], layer.name + '_bias')
     name = layer.name
-    activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
     if len(weights) > 2:
-        f.write('const GRULayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}_recurrent_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
+        f.write('static const GRULayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}_recurrent_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
                 .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation))
-        hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]/3))
-        hf.write('extern const GRULayer {};\n\n'.format(name));
     else:
-        f.write('const DenseLayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
+        f.write('static const DenseLayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
                 .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation))
-        hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]))
-        hf.write('extern const DenseLayer {};\n\n'.format(name));
+
+def structLayer(f, layer):
+    weights = layer.get_weights()
+    name = layer.name
+    if len(weights) > 2:
+        f.write('    {},\n'.format(weights[0].shape[1]/3))
+    else:
+        f.write('    {},\n'.format(weights[0].shape[1]))
+    f.write('    &{},\n'.format(name))
 
 def foo(c, name):
-    return 1
+    return None
 
 def mean_squared_sqrt_error(y_true, y_pred):
     return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)
 
@@ -62,27 +80,28 @@
 
 weights = model.get_weights()
 
 f = open(sys.argv[2], 'w')
-hf = open(sys.argv[3], 'w')
+ft = open(sys.argv[3], 'w')
 
 f.write('/*This file is automatically generated from a Keras model*/\n\n')
-f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')
-
-hf.write('/*This file is automatically generated from a Keras model*/\n\n')
-hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n')
+f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n')
"rnn.h"\n#include "rnn_data.h"\n\n') +ft.write('rnnoise-nu model file version 1\n') layer_list = [] for i, layer in enumerate(model.layers): if len(layer.get_weights()) > 0: - printLayer(f, hf, layer) + printLayer(f, ft, layer) if len(layer.get_weights()) > 2: layer_list.append(layer.name) -hf.write('struct RNNState {\n') -for i, name in enumerate(layer_list): - hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) -hf.write('};\n') +f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[4])) +for i, layer in enumerate(model.layers): + if len(layer.get_weights()) > 0: + structLayer(f, layer) +f.write('};\n') -hf.write('\n\n#endif\n') +#hf.write('struct RNNState {\n') +#for i, name in enumerate(layer_list): +# hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) +#hf.write('};\n') f.close() -hf.close() diff --git a/training/rnn_train.py b/training/rnn_train.py index bb53f89b..06d7e1a4 100755 --- a/training/rnn_train.py +++ b/training/rnn_train.py @@ -82,7 +82,7 @@ def get_config(self): batch_size = 32 print('Loading data...') -with h5py.File('denoise_data9.h5', 'r') as hf: +with h5py.File('training.h5', 'r') as hf: all_data = hf['data'][:] print('done.') @@ -113,4 +113,4 @@ def get_config(self): batch_size=batch_size, epochs=120, validation_split=0.1) -model.save("newweights9i.hdf5") +model.save("weights.hdf5")