diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 78 |
1 files changed, 4 insertions, 74 deletions
@@ -26,91 +26,20 @@ #include <json-c/json.h> #include "util.h" +#include "parse.h" #include "nn.h" #define MAX_FILE_SIZE 536870912 //1<<29; 0.5 GiB -typedef struct Array { - double *data; - size_t shape[2]; -} Array; - #define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type) -static void json_read( - const char *filepath, - Array *input, Array *out, - char *out_keys[], size_t out_keys_size, - char *in_keys[], size_t in_keys_size, - bool read_output); - static void json_write( const char *filepath, Array input, Array out, char *out_keys[], size_t out_keys_size, char *in_keys[], size_t in_keys_size); -void json_read( - const char *filepath, - Array *input, Array *out, - char *out_keys[], size_t n_out_keys, - char *in_keys[], size_t n_input_keys, - bool read_output) -{ - FILE *fp = NULL; - static char fp_buffer[MAX_FILE_SIZE]; - - fp = (!strcmp(filepath, "-")) ? fopen("/dev/stdin", "r") : fopen(filepath, "r"); - - if (fp == NULL) goto json_read_error; - - size_t i = 0; - do { - if (i >= MAX_FILE_SIZE) die("json_read() Error: file size is bigger than '%zu'", i, MAX_FILE_SIZE); - fp_buffer[i] = fgetc(fp); - } while (fp_buffer[i++] != EOF); - - json_object *json_obj; - json_obj = json_tokener_parse(fp_buffer); - size_t json_obj_length = json_object_array_length(json_obj); - - input->shape[0] = (size_t)json_obj_length; - input->shape[1] = n_input_keys; - input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); - - out->shape[0] = (size_t)json_obj_length; - out->shape[1] = n_out_keys; - out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); - - if (!input->data || !out->data) goto json_read_error; - - for (int i = 0; i < json_object_array_length(json_obj); i++) { - json_object *item = json_object_array_get_idx(json_obj, i); - - for (int j = 0; j < n_input_keys; j++) { - size_t index = n_input_keys * i + j; - input->data[index] = json_object_get_double(json_object_object_get(item, in_keys[j])); - } - - if (!read_output) continue; - - for (int j = 0; j < n_out_keys; j++) { - size_t index = n_out_keys * i + j; - out->data[index] = json_object_get_double(json_object_object_get(item, out_keys[j])); - } - } - - json_object_put(json_obj); - fclose(fp); - - return; - -json_read_error: - perror("json_read() Error"); - exit(1); -} - void json_write( const char *filepath, Array input, Array out, @@ -204,6 +133,7 @@ int main(int argc, char *argv[]) { .alpha = 1e-5, .config_filepath = "utils/settings.cfg", .network_size = 0, + .file_format = NULL, .out_filepath = NULL, }; @@ -223,7 +153,7 @@ int main(int argc, char *argv[]) { Array X, y; if (!strcmp("train", argv[0])) { - json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, true); + file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); nn_network_train( network, ml_configs.network_size, @@ -235,7 +165,7 @@ int main(int argc, char *argv[]) { nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, false); + file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, false, ml_configs.file_format); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); |