diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 4 | ||||
-rw-r--r-- | src/nn.c | 123 | ||||
-rw-r--r-- | src/nn.h | 3 | ||||
-rw-r--r-- | src/util.c | 7 | ||||
-rw-r--r-- | src/util.h | 1 |
5 files changed, 132 insertions, 6 deletions
@@ -93,6 +93,7 @@ int main(int argc, char *argv[]) { .epochs = 100, .batch_size = 32, .alpha = 1e-5, + .shuffle = true, .config_filepath = "", .network_size = 0, .only_out = false, @@ -140,7 +141,8 @@ int main(int argc, char *argv[]) { load_loss(ml_configs), ml_configs.epochs, ml_configs.batch_size, - ml_configs.alpha); + ml_configs.alpha, + ml_configs.shuffle); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { @@ -29,7 +29,12 @@ #include "util.h" #include "nn.h" +static void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]); + static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); + static double get_avg_loss( double labels[], double outs[], size_t shape[2], double (*loss)(double *, double *, size_t)); @@ -72,7 +77,8 @@ void nn_network_train( double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], struct Cost cost, size_t epochs, - size_t batch_size, double alpha) + size_t batch_size, double alpha, + bool shuffle) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); @@ -83,6 +89,13 @@ void nn_network_train( if (!outs || !zouts || !weights || !biases) goto nn_network_train_error; + double *input_random = calloc(input_shape[0] * input_shape[1], sizeof(double)); + double *labels_random = calloc(labels_shape[0] * labels_shape[1], sizeof(double)); + + if (!input_random || !labels_random) goto nn_network_train_error; + + memcpy(input_random, input, sizeof(double) * input_shape[0] * input_shape[1]); + memcpy(labels_random, labels, sizeof(double) * labels_shape[0] * labels_shape[1]); size_t samples = input_shape[0]; @@ -107,11 +120,17 @@ void nn_network_train( n_batches++; } for (size_t epoch = 0; epoch < epochs; epoch++) { + + if (shuffle) { + dataset_shuffle_rows(input_random, input_shape, labels_random, labels_shape); + } + + for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { size_t index = batch_size * batch_idx; - double *input_batch = input + index * input_shape[1]; - double *labels_batch = labels + index * labels_shape[1]; + double *input_batch = input_random + index * input_shape[1]; + double *labels_batch = labels_random + index * labels_shape[1]; if (batch_idx == n_batches - 1 && samples % batch_size) { batch_input_shape[0] = samples % batch_size; @@ -454,6 +473,52 @@ nn_fill_random_weights_error: exit(1); } +void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]) +{ + size_t random_row; + size_t in_index, lbl_index; + size_t shuffle_in_index, column_in_bytes; + size_t shuffle_lbl_index, column_lbl_bytes; + double *in_buffer, *lbl_buffer; + + in_buffer = malloc(sizeof(double) * in_shape[1]); + lbl_buffer = malloc(sizeof(double) * lbl_shape[1]); + + if (in_buffer == NULL || lbl_buffer == NULL) + goto dataset_shuffle_rows_error; + + column_in_bytes = sizeof(double) * in_shape[1]; + column_lbl_bytes = sizeof(double) * lbl_shape[1]; + for (size_t row = 0; row < in_shape[0]; row++) { + /* Swap actual row with a random row*/ + random_row = random() % in_shape[0]; + + /* Input Swap */ + in_index = row * in_shape[1]; + shuffle_in_index = random_row * in_shape[1]; + memcpy(in_buffer, inputs + in_index, column_in_bytes); + memcpy(inputs + in_index, inputs + shuffle_in_index, column_in_bytes); + memcpy(inputs + shuffle_in_index, in_buffer, column_in_bytes); + + /* Label Swap */ + lbl_index = row * lbl_shape[1]; + shuffle_lbl_index = random_row * lbl_shape[1]; + memcpy(lbl_buffer, labels + lbl_index, column_lbl_bytes); + memcpy(labels + lbl_index, labels + shuffle_lbl_index, column_lbl_bytes); + memcpy(labels + shuffle_lbl_index, lbl_buffer, column_lbl_bytes); + + } + + free(in_buffer); + free(lbl_buffer); + return; + +dataset_shuffle_rows_error: + die("dataset_shuffle_rows() malloc Error:"); +} + double square_loss(double labels[], double net_out[], size_t shape) { double sum = 0; @@ -478,3 +543,55 @@ double get_avg_loss( } return sum / shape[0]; } + +#ifdef NN_TEST +/* + * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm + */ +int main(void) { + /* + * array_shuffle_rows() test + */ + srandom(42); + double input_array[12] = { + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 41, 42, 43, + }; + double shuffled_input_array[12] = { + 21, 22, 23, + 41, 42, 43, + 31, 32, 33, + 11, 12, 13, + }; + size_t in_shape[2] = {4,3}; + + double label_array[4] = {1, 2, 3, 4}; + double shuffled_label_array[4] = {2, 4, 3, 1}; + size_t lbl_shape[2] = {4,1}; + + dataset_shuffle_rows(input_array, in_shape, label_array, lbl_shape); + size_t i, j, index; + for (i = 0; i < in_shape[0]; i++) { + for (j = 0; j < in_shape[1]; j++) { + index = i * in_shape[1] + j; + if (input_array[index] != shuffled_input_array[index]) { + printf("- array_shuffle_rows() failure: input_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + + for (j = 0; j < lbl_shape[1]; j++) { + index = i * lbl_shape[1] + j; + if (label_array[index] != shuffled_label_array[index]) { + printf("- array_shuffle_rows() failure: label_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + } + printf("- array_shuffle_rows() success\n"); + + return 0; +} +#endif //NN_TEST @@ -53,7 +53,8 @@ void nn_network_train( double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], struct Cost cost, size_t epochs, - size_t batch_size, double alpha); + size_t batch_size, double alpha, + bool shuffle); void nn_layer_map_activation( double (*activation)(double), @@ -105,6 +105,7 @@ void usage(int exit_code) " -o, --output=FILE Output file (only works with predict)\n" " -p, --precision=INT Decimals output precision (only works with predict)\n" " [default=auto]\n" + " -S, --no-shuffle Don't shuffle data each epoch (only works with train)\n" "\n" ); exit(exit_code); @@ -120,6 +121,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) {"epochs", required_argument, 0, 'e'}, {"batch", required_argument, 0, 'b'}, {"alpha", required_argument, 0, 'a'}, + {"no-shuffle", no_argument, 0, 'S'}, {"output", required_argument, 0, 'o'}, {"config", required_argument, 0, 'c'}, {"only-out", no_argument, 0, 'O'}, @@ -129,7 +131,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) int c; while (1) { - c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL); + c = getopt_long(argc, argv, "hvOSc:e:a:o:i:f:p:b:", long_opts, NULL); if (c == -1) { break; @@ -160,6 +162,9 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0"); ml->batch_size = (size_t)atol(optarg); break; + case 'S': + ml->shuffle = false; + break; case 'h': usage(0); break; @@ -14,6 +14,7 @@ struct Configs { size_t n_input_keys, n_label_keys; char *weights_filepath; char *config_filepath; + bool shuffle; /* cli cfgs */ char *file_format; char *in_filepath; |