diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/nn.c | 26 | ||||
-rw-r--r-- | src/nn.h | 5 |
2 files changed, 30 insertions, 1 deletions
@@ -14,6 +14,30 @@ struct Cost NN_SQUARE = { .dfunc_out = square_dloss_out }; +void nn_network_predict( + double *output, size_t output_shape[2], + double *input, size_t input_shape[2], + Layer network[], size_t network_size) +{ + double **outs = calloc(network_size, sizeof(double *)); + double **zouts = calloc(network_size, sizeof(double *)); + size_t samples = input_shape[0]; + for (size_t l = 0; l < network_size; l++) { + outs[l] = calloc(samples * network[l].neurons, sizeof(double)); + zouts[l] = calloc(samples * network[l].neurons, sizeof(double)); + } + + nn_forward(outs, zouts, input, input_shape, network, network_size); + memmove(output, outs[network_size - 1], samples * output_shape[1] * sizeof(double)); + + for (size_t l = 0; l < network_size; l++) { + free(outs[l]); + free(zouts[l]); + } + free(outs); + free(zouts); +} + void nn_network_train( Layer network[], size_t network_size, double *input, size_t input_shape[2], @@ -49,7 +73,7 @@ void nn_network_train( network, network_size, cost.dfunc_out, alpha); double *net_out = outs[network_size - 1]; - fprintf(stdout, "epoch: %zu \t loss: %6.2lf\n", + fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n", epoch, get_avg_loss(labels, net_out, labels_shape, cost.func)); } @@ -29,6 +29,11 @@ typedef struct Layer { void nn_network_init_weights(Layer *network, size_t nmemb, size_t input_cols); void nn_network_free_weights(Layer *network, size_t nmemb); +void nn_network_predict( + double *out, size_t out_shape[2], + double *input, size_t input_shape[2], + Layer network[], size_t network_size); + void nn_network_train( Layer network[], size_t network_size, double *input, size_t input_shape[2], |