diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 61 |
1 files changed, 19 insertions, 42 deletions
@@ -34,45 +34,6 @@ #define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type) -static void json_write( - const char *filepath, - Array input, Array out, - char *out_keys[], size_t out_keys_size, - char *in_keys[], size_t in_keys_size); - -void json_write( - const char *filepath, - Array input, Array out, - char *out_keys[], size_t out_keys_size, - char *in_keys[], size_t in_keys_size) -{ - FILE *fp = (!filepath) ? fopen("/dev/stdout", "w") : fopen(filepath, "w"); - if (!fp) die("json_read() Error:"); - fprintf(fp, "[\n"); - - for (size_t i = 0; i < input.shape[0]; i++) { - fprintf(fp, " {\n"); - - for (size_t j = 0; j < input.shape[1]; j++) { - size_t index = input.shape[1] * i + j; - fprintf(fp, " \"%s\": %lf,\n", in_keys[j], input.data[index]); - } - - for (size_t j = 0; j < out.shape[1]; j++) { - size_t index = out.shape[1] * i + j; - fprintf(fp, " \"%s\": %lf", out_keys[j], out.data[index]); - - if (j == out.shape[1] - 1) fprintf(fp, "\n"); - else fprintf(fp, ",\n"); - } - - if (i == input.shape[0] - 1) fprintf(fp, " }\n"); - else fprintf(fp, " },\n"); - } - fprintf(fp, "]\n"); - fclose(fp); -} - void load_config(struct Configs *cfg, int n_args, ...) { char *filepath; @@ -133,6 +94,7 @@ int main(int argc, char *argv[]) { .alpha = 1e-5, .config_filepath = "utils/settings.cfg", .network_size = 0, + .only_out = false, .file_format = NULL, .out_filepath = NULL, }; @@ -153,7 +115,10 @@ int main(int argc, char *argv[]) { Array X, y; if (!strcmp("train", argv[0])) { - file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format); + file_read(argv[1], &X, &y, + ml_configs.input_keys, ml_configs.n_input_keys, + ml_configs.label_keys, ml_configs.n_label_keys, + true, ml_configs.file_format); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); nn_network_train( network, ml_configs.network_size, @@ -165,11 +130,23 @@ int main(int argc, char *argv[]) { nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, false, ml_configs.file_format); + file_read(argv[1], &X, &y, + ml_configs.input_keys, ml_configs.n_input_keys, + ml_configs.label_keys, ml_configs.n_label_keys, + false, ml_configs.file_format); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); - json_write(ml_configs.out_filepath, X, y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys); + + // If neither output and file_format defined use input to define the format + if (!ml_configs.file_format && !ml_configs.out_filepath) { + ml_configs.file_format = file_format_infer(ml_configs.in_filepath); + } + + file_write(ml_configs.out_filepath, X, y, + ml_configs.input_keys, ml_configs.n_input_keys, + ml_configs.label_keys, ml_configs.n_label_keys, + !ml_configs.only_out, ml_configs.file_format); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); |