From a63ee8ba14aad0e364e928399b24196c72a4217f Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 4 Sep 2024 22:15:23 -0500 Subject: fix: ml_network_train() refactored --- src/nn.h | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'src/nn.h') diff --git a/src/nn.h b/src/nn.h index 9f8e2a5..105fae7 100644 --- a/src/nn.h +++ b/src/nn.h @@ -22,6 +22,8 @@ #include #include +#include "util.h" + struct Cost { double (*func)(double labels[], double net_out[], size_t shape); double (*dfunc_out)(double labels, double net_out); @@ -49,12 +51,9 @@ void nn_network_predict( Layer network[], size_t network_size); void nn_network_train( - Layer network[], size_t network_size, + Layer network[], struct Configs ml_configs, double *input, size_t input_shape[2], - double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, - size_t batch_size, double alpha, - bool shuffle); + double *labels, size_t labels_shape[2]); void nn_layer_map_activation( double (*activation)(double), @@ -89,7 +88,7 @@ void nn_layer_forward( void nn_layer_backward( double *weights, double *bias, size_t weigths_shape[2], double *delta, double *out_prev, - Layer layer, double alpha); + double alpha); void nn_layer_out_delta( double *delta, double *dcost_out, double *zout, size_t cols, -- cgit v1.2.3-70-g09d2