diff options
Diffstat (limited to 'src/nn.h')
-rw-r--r-- | src/nn.h | 11 |
1 files changed, 5 insertions, 6 deletions
@@ -22,6 +22,8 @@ #include <stdbool.h> #include <stddef.h> +#include "util.h" + struct Cost { double (*func)(double labels[], double net_out[], size_t shape); double (*dfunc_out)(double labels, double net_out); @@ -49,12 +51,9 @@ void nn_network_predict( Layer network[], size_t network_size); void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle); + double *labels, size_t labels_shape[2]); void nn_layer_map_activation( double (*activation)(double), @@ -89,7 +88,7 @@ void nn_layer_forward( void nn_layer_backward( double *weights, double *bias, size_t weigths_shape[2], double *delta, double *out_prev, - Layer layer, double alpha); + double alpha); void nn_layer_out_delta( double *delta, double *dcost_out, double *zout, size_t cols, |