aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main.c17
-rw-r--r--src/parse.c46
-rw-r--r--src/parse.h22
3 files changed, 39 insertions, 46 deletions
diff --git a/src/main.c b/src/main.c
index 01ada0a..e779e2d 100644
--- a/src/main.c
+++ b/src/main.c
@@ -116,10 +116,7 @@ int main(int argc, char *argv[]) {
Array X, y;
if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) {
- file_read(argv[1], &X, &y,
- ml_configs.input_keys, ml_configs.n_input_keys,
- ml_configs.label_keys, ml_configs.n_label_keys,
- true, ml_configs.file_format);
+ file_read(argv[1], &X, &y, ml_configs, true);
if (!strcmp("train", argv[0])) {
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
} else if (!strcmp("retrain", argv[0])) {
@@ -130,10 +127,7 @@ int main(int argc, char *argv[]) {
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
- file_read(argv[1], &X, &y,
- ml_configs.input_keys, ml_configs.n_input_keys,
- ml_configs.label_keys, ml_configs.n_label_keys,
- false, ml_configs.file_format);
+ file_read(argv[1], &X, &y, ml_configs, false);
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
@@ -142,12 +136,7 @@ int main(int argc, char *argv[]) {
if (!ml_configs.file_format && !ml_configs.out_filepath) {
ml_configs.file_format = file_format_infer(ml_configs.in_filepath);
}
-
- file_write(ml_configs.out_filepath, X, y,
- ml_configs.input_keys, ml_configs.n_input_keys,
- ml_configs.label_keys, ml_configs.n_label_keys,
- !ml_configs.only_out, ml_configs.file_format,
- ml_configs.decimal_precision);
+ file_write(X, y, ml_configs);
} else usage(1);
nn_network_free_weights(network, ml_configs.network_size);
diff --git a/src/parse.c b/src/parse.c
index a06f0f3..0533a13 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -77,13 +77,17 @@ static void csv_keys2cols(size_t cols[], char *keys[], size_t keys_size);
void file_read(
char *filepath,
Array *input, Array *out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool read_output,
- char *file_format)
+ struct Configs ml_config,
+ bool read_output)
{
FILE *fp;
+ char **in_keys = ml_config.input_keys;
+ char **out_keys = ml_config.label_keys;
+ size_t n_in_keys = ml_config.n_input_keys;
+ size_t n_out_keys = ml_config.n_label_keys;
+ char *file_format = ml_config.file_format;
+
if (filepath != NULL && strcmp(filepath, "-")) {
fp = fopen(filepath, "r");
file_format = file_format_infer(filepath);
@@ -108,17 +112,19 @@ void file_read(
fclose(fp);
}
-void file_write(
- char *filepath,
- Array input, Array out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool write_input,
- char *file_format,
- int decimal_precision)
+void file_write(Array input, Array out, struct Configs ml_config)
{
FILE *fp;
+ char *filepath = ml_config.out_filepath;
+ char **in_keys = ml_config.input_keys;
+ char **out_keys = ml_config.label_keys;
+ size_t n_in_keys = ml_config.n_input_keys;
+ size_t n_out_keys = ml_config.n_label_keys;
+ bool write_input = !ml_config.only_out;
+ char *file_format = ml_config.file_format;
+ char decimal_precision = ml_config.decimal_precision;
+
if (filepath != NULL && strcmp(filepath, "-")) {
fp = fopen(filepath, "w");
@@ -504,7 +510,19 @@ int main(int argc, char *argv[]) {
n_in_cols = parse_keys(in_cols, argv[3], keys_buffer);
n_out_cols = parse_keys(out_cols, argv[4], keys_buffer);
- file_read(in_file, &X, &y, in_cols, n_in_cols, out_cols, n_out_cols, true, format);
+ struct Configs ml_config = {
+ .in_filepath = in_file,
+ .out_filepath = out_file,
+ .input_keys = in_cols,
+ .label_keys = out_cols,
+ .n_input_keys = n_in_cols,
+ .n_label_keys = n_out_cols,
+ .file_format = format,
+ .only_out = false,
+ .decimal_precision = -1
+ };
+
+ file_read(in_file, &X, &y, ml_config, true);
for (i = 0; i < X.shape[0]; i++) {
for (j = 0; j < X.shape[1]; j++) {
@@ -522,7 +540,7 @@ int main(int argc, char *argv[]) {
// use input format if format variable is not defined
format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format;
- file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1);
+ file_write(X, y, ml_config);
for (i = 0; i < n_in_cols; i++) free(in_cols[i]);
for (i = 0; i < n_out_cols; i++) free(out_cols[i]);
diff --git a/src/parse.h b/src/parse.h
index 18130c7..07f740b 100644
--- a/src/parse.h
+++ b/src/parse.h
@@ -4,29 +4,15 @@
#include <stdio.h>
#include <stdbool.h>
+#include "util.h"
+
typedef struct Array {
double *data;
size_t shape[2];
} Array;
-void file_read(
- char *filepath,
- Array *input, Array *out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool read_output,
- char *file_format
- );
-
-void file_write(
- char *filepath,
- Array input, Array out,
- char *in_keys[], size_t n_in_keys,
- char *out_keys[], size_t n_out_keys,
- bool write_input,
- char *file_format,
- int decimal_precision);
-
+void file_read(char *filepath, Array *input, Array *out, struct Configs configs, bool read_output);
+void file_write(Array input, Array out, struct Configs ml_configs);
char * file_format_infer(char *filename);
#endif
Feel free to download, copy and edit any repo