From ed930338f7936630705c665cad9dd6d562344efc Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 30 Aug 2023 20:59:20 -0500 Subject: add: network read and write done json_read reactored --- src/main.c | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 16 deletions(-) (limited to 'src/main.c') diff --git a/src/main.c b/src/main.c index 30aca1e..aac7e94 100644 --- a/src/main.c +++ b/src/main.c @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -19,17 +20,25 @@ typedef struct Array { #define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type) -static void json_read(const char *filepath, - Array *input, Array *out, - char *out_key, - char *in_keys[], - size_t in_keys_size); - -void json_read(const char *filepath, - Array *input, Array *out, - char *out_key, - char *in_keys[], - size_t n_input_keys) +static void json_read( + const char *filepath, + Array *input, Array *out, + char *out_keys[], size_t out_keys_size, + char *in_keys[], size_t in_keys_size, + bool read_output); + +static void json_write( + const char *filepath, + Array input, Array out, + char *out_keys[], size_t out_keys_size, + char *in_keys[], size_t in_keys_size); + +void json_read( + const char *filepath, + Array *input, Array *out, + char *out_keys[], size_t n_out_keys, + char *in_keys[], size_t n_input_keys, + bool read_output) { FILE *fp = NULL; char *fp_buffer = NULL; @@ -68,15 +77,24 @@ void json_read(const char *filepath, input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); out->shape[0] = (size_t)json_obj_length; - out->shape[1] = 1; + out->shape[1] = n_out_keys; out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); + if (!input->data || !out->data) goto json_read_error; + for (int i = 0; i < json_object_array_length(json_obj); i++) { json_object *item = json_object_array_get_idx(json_obj, i); - out->data[i] = json_object_get_double(json_object_object_get(item, out_key)); for (int j = 0; j < n_input_keys; j++) { - input->data[n_input_keys * i + j] = json_object_get_double(json_object_object_get(item, in_keys[j])); + size_t index = n_input_keys * i + j; + input->data[index] = json_object_get_double(json_object_object_get(item, in_keys[j])); + } + + if (!read_output) continue; + + for (int j = 0; j < n_out_keys; j++) { + size_t index = n_out_keys * i + j; + out->data[index] = json_object_get_double(json_object_object_get(item, out_keys[j])); } } @@ -90,6 +108,39 @@ json_read_error: exit(1); } +void json_write( + const char *filepath, + Array input, Array out, + char *out_keys[], size_t out_keys_size, + char *in_keys[], size_t in_keys_size) +{ + FILE *fp = (!filepath) ? fopen("/dev/stdout", "w") : fopen(filepath, "w"); + if (!fp) die("json_read() Error:"); + fprintf(fp, "[\n"); + + for (size_t i = 0; i < input.shape[0]; i++) { + fprintf(fp, " {\n"); + + for (size_t j = 0; j < input.shape[1]; j++) { + size_t index = input.shape[1] * i + j; + fprintf(fp, " \"%s\": %lf,\n", in_keys[j], input.data[index]); + } + + for (size_t j = 0; j < out.shape[1]; j++) { + size_t index = out.shape[1] * i + j; + fprintf(fp, " \"%s\": %lf", out_keys[j], out.data[index]); + + if (j == out.shape[1] - 1) fprintf(fp, "\n"); + else fprintf(fp, ",\n"); + } + + if (i == input.shape[0] - 1) fprintf(fp, " }\n"); + else fprintf(fp, " },\n"); + } + fprintf(fp, "]\n"); + fclose(fp); +} + void load_config(struct Configs *cfg, int n_args, ...) { char *filepath; @@ -145,6 +196,7 @@ int main(int argc, char *argv[]) { .alpha = 1e-5, .config_filepath = "utils/settings.cfg", .network_size = 0, + .out_filepath = NULL, }; // Try different config paths @@ -156,8 +208,8 @@ int main(int argc, char *argv[]) { Array X, y; if (!strcmp("train", argv[0])) { - json_read(argv[1], &X, &y, ml_configs.label_keys[0], ml_configs.input_keys, ml_configs.n_input_keys); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1]); + json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, true); + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); nn_network_train( network, ml_configs.network_size, X.data, X.shape, @@ -165,7 +217,14 @@ int main(int argc, char *argv[]) { load_loss(ml_configs), ml_configs.epochs, ml_configs.alpha); + nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); + fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { + json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, false); + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); + nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); + json_write(ml_configs.out_filepath, X, y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); -- cgit v1.2.3-70-g09d2