aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/nn.c35
-rw-r--r--src/nn.h23
2 files changed, 55 insertions, 3 deletions
diff --git a/src/nn.c b/src/nn.c
index 14120a5..6aa7213 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -2,6 +2,41 @@
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+void nn_layer_hidden_delta(
+ double *delta, size_t delta_cols,
+ double *delta_next, size_t delta_next_cols,
+ double *weigths_next, size_t weigths_shape[2],
+ double *zout, size_t zout_cols,
+ double (*activation_derivative)(double))
+{
+ assert(delta_cols == zout_cols);
+ assert(delta_cols == weigths_shape[0]);
+ assert(delta_next_cols == weigths_shape[1]);
+
+ for (size_t j = 0; j < delta_cols; j++) {
+ double sum = 0;
+ for (size_t k = 0; k < delta_next_cols; k++) {
+ size_t index = j * delta_cols + k;
+ sum += delta_next[k] * weigths_next[index];
+ }
+ delta[j] = sum * activation_derivative(zout[j]);
+ }
+}
+
+void nn_layer_out_delta(
+ double *delta, size_t delta_cols,
+ double *error, size_t error_cols,
+ double *zout, size_t zout_cols,
+ double (*activation_derivative)(double))
+{
+ assert(delta_cols == error_cols);
+ assert(zout_cols == error_cols);
+
+ for (size_t i = 0; i < delta_cols; i++) {
+ delta[i] = error[i] * activation_derivative(zout[i]);
+ }
+}
+
void nn_forward(
double **out, double **zout,
double *X, size_t X_shape[2],
diff --git a/src/nn.h b/src/nn.h
index 62f438d..d6dd1fe 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -2,6 +2,7 @@
#define __NN__
#include <stdlib.h>
+#include <assert.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
@@ -18,11 +19,16 @@ typedef struct Layer {
void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols);
void nn_network_free_weights(Layer *network, size_t nmemb);
-void nn_layer_map_activation(double (*activation)(double), double *aout, size_t aout_shape[2], double *zout, size_t zout_shape[2]);
+void nn_layer_map_activation(
+ double (*activation)(double),
+ double *aout, size_t aout_shape[2],
+ double *zout, size_t zout_shape[2]);
+
void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]);
void nn_layer_backward(
- Layer *layer,
+ Layer layer,
double *weights,
+ double *cost_derivative, size_t dcost_shape[2],
double *out, size_t out_shape[2],
double *labels, size_t labels_shape[2],
double *local_gradient); //TODO
@@ -33,5 +39,16 @@ double identity(double x);
void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size);
-double nn_layer_out_delta(double error, double (*activation_derivative)(double));
+void nn_layer_out_delta(
+ double *delta, size_t delta_cols,
+ double *error, size_t error_cols,
+ double *zout, size_t zout_cols,
+ double (*activation_derivative)(double));//TODO
+
+void nn_layer_hidden_delta(
+ double *delta, size_t delta_cols,
+ double *delta_next, size_t delta_next_cols,
+ double *weigths_next, size_t weigths_shape[2],
+ double *zout, size_t zout_cols,
+ double (*activation_derivative)(double));//TODO
#endif
Feel free to download, copy and edit any repo