aboutsummaryrefslogtreecommitdiff
path: root/src/nn.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.h')
-rw-r--r--src/nn.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/src/nn.h b/src/nn.h
index 0c794b5..40066e3 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -10,6 +10,11 @@
#include <unistd.h>
#include <openblas/cblas.h>
+struct Cost {
+ double (*func)(double labels, double net_out);
+ double (*dfunc_out)(double labels, double net_out);
+};
+
struct Activation {
double (*func)(double);
double (*dfunc)(double);
@@ -24,6 +29,12 @@ typedef struct Layer {
void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols);
void nn_network_free_weights(Layer *network, size_t nmemb);
+void nn_network_train(
+ Layer network[], size_t network_size,
+ double *input, size_t input_shape[2],
+ double *labels, size_t labels_shape[2],
+ struct Cost cost, size_t epochs, double alpha);
+
void nn_layer_map_activation(
double (*activation)(double),
double *aout, size_t aout_shape[2],
Feel free to download, copy and edit any repo