aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.c')
-rw-r--r--src/nn.c81
1 files changed, 79 insertions, 2 deletions
diff --git a/src/nn.c b/src/nn.c
index 904fffc..c7b8ac3 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -1,3 +1,14 @@
+#include <stdlib.h>
+#include <assert.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <string.h>
+#include <math.h>
+#include <unistd.h>
+#include <openblas/cblas.h>
+
+#include "util.h"
#include "nn.h"
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
@@ -267,7 +278,71 @@ void nn_layer_forward(Layer layer, double *zout, size_t zout_shape[2], double *i
1.0, zout, layer.neurons); // beta B
}
-void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs)
+void nn_network_read_weights(char *filepath, Layer *network, size_t network_size)
+{
+ FILE *fp = fopen(filepath, "rb");
+ if (fp == NULL) die("nn_network_read_weights Error():");
+
+ size_t net_size, shape[2], ret;
+ ret = fread(&net_size, sizeof(size_t), 1, fp);
+ if (net_size != network_size) goto nn_network_read_weights_error;
+
+ for (size_t i = 0; i < network_size; i++) {
+ fread(shape, sizeof(size_t), 2, fp);
+ if (shape[0] != network[i].input_nodes
+ || shape[1] != network[i].neurons) {
+ goto nn_network_read_weights_error;
+ }
+
+ if (!network[i].weights || !network[i].bias) {
+ die("nn_network_read_weights() Error: "
+ "the weights on layer %zu haven't been initialized", i);
+ }
+
+ ret = fread(network[i].weights, sizeof(double), shape[0] * shape[1], fp);
+ if (ret != shape[0] * shape[1]) goto nn_network_read_weights_error;
+
+ ret = fread(network[i].bias, sizeof(double), shape[1], fp);
+ if (ret != shape[1]) goto nn_network_read_weights_error;
+ }
+
+ return;
+
+nn_network_read_weights_error:
+ die("nn_network_read_weights() Error: "
+ "number of read objects does not match with expected ones");
+}
+
+void nn_network_write_weights(char *filepath, Layer *network, size_t network_size)
+{
+ FILE *fp = fopen(filepath, "wb");
+ if (fp == NULL) die("nn_network_write_weights() Error:");
+
+ fwrite(&network_size, sizeof(size_t), 1, fp);
+
+ size_t ret;
+ for (size_t i = 0; i < network_size; i++) {
+ size_t shape[2] = {network[i].input_nodes, network[i].neurons};
+ size_t size = shape[0] * shape[1];
+
+ ret = fwrite(shape, sizeof(size_t), 2, fp);
+ if (ret != 2) goto nn_network_write_weights_error;
+
+ ret = fwrite(network[i].weights, sizeof(double), size, fp);
+ if (ret != size) goto nn_network_write_weights_error;
+
+ ret = fwrite(network[i].weights, sizeof(double), network[i].neurons, fp);
+ if (ret != network[i].neurons) goto nn_network_write_weights_error;
+ }
+ fclose(fp);
+ return;
+
+nn_network_write_weights_error:
+ die("nn_network_write_weights() Error: "
+ "number of written objects does not match with number of objects");
+}
+
+void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs, bool fill_random)
{
int i;
size_t prev_size = n_inputs;
@@ -280,7 +355,9 @@ void nn_network_init_weights(Layer layers[], size_t nmemb, size_t n_inputs)
if (layers[i].weights == NULL || layers[i].bias == NULL) {
goto nn_layers_calloc_weights_error;
}
- fill_random_weights(layers[i].weights, layers[i].bias, prev_size, layers[i].neurons);
+
+ if (fill_random) fill_random_weights(layers[i].weights, layers[i].bias, prev_size, layers[i].neurons);
+
layers[i].input_nodes = prev_size;
prev_size = layers[i].neurons;
}
Feel free to download, copy and edit any repo