diff options
-rw-r--r-- | src/main.c | 17 | ||||
-rw-r--r-- | src/parse.c | 46 | ||||
-rw-r--r-- | src/parse.h | 22 |
3 files changed, 39 insertions, 46 deletions
@@ -116,10 +116,7 @@ int main(int argc, char *argv[]) { Array X, y; if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { - file_read(argv[1], &X, &y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - true, ml_configs.file_format); + file_read(argv[1], &X, &y, ml_configs, true); if (!strcmp("train", argv[0])) { nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); } else if (!strcmp("retrain", argv[0])) { @@ -130,10 +127,7 @@ int main(int argc, char *argv[]) { nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - file_read(argv[1], &X, &y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - false, ml_configs.file_format); + file_read(argv[1], &X, &y, ml_configs, false); nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); @@ -142,12 +136,7 @@ int main(int argc, char *argv[]) { if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } - - file_write(ml_configs.out_filepath, X, y, - ml_configs.input_keys, ml_configs.n_input_keys, - ml_configs.label_keys, ml_configs.n_label_keys, - !ml_configs.only_out, ml_configs.file_format, - ml_configs.decimal_precision); + file_write(X, y, ml_configs); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); diff --git a/src/parse.c b/src/parse.c index a06f0f3..0533a13 100644 --- a/src/parse.c +++ b/src/parse.c @@ -77,13 +77,17 @@ static void csv_keys2cols(size_t cols[], char *keys[], size_t keys_size); void file_read( char *filepath, Array *input, Array *out, - char *in_keys[], size_t n_in_keys, - char *out_keys[], size_t n_out_keys, - bool read_output, - char *file_format) + struct Configs ml_config, + bool read_output) { FILE *fp; + char **in_keys = ml_config.input_keys; + char **out_keys = ml_config.label_keys; + size_t n_in_keys = ml_config.n_input_keys; + size_t n_out_keys = ml_config.n_label_keys; + char *file_format = ml_config.file_format; + if (filepath != NULL && strcmp(filepath, "-")) { fp = fopen(filepath, "r"); file_format = file_format_infer(filepath); @@ -108,17 +112,19 @@ void file_read( fclose(fp); } -void file_write( - char *filepath, - Array input, Array out, - char *in_keys[], size_t n_in_keys, - char *out_keys[], size_t n_out_keys, - bool write_input, - char *file_format, - int decimal_precision) +void file_write(Array input, Array out, struct Configs ml_config) { FILE *fp; + char *filepath = ml_config.out_filepath; + char **in_keys = ml_config.input_keys; + char **out_keys = ml_config.label_keys; + size_t n_in_keys = ml_config.n_input_keys; + size_t n_out_keys = ml_config.n_label_keys; + bool write_input = !ml_config.only_out; + char *file_format = ml_config.file_format; + char decimal_precision = ml_config.decimal_precision; + if (filepath != NULL && strcmp(filepath, "-")) { fp = fopen(filepath, "w"); @@ -504,7 +510,19 @@ int main(int argc, char *argv[]) { n_in_cols = parse_keys(in_cols, argv[3], keys_buffer); n_out_cols = parse_keys(out_cols, argv[4], keys_buffer); - file_read(in_file, &X, &y, in_cols, n_in_cols, out_cols, n_out_cols, true, format); + struct Configs ml_config = { + .in_filepath = in_file, + .out_filepath = out_file, + .input_keys = in_cols, + .label_keys = out_cols, + .n_input_keys = n_in_cols, + .n_label_keys = n_out_cols, + .file_format = format, + .only_out = false, + .decimal_precision = -1 + }; + + file_read(in_file, &X, &y, ml_config, true); for (i = 0; i < X.shape[0]; i++) { for (j = 0; j < X.shape[1]; j++) { @@ -522,7 +540,7 @@ int main(int argc, char *argv[]) { // use input format if format variable is not defined format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format; - file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1); + file_write(X, y, ml_config); for (i = 0; i < n_in_cols; i++) free(in_cols[i]); for (i = 0; i < n_out_cols; i++) free(out_cols[i]); diff --git a/src/parse.h b/src/parse.h index 18130c7..07f740b 100644 --- a/src/parse.h +++ b/src/parse.h @@ -4,29 +4,15 @@ #include <stdio.h> #include <stdbool.h> +#include "util.h" + typedef struct Array { double *data; size_t shape[2]; } Array; -void file_read( - char *filepath, - Array *input, Array *out, - char *in_keys[], size_t n_in_keys, - char *out_keys[], size_t n_out_keys, - bool read_output, - char *file_format - ); - -void file_write( - char *filepath, - Array input, Array out, - char *in_keys[], size_t n_in_keys, - char *out_keys[], size_t n_out_keys, - bool write_input, - char *file_format, - int decimal_precision); - +void file_read(char *filepath, Array *input, Array *out, struct Configs configs, bool read_output); +void file_write(Array input, Array out, struct Configs ml_configs); char * file_format_infer(char *filename); #endif |