diff options
Diffstat (limited to 'src/parse.c')
-rw-r--r-- | src/parse.c | 353 |
1 files changed, 280 insertions, 73 deletions
diff --git a/src/parse.c b/src/parse.c index 0533a13..c3fe2a6 100644 --- a/src/parse.c +++ b/src/parse.c @@ -29,8 +29,7 @@ static void json_read( FILE *fp, Array *input, Array *out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, + struct Configs cfgs, bool read_output ); @@ -47,11 +46,7 @@ static void csv_read( static void json_write( FILE *fp, Array input, Array out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, - bool write_input, - int decimal_precision - ); + struct Configs cfgs); static void csv_write( FILE *fp, @@ -73,7 +68,6 @@ static void csv_readline_values( static void csv_keys2cols(size_t cols[], char *keys[], size_t keys_size); - void file_read( char *filepath, Array *input, Array *out, @@ -82,10 +76,6 @@ void file_read( { FILE *fp; - char **in_keys = ml_config.input_keys; - char **out_keys = ml_config.label_keys; - size_t n_in_keys = ml_config.n_input_keys; - size_t n_out_keys = ml_config.n_label_keys; char *file_format = ml_config.file_format; if (filepath != NULL && strcmp(filepath, "-")) { @@ -102,9 +92,11 @@ void file_read( file_format = file_format_infer(filepath); } + /* if (!strcmp(file_format, "csv")) csv_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output, false, ','); else if (!strcmp(file_format, "tsv")) csv_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output, false, '\t'); - else if (!strcmp(file_format, "json")) json_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output); + */ + if (!strcmp(file_format, "json")) json_read(fp, input, out, ml_config, read_output); else { die("file_read() Error: unable to parse %s files", file_format); } @@ -117,14 +109,7 @@ void file_write(Array input, Array out, struct Configs ml_config) FILE *fp; char *filepath = ml_config.out_filepath; - char **in_keys = ml_config.input_keys; - char **out_keys = ml_config.label_keys; - size_t n_in_keys = ml_config.n_input_keys; - size_t n_out_keys = ml_config.n_label_keys; - bool write_input = !ml_config.only_out; char *file_format = ml_config.file_format; - char decimal_precision = ml_config.decimal_precision; - if (filepath != NULL && strcmp(filepath, "-")) { fp = fopen(filepath, "w"); @@ -138,20 +123,165 @@ void file_write(Array input, Array out, struct Configs ml_config) if (fp == NULL) die("file_write() Error:"); - if (!strcmp(file_format, "json")) json_write(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, write_input, decimal_precision); + if (!strcmp(file_format, "json")) json_write(fp, input, out, ml_config); + /* else if (!strcmp(file_format, "csv")) csv_write(fp, input, out, write_input, ',', decimal_precision); else if (!strcmp(file_format, "tsv")) csv_write(fp, input, out, write_input, '\t', decimal_precision); + */ else { die("file_write() Error: unable to write %s files", file_format); } fclose(fp); } +void data_postprocess( + Array *out, + double *data, size_t data_shape[2], + struct Configs cfgs, + bool is_input) +{ + char **keys = (is_input) ? cfgs.input_keys : cfgs.label_keys; + size_t n_keys = (is_input) ? cfgs.n_input_keys : cfgs.n_label_keys; + + char **categorical_keys = cfgs.categorical_keys; + size_t n_categorical_keys = cfgs.n_categorical_keys; + + char ***categorical_values = cfgs.categorical_values; + size_t *n_categorical_values = cfgs.n_categorical_values; + + size_t i, j, data_j; + for (data_j = j = 0; j < n_keys; j++) { + int k; + switch (out->type[j]) { + case ARRAY_NUMERICAL: + for (i = 0; i < data_shape[0]; i++) { + size_t data_index = i * data_shape[1] + data_j; + size_t index = i * out->shape[1] + j; + out->data[index].numeric = data[data_index]; + } + data_j++; + break; + case ARRAY_ONEHOT: + k = util_get_key_index(keys[j], categorical_keys, n_categorical_keys); + if (k == -1) { + die("data_postprocess() Error: field '%s' is not registered as categorical", + keys[j]); + } + for (i = 0; i < data_shape[0]; i++) { + size_t index = i * out->shape[1] + j; + size_t data_index = i * data_shape[1] + data_j; + int onehot_i = util_argmax(data + data_index, n_categorical_values[k]); + out->data[index].categorical = e_strdup(categorical_values[k][onehot_i]); + } + data_j += n_categorical_values[k]; + break; + default: + die("data_postprocess() Error: unexpected type received on '%s' field", keys[j]); + } + } +} + +double * data_preprocess( + size_t out_shape[2], + Array data, + struct Configs cfgs, + bool is_input, + bool only_allocate) +{ + double *out; + + char **keys = (is_input) ? cfgs.input_keys : cfgs.label_keys; + size_t n_keys = (is_input) ? cfgs.n_input_keys : cfgs.n_label_keys; + + char **categorical_keys = cfgs.categorical_keys; + size_t n_categorical_keys = cfgs.n_categorical_keys; + + char ***categorical_values = cfgs.categorical_values; + size_t *n_categorical_values = cfgs.n_categorical_values; + + size_t i, j, out_j; + + out_shape[0] = data.shape[0]; + out_shape[1] = 0; + for (i = 0; i < n_keys; i++) { + int n; + switch (data.type[i]) { + case ARRAY_NUMERICAL: + out_shape[1]++; + break; + case ARRAY_ONEHOT: + n = util_get_key_index(keys[i], categorical_keys, n_categorical_keys); + if (n == -1) die("data_preprocess() Error: field '%s' is not marked as categorical", keys[i]); + out_shape[1] += n_categorical_values[n]; + break; + default: + die("data_preprocess() Error: field '%s' has an unknown type", keys[i]); + break; + } + } + + out = ecalloc(out_shape[0] * out_shape[1], sizeof(double)); + if (only_allocate) return out; + + for (out_j = j = 0; j < data.shape[1]; j++) { + switch (data.type[j]) { + int k; + case ARRAY_NUMERICAL: + for (i = 0; i < out_shape[0]; i++) { + size_t index = i * data.shape[1] + j; + size_t out_index = i * out_shape[1] + out_j; + out[out_index] = data.data[index].numeric; + } + out_j++; + break; + case ARRAY_ONEHOT: + k = util_get_key_index(keys[j], categorical_keys, n_categorical_keys); + for (i = 0; i < out_shape[0]; i++) { + int onehot_i; + size_t index = i * data.shape[1] + j; + onehot_i = util_get_key_index(data.data[index].categorical, + categorical_values[k], + n_categorical_values[k]); + if (onehot_i == -1) { + die("data_preprocess() Error: unexpected '%s' value found", + data.data[index].categorical); + } + size_t out_index = i * out_shape[1] + out_j + onehot_i; + out[out_index] = 1.0; + } + out_j += n_categorical_values[k]; + break; + default: + die("data_preprocess() Error: field '%s' has an unknown type", keys[j]); + } + } + + return out; +} + +void array_free(Array *x) { + size_t i, j, index; + for (j = 1; j < x->shape[1]; j++) { + switch (x->type[j]) { + case ARRAY_ORDINAL: + case ARRAY_ONEHOT: + for (i = 0; i < x->shape[0]; i++) { + index = x->shape[1] * i + j; + free(x->data[index].categorical); + } + break; + default: + break; + } + } + free(x->type); + free(x->data); +} + void json_read( FILE *fp, Array *input, Array *out, - char *in_keys[], size_t n_input_keys, - char *out_keys[], size_t n_out_keys, + struct Configs cfgs, bool read_output) { static char fp_buffer[MAX_FILE_SIZE]; @@ -159,8 +289,15 @@ void json_read( json_object *json_obj, *item, *value; json_type obj_type; + char **in_keys = cfgs.input_keys; + char **out_keys = cfgs.label_keys; + char **onehot_keys = cfgs.onehot_keys; + size_t n_input_keys = cfgs.n_input_keys; + size_t n_out_keys = cfgs.n_label_keys; + size_t n_onehot_keys = cfgs.n_onehot_keys; - if (fp == NULL) goto json_read_error; + + if (fp == NULL) die("json_read() Error:"); i = 0; do { @@ -178,13 +315,28 @@ void json_read( input->shape[0] = (size_t)json_obj_length; input->shape[1] = n_input_keys; - input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); + input->type = ecalloc(input->shape[1], sizeof(enum ArrayType)); + input->data = ecalloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); out->shape[0] = (size_t)json_obj_length; out->shape[1] = n_out_keys; - out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); + out->type = ecalloc(out->shape[1], sizeof(enum ArrayType)); + out->data = ecalloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); + + for (i = 0; i < n_onehot_keys; i++) { + for (j = 0; j < n_input_keys; j++) { + if (!strcmp(onehot_keys[i], in_keys[j])) { + input->type[j] = ARRAY_ONEHOT; + } + } + + for (j = 0; j < n_out_keys; j++) { + if (!strcmp(onehot_keys[i], out_keys[j])) { + out->type[j] = ARRAY_ONEHOT; + } + } + } - if (!input->data || !out->data) goto json_read_error; for (i = 0; i < json_object_array_length(json_obj); i++) { item = json_object_array_get_idx(json_obj, i); @@ -201,15 +353,30 @@ void json_read( for (j = 0; j < n_input_keys; j++) { value = json_object_object_get(item, in_keys[j]); obj_type = json_object_get_type(value); - switch (obj_type) { - case json_type_double: - case json_type_int: - index = n_input_keys * i + j; - input->data[index] = json_object_get_double(value); + index = n_input_keys * i + j; + switch (input->type[j]) { + case ARRAY_NUMERICAL: + switch (obj_type) { + case json_type_int: + case json_type_double: + input->data[index].numeric = json_object_get_double(value); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a number"); + } break; - default: - die("json_read() Error: unexpected JSON data received, expecting a number"); + case ARRAY_ONEHOT: + switch (obj_type) { + case json_type_int: + case json_type_string: + input->data[index].categorical = e_strdup(json_object_get_string(value)); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a string or integer"); + } break; + default: + die("json_read() Error: preprocess field type '%s' is not implemented", in_keys[j]); } } @@ -218,28 +385,40 @@ void json_read( for (j = 0; j < n_out_keys; j++) { value = json_object_object_get(item, out_keys[j]); obj_type = json_object_get_type(value); - switch (obj_type) { - case json_type_double: - case json_type_int: - index = n_out_keys * i + j; - out->data[index] = json_object_get_double(value); + index = n_out_keys * i + j; + switch (out->type[j]) { + case ARRAY_NUMERICAL: + switch (obj_type) { + case json_type_int: + case json_type_double: + out->data[index].numeric = json_object_get_double(value); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a number"); + } break; - default: - die("json_read() Error: unexpected JSON data received, expecting a number"); + case ARRAY_ONEHOT: + switch (obj_type) { + case json_type_int: + case json_type_string: + out->data[index].categorical = e_strdup(json_object_get_string(value)); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting string or integer"); + } break; + default: + die("json_read() Error: preprocess field type '%s' is not implemented", out_keys[j]); } } } json_object_put(json_obj); return; - -json_read_error: - perror("json_read() Error"); - exit(1); } +/* void csv_read( FILE *fp, Array *input, Array *out, @@ -304,51 +483,78 @@ void csv_read( free(out_cols); return; } +*/ void json_write( FILE *fp, Array input, Array out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, - bool write_input, - int decimal_precision) + struct Configs cfgs) { - fprintf(fp, "[\n"); - - if (in_keys_size != input.shape[1] && write_input) { - die("json_write() Error: there are more keys (%zu) than input columns (%zu)", - in_keys_size, input.shape[1]); + char **in_keys = cfgs.input_keys; + char **out_keys = cfgs.label_keys; + size_t n_in_keys = cfgs.n_input_keys; + size_t n_out_keys = cfgs.n_label_keys; + bool write_input = !cfgs.only_out; + int decimal_precision = cfgs.decimal_precision; + + json_object *root = json_object_new_array(); + if (!root) { + die("json_write() Error: Unable to create json_data"); } - if (out_keys_size != out.shape[1]) { - die("json_write() Error: there are more keys (%zu) than output columns (%zu)", - out_keys_size, out.shape[1]); - } + if (n_in_keys != input.shape[1]) + die("json_write() Error: input keys and data columns have different sizes"); + if (n_out_keys != out.shape[1]) + die("json_write() Error: output keys and data columns have different sizes"); - for (size_t i = 0; i < input.shape[0]; i++) { - fprintf(fp, " {\n"); + size_t i, j; + for (i = 0; i < input.shape[0]; i++) { + json_object *obj = json_object_new_object(); if (write_input) { - for (size_t j = 0; j < input.shape[1]; j++) { - size_t index = input.shape[1] * i + j; - fprintf(fp, " \"%s\": %g,\n", in_keys[j], input.data[index]); + for (j = 0; j < input.shape[1]; j++) { + char buffer[128]; + size_t index = i * input.shape[1] + j; + switch (input.type[j]) { + case ARRAY_NUMERICAL: + sprintf(buffer, "%g", input.data[index].numeric); + json_object_object_add(obj, in_keys[j], json_object_new_double_s(input.data[index].numeric, buffer)); + break; + case ARRAY_ONEHOT: + json_object_object_add(obj, in_keys[j], json_object_new_string(input.data[index].categorical)); + break; + default: + die("json_write(): Unexpected value received"); + } } } - for (size_t j = 0; j < out.shape[1]; j++) { - size_t index = out.shape[1] * i + j; - fprintf(fp, " \"%s\": %.*g", out_keys[j], decimal_precision, out.data[index]); - if (j == out.shape[1] - 1) fprintf(fp, "\n"); - else fprintf(fp, ",\n"); + for (j = 0; j < out.shape[1]; j++) { + size_t index = i * out.shape[1] + j; + char buffer[32]; + switch (out.type[j]) { + case ARRAY_NUMERICAL: + sprintf(buffer, "%.*g", decimal_precision, out.data[index].numeric); + json_object_object_add(obj, out_keys[j], json_object_new_double_s(out.data[index].numeric, buffer)); + break; + case ARRAY_ONEHOT: + json_object_object_add(obj, out_keys[j], json_object_new_string(out.data[index].categorical)); + break; + default: + die("json_write(): Unexpected value received"); + } } - - if (i == input.shape[0] - 1) fprintf(fp, " }\n"); - else fprintf(fp, " },\n"); + json_object_array_add(root, obj); + } + int ret = fprintf(fp, "%s", json_object_to_json_string_ext(root, JSON_C_TO_STRING_PRETTY | JSON_C_TO_STRING_SPACED)); + if (ret == -1) { + die("json_write() Error: unable to write json data"); } - fprintf(fp, "]\n"); + json_object_put(root); } +/* void csv_write( FILE *fp, Array input, Array out, @@ -373,6 +579,7 @@ void csv_write( fprintf(fp, "\n"); } } +*/ void csv_columns_select( double *dst_row, double *src_row, @@ -457,7 +664,6 @@ char * file_format_infer(char *filename) return file_format; } - #ifdef PARSE_TEST #include <assert.h> #include <string.h> @@ -546,4 +752,5 @@ int main(int argc, char *argv[]) { return 0; } + #endif |