aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
authorjvech <jmvalenciae@unal.edu.co>2024-07-24 15:31:02 -0500
committerjvech <jmvalenciae@unal.edu.co>2024-07-24 15:31:02 -0500
commitd45581c0b067b9526ce88ba9d3a1bd861f4ff7cc (patch)
treea907346b2b282437537d7f4f6b138b3efddcce22 /src/main.c
parentb9deaf6ec1ba587f2b81a63c75b696c6def33436 (diff)
add: file_read() and format integraded on main program
things implemented: - read output in false bug was solved. - Make generic rule added to build test executables - format option added to the CLI
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c78
1 files changed, 4 insertions, 74 deletions
diff --git a/src/main.c b/src/main.c
index dab8bd9..e692756 100644
--- a/src/main.c
+++ b/src/main.c
@@ -26,91 +26,20 @@
#include <json-c/json.h>
#include "util.h"
+#include "parse.h"
#include "nn.h"
#define MAX_FILE_SIZE 536870912 //1<<29; 0.5 GiB
-typedef struct Array {
- double *data;
- size_t shape[2];
-} Array;
-
#define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type)
-static void json_read(
- const char *filepath,
- Array *input, Array *out,
- char *out_keys[], size_t out_keys_size,
- char *in_keys[], size_t in_keys_size,
- bool read_output);
-
static void json_write(
const char *filepath,
Array input, Array out,
char *out_keys[], size_t out_keys_size,
char *in_keys[], size_t in_keys_size);
-void json_read(
- const char *filepath,
- Array *input, Array *out,
- char *out_keys[], size_t n_out_keys,
- char *in_keys[], size_t n_input_keys,
- bool read_output)
-{
- FILE *fp = NULL;
- static char fp_buffer[MAX_FILE_SIZE];
-
- fp = (!strcmp(filepath, "-")) ? fopen("/dev/stdin", "r") : fopen(filepath, "r");
-
- if (fp == NULL) goto json_read_error;
-
- size_t i = 0;
- do {
- if (i >= MAX_FILE_SIZE) die("json_read() Error: file size is bigger than '%zu'", i, MAX_FILE_SIZE);
- fp_buffer[i] = fgetc(fp);
- } while (fp_buffer[i++] != EOF);
-
- json_object *json_obj;
- json_obj = json_tokener_parse(fp_buffer);
- size_t json_obj_length = json_object_array_length(json_obj);
-
- input->shape[0] = (size_t)json_obj_length;
- input->shape[1] = n_input_keys;
- input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0]));
-
- out->shape[0] = (size_t)json_obj_length;
- out->shape[1] = n_out_keys;
- out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0]));
-
- if (!input->data || !out->data) goto json_read_error;
-
- for (int i = 0; i < json_object_array_length(json_obj); i++) {
- json_object *item = json_object_array_get_idx(json_obj, i);
-
- for (int j = 0; j < n_input_keys; j++) {
- size_t index = n_input_keys * i + j;
- input->data[index] = json_object_get_double(json_object_object_get(item, in_keys[j]));
- }
-
- if (!read_output) continue;
-
- for (int j = 0; j < n_out_keys; j++) {
- size_t index = n_out_keys * i + j;
- out->data[index] = json_object_get_double(json_object_object_get(item, out_keys[j]));
- }
- }
-
- json_object_put(json_obj);
- fclose(fp);
-
- return;
-
-json_read_error:
- perror("json_read() Error");
- exit(1);
-}
-
void json_write(
const char *filepath,
Array input, Array out,
@@ -204,6 +133,7 @@ int main(int argc, char *argv[]) {
.alpha = 1e-5,
.config_filepath = "utils/settings.cfg",
.network_size = 0,
+ .file_format = NULL,
.out_filepath = NULL,
};
@@ -223,7 +153,7 @@ int main(int argc, char *argv[]) {
Array X, y;
if (!strcmp("train", argv[0])) {
- json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, true);
+ file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format);
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
nn_network_train(
network, ml_configs.network_size,
@@ -235,7 +165,7 @@ int main(int argc, char *argv[]) {
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
- json_read(argv[1], &X, &y, ml_configs.label_keys, ml_configs.n_label_keys, ml_configs.input_keys, ml_configs.n_input_keys, false);
+ file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, false, ml_configs.file_format);
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
Feel free to download, copy and edit any repo