aboutsummaryrefslogtreecommitdiff
path: root/src/nn.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.h')
-rw-r--r--src/nn.h22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/nn.h b/src/nn.h
index d6dd1fe..94772c2 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -26,19 +26,27 @@ void nn_layer_map_activation(
void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]);
void nn_layer_backward(
- Layer layer,
- double *weights,
- double *cost_derivative, size_t dcost_shape[2],
- double *out, size_t out_shape[2],
- double *labels, size_t labels_shape[2],
- double *local_gradient); //TODO
+ double *weights, size_t weigths_shape[2],
+ double *delta, size_t dcost_cols,
+ double *out_prev, size_t out_cols,
+ Layer layer, double alpha);
double sigmoid(double x);
double relu(double x);
double identity(double x);
-void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size);
+void nn_forward(
+ double **aout, double **zout,
+ double *input, size_t input_shape[2],
+ Layer network[], size_t network_size);
+
+void nn_backwad(
+ double **weights,
+ double **zout, double **outs, size_t n_rows,
+ Layer network[], size_t network_size,
+ double (cost_derivative)(double, double));
+
void nn_layer_out_delta(
double *delta, size_t delta_cols,
double *error, size_t error_cols,
Feel free to download, copy and edit any repo