Diffstat (limited to 'src')
-rw-r--r--  src/nn.c  18
-rw-r--r--  src/nn.h  22
2 files changed, 33 insertions, 7 deletions
diff --git a/src/nn.c b/src/nn.c
index 6aa7213..1fdb192 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -2,6 +2,24 @@
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+void nn_layer_backward(
+    double *weights, size_t weights_shape[2],
+    double *delta, size_t delta_cols,
+    double *out_prev, size_t out_cols,
+    Layer layer, double alpha)
+{
+    assert(out_cols == weights_shape[0] && "out_cols does not match weight rows");
+    assert(delta_cols == weights_shape[1] && "delta_cols does not match weight cols");
+
+    for (size_t i = 0; i < weights_shape[0]; i++) {
+        for (size_t j = 0; j < weights_shape[1]; j++) {
+            size_t index = weights_shape[1] * i + j;
+            double dcost_w = delta[j] * out_prev[i];
+            weights[index] = layer.weights[index] + alpha * dcost_w;
+        }
+    }
+}
+
void nn_layer_hidden_delta(
double *delta, size_t delta_cols,
double *delta_next, size_t delta_next_cols,
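For readers following the math, here is a minimal standalone sketch (independent of nn.h) of the gradient step the new nn_layer_backward performs: each weight w[i][j] is shifted by alpha times delta[j] * out_prev[i]. The 2x3 shape and all numeric values below are made up for illustration only.

#include <stdio.h>

int main(void)
{
    /* Toy shapes: 2 previous-layer outputs feeding 3 units in this layer. */
    double weights[2][3] = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}};
    double out_prev[2]   = {1.0, -1.0};          /* previous-layer activations */
    double delta[3]      = {0.05, -0.02, 0.10};  /* local gradients of this layer */
    double alpha         = 0.1;                  /* learning rate */

    /* Same update rule as nn_layer_backward: w_ij += alpha * delta_j * out_prev_i */
    for (size_t i = 0; i < 2; i++)
        for (size_t j = 0; j < 3; j++)
            weights[i][j] += alpha * delta[j] * out_prev[i];

    for (size_t i = 0; i < 2; i++)
        printf("%.3f %.3f %.3f\n", weights[i][0], weights[i][1], weights[i][2]);
    return 0;
}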
diff --git a/src/nn.h b/src/nn.h
index d6dd1fe..94772c2 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -26,19 +26,27 @@ void nn_layer_map_activation(
void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]);
void nn_layer_backward(
-    Layer layer,
-    double *weights,
-    double *cost_derivative, size_t dcost_shape[2],
-    double *out, size_t out_shape[2],
-    double *labels, size_t labels_shape[2],
-    double *local_gradient); //TODO
+    double *weights, size_t weights_shape[2],
+    double *delta, size_t delta_cols,
+    double *out_prev, size_t out_cols,
+    Layer layer, double alpha);
double sigmoid(double x);
double relu(double x);
double identity(double x);
-void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size);
+void nn_forward(
+    double **aout, double **zout,
+    double *input, size_t input_shape[2],
+    Layer network[], size_t network_size);
+
+void nn_backward(
+    double **weights,
+    double **zout, double **outs, size_t n_rows,
+    Layer network[], size_t network_size,
+    double (*cost_derivative)(double, double));
+
void nn_layer_out_delta(
double *delta, size_t delta_cols,
double *error, size_t error_cols,
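The new nn_backward declaration takes the cost derivative as a function pointer. As a hedged illustration (the out-then-label argument order is an assumption, not confirmed by this diff), a squared-error derivative that fits the double (*)(double, double) slot could look like:

/* Hypothetical cost derivative matching the double (*)(double, double)
 * parameter of nn_backward: d/d(out) of C = 0.5 * (out - label)^2. */
static double squared_error_derivative(double out, double label)
{
    return out - label;
}

It would then be passed as the last argument of nn_backward alongside the network and the stored activations.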