From 7796b9e4dc1fd138108b0262ab131e51453d8e66 Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 2 Aug 2023 20:49:21 -0500 Subject: add: layer backward done --- src/nn.c | 18 ++++++++++++++++++ src/nn.h | 22 +++++++++++++++------- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/nn.c b/src/nn.c index 6aa7213..1fdb192 100644 --- a/src/nn.c +++ b/src/nn.c @@ -2,6 +2,24 @@ static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); +void nn_layer_backward( + double *weights, size_t weigths_shape[2], + double *delta, size_t delta_cols, + double *out_prev, size_t out_cols, + Layer layer, double alpha) +{ + assert(out_cols == weigths_shape[0] && "out_cols does not match with weight rows"); + assert(delta_cols == weigths_shape[1] && "delta_cols does not match with weight cols"); + + for (size_t i = 0; i < weigths_shape[0]; i++) { + for (size_t j = 0; j < weigths_shape[0]; j++) { + size_t index = weigths_shape[1] * i + j; + double dcost_w = delta[j] * out_prev[i]; + weights[index] = layer.weights[index] + alpha * dcost_w; + } + } +} + void nn_layer_hidden_delta( double *delta, size_t delta_cols, double *delta_next, size_t delta_next_cols, diff --git a/src/nn.h b/src/nn.h index d6dd1fe..94772c2 100644 --- a/src/nn.h +++ b/src/nn.h @@ -26,19 +26,27 @@ void nn_layer_map_activation( void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]); void nn_layer_backward( - Layer layer, - double *weights, - double *cost_derivative, size_t dcost_shape[2], - double *out, size_t out_shape[2], - double *labels, size_t labels_shape[2], - double *local_gradient); //TODO + double *weights, size_t weigths_shape[2], + double *delta, size_t dcost_cols, + double *out_prev, size_t out_cols, + Layer layer, double alpha); double sigmoid(double x); double relu(double x); double identity(double x); -void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size); +void nn_forward( + double **aout, double **zout, + double *input, size_t input_shape[2], + Layer network[], size_t network_size); + +void nn_backwad( + double **weights, + double **zout, double **outs, size_t n_rows, + Layer network[], size_t network_size, + double (cost_derivative)(double, double)); + void nn_layer_out_delta( double *delta, size_t delta_cols, double *error, size_t error_cols, -- cgit v1.2.3-70-g09d2