aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c61
1 files changed, 19 insertions, 42 deletions
diff --git a/src/main.c b/src/main.c
index e692756..216d8d4 100644
--- a/src/main.c
+++ b/src/main.c
@@ -34,45 +34,6 @@
#define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type)
-static void json_write(
- const char *filepath,
- Array input, Array out,
- char *out_keys[], size_t out_keys_size,
- char *in_keys[], size_t in_keys_size);
-
-void json_write(
- const char *filepath,
- Array input, Array out,
- char *out_keys[], size_t out_keys_size,
- char *in_keys[], size_t in_keys_size)
-{
- FILE *fp = (!filepath) ? fopen("/dev/stdout", "w") : fopen(filepath, "w");
- if (!fp) die("json_read() Error:");
- fprintf(fp, "[\n");
-
- for (size_t i = 0; i < input.shape[0]; i++) {
- fprintf(fp, " {\n");
-
- for (size_t j = 0; j < input.shape[1]; j++) {
- size_t index = input.shape[1] * i + j;
- fprintf(fp, " \"%s\": %lf,\n", in_keys[j], input.data[index]);
- }
-
- for (size_t j = 0; j < out.shape[1]; j++) {
- size_t index = out.shape[1] * i + j;
- fprintf(fp, " \"%s\": %lf", out_keys[j], out.data[index]);
-
- if (j == out.shape[1] - 1) fprintf(fp, "\n");
- else fprintf(fp, ",\n");
- }
-
- if (i == input.shape[0] - 1) fprintf(fp, " }\n");
- else fprintf(fp, " },\n");
- }
- fprintf(fp, "]\n");
- fclose(fp);
-}
-
void load_config(struct Configs *cfg, int n_args, ...)
{
char *filepath;
@@ -133,6 +94,7 @@ int main(int argc, char *argv[]) {
.alpha = 1e-5,
.config_filepath = "utils/settings.cfg",
.network_size = 0,
+ .only_out = false,
.file_format = NULL,
.out_filepath = NULL,
};
@@ -153,7 +115,10 @@ int main(int argc, char *argv[]) {
Array X, y;
if (!strcmp("train", argv[0])) {
- file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format);
+ file_read(argv[1], &X, &y,
+ ml_configs.input_keys, ml_configs.n_input_keys,
+ ml_configs.label_keys, ml_configs.n_label_keys,
+ true, ml_configs.file_format);
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
nn_network_train(
network, ml_configs.network_size,
@@ -165,11 +130,23 @@ int main(int argc, char *argv[]) {
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
- file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, false, ml_configs.file_format);
+ file_read(argv[1], &X, &y,
+ ml_configs.input_keys, ml_configs.n_input_keys,
+ ml_configs.label_keys, ml_configs.n_label_keys,
+ false, ml_configs.file_format);
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
- json_write(ml_configs.out_filepath, X, y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys);
+
+ // If neither output and file_format defined use input to define the format
+ if (!ml_configs.file_format && !ml_configs.out_filepath) {
+ ml_configs.file_format = file_format_infer(ml_configs.in_filepath);
+ }
+
+ file_write(ml_configs.out_filepath, X, y,
+ ml_configs.input_keys, ml_configs.n_input_keys,
+ ml_configs.label_keys, ml_configs.n_label_keys,
+ !ml_configs.only_out, ml_configs.file_format);
} else usage(1);
nn_network_free_weights(network, ml_configs.network_size);
Feel free to download, copy and edit any repo