diff options
Diffstat (limited to 'src/nn.c')
-rw-r--r-- | src/nn.c | 81 |
1 files changed, 79 insertions, 2 deletions
@@ -1,3 +1,14 @@ +#include <stdlib.h> +#include <assert.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdint.h> +#include <string.h> +#include <math.h> +#include <unistd.h> +#include <openblas/cblas.h> + +#include "util.h" #include "nn.h" static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); @@ -267,7 +278,71 @@ void nn_layer_forward(Layer layer, double *zout, size_t zout_shape[2], double *i 1.0, zout, layer.neurons); // beta B } -void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs) +void nn_network_read_weights(char *filepath, Layer *network, size_t network_size) +{ + FILE *fp = fopen(filepath, "rb"); + if (fp == NULL) die("nn_network_read_weights Error():"); + + size_t net_size, shape[2], ret; + ret = fread(&net_size, sizeof(size_t), 1, fp); + if (net_size != network_size) goto nn_network_read_weights_error; + + for (size_t i = 0; i < network_size; i++) { + fread(shape, sizeof(size_t), 2, fp); + if (shape[0] != network[i].input_nodes + || shape[1] != network[i].neurons) { + goto nn_network_read_weights_error; + } + + if (!network[i].weights || !network[i].bias) { + die("nn_network_read_weights() Error: " + "the weights on layer %zu haven't been initialized", i); + } + + ret = fread(network[i].weights, sizeof(double), shape[0] * shape[1], fp); + if (ret != shape[0] * shape[1]) goto nn_network_read_weights_error; + + ret = fread(network[i].bias, sizeof(double), shape[1], fp); + if (ret != shape[1]) goto nn_network_read_weights_error; + } + + return; + +nn_network_read_weights_error: + die("nn_network_read_weights() Error: " + "number of read objects does not match with expected ones"); +} + +void nn_network_write_weights(char *filepath, Layer *network, size_t network_size) +{ + FILE *fp = fopen(filepath, "wb"); + if (fp == NULL) die("nn_network_write_weights() Error:"); + + fwrite(&network_size, sizeof(size_t), 1, fp); + + size_t ret; + for (size_t i = 0; i < network_size; i++) { + size_t shape[2] = {network[i].input_nodes, network[i].neurons}; + size_t size = shape[0] * shape[1]; + + ret = fwrite(shape, sizeof(size_t), 2, fp); + if (ret != 2) goto nn_network_write_weights_error; + + ret = fwrite(network[i].weights, sizeof(double), size, fp); + if (ret != size) goto nn_network_write_weights_error; + + ret = fwrite(network[i].weights, sizeof(double), network[i].neurons, fp); + if (ret != network[i].neurons) goto nn_network_write_weights_error; + } + fclose(fp); + return; + +nn_network_write_weights_error: + die("nn_network_write_weights() Error: " + "number of written objects does not match with number of objects"); +} + +void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs, bool fill_random) { int i; size_t prev_size = n_inputs; @@ -280,7 +355,9 @@ void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs) if (layers[i].weights == NULL || layers[i].bias == NULL) { goto nn_layers_calloc_weights_error; } - fill_random_weights(layers[i].weights, layers[i].bias, prev_size, layers[i].neurons); + + if (fill_random) fill_random_weights(layers[i].weights, layers[i].bias, prev_size, layers[i].neurons); + layers[i].input_nodes = prev_size; prev_size = layers[i].neurons; } |