-rw-r--r--  src/nn.c  44
-rw-r--r--  src/nn.h  13
2 files changed, 39 insertions, 18 deletions
diff --git a/src/nn.c b/src/nn.c
index e47b56d..14120a5 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -3,7 +3,7 @@
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
void nn_forward(
- double **out,
+ double **out, double **zout,
double *X, size_t X_shape[2],
Layer network[], size_t network_size)
{
@@ -14,25 +14,46 @@ void nn_forward(
for (size_t l = 0; l < network_size; l++) {
out_shape[1] = network[l].neurons;
- nn_layer_forward(network[l], out[l], out_shape, input, in_shape);
+ nn_layer_forward(network[l], zout[l], out_shape, input, in_shape);
+ nn_layer_map_activation(network[l].activation, out[l], out_shape, zout[l], out_shape);
in_shape[1] = out_shape[1];
input = out[l];
}
}
-void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2])
+void nn_layer_map_activation(
+ double (*activation)(double),
+ double *aout, size_t aout_shape[2],
+ double *zout, size_t zout_shape[2])
{
- if (out_shape[0] != input_shape[0] || out_shape[1] != layer.neurons) {
+ if (zout_shape[0] != aout_shape[0] || zout_shape[1] != aout_shape[1]) {
fprintf(stderr,
- "nn_layer_forward() Error: out must have (%zu x %zu) dimensions not (%zu x %zu)\n",
- input_shape[0], layer.neurons, out_shape[0], out_shape[1]);
+ "nn_layer_map_activation() Error: zout must have (%zu x %zu) dimensions not (%zu x %zu)\n",
+ aout_shape[0], aout_shape[1], zout_shape[0], zout_shape[1]);
+ exit(1);
+ }
+
+ for (size_t i = 0; i < aout_shape[0]; i++) {
+ for (size_t j = 0; j < aout_shape[1]; j++) {
+ size_t index = aout_shape[1] * i + j;
+ aout[index] = activation(zout[index]);
+ }
+ }
+}
+
+void nn_layer_forward(Layer layer, double *zout, size_t zout_shape[2], double *input, size_t input_shape[2])
+{
+ if (zout_shape[0] != input_shape[0] || zout_shape[1] != layer.neurons) {
+ fprintf(stderr,
+ "nn_layer_forward() Error: zout must have (%zu x %zu) dimensions not (%zu x %zu)\n",
+ input_shape[0], layer.neurons, zout_shape[0], zout_shape[1]);
exit(1);
}
for (size_t i = 0; i < input_shape[0]; i++) {
for (size_t j = 0; j < layer.neurons; j++) {
size_t index = layer.neurons * i + j;
- out[index] = layer.bias[j];
+ zout[index] = layer.bias[j];
}
}
@@ -40,14 +61,7 @@ void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *inp
input_shape[0], layer.neurons, layer.input_nodes, // m, n, k
1.0, input, input_shape[1], //alpha X
layer.weights, layer.neurons, // W
- 1.0, out, layer.neurons); // beta B
-
- for (size_t i = 0; i < input_shape[0]; i++) {
- for (size_t j = 0; j < layer.neurons; j ++) {
- size_t index = layer.neurons * i + j;
- out[index] = layer.activation(out[index]);
- }
- }
+ 1.0, zout, layer.neurons); // beta B
}
void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs)
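With this change nn_layer_forward() only computes the affine part: the bias is copied into zout and cblas_dgemm() accumulates X*W on top of it (beta = 1.0), so zout holds the pre-activations and nn_layer_map_activation() writes the activated values into a separate buffer. A minimal caller sketch follows; it is not part of this commit, and the per-layer calloc pattern and the forward_example() wrapper are assumptions for illustration only:

#include <stdlib.h>
#include "nn.h"

/* Hypothetical helper: allocate one (batch x neurons) buffer per layer
 * for both the pre-activations (zout) and the activations (aout),
 * run the forward pass, then release everything. */
void forward_example(Layer network[], size_t n_layers,
                     double *X, size_t X_shape[2])
{
    double **aout = malloc(n_layers * sizeof(double *));
    double **zout = malloc(n_layers * sizeof(double *));
    for (size_t l = 0; l < n_layers; l++) {
        aout[l] = calloc(X_shape[0] * network[l].neurons, sizeof(double));
        zout[l] = calloc(X_shape[0] * network[l].neurons, sizeof(double));
    }

    nn_forward(aout, zout, X, X_shape, network, n_layers);

    /* aout[n_layers - 1] now holds the network output;
     * zout[l] keeps each layer's pre-activations for backprop. */
    for (size_t l = 0; l < n_layers; l++) {
        free(aout[l]);
        free(zout[l]);
    }
    free(aout);
    free(zout);
}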
diff --git a/src/nn.h b/src/nn.h
index a339dfc..62f438d 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -18,13 +18,20 @@ typedef struct Layer {
void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols);
void nn_network_free_weights(Layer *network, size_t nmemb);
-void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2]); //TODO
-void nn_layer_backward(Layer *layer, double *out, size_t out_shape[2]); //TODO
+void nn_layer_map_activation(double (*activation)(double), double *aout, size_t aout_shape[2], double *zout, size_t zout_shape[2]);
+void nn_layer_forward(Layer layer, double *zout, size_t zout_shape[2], double *input, size_t input_shape[2]);
+void nn_layer_backward(
+ Layer *layer,
+ double *weights,
+ double *out, size_t out_shape[2],
+ double *labels, size_t labels_shape[2],
+ double *local_gradient); //TODO
double sigmoid(double x);
double relu(double x);
double identity(double x);
-void nn_forward(double **out, double *input, size_t input_shape[2], Layer network[], size_t network_size);
+void nn_forward(double **aout, double **zout, double *input, size_t input_shape[2], Layer network[], size_t network_size);
+double nn_layer_out_delta(double error, double (*activation_derivative)(double));
#endif
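Keeping the pre-activations around is what makes the newly declared backward pass feasible: an activation derivative such as sigmoid'(z) is evaluated at the pre-activation z, not at the output a = sigmoid(z). A short sketch of that relationship; sigmoid_derivative() is hypothetical and not declared in this header:

#include <math.h>

/* sigmoid(z) = 1 / (1 + e^-z), as declared in nn.h */
double sigmoid(double z) { return 1.0 / (1.0 + exp(-z)); }

/* Hypothetical derivative helper: d/dz sigmoid(z) = s(z) * (1 - s(z)).
 * It must be evaluated at the stored pre-activation zout, which is why
 * nn_forward() now preserves zout instead of overwriting it in place. */
double sigmoid_derivative(double z)
{
    double s = sigmoid(z);
    return s * (1.0 - s);
}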