From a63ee8ba14aad0e364e928399b24196c72a4217f Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 4 Sep 2024 22:15:23 -0500 Subject: fix: ml_network_train() refactored --- src/nn.c | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) (limited to 'src/nn.c') diff --git a/src/nn.c b/src/nn.c index 867819c..19076bb 100644 --- a/src/nn.c +++ b/src/nn.c @@ -29,6 +29,8 @@ #include "util.h" #include "nn.h" + +struct Cost load_loss(struct Configs cfg); static void dataset_shuffle_rows( double *inputs, size_t in_shape[2], double *labels, size_t lbl_shape[2]); @@ -73,15 +75,19 @@ void nn_network_predict( } void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle) + double *labels, size_t labels_shape[2]) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); + size_t epochs = ml_configs.epochs; + size_t batch_size = ml_configs.batch_size; + size_t network_size = ml_configs.network_size; + double alpha = ml_configs.alpha; + bool shuffle = ml_configs.shuffle; + struct Cost cost = load_loss(ml_configs); + double **outs = calloc(network_size, sizeof(double *)); double **zouts = calloc(network_size, sizeof(double *)); double **weights = calloc(network_size, sizeof(double *)); @@ -204,19 +210,19 @@ void nn_backward( double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; double *dcost_out = dcost_outs + sample * network[l].neurons; nn_layer_out_delta(delta, dcost_out, zout, network[l].neurons, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } else if (l == 0) { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *input = Input + sample * input_shape[1]; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, alpha); } else { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } memmove(delta_next, delta, weights_shape[1] * sizeof(double)); } @@ -241,7 +247,7 @@ nn_backward_error: void nn_layer_backward( double *weights, double *bias, size_t weights_shape[2], double *delta, double *out_prev, - Layer layer, double alpha) + double alpha) { // W_next = W - alpha * out_prev @ delta.T cblas_dger(CblasRowMajor, weights_shape[0], weights_shape[1], -alpha, @@ -544,6 +550,13 @@ double get_avg_loss( return sum / shape[0]; } +struct Cost load_loss(struct Configs cfg) +{ + if (!strcmp("square", cfg.loss)) return NN_SQUARE; + die("load_loss() Error: Unknown '%s' loss function", cfg.loss); + exit(1); +} + #ifdef NN_TEST /* * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm -- cgit v1.2.3-70-g09d2