diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 17 |
1 files changed, 3 insertions, 14 deletions
@@ -116,10 +116,7 @@ int main(int argc, char *argv[]) { Array X, y; if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { - file_read(argv[1], &X, &y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - true, ml_configs.file_format); + file_read(argv[1], &X, &y, ml_configs, true); if (!strcmp("train", argv[0])) { nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); } else if (!strcmp("retrain", argv[0])) { @@ -130,10 +127,7 @@ int main(int argc, char *argv[]) { nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - file_read(argv[1], &X, &y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - false, ml_configs.file_format); + file_read(argv[1], &X, &y, ml_configs, false); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); @@ -142,12 +136,7 @@ int main(int argc, char *argv[]) { if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } - - file_write(ml_configs.out_filepath, X, y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - !ml_configs.only_out, ml_configs.file_format, - ml_configs.decimal_precision); + file_write(X, y, ml_configs); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); |