aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c91
1 files changed, 75 insertions, 16 deletions
diff --git a/src/main.c b/src/main.c
index 30aca1e..aac7e94 100644
--- a/src/main.c
+++ b/src/main.c
@@ -1,5 +1,6 @@
#include <stdio.h>
#include <stdint.h>
+#include <stdbool.h>
#include <string.h>
#include <stdarg.h>
#include <errno.h>
@@ -19,17 +20,25 @@ typedef struct Array {
#define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type)
-static void json_read(const char *filepath,
- Array *input, Array *out,
- char *out_key,
- char *in_keys[],
- size_t in_keys_size);
-
-void json_read(const char *filepath,
- Array *input, Array *out,
- char *out_key,
- char *in_keys[],
- size_t n_input_keys)
+static void json_read(
+ const char *filepath,
+ Array *input, Array *out,
+ char *out_keys[], size_t out_keys_size,
+ char *in_keys[], size_t in_keys_size,
+ bool read_output);
+
+static void json_write(
+ const char *filepath,
+ Array input, Array out,
+ char *out_keys[], size_t out_keys_size,
+ char *in_keys[], size_t in_keys_size);
+
+void json_read(
+ const char *filepath,
+ Array *input, Array *out,
+ char *out_keys[], size_t n_out_keys,
+ char *in_keys[], size_t n_input_keys,
+ bool read_output)
{
FILE *fp = NULL;
char *fp_buffer = NULL;
@@ -68,15 +77,24 @@ void json_read(const char *filepath,
input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0]));
out->shape[0] = (size_t)json_obj_length;
- out->shape[1] = 1;
+ out->shape[1] = n_out_keys;
out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0]));
+ if (!input->data || !out->data) goto json_read_error;
+
for (int i = 0; i < json_object_array_length(json_obj); i++) {
json_object *item = json_object_array_get_idx(json_obj, i);
- out->data[i] = json_object_get_double(json_object_object_get(item, out_key));
for (int j = 0; j < n_input_keys; j++) {
- input->data[n_input_keys * i + j] = json_object_get_double(json_object_object_get(item, in_keys[j]));
+ size_t index = n_input_keys * i + j;
+ input->data[index] = json_object_get_double(json_object_object_get(item, in_keys[j]));
+ }
+
+ if (!read_output) continue;
+
+ for (int j = 0; j < n_out_keys; j++) {
+ size_t index = n_out_keys * i + j;
+ out->data[index] = json_object_get_double(json_object_object_get(item, out_keys[j]));
}
}
@@ -90,6 +108,39 @@ json_read_error:
exit(1);
}
+void json_write(
+ const char *filepath,
+ Array input, Array out,
+ char *out_keys[], size_t out_keys_size,
+ char *in_keys[], size_t in_keys_size)
+{
+ FILE *fp = (!filepath) ? fopen("/dev/stdout", "w") : fopen(filepath, "w");
+ if (!fp) die("json_read() Error:");
+ fprintf(fp, "[\n");
+
+ for (size_t i = 0; i < input.shape[0]; i++) {
+ fprintf(fp, " {\n");
+
+ for (size_t j = 0; j < input.shape[1]; j++) {
+ size_t index = input.shape[1] * i + j;
+ fprintf(fp, " \"%s\": %lf,\n", in_keys[j], input.data[index]);
+ }
+
+ for (size_t j = 0; j < out.shape[1]; j++) {
+ size_t index = out.shape[1] * i + j;
+ fprintf(fp, " \"%s\": %lf", out_keys[j], out.data[index]);
+
+ if (j == out.shape[1] - 1) fprintf(fp, "\n");
+ else fprintf(fp, ",\n");
+ }
+
+ if (i == input.shape[0] - 1) fprintf(fp, " }\n");
+ else fprintf(fp, " },\n");
+ }
+ fprintf(fp, "]\n");
+ fclose(fp);
+}
+
void load_config(struct Configs *cfg, int n_args, ...)
{
char *filepath;
@@ -145,6 +196,7 @@ int main(int argc, char *argv[]) {
.alpha = 1e-5,
.config_filepath = "utils/settings.cfg",
.network_size = 0,
+ .out_filepath = NULL,
};
// Try different config paths
@@ -156,8 +208,8 @@ int main(int argc, char *argv[]) {
Array X, y;
if (!strcmp("train", argv[0])) {
- json_read(argv[1], &X, &y, ml_configs.label_keys[0], ml_configs.input_keys, ml_configs.n_input_keys);
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1]);
+ json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, true);
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
nn_network_train(
network, ml_configs.network_size,
X.data, X.shape,
@@ -165,7 +217,14 @@ int main(int argc, char *argv[]) {
load_loss(ml_configs),
ml_configs.epochs,
ml_configs.alpha);
+ nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
+ fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
+ json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, false);
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
+ nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
+ nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
+ json_write(ml_configs.out_filepath, X, y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys);
} else usage(1);
nn_network_free_weights(network, ml_configs.network_size);
Feel free to download, copy and edit any repo