diff options
author | jvech <jmvalenciae@unal.edu.co> | 2023-08-02 15:11:27 -0500 |
---|---|---|
committer | jvech <jmvalenciae@unal.edu.co> | 2023-08-02 15:11:27 -0500 |
commit | 2d84b6e5a34fd8fbc62a96c4842665701ab4e9bd (patch) | |
tree | 37e03c0aeb45ca1e6f5f6b52f8497e6ec29549c5 /src/nn.h | |
parent | 525f8398c58cc2ca7f92c416df880068c62abbd5 (diff) |
add: delta calculation for hidden and output layer done
Diffstat (limited to 'src/nn.h')
-rw-r--r-- | src/nn.h | 23 |
1 files changed, 20 insertions, 3 deletions
@@ -2,6 +2,7 @@ #define __NN__ #include <stdlib.h> +#include <assert.h> #include <stdio.h> #include <stdint.h> #include <string.h> @@ -18,11 +19,16 @@ typedef struct Layer { void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); void nn_network_free_weights(Layer *network, size_t nmemb); -void nn_layer_map_activation(double (*activation)(double), double *aout, size_t aout_shape[2], double *zout, size_t zout_shape[2]); +void nn_layer_map_activation( + double (*activation)(double), + double *aout, size_t aout_shape[2], + double *zout, size_t zout_shape[2]); + void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]); void nn_layer_backward( - Layer *layer, + Layer layer, double *weights, + double *cost_derivative, size_t dcost_shape[2], double *out, size_t out_shape[2], double *labels, size_t labels_shape[2], double *local_gradient); //TODO @@ -33,5 +39,16 @@ double identity(double x); void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size); -double nn_layer_out_delta(double error, double (*activation_derivative)(double)); +void nn_layer_out_delta( + double *delta, size_t delta_cols, + double *error, size_t error_cols, + double *zout, size_t zout_cols, + double (*activation_derivative)(double));//TODO + +void nn_layer_hidden_delta( + double *delta, size_t delta_cols, + double *delta_next, size_t delta_next_cols, + double *weigths_next, size_t weigths_shape[2], + double *zout, size_t zout_cols, + double (*activation_derivative)(double));//TODO #endif |