aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c33
1 files changed, 21 insertions, 12 deletions
diff --git a/src/main.c b/src/main.c
index e779e2d..7918706 100644
--- a/src/main.c
+++ b/src/main.c
@@ -79,6 +79,7 @@ Layer * load_network(struct Configs cfg)
return network;
}
+
int main(int argc, char *argv[]) {
char default_config_path[512], *env_config_path;
struct Configs ml_configs = {
@@ -113,36 +114,44 @@ int main(int argc, char *argv[]) {
argv += optind;
Layer *network = load_network(ml_configs);
-
- Array X, y;
+ Array in, out;
+ double *X = NULL, *y = NULL;
+ size_t X_shape[2], y_shape[2];
if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) {
- file_read(argv[1], &X, &y, ml_configs, true);
+ file_read(argv[1], &in, &out, ml_configs, true);
+ X = data_preprocess(X_shape, in, ml_configs, true, false);
+ y = data_preprocess(y_shape, out, ml_configs, false, false);
if (!strcmp("train", argv[0])) {
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
+ nn_network_init_weights(network, ml_configs.network_size, X_shape[1], true);
} else if (!strcmp("retrain", argv[0])) {
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
+ nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
}
- nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape);
+ nn_network_train(network, ml_configs, X, X_shape, y, y_shape);
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
- file_read(argv[1], &X, &y, ml_configs, false);
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
+ file_read(argv[1], &in, &out, ml_configs, false);
+ X = data_preprocess(X_shape, in, ml_configs, true, false);
+ y = data_preprocess(y_shape, out, ml_configs, false, true);
+ nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
- nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
+ nn_network_predict(y, y_shape, X, X_shape, network, ml_configs.network_size);
// If neither output and file_format defined use input to define the output format
if (!ml_configs.file_format && !ml_configs.out_filepath) {
ml_configs.file_format = file_format_infer(ml_configs.in_filepath);
}
- file_write(X, y, ml_configs);
+ data_postprocess(&out, y, y_shape, ml_configs, false);
+ file_write(in, out, ml_configs);
} else usage(1);
nn_network_free_weights(network, ml_configs.network_size);
free(network);
- free(X.data);
- free(y.data);
+ array_free(&in);
+ array_free(&out);
+ free(X);
+ free(y);
util_free_config(&ml_configs);
return 0;
}
Feel free to download, copy and edit any repo