aboutsummaryrefslogtreecommitdiff
path: root/src/nn.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.h')
-rw-r--r--src/nn.h36
1 files changed, 20 insertions, 16 deletions
diff --git a/src/nn.h b/src/nn.h
index 94772c2..c07a943 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -13,6 +13,7 @@
typedef struct Layer {
double *weights, *bias;
double (*activation)(double x);
+ double (*activation_derivative)(double x);
size_t neurons, input_nodes;
} Layer;
@@ -24,12 +25,6 @@ void nn_layer_map_activation(
double *aout, size_t aout_shape[2],
double *zout, size_t zout_shape[2]);
-void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]);
-void nn_layer_backward(
- double *weights, size_t weigths_shape[2],
- double *delta, size_t dcost_cols,
- double *out_prev, size_t out_cols,
- Layer layer, double alpha);
double sigmoid(double x);
double relu(double x);
@@ -41,22 +36,31 @@ void nn_forward(
double *input, size_t input_shape[2],
Layer network[], size_t network_size);
-void nn_backwad(
+void nn_backward(
double **weights,
- double **zout, double **outs, size_t n_rows,
+ double **zout, double **outs,
+ double *input, size_t input_shape[2],
+ double *labels, size_t labels_shape[2],
Layer network[], size_t network_size,
- double (cost_derivative)(double, double));
+ double (cost_derivative)(double, double),
+ double alpha);
+
+void nn_layer_forward(
+ Layer layer,
+ double *out, size_t out_shape[2],
+ double *input, size_t input_shape[2]);
+
+void nn_layer_backward(
+ double *weights, size_t weigths_shape[2],
+ double *delta, double *out_prev,
+ Layer layer, double alpha);
void nn_layer_out_delta(
- double *delta, size_t delta_cols,
- double *error, size_t error_cols,
- double *zout, size_t zout_cols,
+ double *delta, double *dcost_out, double *zout, size_t cols,
double (*activation_derivative)(double));//TODO
void nn_layer_hidden_delta(
- double *delta, size_t delta_cols,
- double *delta_next, size_t delta_next_cols,
- double *weigths_next, size_t weigths_shape[2],
- double *zout, size_t zout_cols,
+ double *delta, double *delta_next, double *zout,
+ double *weights_next, size_t weights_next_shape[2],
double (*activation_derivative)(double));//TODO
#endif
Feel free to download, copy and edit any repo