diff options
author | jvech <jmvalenciae@unal.edu.co> | 2024-09-04 22:15:23 -0500 |
---|---|---|
committer | jvech <jmvalenciae@unal.edu.co> | 2024-09-04 22:15:23 -0500 |
commit | a63ee8ba14aad0e364e928399b24196c72a4217f (patch) | |
tree | e9561f3ce87bcca5ea0fc8dcf002fa6024723926 /src/nn.c | |
parent | f39f6d5b0a907d519377e70876b32daad1a676f2 (diff) |
fix: ml_network_train() refactored
Diffstat (limited to 'src/nn.c')
-rw-r--r-- | src/nn.c | 31 |
1 files changed, 22 insertions, 9 deletions
@@ -29,6 +29,8 @@ #include "util.h" #include "nn.h" + +struct Cost load_loss(struct Configs cfg); static void dataset_shuffle_rows( double *inputs, size_t in_shape[2], double *labels, size_t lbl_shape[2]); @@ -73,15 +75,19 @@ void nn_network_predict( } void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle) + double *labels, size_t labels_shape[2]) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); + size_t epochs = ml_configs.epochs; + size_t batch_size = ml_configs.batch_size; + size_t network_size = ml_configs.network_size; + double alpha = ml_configs.alpha; + bool shuffle = ml_configs.shuffle; + struct Cost cost = load_loss(ml_configs); + double **outs = calloc(network_size, sizeof(double *)); double **zouts = calloc(network_size, sizeof(double *)); double **weights = calloc(network_size, sizeof(double *)); @@ -204,19 +210,19 @@ void nn_backward( double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; double *dcost_out = dcost_outs + sample * network[l].neurons; nn_layer_out_delta(delta, dcost_out, zout, network[l].neurons, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } else if (l == 0) { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *input = Input + sample * input_shape[1]; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, alpha); } else { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } memmove(delta_next, delta, weights_shape[1] * sizeof(double)); } @@ -241,7 +247,7 @@ nn_backward_error: void nn_layer_backward( double *weights, double *bias, size_t weights_shape[2], double *delta, double *out_prev, - Layer layer, double alpha) + double alpha) { // W_next = W - alpha * out_prev @ delta.T cblas_dger(CblasRowMajor, weights_shape[0], weights_shape[1], -alpha, @@ -544,6 +550,13 @@ double get_avg_loss( return sum / shape[0]; } +struct Cost load_loss(struct Configs cfg) +{ + if (!strcmp("square", cfg.loss)) return NN_SQUARE; + die("load_loss() Error: Unknown '%s' loss function", cfg.loss); + exit(1); +} + #ifdef NN_TEST /* * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm |