diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/nn.c | 78 | ||||
-rw-r--r-- | src/nn.h | 11 |
2 files changed, 88 insertions, 1 deletion
@@ -1,6 +1,8 @@ #include "nn.h" static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); +static double get_avg_loss(double labels[], double outs[], size_t shape[2], double (*loss)(double, double)); + double relu(double x); double drelu(double x); @@ -17,6 +19,63 @@ struct Activation NN_SIGMOID = { .dfunc = dsigmoid }; + +void nn_network_train( + Layer network[], size_t network_size, + double *input, size_t input_shape[2], + double *labels, size_t labels_shape[2], + struct Cost cost, size_t epochs, double alpha) +{ + assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); + + double **outs = calloc(network_size, sizeof(double *)); + double **zouts = calloc(network_size, sizeof(double *)); + double **weights = calloc(network_size, sizeof(double *)); + double **biases = calloc(network_size, sizeof(double *)); + + if (!outs || !zouts || !weights || !biases) goto nn_network_train_error; + + + + size_t samples = input_shape[0]; + for (size_t l = 0; l < network_size; l++) { + outs[l] = calloc(samples * network[l].neurons, sizeof(double)); + zouts[l] = calloc(samples * network[l].neurons, sizeof(double)); + weights[l] = calloc(network[l].input_nodes * network[l].neurons, sizeof(double)); + biases[l] = calloc(network[l].neurons, sizeof(double)); + } + + for (size_t epoch = 0; epoch < epochs; epochs++) { + nn_forward(outs, zouts, input, input_shape, network, network_size); + nn_backward( + weights, biases, + zouts, outs, + input, input_shape, + labels, labels_shape, + network, network_size, + cost.dfunc_out, alpha); + double *net_out = outs[network_size - 1]; + fprintf(stderr, "epoch: %zu \tavg loss: %6.2lf\n", + epoch, get_avg_loss(labels, net_out, labels_shape, cost.func)); + } + + for (size_t l = 0; l < network_size; l++) { + free(outs[l]); + free(zouts[l]); + free(weights[l]); + free(biases[l]); + } + + free(zouts); + free(outs); + free(weights); + free(biases); + +nn_network_train_error: + 
perror("nn_network_train() Error"); + exit(1); +} + void nn_backward( double **weights, double **bias, double **Zout, double **Outs, @@ -65,7 +124,12 @@ void nn_backward( nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weigths_next_shape, network[l].activation.dfunc); nn_layer_backward(weights[l], bias[l], weigths_shape, delta, out_prev, network[l], alpha); } - memcpy(delta_next, delta, weigths_shape[1] * sizeof(double)); + memmove(delta_next, delta, weigths_shape[1] * sizeof(double)); + } + for (size_t l = network_size - 1; l >= 0; l--) { + size_t weigths_shape[2] = {network[l].input_nodes, network[l].neurons}; + memmove(network[l].weights, weights[l], weigths_shape[0] * weigths_shape[1] * sizeof(double)); + memmove(network[l].bias, bias[l], weigths_shape[1] * sizeof(double)); } } @@ -266,3 +330,15 @@ double relu(double x) double drelu(double x) { return (x > 0) ? 1 : 0; } + +double get_avg_loss(double labels[], double outs[], size_t shape[2], double (*loss)(double, double)) +{ + double sum = 0; + for (size_t i = 0; i < shape[0]; i++) { + for (size_t j = 0; j < shape[1]; j++) { + size_t index = i * shape[1] + j; + sum += loss(labels[index], outs[index]); + } + } + return sum / shape[1]; +} @@ -10,6 +10,11 @@ #include <unistd.h> #include <openblas/cblas.h> +struct Cost { + double (*func)(double labels, double net_out); + double (*dfunc_out)(double labels, double net_out); +}; + struct Activation { double (*func)(double); double (*dfunc)(double); @@ -24,6 +29,12 @@ typedef struct Layer { void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); void nn_network_free_weights(Layer *network, size_t nmemb); +void nn_network_train( + Layer network[], size_t network_size, + double *input, size_t input_shape[2], + double *labels, size_t labels_shape[2], + struct Cost cost, size_t epochs, double alpha); + void nn_layer_map_activation( double (*activation)(double), double *aout, size_t aout_shape[2], |