From 30dd1e327571c3ba9de8ee8740c607dcc0ece584 Mon Sep 17 00:00:00 2001 From: jvech Date: Sat, 12 Aug 2023 07:49:19 -0500 Subject: add: nn_network_predict done --- src/nn.c | 26 +++++++++++++++++++++++++- src/nn.h | 5 +++++ 2 files changed, 30 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/nn.c b/src/nn.c index 63dd643..904fffc 100644 --- a/src/nn.c +++ b/src/nn.c @@ -14,6 +14,30 @@ struct Cost NN_SQUARE = { .dfunc_out = square_dloss_out }; +void nn_network_predict( + double *output, size_t output_shape[2], + double *input, size_t input_shape[2], + Layer network[], size_t network_size) +{ + double **outs = calloc(network_size, sizeof(double *)); + double **zouts = calloc(network_size, sizeof(double *)); + size_t samples = input_shape[0]; + for (size_t l = 0; l < network_size; l++) { + outs[l] = calloc(samples * network[l].neurons, sizeof(double)); + zouts[l] = calloc(samples * network[l].neurons, sizeof(double)); + } + + nn_forward(outs, zouts, input, input_shape, network, network_size); + memmove(output, outs[network_size - 1], samples * output_shape[1] * sizeof(double)); + + for (size_t l = 0; l < network_size; l++) { + free(outs[l]); + free(zouts[l]); + } + free(outs); + free(zouts); +} + void nn_network_train( Layer network[], size_t network_size, double *input, size_t input_shape[2], @@ -49,7 +73,7 @@ void nn_network_train( network, network_size, cost.dfunc_out, alpha); double *net_out = outs[network_size - 1]; - fprintf(stdout, "epoch: %zu \t loss: %6.2lf\n", + fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n", epoch, get_avg_loss(labels, net_out, labels_shape, cost.func)); } diff --git a/src/nn.h b/src/nn.h index 2fcf9be..9005364 100644 --- a/src/nn.h +++ b/src/nn.h @@ -29,6 +29,11 @@ typedef struct Layer { void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); void nn_network_free_weights(Layer *network, size_t nmemb); +void nn_network_predict( + double *out, size_t out_shape[2], + double *input, size_t input_shape[2], + Layer network[], size_t network_size); + void nn_network_train( Layer network[], size_t network_size, double *input, size_t input_shape[2], -- cgit v1.2.3-70-g09d2