aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.c')
-rw-r--r--src/nn.c26
1 files changed, 25 insertions, 1 deletions
diff --git a/src/nn.c b/src/nn.c
index 63dd643..904fffc 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -14,6 +14,30 @@ struct Cost NN_SQUARE = {
.dfunc_out = square_dloss_out
};
+void nn_network_predict(
+ double *output, size_t output_shape[2],
+ double *input, size_t input_shape[2],
+ Layer network[], size_t network_size)
+{
+ double **outs = calloc(network_size, sizeof(double *));
+ double **zouts = calloc(network_size, sizeof(double *));
+ size_t samples = input_shape[0];
+ for (size_t l = 0; l < network_size; l++) {
+ outs[l] = calloc(samples * network[l].neurons, sizeof(double));
+ zouts[l] = calloc(samples * network[l].neurons, sizeof(double));
+ }
+
+ nn_forward(outs, zouts, input, input_shape, network, network_size);
+ memmove(output, outs[network_size - 1], samples * output_shape[1] * sizeof(double));
+
+ for (size_t l = 0; l < network_size; l++) {
+ free(outs[l]);
+ free(zouts[l]);
+ }
+ free(outs);
+ free(zouts);
+}
+
void nn_network_train(
Layer network[], size_t network_size,
double *input, size_t input_shape[2],
@@ -49,7 +73,7 @@ void nn_network_train(
network, network_size,
cost.dfunc_out, alpha);
double *net_out = outs[network_size - 1];
- fprintf(stdout, "epoch: %zu \t loss: %6.2lf\n",
+ fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n",
epoch, get_avg_loss(labels, net_out, labels_shape, cost.func));
}
Feel free to download, copy and edit any repo