aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.c4
-rw-r--r--src/nn.c123
-rw-r--r--src/nn.h3
-rw-r--r--src/util.c7
-rw-r--r--src/util.h1
5 files changed, 132 insertions, 6 deletions
diff --git a/src/main.c b/src/main.c
index 3dc1b5f..848d638 100644
--- a/src/main.c
+++ b/src/main.c
@@ -93,6 +93,7 @@ int main(int argc, char *argv[]) {
.epochs = 100,
.batch_size = 32,
.alpha = 1e-5,
+ .shuffle = true,
.config_filepath = "",
.network_size = 0,
.only_out = false,
@@ -140,7 +141,8 @@ int main(int argc, char *argv[]) {
load_loss(ml_configs),
ml_configs.epochs,
ml_configs.batch_size,
- ml_configs.alpha);
+ ml_configs.alpha,
+ ml_configs.shuffle);
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
diff --git a/src/nn.c b/src/nn.c
index 56c35fc..867819c 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -29,7 +29,12 @@
#include "util.h"
#include "nn.h"
+static void dataset_shuffle_rows(
+ double *inputs, size_t in_shape[2],
+ double *labels, size_t lbl_shape[2]);
+
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+
static double get_avg_loss(
double labels[], double outs[], size_t shape[2],
double (*loss)(double *, double *, size_t));
@@ -72,7 +77,8 @@ void nn_network_train(
double *input, size_t input_shape[2],
double *labels, size_t labels_shape[2],
struct Cost cost, size_t epochs,
- size_t batch_size, double alpha)
+ size_t batch_size, double alpha,
+ bool shuffle)
{
assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n");
@@ -83,6 +89,13 @@ void nn_network_train(
if (!outs || !zouts || !weights || !biases) goto nn_network_train_error;
+ double *input_random = calloc(input_shape[0] * input_shape[1], sizeof(double));
+ double *labels_random = calloc(labels_shape[0] * labels_shape[1], sizeof(double));
+
+ if (!input_random || !labels_random) goto nn_network_train_error;
+
+ memcpy(input_random, input, sizeof(double) * input_shape[0] * input_shape[1]);
+ memcpy(labels_random, labels, sizeof(double) * labels_shape[0] * labels_shape[1]);
size_t samples = input_shape[0];
@@ -107,11 +120,17 @@ void nn_network_train(
n_batches++;
}
for (size_t epoch = 0; epoch < epochs; epoch++) {
+
+ if (shuffle) {
+ dataset_shuffle_rows(input_random, input_shape, labels_random, labels_shape);
+ }
+
+
for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
size_t index = batch_size * batch_idx;
- double *input_batch = input + index * input_shape[1];
- double *labels_batch = labels + index * labels_shape[1];
+ double *input_batch = input_random + index * input_shape[1];
+ double *labels_batch = labels_random + index * labels_shape[1];
if (batch_idx == n_batches - 1 && samples % batch_size) {
batch_input_shape[0] = samples % batch_size;
@@ -454,6 +473,52 @@ nn_fill_random_weights_error:
exit(1);
}
+void dataset_shuffle_rows(
+ double *inputs, size_t in_shape[2],
+ double *labels, size_t lbl_shape[2])
+{
+ size_t random_row;
+ size_t in_index, lbl_index;
+ size_t shuffle_in_index, column_in_bytes;
+ size_t shuffle_lbl_index, column_lbl_bytes;
+ double *in_buffer, *lbl_buffer;
+
+ in_buffer = malloc(sizeof(double) * in_shape[1]);
+ lbl_buffer = malloc(sizeof(double) * lbl_shape[1]);
+
+ if (in_buffer == NULL || lbl_buffer == NULL)
+ goto dataset_shuffle_rows_error;
+
+ column_in_bytes = sizeof(double) * in_shape[1];
+ column_lbl_bytes = sizeof(double) * lbl_shape[1];
+ for (size_t row = 0; row < in_shape[0]; row++) {
+ /* Swap actual row with a random row*/
+ random_row = random() % in_shape[0];
+
+ /* Input Swap */
+ in_index = row * in_shape[1];
+ shuffle_in_index = random_row * in_shape[1];
+ memcpy(in_buffer, inputs + in_index, column_in_bytes);
+ memcpy(inputs + in_index, inputs + shuffle_in_index, column_in_bytes);
+ memcpy(inputs + shuffle_in_index, in_buffer, column_in_bytes);
+
+ /* Label Swap */
+ lbl_index = row * lbl_shape[1];
+ shuffle_lbl_index = random_row * lbl_shape[1];
+ memcpy(lbl_buffer, labels + lbl_index, column_lbl_bytes);
+ memcpy(labels + lbl_index, labels + shuffle_lbl_index, column_lbl_bytes);
+ memcpy(labels + shuffle_lbl_index, lbl_buffer, column_lbl_bytes);
+
+ }
+
+ free(in_buffer);
+ free(lbl_buffer);
+ return;
+
+dataset_shuffle_rows_error:
+ die("dataset_shuffle_rows() malloc Error:");
+}
+
double square_loss(double labels[], double net_out[], size_t shape)
{
double sum = 0;
@@ -478,3 +543,55 @@ double get_avg_loss(
}
return sum / shape[0];
}
+
+#ifdef NN_TEST
+/*
+ * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm
+ */
+int main(void) {
+ /*
+ * array_shuffle_rows() test
+ */
+ srandom(42);
+ double input_array[12] = {
+ 11, 12, 13,
+ 21, 22, 23,
+ 31, 32, 33,
+ 41, 42, 43,
+ };
+ double shuffled_input_array[12] = {
+ 21, 22, 23,
+ 41, 42, 43,
+ 31, 32, 33,
+ 11, 12, 13,
+ };
+ size_t in_shape[2] = {4,3};
+
+ double label_array[4] = {1, 2, 3, 4};
+ double shuffled_label_array[4] = {2, 4, 3, 1};
+ size_t lbl_shape[2] = {4,1};
+
+ dataset_shuffle_rows(input_array, in_shape, label_array, lbl_shape);
+ size_t i, j, index;
+ for (i = 0; i < in_shape[0]; i++) {
+ for (j = 0; j < in_shape[1]; j++) {
+ index = i * in_shape[1] + j;
+ if (input_array[index] != shuffled_input_array[index]) {
+ printf("- array_shuffle_rows() failure: input_array mismatch on (%zu,%zu)\n", i, j);
+ return 1;
+ }
+ }
+
+ for (j = 0; j < lbl_shape[1]; j++) {
+ index = i * lbl_shape[1] + j;
+ if (label_array[index] != shuffled_label_array[index]) {
+ printf("- array_shuffle_rows() failure: label_array mismatch on (%zu,%zu)\n", i, j);
+ return 1;
+ }
+ }
+ }
+ printf("- array_shuffle_rows() success\n");
+
+ return 0;
+}
+#endif //NN_TEST
diff --git a/src/nn.h b/src/nn.h
index 5dbb656..9f8e2a5 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -53,7 +53,8 @@ void nn_network_train(
double *input, size_t input_shape[2],
double *labels, size_t labels_shape[2],
struct Cost cost, size_t epochs,
- size_t batch_size, double alpha);
+ size_t batch_size, double alpha,
+ bool shuffle);
void nn_layer_map_activation(
double (*activation)(double),
diff --git a/src/util.c b/src/util.c
index 71cbf8f..81950f1 100644
--- a/src/util.c
+++ b/src/util.c
@@ -105,6 +105,7 @@ void usage(int exit_code)
" -o, --output=FILE Output file (only works with predict)\n"
" -p, --precision=INT Decimals output precision (only works with predict)\n"
" [default=auto]\n"
+ " -S, --no-shuffle Don't shuffle data each epoch (only works with train)\n"
"\n"
);
exit(exit_code);
@@ -120,6 +121,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
{"epochs", required_argument, 0, 'e'},
{"batch", required_argument, 0, 'b'},
{"alpha", required_argument, 0, 'a'},
+ {"no-shuffle", no_argument, 0, 'S'},
{"output", required_argument, 0, 'o'},
{"config", required_argument, 0, 'c'},
{"only-out", no_argument, 0, 'O'},
@@ -129,7 +131,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
int c;
while (1) {
- c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL);
+ c = getopt_long(argc, argv, "hvOSc:e:a:o:i:f:p:b:", long_opts, NULL);
if (c == -1) {
break;
@@ -160,6 +162,9 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0");
ml->batch_size = (size_t)atol(optarg);
break;
+ case 'S':
+ ml->shuffle = false;
+ break;
case 'h':
usage(0);
break;
diff --git a/src/util.h b/src/util.h
index 10aa4aa..69ebfbe 100644
--- a/src/util.h
+++ b/src/util.h
@@ -14,6 +14,7 @@ struct Configs {
size_t n_input_keys, n_label_keys;
char *weights_filepath;
char *config_filepath;
+ bool shuffle;
/* cli cfgs */
char *file_format;
char *in_filepath;
Feel free to download, copy and edit any repo