diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 33 |
1 files changed, 21 insertions, 12 deletions
@@ -79,6 +79,7 @@ Layer * load_network(struct Configs cfg) return network; } + int main(int argc, char *argv[]) { char default_config_path[512], *env_config_path; struct Configs ml_configs = { @@ -113,36 +114,44 @@ int main(int argc, char *argv[]) { argv += optind; Layer *network = load_network(ml_configs); - - Array X, y; + Array in, out; + double *X = NULL, *y = NULL; + size_t X_shape[2], y_shape[2]; if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { - file_read(argv[1], &X, &y, ml_configs, true); + file_read(argv[1], &in, &out, ml_configs, true); + X = data_preprocess(X_shape, in, ml_configs, true, false); + y = data_preprocess(y_shape, out, ml_configs, false, false); if (!strcmp("train", argv[0])) { - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], true); } else if (!strcmp("retrain", argv[0])) { - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); } - nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape); + nn_network_train(network, ml_configs, X, X_shape, y, y_shape); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - file_read(argv[1], &X, &y, ml_configs, false); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + file_read(argv[1], &in, &out, ml_configs, false); + X = data_preprocess(X_shape, in, ml_configs, true, false); + y = data_preprocess(y_shape, out, ml_configs, false, true); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); - nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); + nn_network_predict(y, y_shape, X, X_shape, network, ml_configs.network_size); // If neither output and file_format defined use input to define the output format if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } - file_write(X, y, ml_configs); + data_postprocess(&out, y, y_shape, ml_configs, false); + file_write(in, out, ml_configs); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); free(network); - free(X.data); - free(y.data); + array_free(&in); + array_free(&out); + free(X); + free(y); util_free_config(&ml_configs); return 0; } |