aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/nn.c26
-rw-r--r--src/nn.h5
2 files changed, 30 insertions, 1 deletions
diff --git a/src/nn.c b/src/nn.c
index 63dd643..904fffc 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -14,6 +14,30 @@ struct Cost NN_SQUARE = {
.dfunc_out = square_dloss_out
};
+void nn_network_predict(
+ double *output, size_t output_shape[2],
+ double *input, size_t input_shape[2],
+ Layer network[], size_t network_size)
+{
+ double **outs = calloc(network_size, sizeof(double *));
+ double **zouts = calloc(network_size, sizeof(double *));
+ size_t samples = input_shape[0];
+ for (size_t l = 0; l < network_size; l++) {
+ outs[l] = calloc(samples * network[l].neurons, sizeof(double));
+ zouts[l] = calloc(samples * network[l].neurons, sizeof(double));
+ }
+
+ nn_forward(outs, zouts, input, input_shape, network, network_size);
+ memmove(output, outs[network_size - 1], samples * output_shape[1] * sizeof(double));
+
+ for (size_t l = 0; l < network_size; l++) {
+ free(outs[l]);
+ free(zouts[l]);
+ }
+ free(outs);
+ free(zouts);
+}
+
void nn_network_train(
Layer network[], size_t network_size,
double *input, size_t input_shape[2],
@@ -49,7 +73,7 @@ void nn_network_train(
network, network_size,
cost.dfunc_out, alpha);
double *net_out = outs[network_size - 1];
- fprintf(stdout, "epoch: %zu \t loss: %6.2lf\n",
+ fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n",
epoch, get_avg_loss(labels, net_out, labels_shape, cost.func));
}
diff --git a/src/nn.h b/src/nn.h
index 2fcf9be..9005364 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -29,6 +29,11 @@ typedef struct Layer {
void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols);
void nn_network_free_weights(Layer *network, size_t nmemb);
+void nn_network_predict(
+ double *out, size_t out_shape[2],
+ double *input, size_t input_shape[2],
+ Layer network[], size_t network_size);
+
void nn_network_train(
Layer network[], size_t network_size,
double *input, size_t input_shape[2],
Feel free to download, copy and edit any repo