aboutsummaryrefslogtreecommitdiff
path: root/src/nn.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.h')
-rw-r--r--src/nn.h11
1 files changed, 5 insertions, 6 deletions
diff --git a/src/nn.h b/src/nn.h
index 9f8e2a5..105fae7 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -22,6 +22,8 @@
#include <stdbool.h>
#include <stddef.h>
+#include "util.h"
+
struct Cost {
double (*func)(double labels[], double net_out[], size_t shape);
double (*dfunc_out)(double labels, double net_out);
@@ -49,12 +51,9 @@ void nn_network_predict(
Layer network[], size_t network_size);
void nn_network_train(
- Layer network[], size_t network_size,
+ Layer network[], struct Configs ml_configs,
double *input, size_t input_shape[2],
- double *labels, size_t labels_shape[2],
- struct Cost cost, size_t epochs,
- size_t batch_size, double alpha,
- bool shuffle);
+ double *labels, size_t labels_shape[2]);
void nn_layer_map_activation(
double (*activation)(double),
@@ -89,7 +88,7 @@ void nn_layer_forward(
void nn_layer_backward(
double *weights, double *bias, size_t weigths_shape[2],
double *delta, double *out_prev,
- Layer layer, double alpha);
+ double alpha);
void nn_layer_out_delta(
double *delta, double *dcost_out, double *zout, size_t cols,
Feel free to download, copy and edit any repo