aboutsummaryrefslogtreecommitdiff
path: root/src/parse.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/parse.c')
-rw-r--r--src/parse.c46
1 files changed, 32 insertions, 14 deletions
diff --git a/src/parse.c b/src/parse.c
index a06f0f3..0533a13 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -77,13 +77,17 @@ static void csv_keys2cols(size_t cols[], char *keys[], size_t keys_size);
void file_read(
char *filepath,
Array *input, Array *out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool read_output,
- char *file_format)
+ struct Configs ml_config,
+ bool read_output)
{
FILE *fp;
+ char **in_keys = ml_config.input_keys;
+ char **out_keys = ml_config.label_keys;
+ size_t n_in_keys = ml_config.n_input_keys;
+ size_t n_out_keys = ml_config.n_label_keys;
+ char *file_format = ml_config.file_format;
+
if (filepath != NULL && strcmp(filepath, "-")) {
fp = fopen(filepath, "r");
file_format = file_format_infer(filepath);
@@ -108,17 +112,19 @@ void file_read(
fclose(fp);
}
-void file_write(
- char *filepath,
- Array input, Array out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool write_input,
- char *file_format,
- int decimal_precision)
+void file_write(Array input, Array out, struct Configs ml_config)
{
FILE *fp;
+ char *filepath = ml_config.out_filepath;
+ char **in_keys = ml_config.input_keys;
+ char **out_keys = ml_config.label_keys;
+ size_t n_in_keys = ml_config.n_input_keys;
+ size_t n_out_keys = ml_config.n_label_keys;
+ bool write_input = !ml_config.only_out;
+ char *file_format = ml_config.file_format;
+ char decimal_precision = ml_config.decimal_precision;
+
if (filepath != NULL && strcmp(filepath, "-")) {
fp = fopen(filepath, "w");
@@ -504,7 +510,19 @@ int main(int argc, char *argv[]) {
n_in_cols = parse_keys(in_cols, argv[3], keys_buffer);
n_out_cols = parse_keys(out_cols, argv[4], keys_buffer);
- file_read(in_file, &X, &y, in_cols, n_in_cols, out_cols, n_out_cols, true, format);
+ struct Configs ml_config = {
+ .in_filepath = in_file,
+ .out_filepath = out_file,
+ .input_keys = in_cols,
+ .label_keys = out_cols,
+ .n_input_keys = n_in_cols,
+ .n_label_keys = n_out_cols,
+ .file_format = format,
+ .only_out = false,
+ .decimal_precision = -1
+ };
+
+ file_read(in_file, &X, &y, ml_config, true);
for (i = 0; i < X.shape[0]; i++) {
for (j = 0; j < X.shape[1]; j++) {
@@ -522,7 +540,7 @@ int main(int argc, char *argv[]) {
// use input format if format variable is not defined
format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format;
- file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1);
+ file_write(X, y, ml_config);
for (i = 0; i < n_in_cols; i++) free(in_cols[i]);
for (i = 0; i < n_out_cols; i++) free(out_cols[i]);
Feel free to download, copy and edit any repo