aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.c')
-rw-r--r--src/nn.c78
1 files changed, 77 insertions, 1 deletions
diff --git a/src/nn.c b/src/nn.c
index 1f70785..1d08c6b 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -1,6 +1,8 @@
#include "nn.h"
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+static double get_avg_loss(double labels[], double outs[], size_t shape[2], double (*loss)(double, double));
+
double relu(double x);
double drelu(double x);
@@ -17,6 +19,63 @@ struct Activation NN_SIGMOID = {
.dfunc = dsigmoid
};
+
+void nn_network_train(
+ Layer network[], size_t network_size,
+ double *input, size_t input_shape[2],
+ double *labels, size_t labels_shape[2],
+ struct Cost cost, size_t epochs, double alpha)
+{
+ assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n");
+
+ double **outs = calloc(network_size, sizeof(double *));
+ double **zouts = calloc(network_size, sizeof(double *));
+ double **weights = calloc(network_size, sizeof(double *));
+ double **biases = calloc(network_size, sizeof(double *));
+
+ if (!outs || !zouts || !weights || !biases) goto nn_network_train_error;
+
+
+
+ size_t samples = input_shape[0];
+ for (size_t l = 0; l < network_size; l++) {
+ outs[l] = calloc(samples * network[l].neurons, sizeof(double));
+ zouts[l] = calloc(samples * network[l].neurons, sizeof(double));
+ weights[l] = calloc(network[l].input_nodes * network[l].neurons, sizeof(double));
+ biases[l] = calloc(network[l].neurons, sizeof(double));
+ }
+
+ for (size_t epoch = 0; epoch < epochs; epochs++) {
+ nn_forward(outs, zouts, input, input_shape, network, network_size);
+ nn_backward(
+ weights, biases,
+ zouts, outs,
+ input, input_shape,
+ labels, labels_shape,
+ network, network_size,
+ cost.dfunc_out, alpha);
+ double *net_out = outs[network_size - 1];
+ fprintf(stderr, "epoch: %zu \tavg loss: %6.2lf\n",
+ epoch, get_avg_loss(labels, net_out, labels_shape, cost.func));
+ }
+
+ for (size_t l = 0; l < network_size; l++) {
+ free(outs[l]);
+ free(zouts[l]);
+ free(weights[l]);
+ free(biases[l]);
+ }
+
+ free(zouts);
+ free(outs);
+ free(weights);
+ free(biases);
+
+nn_network_train_error:
+ perror("nn_network_train() Error");
+ exit(1);
+}
+
void nn_backward(
double **weights, double **bias,
double **Zout, double **Outs,
@@ -65,7 +124,12 @@ void nn_backward(
nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weigths_next_shape, network[l].activation.dfunc);
nn_layer_backward(weights[l], bias[l], weigths_shape, delta, out_prev, network[l], alpha);
}
- memcpy(delta_next, delta, weigths_shape[1] * sizeof(double));
+ memmove(delta_next, delta, weigths_shape[1] * sizeof(double));
+ }
+ for (size_t l = network_size - 1; l >= 0; l--) {
+ size_t weigths_shape[2] = {network[l].input_nodes, network[l].neurons};
+ memmove(network[l].weights, weights[l], weigths_shape[0] * weigths_shape[1] * sizeof(double));
+ memmove(network[l].bias, bias[l], weigths_shape[1] * sizeof(double));
}
}
@@ -266,3 +330,15 @@ double relu(double x)
double drelu(double x) {
return (x > 0) ? 1 : 0;
}
+
+double get_avg_loss(double labels[], double outs[], size_t shape[2], double (*loss)(double, double))
+{
+ double sum = 0;
+ for (size_t i = 0; i < shape[0]; i++) {
+ for (size_t j = 0; j < shape[1]; j++) {
+ size_t index = i * shape[1] + j;
+ sum += loss(labels[index], outs[index]);
+ }
+ }
+ return sum / shape[1];
+}
Feel free to download, copy and edit any repo