From 2d84b6e5a34fd8fbc62a96c4842665701ab4e9bd Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 2 Aug 2023 15:11:27 -0500 Subject: add: delta calculation for hidden and output layer done --- src/nn.c | 35 +++++++++++++++++++++++++++++++++++ src/nn.h | 23 ++++++++++++++++++++--- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/nn.c b/src/nn.c index 14120a5..6aa7213 100644 --- a/src/nn.c +++ b/src/nn.c @@ -2,6 +2,41 @@ static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); +void nn_layer_hidden_delta( + double *delta, size_t delta_cols, + double *delta_next, size_t delta_next_cols, + double *weigths_next, size_t weigths_shape[2], + double *zout, size_t zout_cols, + double (*activation_derivative)(double)) +{ + assert(delta_cols == zout_cols); + assert(delta_cols == weigths_shape[0]); + assert(delta_next_cols == weigths_shape[1]); + + for (size_t j = 0; j < delta_cols; j++) { + double sum = 0; + for (size_t k = 0; k < delta_next_cols; k++) { + size_t index = j * delta_cols + k; + sum += delta_next[k] * weigths_next[index]; + } + delta[j] = sum * activation_derivative(zout[j]); + } +} + +void nn_layer_out_delta( + double *delta, size_t delta_cols, + double *error, size_t error_cols, + double *zout, size_t zout_cols, + double (*activation_derivative)(double)) +{ + assert(delta_cols == error_cols); + assert(zout_cols == error_cols); + + for (size_t i = 0; i < delta_cols; i++) { + delta[i] = error[i] * activation_derivative(zout[i]); + } +} + void nn_forward( double **out, double **zout, double *X, size_t X_shape[2], diff --git a/src/nn.h b/src/nn.h index 62f438d..d6dd1fe 100644 --- a/src/nn.h +++ b/src/nn.h @@ -2,6 +2,7 @@ #define __NN__ #include +#include #include #include #include @@ -18,11 +19,16 @@ typedef struct Layer { void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); void nn_network_free_weights(Layer *network, size_t nmemb); -void nn_layer_map_activation(double (*activation)(double), double *aout, size_t aout_shape[2], double *zout, size_t zout_shape[2]); +void nn_layer_map_activation( + double (*activation)(double), + double *aout, size_t aout_shape[2], + double *zout, size_t zout_shape[2]); + void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]); void nn_layer_backward( - Layer *layer, + Layer layer, double *weights, + double *cost_derivative, size_t dcost_shape[2], double *out, size_t out_shape[2], double *labels, size_t labels_shape[2], double *local_gradient); //TODO @@ -33,5 +39,16 @@ double identity(double x); void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size); -double nn_layer_out_delta(double error, double (*activation_derivative)(double)); +void nn_layer_out_delta( + double *delta, size_t delta_cols, + double *error, size_t error_cols, + double *zout, size_t zout_cols, + double (*activation_derivative)(double));//TODO + +void nn_layer_hidden_delta( + double *delta, size_t delta_cols, + double *delta_next, size_t delta_next_cols, + double *weigths_next, size_t weigths_shape[2], + double *zout, size_t zout_cols, + double (*activation_derivative)(double));//TODO #endif -- cgit v1.2.3-70-g09d2