From a63ee8ba14aad0e364e928399b24196c72a4217f Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 4 Sep 2024 22:15:23 -0500 Subject: fix: ml_network_train() refactored --- src/main.c | 18 +----------------- src/nn.c | 31 ++++++++++++++++++++++--------- src/nn.h | 11 +++++------ 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/src/main.c b/src/main.c index 848d638..63a2fa9 100644 --- a/src/main.c +++ b/src/main.c @@ -79,14 +79,6 @@ Layer * load_network(struct Configs cfg) return network; } -struct Cost load_loss(struct Configs cfg) -{ - extern struct Cost NN_SQUARE; - if (!strcmp("square", cfg.loss)) return NN_SQUARE; - die("load_loss() Error: Unknown '%s' loss function", cfg.loss); - exit(1); -} - int main(int argc, char *argv[]) { char default_config_path[512], *env_config_path; struct Configs ml_configs = { @@ -134,15 +126,7 @@ int main(int argc, char *argv[]) { nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); } - nn_network_train( - network, ml_configs.network_size, - X.data, X.shape, - y.data, y.shape, - load_loss(ml_configs), - ml_configs.epochs, - ml_configs.batch_size, - ml_configs.alpha, - ml_configs.shuffle); + nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { diff --git a/src/nn.c b/src/nn.c index 867819c..19076bb 100644 --- a/src/nn.c +++ b/src/nn.c @@ -29,6 +29,8 @@ #include "util.h" #include "nn.h" + +struct Cost load_loss(struct Configs cfg); static void dataset_shuffle_rows( double *inputs, size_t in_shape[2], double *labels, size_t lbl_shape[2]); @@ -73,15 +75,19 @@ void nn_network_predict( } void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle) + double *labels, size_t labels_shape[2]) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); + size_t epochs = ml_configs.epochs; + size_t batch_size = ml_configs.batch_size; + size_t network_size = ml_configs.network_size; + double alpha = ml_configs.alpha; + bool shuffle = ml_configs.shuffle; + struct Cost cost = load_loss(ml_configs); + double **outs = calloc(network_size, sizeof(double *)); double **zouts = calloc(network_size, sizeof(double *)); double **weights = calloc(network_size, sizeof(double *)); @@ -204,19 +210,19 @@ void nn_backward( double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; double *dcost_out = dcost_outs + sample * network[l].neurons; nn_layer_out_delta(delta, dcost_out, zout, network[l].neurons, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } else if (l == 0) { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *input = Input + sample * input_shape[1]; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, alpha); } else { size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc); - nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha); } memmove(delta_next, delta, weights_shape[1] * sizeof(double)); } @@ -241,7 +247,7 @@ nn_backward_error: void nn_layer_backward( double *weights, double *bias, size_t weights_shape[2], double *delta, double *out_prev, - Layer layer, double alpha) + double alpha) { // W_next = W - alpha * out_prev @ delta.T cblas_dger(CblasRowMajor, weights_shape[0], weights_shape[1], -alpha, @@ -544,6 +550,13 @@ double get_avg_loss( return sum / shape[0]; } +struct Cost load_loss(struct Configs cfg) +{ + if (!strcmp("square", cfg.loss)) return NN_SQUARE; + die("load_loss() Error: Unknown '%s' loss function", cfg.loss); + exit(1); +} + #ifdef NN_TEST /* * compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm diff --git a/src/nn.h b/src/nn.h index 9f8e2a5..105fae7 100644 --- a/src/nn.h +++ b/src/nn.h @@ -22,6 +22,8 @@ #include #include +#include "util.h" + struct Cost { double (*func)(double labels[], double net_out[], size_t shape); double (*dfunc_out)(double labels, double net_out); @@ -49,12 +51,9 @@ void nn_network_predict( Layer network[], size_t network_size); void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle); + double *labels, size_t labels_shape[2]); void nn_layer_map_activation( double (*activation)(double), @@ -89,7 +88,7 @@ void nn_layer_forward( void nn_layer_backward( double *weights, double *bias, size_t weigths_shape[2], double *delta, double *out_prev, - Layer layer, double alpha); + double alpha); void nn_layer_out_delta( double *delta, double *dcost_out, double *zout, size_t cols, -- cgit v1.2.3-70-g09d2