From f39f6d5b0a907d519377e70876b32daad1a676f2 Mon Sep 17 00:00:00 2001 From: jvech Date: Tue, 3 Sep 2024 20:08:25 -0500 Subject: feat: shuffle dataset on each epoch done The CLI option to disable it was also added. --- src/main.c | 4 +- src/nn.c | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- src/nn.h | 3 +- src/util.c | 7 +++- src/util.h | 1 + 5 files changed, 132 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/main.c b/src/main.c index 3dc1b5f..848d638 100644 --- a/src/main.c +++ b/src/main.c @@ -93,6 +93,7 @@ int main(int argc, char *argv[]) { .epochs = 100, .batch_size = 32, .alpha = 1e-5, + .shuffle = true, .config_filepath = "", .network_size = 0, .only_out = false, @@ -140,7 +141,8 @@ int main(int argc, char *argv[]) { load_loss(ml_configs), ml_configs.epochs, ml_configs.batch_size, - ml_configs.alpha); + ml_configs.alpha, + ml_configs.shuffle); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { diff --git a/src/nn.c b/src/nn.c index 56c35fc..867819c 100644 --- a/src/nn.c +++ b/src/nn.c @@ -29,7 +29,12 @@ #include "util.h" #include "nn.h" +static void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]); + static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); + static double get_avg_loss( double labels[], double outs[], size_t shape[2], double (*loss)(double *, double *, size_t)); @@ -72,7 +77,8 @@ void nn_network_train( double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], struct Cost cost, size_t epochs, - size_t batch_size, double alpha) + size_t batch_size, double alpha, + bool shuffle) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); @@ -83,6 +89,13 @@ void nn_network_train( if (!outs || !zouts || !weights || !biases) goto nn_network_train_error; + double *input_random = calloc(input_shape[0] * input_shape[1], sizeof(double)); + double *labels_random = calloc(labels_shape[0] * labels_shape[1], sizeof(double)); + + if (!input_random || !labels_random) goto nn_network_train_error; + + memcpy(input_random, input, sizeof(double) * input_shape[0] * input_shape[1]); + memcpy(labels_random, labels, sizeof(double) * labels_shape[0] * labels_shape[1]); size_t samples = input_shape[0]; @@ -107,11 +120,17 @@ void nn_network_train( n_batches++; } for (size_t epoch = 0; epoch < epochs; epoch++) { + + if (shuffle) { + dataset_shuffle_rows(input_random, input_shape, labels_random, labels_shape); + } + + for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { size_t index = batch_size * batch_idx; - double *input_batch = input + index * input_shape[1]; - double *labels_batch = labels + index * labels_shape[1]; + double *input_batch = input_random + index * input_shape[1]; + double *labels_batch = labels_random + index * labels_shape[1]; if (batch_idx == n_batches - 1 && samples % batch_size) { batch_input_shape[0] = samples % batch_size; @@ -454,6 +473,52 @@ nn_fill_random_weights_error: exit(1); } +void dataset_shuffle_rows( + double *inputs, size_t in_shape[2], + double *labels, size_t lbl_shape[2]) +{ + size_t random_row; + size_t in_index, lbl_index; + size_t shuffle_in_index, column_in_bytes; + size_t shuffle_lbl_index, column_lbl_bytes; + double *in_buffer, *lbl_buffer; + + in_buffer = malloc(sizeof(double) * in_shape[1]); + lbl_buffer = malloc(sizeof(double) * lbl_shape[1]); + + if (in_buffer == NULL || lbl_buffer == NULL) + goto dataset_shuffle_rows_error; + + column_in_bytes = sizeof(double) * in_shape[1]; + column_lbl_bytes = sizeof(double) * lbl_shape[1]; + for (size_t row = 0; row < in_shape[0]; row++) { + /* Swap actual row with a random row*/ + random_row = random() % in_shape[0]; + + /* Input Swap */ + in_index = row * in_shape[1]; + shuffle_in_index = random_row * in_shape[1]; + memcpy(in_buffer, inputs + in_index, column_in_bytes); + memcpy(inputs + in_index, inputs + shuffle_in_index, column_in_bytes); + memcpy(inputs + shuffle_in_index, in_buffer, column_in_bytes); + + /* Label Swap */ + lbl_index = row * lbl_shape[1]; + shuffle_lbl_index = random_row * lbl_shape[1]; + memcpy(lbl_buffer, labels + lbl_index, column_lbl_bytes); + memcpy(labels + lbl_index, labels + shuffle_lbl_index, column_lbl_bytes); + memcpy(labels + shuffle_lbl_index, lbl_buffer, column_lbl_bytes); + + } + + free(in_buffer); + free(lbl_buffer); + return; + +dataset_shuffle_rows_error: + die("dataset_shuffle_rows() malloc Error:"); +} + double square_loss(double labels[], double net_out[], size_t shape) { double sum = 0; @@ -478,3 +543,55 @@ double get_avg_loss( } return sum / shape[0]; } + +#ifdef NN_TEST +/* + * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm + */ +int main(void) { + /* + * array_shuffle_rows() test + */ + srandom(42); + double input_array[12] = { + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 41, 42, 43, + }; + double shuffled_input_array[12] = { + 21, 22, 23, + 41, 42, 43, + 31, 32, 33, + 11, 12, 13, + }; + size_t in_shape[2] = {4,3}; + + double label_array[4] = {1, 2, 3, 4}; + double shuffled_label_array[4] = {2, 4, 3, 1}; + size_t lbl_shape[2] = {4,1}; + + dataset_shuffle_rows(input_array, in_shape, label_array, lbl_shape); + size_t i, j, index; + for (i = 0; i < in_shape[0]; i++) { + for (j = 0; j < in_shape[1]; j++) { + index = i * in_shape[1] + j; + if (input_array[index] != shuffled_input_array[index]) { + printf("- array_shuffle_rows() failure: input_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + + for (j = 0; j < lbl_shape[1]; j++) { + index = i * lbl_shape[1] + j; + if (label_array[index] != shuffled_label_array[index]) { + printf("- array_shuffle_rows() failure: label_array mismatch on (%zu,%zu)\n", i, j); + return 1; + } + } + } + printf("- array_shuffle_rows() success\n"); + + return 0; +} +#endif //NN_TEST diff --git a/src/nn.h b/src/nn.h index 5dbb656..9f8e2a5 100644 --- a/src/nn.h +++ b/src/nn.h @@ -53,7 +53,8 @@ void nn_network_train( double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], struct Cost cost, size_t epochs, - size_t batch_size, double alpha); + size_t batch_size, double alpha, + bool shuffle); void nn_layer_map_activation( double (*activation)(double), diff --git a/src/util.c b/src/util.c index 71cbf8f..81950f1 100644 --- a/src/util.c +++ b/src/util.c @@ -105,6 +105,7 @@ void usage(int exit_code) " -o, --output=FILE Output file (only works with predict)\n" " -p, --precision=INT Decimals output precision (only works with predict)\n" " [default=auto]\n" + " -S, --no-shuffle Don't shuffle data each epoch (only works with train)\n" "\n" ); exit(exit_code); @@ -120,6 +121,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) {"epochs", required_argument, 0, 'e'}, {"batch", required_argument, 0, 'b'}, {"alpha", required_argument, 0, 'a'}, + {"no-shuffle", no_argument, 0, 'S'}, {"output", required_argument, 0, 'o'}, {"config", required_argument, 0, 'c'}, {"only-out", no_argument, 0, 'O'}, @@ -129,7 +131,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) int c; while (1) { - c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL); + c = getopt_long(argc, argv, "hvOSc:e:a:o:i:f:p:b:", long_opts, NULL); if (c == -1) { break; @@ -160,6 +162,9 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0"); ml->batch_size = (size_t)atol(optarg); break; + case 'S': + ml->shuffle = false; + break; case 'h': usage(0); break; diff --git a/src/util.h b/src/util.h index 10aa4aa..69ebfbe 100644 --- a/src/util.h +++ b/src/util.h @@ -14,6 +14,7 @@ struct Configs { size_t n_input_keys, n_label_keys; char *weights_filepath; char *config_filepath; + bool shuffle; /* cli cfgs */ char *file_format; char *in_filepath; -- cgit v1.2.3-70-g09d2