From 21a570b6d98497835785eccf28fc7f16e57ab197 Mon Sep 17 00:00:00 2001 From: jvech Date: Fri, 4 Aug 2023 18:40:41 -0500 Subject: add: nn_backward implemented It needs to be tested and some backward layer functions were redefined to improve readability --- src/nn.h | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) (limited to 'src/nn.h') diff --git a/src/nn.h b/src/nn.h index 94772c2..c07a943 100644 --- a/src/nn.h +++ b/src/nn.h @@ -13,6 +13,7 @@ typedef struct Layer { double *weights, *bias; double (*activation)(double x); + double (*activation_derivative)(double x); size_t neurons, input_nodes; } Layer; @@ -24,12 +25,6 @@ void nn_layer_map_activation( double *aout, size_t aout_shape[2], double *zout, size_t zout_shape[2]); -void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]); -void nn_layer_backward( - double *weights, size_t weigths_shape[2], - double *delta, size_t dcost_cols, - double *out_prev, size_t out_cols, - Layer layer, double alpha); double sigmoid(double x); double relu(double x); @@ -41,22 +36,31 @@ void nn_forward( double *input, size_t input_shape[2], Layer network[], size_t network_size); -void nn_backwad( +void nn_backward( double **weights, - double **zout, double **outs, size_t n_rows, + double **zout, double **outs, + double *input, size_t input_shape[2], + double *labels, size_t labels_shape[2], Layer network[], size_t network_size, - double (cost_derivative)(double, double)); + double (cost_derivative)(double, double), + double alpha); + +void nn_layer_forward( + Layer layer, + double *out, size_t out_shape[2], + double *input, size_t input_shape[2]); + +void nn_layer_backward( + double *weights, size_t weigths_shape[2], + double *delta, double *out_prev, + Layer layer, double alpha); void nn_layer_out_delta( - double *delta, size_t delta_cols, - double *error, size_t error_cols, - double *zout, size_t zout_cols, + double *delta, double *dcost_out, double *zout, size_t cols, double (*activation_derivative)(double));//TODO void nn_layer_hidden_delta( - double *delta, size_t delta_cols, - double *delta_next, size_t delta_next_cols, - double *weigths_next, size_t weigths_shape[2], - double *zout, size_t zout_cols, + double *delta, double *delta_next, double *zout, + double *weights_next, size_t weights_next_shape[2], double (*activation_derivative)(double));//TODO #endif -- cgit v1.2.3-70-g09d2