diff options
Diffstat (limited to 'src/nn.c')
-rw-r--r-- | src/nn.c | 123 |
1 files changed, 120 insertions, 3 deletions
@@ -29,7 +29,12 @@ #include "util.h" #include "nn.h" +static void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]); + static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); + static double get_avg_loss( double labels[], double outs[], size_t shape[2], double (*loss)(double *, double *, size_t)); @@ -72,7 +77,8 @@ void nn_network_train( double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], struct Cost cost, size_t epochs, - size_t batch_size, double alpha) + size_t batch_size, double alpha, + bool shuffle) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); @@ -83,6 +89,13 @@ void nn_network_train( if (!outs || !zouts || !weights || !biases) goto nn_network_train_error; + double *input_random = calloc(input_shape[0] * input_shape[1], sizeof(double)); + double *labels_random = calloc(labels_shape[0] * labels_shape[1], sizeof(double)); + + if (!input_random || !labels_random) goto nn_network_train_error; + + memcpy(input_random, input, sizeof(double) * input_shape[0] * input_shape[1]); + memcpy(labels_random, labels, sizeof(double) * labels_shape[0] * labels_shape[1]); size_t samples = input_shape[0]; @@ -107,11 +120,17 @@ void nn_network_train( n_batches++; } for (size_t epoch = 0; epoch < epochs; epoch++) { + + if (shuffle) { + dataset_shuffle_rows(input_random, input_shape, labels_random, labels_shape); + } + + for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { size_t index = batch_size * batch_idx; - double *input_batch = input + index * input_shape[1]; - double *labels_batch = labels + index * labels_shape[1]; + double *input_batch = input_random + index * input_shape[1]; + double *labels_batch = labels_random + index * labels_shape[1]; if (batch_idx == n_batches - 1 && samples % batch_size) { batch_input_shape[0] = samples % batch_size; @@ -454,6 +473,52 @@ nn_fill_random_weights_error: exit(1); } +void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]) +{ + size_t random_row; + size_t in_index, lbl_index; + size_t shuffle_in_index, column_in_bytes; + size_t shuffle_lbl_index, column_lbl_bytes; + double *in_buffer, *lbl_buffer; + + in_buffer = malloc(sizeof(double) * in_shape[1]); + lbl_buffer = malloc(sizeof(double) * lbl_shape[1]); + + if (in_buffer == NULL || lbl_buffer == NULL) + goto dataset_shuffle_rows_error; + + column_in_bytes = sizeof(double) * in_shape[1]; + column_lbl_bytes = sizeof(double) * lbl_shape[1]; + for (size_t row = 0; row < in_shape[0]; row++) { + /* Swap actual row with a random row*/ + random_row = random() % in_shape[0]; + + /* Input Swap */ + in_index = row * in_shape[1]; + shuffle_in_index = random_row * in_shape[1]; + memcpy(in_buffer, inputs + in_index, column_in_bytes); + memcpy(inputs + in_index, inputs + shuffle_in_index, column_in_bytes); + memcpy(inputs + shuffle_in_index, in_buffer, column_in_bytes); + + /* Label Swap */ + lbl_index = row * lbl_shape[1]; + shuffle_lbl_index = random_row * lbl_shape[1]; + memcpy(lbl_buffer, labels + lbl_index, column_lbl_bytes); + memcpy(labels + lbl_index, labels + shuffle_lbl_index, column_lbl_bytes); + memcpy(labels + shuffle_lbl_index, lbl_buffer, column_lbl_bytes); + + } + + free(in_buffer); + free(lbl_buffer); + return; + +dataset_shuffle_rows_error: + die("dataset_shuffle_rows() malloc Error:"); +} + double square_loss(double labels[], double net_out[], size_t shape) { double sum = 0; @@ -478,3 +543,55 @@ double get_avg_loss( } return sum / shape[0]; } + +#ifdef NN_TEST +/* + * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm + */ +int main(void) { + /* + * array_shuffle_rows() test + */ + srandom(42); + double input_array[12] = { + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 41, 42, 43, + }; + double shuffled_input_array[12] = { + 21, 22, 23, + 41, 42, 43, + 31, 32, 33, + 11, 12, 13, + }; + size_t in_shape[2] = {4,3}; + + double label_array[4] = {1, 2, 3, 4}; + double shuffled_label_array[4] = {2, 4, 3, 1}; + size_t lbl_shape[2] = {4,1}; + + dataset_shuffle_rows(input_array, in_shape, label_array, lbl_shape); + size_t i, j, index; + for (i = 0; i < in_shape[0]; i++) { + for (j = 0; j < in_shape[1]; j++) { + index = i * in_shape[1] + j; + if (input_array[index] != shuffled_input_array[index]) { + printf("- array_shuffle_rows() failure: input_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + + for (j = 0; j < lbl_shape[1]; j++) { + index = i * lbl_shape[1] + j; + if (label_array[index] != shuffled_label_array[index]) { + printf("- array_shuffle_rows() failure: label_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + } + printf("- array_shuffle_rows() success\n"); + + return 0; +} +#endif //NN_TEST |