aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
author: jvech <jmvalenciae@unal.edu.co> 2024-09-03 20:08:25 -0500
committer: jvech <jmvalenciae@unal.edu.co> 2024-09-03 20:08:25 -0500
commit: f39f6d5b0a907d519377e70876b32daad1a676f2 (patch)
tree: e5462ac42b395e2e9938de53ffbfbf6f0380d3de /src/nn.c
parent: e8624e1ebcabcc831d651e0beefe32df1463c903 (diff)
feat: shuffle dataset on each epoch done
The CLI option to disable it was also added.
Diffstat (limited to 'src/nn.c')
-rw-r--r-- src/nn.c | 123
1 files changed, 120 insertions, 3 deletions
diff --git a/src/nn.c b/src/nn.c
index 56c35fc..867819c 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -29,7 +29,12 @@
#include "util.h"
#include "nn.h"
+static void dataset_shuffle_rows(
+ double *inputs, size_t in_shape[2],
+ double *labels, size_t lbl_shape[2]);
+
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+
static double get_avg_loss(
double labels[], double outs[], size_t shape[2],
double (*loss)(double *, double *, size_t));
@@ -72,7 +77,8 @@ void nn_network_train(
double *input, size_t input_shape[2],
double *labels, size_t labels_shape[2],
struct Cost cost, size_t epochs,
- size_t batch_size, double alpha)
+ size_t batch_size, double alpha,
+ bool shuffle)
{
assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n");
@@ -83,6 +89,13 @@ void nn_network_train(
if (!outs || !zouts || !weights || !biases) goto nn_network_train_error;
+ double *input_random = calloc(input_shape[0] * input_shape[1], sizeof(double));
+ double *labels_random = calloc(labels_shape[0] * labels_shape[1], sizeof(double));
+
+ if (!input_random || !labels_random) goto nn_network_train_error;
+
+ memcpy(input_random, input, sizeof(double) * input_shape[0] * input_shape[1]);
+ memcpy(labels_random, labels, sizeof(double) * labels_shape[0] * labels_shape[1]);
size_t samples = input_shape[0];
@@ -107,11 +120,17 @@ void nn_network_train(
n_batches++;
}
for (size_t epoch = 0; epoch < epochs; epoch++) {
+
+ if (shuffle) {
+ dataset_shuffle_rows(input_random, input_shape, labels_random, labels_shape);
+ }
+
+
for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
size_t index = batch_size * batch_idx;
- double *input_batch = input + index * input_shape[1];
- double *labels_batch = labels + index * labels_shape[1];
+ double *input_batch = input_random + index * input_shape[1];
+ double *labels_batch = labels_random + index * labels_shape[1];
if (batch_idx == n_batches - 1 && samples % batch_size) {
batch_input_shape[0] = samples % batch_size;
@@ -454,6 +473,52 @@ nn_fill_random_weights_error:
exit(1);
}
+/*
+ * Shuffle the rows of a dataset in place, applying the same row swaps to
+ * `inputs` and `labels` so every sample keeps its label.
+ *
+ * inputs:    row-major array of in_shape[0] rows by in_shape[1] columns.
+ * in_shape:  {rows, columns} of `inputs`.
+ * labels:    row-major array with lbl_shape[1] columns; it is indexed with
+ *            the same row numbers as `inputs`, so it must hold at least
+ *            in_shape[0] rows (lbl_shape[0] is not read).
+ * lbl_shape: {rows, columns} of `labels`.
+ *
+ * Row numbers are drawn with random(); seed it with srandom() to get a
+ * repeatable order. die() is called if the scratch rows cannot be allocated.
+ *
+ * NOTE: each row is swapped with a row drawn from the whole range, so the
+ * possible permutations are not all equally likely (this is not a
+ * Fisher-Yates shuffle).
+ */
+void dataset_shuffle_rows(
+    double *inputs, size_t in_shape[2],
+    double *labels, size_t lbl_shape[2])
+{
+    size_t n_rows = in_shape[0];
+    size_t in_cols = in_shape[1];
+    size_t lbl_cols = lbl_shape[1];
+    size_t column_in_bytes = sizeof(double) * in_cols;
+    size_t column_lbl_bytes = sizeof(double) * lbl_cols;
+
+    /* Scratch space holding one row of each array while it is swapped */
+    double *in_buffer = malloc(column_in_bytes);
+    double *lbl_buffer = malloc(column_lbl_bytes);
+
+    if (in_buffer == NULL || lbl_buffer == NULL) {
+        goto dataset_shuffle_rows_error;
+    }
+
+    for (size_t row = 0; row < n_rows; row++) {
+        /* Swap actual row with a random row */
+        size_t random_row = random() % n_rows;
+
+        /*
+         * Swapping a row with itself changes nothing, and memcpy() must
+         * not be given overlapping source and destination. The draw above
+         * still happens, so the sequence of random numbers consumed is
+         * the same as if the swap had been performed.
+         */
+        if (random_row == row) {
+            continue;
+        }
+
+        /* Input Swap */
+        double *in_row = inputs + row * in_cols;
+        double *in_other = inputs + random_row * in_cols;
+        memcpy(in_buffer, in_row, column_in_bytes);
+        memcpy(in_row, in_other, column_in_bytes);
+        memcpy(in_other, in_buffer, column_in_bytes);
+
+        /* Label Swap */
+        double *lbl_row = labels + row * lbl_cols;
+        double *lbl_other = labels + random_row * lbl_cols;
+        memcpy(lbl_buffer, lbl_row, column_lbl_bytes);
+        memcpy(lbl_row, lbl_other, column_lbl_bytes);
+        memcpy(lbl_other, lbl_buffer, column_lbl_bytes);
+    }
+
+    free(in_buffer);
+    free(lbl_buffer);
+    return;
+
+dataset_shuffle_rows_error:
+    die("dataset_shuffle_rows() malloc Error:");
+}
+
+
double square_loss(double labels[], double net_out[], size_t shape)
{
double sum = 0;
@@ -478,3 +543,55 @@ double get_avg_loss(
}
return sum / shape[0];
}
+
+#ifdef NN_TEST
+/*
+ * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm
+ */
+int main(void) {
+    /*
+     * dataset_shuffle_rows() test
+     *
+     * Row k of the input holds {k1, k2, k3} and carries label k. After the
+     * shuffle every label must still sit next to its own row, and every
+     * label must show up exactly once. These properties are checked instead
+     * of one fixed row order so the result does not depend on the sequence
+     * the platform's random() produces for a given seed.
+     */
+    srandom(42);
+    double input_array[12] = {
+        11, 12, 13,
+        21, 22, 23,
+        31, 32, 33,
+        41, 42, 43,
+    };
+    size_t in_shape[2] = {4,3};
+
+    double label_array[4] = {1, 2, 3, 4};
+    size_t lbl_shape[2] = {4,1};
+
+    /* seen[k - 1] is set once label k has been found in the shuffled data */
+    bool seen[4] = {false, false, false, false};
+
+    dataset_shuffle_rows(input_array, in_shape, label_array, lbl_shape);
+    for (size_t i = 0; i < in_shape[0]; i++) {
+        double label = label_array[i];
+
+        /* The label must be one of the original four and not a repeat */
+        if (!(label >= 1 && label <= 4) || seen[(size_t)label - 1]) {
+            printf("- dataset_shuffle_rows() failure: label_array mismatch on (%zu,%zu)\n", i, (size_t)0);
+            return 1;
+        }
+        seen[(size_t)label - 1] = true;
+
+        /* The input row next to label k must still be {k1, k2, k3} */
+        for (size_t j = 0; j < in_shape[1]; j++) {
+            size_t index = i * in_shape[1] + j;
+            if (input_array[index] != label * 10 + (double)(j + 1)) {
+                printf("- dataset_shuffle_rows() failure: input_array mismatch on (%zu,%zu)\n", i, j);
+                return 1;
+            }
+        }
+    }
+    printf("- dataset_shuffle_rows() success\n");
+
+    return 0;
+}
+#endif //NN_TEST
Feel free to download, copy and edit any repo