From ed930338f7936630705c665cad9dd6d562344efc Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 30 Aug 2023 20:59:20 -0500 Subject: add: network read and write done json_read reactored --- src/nn.h | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) (limited to 'src/nn.h') diff --git a/src/nn.h b/src/nn.h index 9005364..5402ffa 100644 --- a/src/nn.h +++ b/src/nn.h @@ -1,14 +1,8 @@ #ifndef __NN__ #define __NN__ -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include struct Cost { double (*func)(double labels[], double net_out[], size_t shape); @@ -26,7 +20,9 @@ typedef struct Layer { size_t neurons, input_nodes; } Layer; -void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); +void nn_network_write_weights(char *filepath, Layer *network, size_t network_size); +void nn_network_read_weights(char *filepath, Layer *network, size_t network_size); +void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols, bool fill_random); void nn_network_free_weights(Layer *network, size_t nmemb); void nn_network_predict( @@ -77,10 +73,10 @@ void nn_layer_backward( void nn_layer_out_delta( double *delta, double *dcost_out, double *zout, size_t cols, - double (*activation_derivative)(double));//TODO + double (*activation_derivative)(double)); void nn_layer_hidden_delta( double *delta, double *delta_next, double *zout, double *weights_next, size_t weights_next_shape[2], - double (*activation_derivative)(double));//TODO + double (*activation_derivative)(double)); #endif -- cgit v1.2.3-70-g09d2